segformermodule

SegformerModule

Bases: SegmentationModule
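
Semantic segmentation module wrapping Hugging Face's SegformerForSemanticSegmentation. It loads an ID-to-label mapping, configures the model and image processor, and handles pre- and post-processing for training and prediction.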

Source code in src/tcd_pipeline/models/segformermodule.py
class SegformerModule(SegmentationModule):
    model: SegformerForSemanticSegmentation

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # For older checkpoints
        if self.hparams.get("backbone"):
            self.hparams["model_name"] = self.hparams["backbone"]

        # User (or load_from_pretrained) must provide an id2label so that
        # we know how many classes we're supporting, etc.
        if "id2label" not in self.hparams:
            raise ValueError("You must provide an ID->Label mapping")

        with open(self.hparams["id2label"], "r") as f:
            id2label = json.load(f)

        # JSON object keys are strings, so coerce the class indices back to int
        self.id2label = {int(k): v for k, v in id2label.items()}
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.num_classes = len(id2label)
        self.processor = None
        self.model_name = self.hparams["model_name"]
        logger.info(f"Initialising model with {self.num_classes} classes.")

        self.save_hyperparameters()

    def configure_models(self, init_pretrained=False):
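        # If the HF_FORCE_LOCAL environment variable is set, load the model and
        # processor from the local Hugging Face cache only (no Hub downloads).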
        self.use_local = os.getenv("HF_FORCE_LOCAL") is not None

        config = SegformerConfig.from_pretrained(
            self.model_name,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            local_files_only=self.use_local,
        )

        if not init_pretrained:
            self.model = SegformerForSemanticSegmentation(config)
            logger.info(f"{self.model_name} initialised")
        else:
            self.model = SegformerForSemanticSegmentation.from_pretrained(
                pretrained_model_name_or_path=self.model_name,
                num_labels=self.num_classes,
                id2label=self.id2label,
                label2id=self.label2id,
                local_files_only=self.use_local,
            )
            logger.info(f"{self.model_name} initialised with weights")

        self.processor = SegformerImageProcessor.from_pretrained(
            self.model_name,
            do_resize=False,
            do_reduce_labels=False,
            local_files_only=self.use_local,
        )

        assert self.model.config.num_labels == self.num_classes

    def _predict_batch(self, batch):
        """Predict on a batch of data. This function is subclassed to handle
        specific details of the transformers library since we need to

        (a) Pre-process data into the correct format (this could also be done
            at the data loader stage)

        (b) Post-process data so that the predicted masks are the correct shape
            with respect to the input. This could also be done in the dataloader
            by passing a (h, w) tuple so we know how to resize the image. However
            we should really to compute loss with respect to the original mask
            and not a downscaled one.

        Returns:
            loss (torch.Tensor): Loss for the batch
            y_hat (torch.Tensor): Logits from the model
            y_hat_hard (torch.Tensor): Argmax output from the model (i.e. predictions)
        """

        encoded_inputs = self.processor(
            batch["image"], batch["mask"], return_tensors="pt"
        )

        # TODO Move device checking and data pre-processing to the dataloader/datamodule
        # For some reason, the processor doesn't respect device and moves everything back
        # to CPU.
        outputs = self.model(
            pixel_values=encoded_inputs.pixel_values.to(self.device),
            labels=encoded_inputs.labels.to(self.device),
        )

        # We need to reshape according to the input mask, not the encoded version
        # as the sizes are likely different. We want to keep hold of the probabilities
        # and not just the segmentation so we don't use the built-in converter:
        # y_hat_hard = self.processor.post_process_semantic_segmentation(outputs, target_sizes=[m.shape[-2] for m in batch['mask']]))

        # Counterintuitively, if labels are provided, the upsampled logits are
        # used to compute the loss, but only the downsampled logits are returned.
        y_hat = nn.functional.interpolate(
            outputs.logits,
            size=batch["mask"].shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        y_hat_hard = y_hat.argmax(dim=1)

        return outputs.loss, y_hat, y_hat_hard

    def predict_step(self, batch: dict):
        """
        Run prediction on a batch of data. Functionally the same as _predict_batch,
        but returns only softmax probabilities rather than intermediate outputs.
        """
        encoded_inputs = self.processor(batch["image"], return_tensors="pt")

        with torch.no_grad():
            encoded_inputs.to(self.model.device)
            logits = self.model(pixel_values=encoded_inputs.pixel_values).logits

        pred = nn.functional.interpolate(
            logits,
            size=encoded_inputs.pixel_values.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

        return pred.softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x: Image array

        Returns:
            Interpolated semantic segmentation probabilities
        """
        return self.predict_step({"image": x})
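
The constructor requires an id2label hyperparameter pointing to a JSON file that maps class indices to names. As a minimal sketch (the file name, class names and checkpoint are illustrative, and this assumes the parent SegmentationModule stores its keyword arguments in self.hparams):

    # id2label.json: {"0": "background", "1": "tree"}

    module = SegformerModule(
        model_name="nvidia/segformer-b0-finetuned-ade-512-512",
        id2label="id2label.json",
    )

    # Build the model and processor; pass init_pretrained=True to start from
    # the checkpoint weights rather than a random initialisation.
    module.configure_models(init_pretrained=True)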

forward(x)

Forward pass of the model.

Parameters:

    Name    Type            Description    Default
    x       torch.Tensor    Image array    required

Returns:

    Type            Description
    torch.Tensor    Interpolated semantic segmentation probabilities

Source code in src/tcd_pipeline/models/segformermodule.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass of the model.

    Args:
        x: Image array

    Returns:
        Interpolated semantic segmentation probabilities
    """
    return self.predict_step({"image": x})
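
Note that forward simply delegates to predict_step, which runs under torch.no_grad(), so this path is intended for inference rather than training. A minimal usage sketch, with an illustrative input shape and assuming the module has been configured as above:

    import torch

    # Batch of 2 RGB images in the format the image processor expects
    # (shape and dtype are illustrative).
    images = torch.randint(0, 256, (2, 3, 512, 512), dtype=torch.uint8)

    probs = module(images)       # (2, num_classes, 512, 512) probabilities
    masks = probs.argmax(dim=1)  # (2, 512, 512) hard class predictions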

predict_step(batch)

Run prediction on a batch of data. Functionally the same as _predict_batch, but returns only softmax probabilities rather than intermediate outputs.

Source code in src/tcd_pipeline/models/segformermodule.py
def predict_step(self, batch: dict):
    """
    Run prediction on a batch of data. Functionally the same as _predict_batch,
    but returns only softmax probabilities rather than intermediate outputs.
    """
    encoded_inputs = self.processor(batch["image"], return_tensors="pt")

    with torch.no_grad():
        encoded_inputs.to(self.model.device)
        logits = self.model(pixel_values=encoded_inputs.pixel_values).logits

    pred = nn.functional.interpolate(
        logits,
        size=encoded_inputs.pixel_values.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )

    return pred.softmax(dim=1)
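
The returned tensor contains per-pixel class probabilities that sum to one over the class dimension. A short sketch of consuming the output (the input shape and class index are illustrative):

    import torch

    batch = {"image": torch.randint(0, 256, (1, 3, 512, 512), dtype=torch.uint8)}
    probs = module.predict_step(batch)  # (1, num_classes, 512, 512)

    # Softmax probabilities sum to 1 at every pixel
    assert torch.allclose(probs.sum(dim=1), torch.ones_like(probs[:, 0]))

    confidence_map = probs[:, 1]  # per-pixel probability of class index 1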