defineMobileUNet {geodl} | R Documentation |
defineMobileUNet
Description
Define a UNet architecture for geospatial semantic segmentation with a MobileNet-v2 backbone.
Usage
defineMobileUNet(
nCls = 3,
pretrainedEncoder = TRUE,
freezeEncoder = TRUE,
actFunc = "relu",
useAttn = FALSE,
useDS = FALSE,
dcChn = c(256, 128, 64, 32, 16),
negative_slope = 0.01
)
Arguments
nCls |
Number of classes being differentiated. For a binary classification, this can be either 1 or 2. If 2, the problem is treated as a multiclass problem, and a multiclass loss metric should be used. Default is 3. |
pretrainedEncoder |
TRUE or FALSE. Whether or not to initialize the MobileNet-v2 encoder using pre-trained ImageNet weights. Default is TRUE. |
freezeEncoder |
TRUE or FALSE. Whether or not to freeze the encoder during training. The default is TRUE. If TRUE, only the decoder component is trained. |
actFunc |
Defines activation function to use throughout the network (note that MobileNet-v2 layers are not impacted). "relu" = rectified linear unit (ReLU); "lrelu" = leaky ReLU; "swish" = swish. Default is "relu". |
useAttn |
TRUE or FALSE. Whether to add attention gates along the skip connections. Default is FALSE (no attention gates are added). |
useDS |
TRUE or FALSE. Whether or not to use deep supervision. If TRUE, four predictions are made, one at each of the four largest decoder block resolutions, and they are returned as a list of 4 predictions. If FALSE, only the final prediction at the original resolution is returned. Default is FALSE (deep supervision is not used). |
dcChn |
Vector of 5 integers defining the number of output feature maps for each of the 5 decoder blocks. Default is 256, 128, 64, 32, and 16. |
negative_slope |
If actFunc = "lrelu", specifies the negative slope term to use. Default is 0.01. |
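To make the role of negative_slope concrete, the following is a minimal base R sketch of the leaky ReLU function itself. It is not the code geodl uses internally, and the helper name leakyRelu is hypothetical and chosen only for illustration.

```r
# Leaky ReLU in base R: positive values pass through unchanged,
# negative values are scaled by negative_slope.
# leakyRelu is a hypothetical helper, not part of geodl or torch.
leakyRelu <- function(x, negative_slope = 0.01) {
  ifelse(x > 0, x, negative_slope * x)
}

x <- c(-2, -0.5, 0, 1.5)
leakyRelu(x)                        # default slope of 0.01
leakyRelu(x, negative_slope = 0.2)  # larger slope keeps more of the negative signal
```

A larger negative_slope lets more of the negative activations through; setting it to 0 would reduce the function to a standard ReLU.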
Details
Define a UNet architecture with a MobileNet-v2 backbone or encoder. This UNet implementation was inspired by a blog post by Sigrid Keydana. The architecture has 6 blocks in the encoder (including the bottleneck) and 5 blocks in the decoder. Deep supervision (useDS = TRUE) and attention gates along the skip connections (useAttn = TRUE) can be enabled. The model requires three input bands or channels.
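Because the model requires exactly three input bands, it can help to check the dimensions of a batch before passing it to the model. The sketch below is one possible check in base R. The helper checkInputDims is hypothetical and not part of geodl, and the (batch, channel, height, width) layout is assumed from the tensor used in the Examples section.

```r
# Hypothetical helper (not part of geodl): confirm that a batch follows the
# (batch, channel, height, width) layout and has exactly three channels.
checkInputDims <- function(dims, nChn = 3) {
  if (length(dims) != 4) {
    stop("Expected 4 dimensions: batch, channel, height, width.")
  }
  if (dims[2] != nChn) {
    stop("Expected ", nChn, " channels but found ", dims[2], ".")
  }
  invisible(TRUE)
}

# A batch of 12 three-band chips passes silently
checkInputDims(c(12, 3, 128, 128))

# A four-band batch is rejected; the error is caught here for illustration
ok <- tryCatch(checkInputDims(c(12, 4, 128, 128)),
               error = function(e) FALSE)
```

For a torch tensor, the dimension vector can be obtained with dim(tensorIn) and passed to the helper.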
Value
MobileUNet model instance as torch nn_module
Examples
require(torch)

# Generate example data as torch tensor
tensorIn <- torch::torch_rand(c(12, 3, 128, 128))

# Instantiate model
model <- defineMobileUNet(nCls = 3,
                          pretrainedEncoder = FALSE,
                          freezeEncoder = FALSE,
                          actFunc = "relu",
                          useAttn = TRUE,
                          useDS = TRUE,
                          dcChn = c(256, 128, 64, 32, 16),
                          negative_slope = 0.01)

# Predict; with useDS = TRUE the result is a list of predictions
pred <- model(tensorIn)
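The return value differs depending on useDS, so it can be useful to inspect it. The following is a sketch, assuming the geodl and torch packages are installed. The variable names modelDS, predDS, modelSingle, and predSingle are illustrative, and the smaller batch size is chosen only to keep the example light. No expected shapes are shown, since they should be confirmed by running the code.

```r
library(torch)
library(geodl)

# Small example batch: 2 chips, 3 bands, 128 x 128 cells
tensorIn <- torch_rand(c(2, 3, 128, 128))

# Deep supervision on: per the argument description, a list of predictions
modelDS <- defineMobileUNet(nCls = 3,
                            pretrainedEncoder = FALSE,
                            freezeEncoder = FALSE,
                            useDS = TRUE)
predDS <- modelDS(tensorIn)
length(predDS)                        # number of predictions in the list
lapply(predDS, function(p) p$shape)   # shape of each prediction

# Deep supervision off: a single prediction tensor
modelSingle <- defineMobileUNet(nCls = 3,
                                pretrainedEncoder = FALSE,
                                freezeEncoder = FALSE,
                                useDS = FALSE)
predSingle <- modelSingle(tensorIn)
predSingle$shape
```

Setting pretrainedEncoder = FALSE avoids loading the ImageNet weights, which is convenient for a quick shape check but not for actual training.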