
smp

SMPModel

Bases: SemanticSegmentationModel

Wrapper around the Segmentation Models PyTorch (SMP) family of semantic segmentation models. Currently supports unet, deeplabv3+ and unet++.

In theory any model variant is supported, but specifying all the classes becomes very verbose. If you need to add more, it should be close to copy-paste in the load_model function; a sketch follows below.
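
For illustration only, a hypothetical extra variant (FPN here, which is not currently wired in) would reuse the same constructor pattern and the same config attributes as the existing branches; a minimal standalone sketch:

import segmentation_models_pytorch as smp

# Illustrative sketch, not part of the source: the kind of branch one might
# paste into load_model() to support a hypothetical "fpn" variant, reusing
# the same config attributes as the existing unet/deeplabv3+/unet++ cases.
def build_extra_variant(config):
    if config.model.name == "fpn":
        return smp.FPN(
            encoder_name=config.model.backbone,
            classes=config.model.num_classes,
            in_channels=config.model.in_channels,
        )
    raise ValueError(f"Model type '{config.model.name}' is not valid.")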

Source code in src/tcd_pipeline/models/smp.py
class SMPModel(SemanticSegmentationModel):
    """
    Wrapper around Segmentation Models Pytorch (SMP) family of semantic
    segmentation models. Currently supports unet, deeplabv3+ and unet++.

    In theory any model variant is supported, but it becomes very verbose
    specifying all the classes. If you need to add more, it should be close
    to copy-paste in the load_model function.
    """

    model: SegmentationModel

    def setup(self):
        pass

    def load_model(self) -> None:
        """
        Load model weights. First, instantiates a model object from the
        segmentation models library. Then, if the model path exists on the
        system, those weights will be loaded. Otherwise, the model will be
        downloaded from HuggingFace Hub (assuming it exists and is accessible).

        The model weights should be a state dictionary for the specified
        architecture. Other parameters like the backbone type (e.g. resnet)
        and input/output channels should be specified via config.
        """

        logger.info("Loading SMP model")

        if self.config.model.name == "unet":
            model = smp.Unet(
                encoder_name=self.config.model.backbone,
                classes=self.config.model.num_classes,
                in_channels=self.config.model.in_channels,
            )
        elif self.config.model.name == "deeplabv3+":
            model = smp.DeepLabV3Plus(
                encoder_name=self.config.model.backbone,
                classes=self.config.model.num_classes,
                in_channels=self.config.model.in_channels,
            )
        elif self.config.model.name == "unet++":
            model = smp.UnetPlusPlus(
                encoder_name=self.config.model.backbone,
                classes=self.config.model.num_classes,
                in_channels=self.config.model.in_channels,
            )
        else:
            raise ValueError(
                f"Model type '{self.config.model.name}' is not valid. "
                f"Currently, only supports 'unet', 'deeplabv3+' and 'unet++'."
            )

        if not os.path.exists(self.config.model.weights):
            from huggingface_hub import HfApi

            api = HfApi()
            self.config.model.weights = api.hf_hub_download(
                repo_id=self.config.model.weights,
                filename="model.pt",
                revision=self.config.model.revision,
            )

        assert os.path.exists(self.config.model.weights)

        model.load_state_dict(
            torch.load(self.config.model.weights, map_location=self.device), strict=True
        )

        self.model = model.to(self.device)
        self.model.eval()

    def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x: Input batch tensor, or a list of tensors to be stacked into a batch

        Returns:
            Interpolated semantic segmentation predictions with softmax applied
        """

        if isinstance(x, list):
            x = torch.stack(x)

        with torch.no_grad():
            logits = self.model(x)
            preds = logits.softmax(dim=1)

        return preds

forward(x)

Forward pass of the model.

Parameters:

    x (Union[Tensor, List[Tensor]]): Input batch tensor, or a list of tensors to be stacked into a batch. Required.

Returns:

    Tensor: Interpolated semantic segmentation predictions with softmax applied.
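
A short usage sketch (illustrative only; it assumes an SMPModel instance whose load_model has already been called, and input tiles that are already on the model's device and normalised as the model expects):

import torch

# `model` is an already-loaded SMPModel (see load_model below). The two dummy
# RGB tiles stand in for real imagery; inputs must match the configured
# in_channels and live on the same device as the model.
tiles = [torch.rand(3, 512, 512), torch.rand(3, 512, 512)]

preds = model.forward(tiles)           # the list is stacked into a (2, C, H, W) batch
classes = preds.argmax(dim=1)          # per-pixel class index, shape (2, H, W)
confidence = preds.max(dim=1).values   # softmax score of the predicted class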

Source code in src/tcd_pipeline/models/smp.py
def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Forward pass of the model.

    Args:
        x: Input batch tensor, or a list of tensors to be stacked into a batch

    Returns:
        Interpolated semantic segmentation predictions with softmax applied
    """

    if isinstance(x, list):
        x = torch.stack(x)

    with torch.no_grad():
        logits = self.model(x)
        preds = logits.softmax(dim=1)

    return preds

load_model()

Load model weights. First, instantiates a model object from the segmentation models library. Then, if the weights path exists on the local filesystem, those weights will be loaded. Otherwise, the weights will be downloaded from the Hugging Face Hub (assuming the repository exists and is accessible).

The model weights should be a state dictionary for the specified architecture. Other parameters, such as the backbone type (e.g. a ResNet encoder) and the number of input/output channels, should be specified via the config.
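
Only the key names below are taken from the code above; whether the pipeline uses OmegaConf/Hydra or another config system, and the example values, are assumptions for the sake of illustration:

from omegaconf import OmegaConf

# Illustrative config sketch: the model.* keys read by load_model(). Values
# are placeholders; "weights" is either a local file or, if that path does
# not exist, treated as a Hugging Face Hub repo id containing model.pt.
config = OmegaConf.create(
    {
        "model": {
            "name": "unet",              # one of: unet, deeplabv3+, unet++
            "backbone": "resnet34",      # any SMP encoder name
            "num_classes": 2,
            "in_channels": 3,            # e.g. 3 for RGB imagery
            "weights": "checkpoints/model.pt",  # local path, else a HF Hub repo id
            "revision": "main",          # HF Hub revision used when downloading
        }
    }
)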

Source code in src/tcd_pipeline/models/smp.py
def load_model(self) -> None:
    """
    Load model weights. First, instantiates a model object from the
    segmentation models library. Then, if the model path exists on the
    system, those weights will be loaded. Otherwise, the model will be
    downloaded from HuggingFace Hub (assuming it exists and is accessible).

    The model weights should be a state dictionary for the specified
    architecture. Other parameters like the backbone type (e.g. resnet)
    and input/output channels should be specified via config.
    """

    logger.info("Loading SMP model")

    if self.config.model.name == "unet":
        model = smp.Unet(
            encoder_name=self.config.model.backbone,
            classes=self.config.model.num_classes,
            in_channels=self.config.model.in_channels,
        )
    elif self.config.model.name == "deeplabv3+":
        model = smp.DeepLabV3Plus(
            encoder_name=self.config.model.backbone,
            classes=self.config.model.num_classes,
            in_channels=self.config.model.in_channels,
        )
    elif self.config.model.name == "unet++":
        model = smp.UnetPlusPlus(
            encoder_name=self.config.model.backbone,
            classes=self.config.model.num_classes,
            in_channels=self.config.model.in_channels,
        )
    else:
        raise ValueError(
            f"Model type '{self.config.model.name}' is not valid. "
            f"Currently, only supports 'unet', 'deeplabv3+' and 'unet++'."
        )

    if not os.path.exists(self.config.model.weights):
        from huggingface_hub import HfApi

        api = HfApi()
        self.config.model.weights = api.hf_hub_download(
            repo_id=self.config.model.weights,
            filename="model.pt",
            revision=self.config.model.revision,
        )

    assert os.path.exists(self.config.model.weights)

    model.load_state_dict(
        torch.load(self.config.model.weights, map_location=self.device), strict=True
    )

    self.model = model.to(self.device)
    self.model.eval()