pvt

Pvt2

Bases: SemanticSegmentationModel

Wrapper around the PvtV2ForSemanticSegmentation model defined in this module.

Source code in src/tcd_pipeline/models/pvt.py
class Pvt2(SemanticSegmentationModel):
    """
    Wrapper around the PvtV2ForSemanticSegmentation model defined in this module.
    """

    model: PvtV2ForSemanticSegmentation

    def setup(self):
        """
        Performs any setup actions required before loading the model.
        No setup is currently required for this model.
        """
        pass

    def load_model(self):
        """
        Load model weights from HuggingFace Hub or local storage. The
        config key model.weights is used. If you want to force using
        local files only, you can set the environment variable:

        HF_FORCE_LOCAL

        This can be useful for testing in offline environments.

        It is assumed that the image processor has the same name as the
        model; if you're providing a local checkpoint then the 'weights'
        path should be a directory containing the state dictionary of the
        model (saved using save_pretrained) and a `preprocessor_config.json`
        file.
        """
        self.processor = PvtImageProcessor.from_pretrained(
            pretrained_model_name_or_path=self.config.model.weights,
            do_resize=False,
            revision=self.config.model.revision,
        )
        self.model = PvtV2ForSemanticSegmentation(model_name=self.config.model.weights)

    def forward(self, x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
        """Forward pass of the model. The batch is first run
        through the processor which constructs a dictionary of inputs
        for the model. This processor handles varying types of input, for
        example tensors, numpy arrays, PIL images. Within the pipeline
        this function is normally called with images pre-converted to tensors
        as they are tiles sampled from a source (geo) image.

        Args:
            x: List[torch.Tensor] or torch.Tensor

        Returns:
            Interpolated semantic segmentation predictions with softmax applied
        """

        encoded_inputs = self.processor(images=x, return_tensors="pt")

        with torch.no_grad():
            encoded_inputs.to(next(self.model.parameters()).device)
            logits = self.model(pixel_values=encoded_inputs.pixel_values)
            pred = logits.softmax(dim=1)

        return pred

forward(x)

Forward pass of the model. The batch is first run through the processor which constructs a dictionary of inputs for the model. This processor handles varying types of input, for example tensors, numpy arrays, PIL images. Within the pipeline this function is normally called with images pre-converted to tensors as they are tiles sampled from a source (geo) image.

Parameters:

    x (Union[torch.Tensor, list[torch.Tensor]]): input batch of image tiles, as a single tensor or a list of tensors. Required.

Returns:

    torch.Tensor: Interpolated semantic segmentation predictions with softmax applied.
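A brief usage sketch of the forward pass. This is a hypothetical example: it assumes model is an already-loaded Pvt2 instance (load_model has been called) and that the tiles are 512x512 RGB crops.

import torch

# Four hypothetical 3-channel, 512x512 uint8 tiles, standing in for tiles
# sampled from a source (geo) image by the pipeline.
tiles = [torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8) for _ in range(4)]

# model: an already-loaded Pvt2 instance (assumption, see above).
pred = model.forward(tiles)

# Per-class probabilities (softmax over the class dimension), interpolated
# back to the tile resolution, e.g. torch.Size([4, 2, 512, 512]) for two classes.
print(pred.shape)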

Source code in src/tcd_pipeline/models/pvt.py
def forward(self, x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
    """Forward pass of the model. The batch is first run
    through the processor which constructs a dictionary of inputs
    for the model. This processor handles varying types of input, for
    example tensors, numpy arrays, PIL images. Within the pipeline
    this function is normally called with images pre-converted to tensors
    as they are tiles sampled from a source (geo) image.

    Args:
        x: List[torch.Tensor] or torch.Tensor

    Returns:
        Interpolated semantic segmentation predictions with softmax applied
    """

    encoded_inputs = self.processor(images=x, return_tensors="pt")

    with torch.no_grad():
        encoded_inputs.to(next(self.model.parameters()).device)
        logits = self.model(pixel_values=encoded_inputs.pixel_values)
        pred = logits.softmax(dim=1)

    return pred

load_model()

Load model weights from HuggingFace Hub or local storage. The config key model.weights is used. If you want to force using local files only, you can set the environment variable:

HF_FORCE_LOCAL

This can be useful for testing in offline environments.

It is assumed that the image processor has the same name as the model; if you're providing a local checkpoint then the 'weights' path should be a directory containing the state dictionary of the model (saved using save_pretrained) and a preprocessor_config.json file.
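A minimal sketch of preparing a local checkpoint directory that model.weights could point at. The path below is hypothetical and only the image-processor side is shown; the model's state dictionary is assumed to be saved into the same directory.

import os
from transformers import PvtImageProcessor

# Hypothetical local checkpoint directory; the config key model.weights
# would be set to this path.
ckpt_dir = "checkpoints/pvt_v2_tcd"

# save_pretrained writes preprocessor_config.json into the directory,
# alongside the (separately saved) model state dictionary.
processor = PvtImageProcessor(do_resize=False)
processor.save_pretrained(ckpt_dir)

# Optionally force HuggingFace to use local files only, e.g. for offline testing.
os.environ["HF_FORCE_LOCAL"] = "1"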

Source code in src/tcd_pipeline/models/pvt.py
def load_model(self):
    """
    Load model weights from HuggingFace Hub or local storage. The
    config key model.weights is used. If you want to force using
    local files only, you can set the environment variable:

    HF_FORCE_LOCAL

    This can be useful for testing in offline environments.

    It is assumed that the image processor has the same name as the
    model; if you're providing a local checkpoint then the 'weights'
    path should be a directory containing the state dictionary of the
    model (saved using save_pretrained) and a `preprocessor_config.json`
    file.
    """
    self.processor = PvtImageProcessor.from_pretrained(
        pretrained_model_name_or_path=self.config.model.weights,
        do_resize=False,
        revision=self.config.model.revision,
    )
    self.model = PvtV2ForSemanticSegmentation(model_name=self.config.model.weights)

setup()

Performs any setup actions required before loading the model. No setup is currently required for this model.

Source code in src/tcd_pipeline/models/pvt.py
def setup(self):
    """
    Performs any setup actions required before loading the model.
    No setup is currently required for this model.
    """
    pass

PvtV2ForSemanticSegmentation

Bases: Module

FPN segmentation decoder per https://arxiv.org/pdf/1901.02446, see Figure 3.

Backbone features are passed to a feature pyramid network. Each feature output is smoothed before being passed into an upscaling sequence:

(conv, gn, gelu, 2x upscale)

until it reaches 1/4 of the input size. The scaled feature maps are then summed and convolved with a 1x1 kernel to form the segmentation output, and the final result is upscaled to the input size (i.e. 4x).
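To make the pyramid concrete, here is a small sketch of how many upscaling stages each level passes through before the maps are summed. It assumes a 512x512 input and the standard PvT-v2 backbone strides of 4, 8, 16 and 32; the numbers are illustrative only.

# Spatial sizes of the backbone features for a hypothetical 512x512 input
# (strides 4, 8, 16 and 32).
input_size = 512
sizes = {"p2": input_size // 4, "p3": input_size // 8,
         "p4": input_size // 16, "p5": input_size // 32}

target = input_size // 4  # every level is upscaled to 1/4 of the input size

# Each (conv, gn, gelu, 2x upscale) stage doubles the spatial size.
for name, size in sizes.items():
    stages = 0
    while size < target:
        size *= 2
        stages += 1
    print(f"{name}: {stages} upscale stage(s) -> {size}x{size}")

# The four 1/4-resolution maps are summed, passed through the 1x1 segmentation
# head and bilinearly interpolated back to the 512x512 input size.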

Source code in src/tcd_pipeline/models/pvt.py
class PvtV2ForSemanticSegmentation(torch.nn.Module):
    """
    FPN segmentation decoder per https://arxiv.org/pdf/1901.02446,
    see Figure 3.

    Backbone features are passed to a feature pyramid network. Each feature
    output is smoothed before being passed into an upscaling sequence:

      (conv, gn, gelu, 2x upscale)

    until it reaches 1/4 of the input size. The scaled feature maps are then summed
    and convolved with a 1x1 kernel to form the segmentation output, and the
    final result is upscaled to the input size (i.e. 4x).

    """

    def __init__(self, model_name, num_classes=2):
        super().__init__()
        config = PvtV2Config.from_pretrained(
            model_name,
            image_size=512,
            out_features=["stage1", "stage2", "stage3", "stage4"],
        )
        self.backbone = PvtV2Backbone(config)

        fpn_channels = config.hidden_sizes
        self.fpn = FeaturePyramidNetwork(fpn_channels, 256)

        self.antialias = torch.nn.ModuleList()
        for i in range(len(fpn_channels)):
            self.antialias.append(
                torch.nn.Conv2d(
                    in_channels=256,
                    out_channels=256,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )

        self.segmentation_head = torch.nn.Conv2d(
            in_channels=256,
            out_channels=num_classes,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.upscale_2x_256 = UpsampleStage(256, 256)
        self.upscale_2x_128 = UpsampleStage(256, 128)
        self.upscale_p2 = UpsampleStage(256, 128, scale=1)

    def forward(self, x):
        # Generate FPN features
        features = self.backbone(x)["feature_maps"]
        fpn_features = self.fpn({f"p{i+2}": p for (i, p) in enumerate(features)})

        # Anti-alias FPN features
        p2 = self.antialias[0](fpn_features["p2"])
        p3 = self.antialias[1](fpn_features["p3"])
        p4 = self.antialias[2](fpn_features["p4"])
        p5 = self.antialias[3](fpn_features["p5"])

        # Upscale each pyramid level to 1/4 of the input size (p5, the smallest map, needs the most stages)
        p5 = self.upscale_2x_256(p5)
        p5 = self.upscale_2x_256(p5)
        p5 = self.upscale_2x_128(p5)

        p4 = self.upscale_2x_256(p4)
        p4 = self.upscale_2x_128(p4)

        p3 = self.upscale_2x_128(p3)

        p2 = self.upscale_p2(p2)

        # Merge features, convolve and upscale
        output = interpolate(
            self.segmentation_head(p2 + p3 + p4 + p5),
            x.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

        return output

UpsampleStage

Bases: Module

FPN Upsample stage which performs a 3x3 convolution followed by a group normalisation, GELU activation and optional upscale.
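A brief, hypothetical usage sketch of the stage; the channel and group counts below are illustrative.

import torch

# A stage that keeps 256 channels, normalises over 128 groups and doubles
# the spatial resolution.
stage = UpsampleStage(channels=256, groups=128, scale=2)

x = torch.rand(1, 256, 32, 32)
y = stage(x)
print(y.shape)  # torch.Size([1, 256, 64, 64])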

Source code in src/tcd_pipeline/models/pvt.py
class UpsampleStage(torch.nn.Module):
    """
    FPN Upsample stage which performs a 3x3 convolution
    followed by a group normalisation, GELU activation and
    optional upscale.
    """

    def __init__(self, channels, groups, scale=2):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.gn = torch.nn.GroupNorm(groups, channels)
        self.act = torch.nn.GELU()
        self.scale = scale

    def forward(self, x):
        x = self.conv(x)
        x = self.gn(x)
        x = self.act(x)
        _, _, h, w = x.size()

        if self.scale != 1:
            return torch.nn.functional.interpolate(
                x,
                (self.scale * h, self.scale * w),
                mode="bilinear",
                align_corners=False,
            )
        else:
            return x