23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class SegformerModule(SegmentationModule):
    """Semantic segmentation module backed by a HuggingFace SegFormer model.

    Hyperparameters read by this class:
        id2label: Path to a JSON file mapping class ids to label names, or an
            already-loaded mapping of the same form.
        model_name: HuggingFace model id / local path used to build the config,
            model and image processor. Older checkpoints stored this under
            ``backbone``, which is copied across when present.
    """

    model: SegformerForSemanticSegmentation

    def __init__(self, **kwargs):
        """Initialise label mappings and model name from hyperparameters.

        Args:
            **kwargs: Forwarded to ``SegmentationModule``; the resulting
                ``self.hparams`` must contain ``id2label`` and ``model_name``
                (or the legacy ``backbone`` key).

        Raises:
            ValueError: If no ``id2label`` mapping is provided.
        """
        from collections.abc import Mapping

        super().__init__(**kwargs)

        # For older checkpoints, which stored the model name as "backbone".
        if self.hparams.get("backbone"):
            self.hparams["model_name"] = self.hparams["backbone"]

        # User (or load_from_pretrained) must provide an id2label so that
        # we know how many classes we're supporting, etc.
        if "id2label" not in self.hparams:
            raise ValueError("You must provide an ID->Label mapping")

        id2label = self.hparams["id2label"]
        if not isinstance(id2label, Mapping):
            # Otherwise treat it as a path to a JSON file; the context manager
            # guarantees the handle is closed.
            with open(id2label, "r", encoding="utf-8") as f:
                id2label = json.load(f)

        # JSON object keys are always strings, so normalise class ids to int.
        self.id2label = {int(k): v for k, v in id2label.items()}
        # Invert the *normalised* mapping so label2id values are ints,
        # consistent with the keys of id2label.
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.num_classes = len(self.id2label)

        # Created later by configure_models().
        self.processor = None
        self.model_name = self.hparams["model_name"]
        logger.info(f"Initialising model with {self.num_classes} classes.")
        self.save_hyperparameters()

    def configure_models(self, init_pretrained=False):
        """Build the SegFormer model and its image processor.

        Args:
            init_pretrained: If True, load weights for ``self.model_name`` via
                ``from_pretrained``; otherwise construct the model from its
                config only (no pretrained weights are loaded).
        """
        # Setting HF_FORCE_LOCAL (to any value) restricts HuggingFace loading
        # to locally available files.
        self.use_local = os.getenv("HF_FORCE_LOCAL") is not None

        # Shared by both branches so the label overrides are defined once.
        config = SegformerConfig.from_pretrained(
            self.model_name,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            local_files_only=self.use_local,
        )

        if init_pretrained:
            self.model = SegformerForSemanticSegmentation.from_pretrained(
                pretrained_model_name_or_path=self.model_name,
                config=config,
                local_files_only=self.use_local,
            )
            logger.info(f"{self.model_name} initialised with weights")
        else:
            self.model = SegformerForSemanticSegmentation(config)
            logger.info(f"{self.model_name} initialised")

        # Resizing and label reduction are disabled so the processor keeps
        # the input spatial size and leaves label values untouched.
        self.processor = SegformerImageProcessor.from_pretrained(
            self.model_name,
            do_resize=False,
            do_reduce_labels=False,
            local_files_only=self.use_local,
        )

        # Internal invariant: the head must match the label mapping.
        assert self.model.config.num_labels == self.num_classes

    def _predict_batch(self, batch):
        """Predict on a batch of data, returning loss, logits and predictions.

        This is specialised for the transformers library because we need to:

        (a) Pre-process data into the format the model expects (this could
            also be done at the data loader stage).
        (b) Post-process outputs so that the predicted masks have the same
            spatial shape as the input mask. This could also be done in the
            dataloader by passing an (h, w) tuple so we know how to resize
            the image, but we should really compute the loss with respect to
            the original mask and not a downscaled one.

        Args:
            batch: Mapping with ``"image"`` and ``"mask"`` entries. The last
                two dimensions of ``batch["mask"]`` are used as the output
                spatial size (presumably (..., H, W) — confirm with the
                datamodule).

        Returns:
            loss (torch.Tensor): Loss for the batch, as computed by the model.
            y_hat (torch.Tensor): Logits upsampled to the mask's spatial size.
            y_hat_hard (torch.Tensor): Argmax over the class dimension of
                ``y_hat`` (i.e. hard predictions).
        """
        encoded_inputs = self.processor(
            batch["image"], batch["mask"], return_tensors="pt"
        )

        # TODO Move device checking and data pre-processing to the dataloader/datamodule
        # The processor returns CPU tensors regardless of the input device,
        # so move them explicitly.
        outputs = self.model(
            pixel_values=encoded_inputs.pixel_values.to(self.device),
            labels=encoded_inputs.labels.to(self.device),
        )

        # Resize according to the input mask, not the encoded version, as the
        # sizes are likely different. We keep the full logits rather than only
        # the segmentation, so the processor's
        # post_process_semantic_segmentation helper is not used here.
        # Note: when labels are provided the model forms the loss from
        # upsampled logits, but only returns the downsampled logits.
        y_hat = nn.functional.interpolate(
            outputs.logits,
            size=batch["mask"].shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        y_hat_hard = y_hat.argmax(dim=1)

        return outputs.loss, y_hat, y_hat_hard

    def predict_step(self, batch: dict):
        """Run inference on a batch and return per-class probabilities.

        Functionally the same as ``_predict_batch`` but without labels or
        loss, and returning only softmax probabilities.

        Args:
            batch: Mapping with an ``"image"`` entry accepted by the processor.

        Returns:
            torch.Tensor: Softmax over dim 1 of the logits, upsampled to the
            spatial size of the processed pixel values (the processor is built
            with ``do_resize=False``, so presumably the input size).
        """
        encoded_inputs = self.processor(batch["image"], return_tensors="pt")

        with torch.no_grad():
            encoded_inputs = encoded_inputs.to(self.model.device)
            logits = self.model(pixel_values=encoded_inputs.pixel_values).logits
            pred = nn.functional.interpolate(
                logits,
                size=encoded_inputs.pixel_values.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )

        return pred.softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model.

        Args:
            x: Image array, passed to the processor as ``batch["image"]``.

        Returns:
            Interpolated semantic segmentation probabilities (see
            ``predict_step``).
        """
        return self.predict_step({"image": x})