smp
SMPModel
Bases: SemanticSegmentationModel
Wrapper around the Segmentation Models PyTorch (SMP) family of semantic segmentation models. Currently supports unet, deeplabv3+ and unet++.
In theory any model variant is supported, but specifying every class becomes very verbose. If you need to add more, it should be close to a copy-paste change in the load_model function.
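As a rough illustration of why adding a variant is close to copy-paste, the sketch below maps architecture names to SMP classes through a lookup table. It is not the code in load_model: `ARCHITECTURES`, `resolve_architecture`, `build_model` and the default argument values are hypothetical names chosen for this example, and it assumes the names listed above are the keys used in config. `Unet`, `UnetPlusPlus` and `DeepLabV3Plus` are classes exported by segmentation_models_pytorch.

```python
from typing import Any, Optional

# Assumed config names mapped to class names in segmentation_models_pytorch.
# Supporting another variant would mean adding one more entry here.
ARCHITECTURES = {
    "unet": "Unet",
    "unet++": "UnetPlusPlus",
    "deeplabv3+": "DeepLabV3Plus",
}


def resolve_architecture(name: str) -> str:
    """Return the SMP class name for an architecture name (hypothetical helper)."""
    key = name.strip().lower()
    if key not in ARCHITECTURES:
        supported = ", ".join(sorted(ARCHITECTURES))
        raise ValueError(f"Unsupported architecture '{name}'. Supported: {supported}")
    return ARCHITECTURES[key]


def build_model(
    name: str,
    encoder_name: str = "resnet34",  # backbone; the default is an assumption
    in_channels: int = 3,
    classes: int = 2,
    smp_module: Optional[Any] = None,
) -> Any:
    """Instantiate an SMP model without pretrained encoder weights."""
    if smp_module is None:
        # Imported lazily so the name lookup above works without SMP installed.
        import segmentation_models_pytorch as smp_module
    model_cls = getattr(smp_module, resolve_architecture(name))
    return model_cls(
        encoder_name=encoder_name,
        encoder_weights=None,  # weights are expected to come from a state dict
        in_channels=in_channels,
        classes=classes,
    )
```

With a table like this, an unsupported name fails early with a message listing the valid options, rather than failing later during weight loading.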
Source code in src/tcd_pipeline/models/smp.py
forward(x)
Forward pass of the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Union[Tensor, List[Tensor]] | List[torch.Tensor] | required |
Returns:
Type | Description |
---|---|
Tensor | Interpolated semantic segmentation predictions with softmax applied |
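To make "with softmax applied" concrete, here is a dependency-free sketch of a softmax taken across the class axis for every pixel, so that the class probabilities at each pixel sum to 1. It is not the implementation of forward: the helper names and the nested-list layout `[classes][height][width]` are assumptions made for illustration, and the interpolation step is not shown. On a batched PyTorch tensor of shape `(N, C, H, W)` the equivalent operation would be `torch.softmax(logits, dim=1)`.

```python
import math
from typing import List


def pixel_softmax(logits: List[float]) -> List[float]:
    """Softmax over the class scores of a single pixel."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_over_classes(logits_chw: List[List[List[float]]]) -> List[List[List[float]]]:
    """Apply softmax across the class axis of a [classes][height][width] list."""
    n_classes = len(logits_chw)
    height = len(logits_chw[0])
    width = len(logits_chw[0][0])
    probs = [[[0.0] * width for _ in range(height)] for _ in range(n_classes)]
    for y in range(height):
        for x in range(width):
            pixel = pixel_softmax([logits_chw[c][y][x] for c in range(n_classes)])
            for c in range(n_classes):
                probs[c][y][x] = pixel[c]
    return probs


# Two classes on a 2x2 image; the values are made up for the example.
logits = [
    [[2.0, 0.0], [0.0, -1.0]],  # class 0 scores
    [[0.0, 0.0], [3.0, 1.0]],   # class 1 scores
]
probs = softmax_over_classes(logits)
```

Where both classes have equal scores, as in the top-right pixel here, each class receives a probability of 0.5.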
Source code in src/tcd_pipeline/models/smp.py
load_model()
Load model weights. First, instantiates a model object from the segmentation models library. Then, if the model path exists on the system, those weights will be loaded. Otherwise, the model will be downloaded from HuggingFace Hub (assuming it exists and is accessible).
The model weights should be a state dictionary for the specified architecture. Other parameters like the backbone type (e.g. resnet) and input/output channels should be specified via config.
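The local-path-first, Hub-as-fallback order described above can be sketched as follows. This is an illustration under stated assumptions, not the body of load_model: `resolve_weights` and `fake_download` are hypothetical names, and the downloader is injected as a callable so the example runs offline. In a real setup the callable could wrap `huggingface_hub.hf_hub_download(repo_id=..., filename=...)`, with the filename depending on how the weights were published, and the returned path would then be passed to `torch.load` and `model.load_state_dict`.

```python
import os
import tempfile
from typing import Callable


def resolve_weights(model_ref: str, download: Callable[[str], str]) -> str:
    """Return a local path to a weights file (hypothetical helper).

    If model_ref exists on disk it is used directly; otherwise it is treated
    as a remote identifier and handed to the download callable.
    """
    if os.path.exists(model_ref):
        return model_ref
    return download(model_ref)


def fake_download(repo_id: str) -> str:
    """Stand-in for a Hub download; only returns a made-up cache path."""
    return "/cache/" + repo_id


# Case 1: the reference is an existing local file, so no download happens.
with tempfile.NamedTemporaryFile(suffix=".pt") as handle:
    local = resolve_weights(handle.name, download=fake_download)
    local_expected = handle.name

# Case 2: the reference does not exist locally, so the downloader is used.
remote = resolve_weights("example-org/example-model", download=fake_download)
```

Passing the downloader in as an argument keeps the path check testable without network access.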
Source code in src/tcd_pipeline/models/smp.py