instance_segmentation

Instance segmentation model framework, using Detectron2 as the backend.

DetectronModel

Bases: Model

Tiled model subclass for Detectron2 models.

Source code in src/tcd_pipeline/models/instance_segmentation.py
class DetectronModel(Model):
    """Tiled model subclass for Detectron2 models."""

    def __init__(self, config: dict):
        """Initialize the model.

        Args:
            config (dict): The global configuration dictionary
        """
        super().__init__(config)
        self.post_processor = InstanceSegmentationPostProcessor(config)
        self.predictor = None
        self.should_reload = False
        self._cfg = None

        self.load_config()

    def setup(self):
        # No-op: configuration is loaded in __init__ via load_config().
        pass

    def load_config(self) -> None:
        """Construct the Detectron2 CfgNode from the pipeline configuration."""
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(self.config.model.architecture))

        if isinstance(self.config.model.config, str):
            cfg.merge_from_file(self.config.model.config)
        elif isinstance(self.config.model.config, DictConfig):
            cfg.merge_from_other_cfg(
                CfgNode(OmegaConf.to_container(self.config.model.config))
            )
        else:
            raise NotImplementedError

        cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(self.config.data.classes)

        cfg.MODEL.DEVICE = self.device

        self._cfg = cfg.clone()

        MetadataCatalog.get(
            self.config.data.name
        ).thing_classes = self.config.data.classes
        self.num_classes = len(self.config.data.classes)
        self.max_detections = self._cfg.TEST.DETECTIONS_PER_IMAGE

    def load_model(self) -> None:
        """Load a detectron2 model from the provided config."""
        gc.collect()

        if torch.cuda.is_available():
            with torch.no_grad():
                torch.cuda.empty_cache()

        if not os.path.exists(self.config.model.weights):
            try:
                from huggingface_hub import HfApi

                api = HfApi()
                self._cfg.MODEL.WEIGHTS = api.hf_hub_download(
                    self.config.model.weights, filename="model.pth"
                )
            except Exception as e:
                logger.warning("Failed to download checkpoint from HF hub: %s", e)
        else:
            self._cfg.MODEL.WEIGHTS = self.config.model.weights

        self.predictor = DefaultPredictor(self._cfg)

        if self._cfg.TEST.AUG.ENABLED:
            logger.info("Using Test-Time Augmentation")
            self.model = GeneralizedRCNNWithTTA(
                self._cfg,
                self.predictor.model,
                batch_size=self.config.model.tta_batch_size,
            )
        else:
            logger.info("Test-Time Augmentation is disabled")
            self.model = self.predictor.model

    def evaluate(
        self,
        annotation_file=None,
        image_folder=None,
        output_folder=None,
        prediction_file=None,
        evaluate=True,
    ) -> Union[npt.NDArray, None]:
        """Evaluate the model.

        If no inputs are provided, then the evaluation is run on the test dataset
        as per the config file. Normally you should explicitly provide an
        annotation file and image folder to test against.

        If you're running this after training a model then you can directly provide
        a prediction file to avoid running inference twice. In this case, the
        predictions must come from the dataset that the evaluator was set up with
        or you'll get nonsense results.
        """

        if self.model is None:
            self.load_model()
            assert self.model is not None

        if annotation_file is None:
            annotation_file = self.config.data.test
        elif image_folder is None:
            raise ValueError(
                "Please provide an image folder if using a custom annotation file."
            )

        if image_folder is None:
            image_folder = self.config.data.images

        logger.info("Image folder: %s", image_folder)
        logger.info("Annotation file: %s", annotation_file)

        # Setup the "test" dataset with the provided annotation file
        if "eval_test" not in DatasetCatalog.list():
            register_coco_instances("eval_test", {}, annotation_file, image_folder)
        else:
            logger.warning("Skipping test dataset registration, already registered.")

        assert self._cfg is not None

        test_loader = (
            build_detection_test_loader(  # pylint: disable=too-many-function-args
                self._cfg, "eval_test", batch_size=1
            )
        )

        if output_folder is None:
            output_folder = self.config.data.output
        os.makedirs(output_folder, exist_ok=True)

        # Use the segm task since we're doing instance segmentation
        if evaluate:
            evaluator = COCOEvaluator(
                dataset_name="eval_test",
                tasks=["segm"],
                distributed=False,
                output_dir=output_folder,
                max_dets_per_image=self._cfg.TEST.DETECTIONS_PER_IMAGE,
                allow_cached_coco=False,
            )
        else:
            evaluator = None

        if not prediction_file:
            inference_on_dataset(self.model, test_loader, evaluator)
        else:
            # Detectron2 ships a modified cocoeval (COCOevalMaxDets) that
            # supports an arbitrary number of detections per image
            from detectron2.evaluation.coco_evaluation import COCOevalMaxDets
            from pycocotools.coco import COCO

            gt = COCO(annotation_file)
            dt = gt.loadRes(prediction_file)

            coco_eval = COCOevalMaxDets(gt, dt)
            coco_eval.params.maxDets[2] = self._cfg.TEST.DETECTIONS_PER_IMAGE

            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()

            return coco_eval.stats

    def predict_batch(
        self, image_tensor: Union[torch.Tensor, List[torch.Tensor]]
    ) -> List[Instances]:
        """Run inference on a single CHW image tensor or a list of tensors."""
        self.model.eval()
        self.should_reload = False
        predictions = None

        t_start_s = time.time()

        with torch.no_grad():
            inputs = []

            if isinstance(image_tensor, list):
                for image in image_tensor:
                    height, width = image.shape[1:]
                    inputs.append(
                        {"image": image[:3, :, :], "height": height, "width": width}
                    )
            else:
                height, width = image_tensor.shape[1:]
                inputs.append(
                    {"image": image_tensor[:3, :, :], "height": height, "width": width}
                )

            try:
                predictions = [p["instances"] for p in self.model(inputs)]

                for prediction in predictions:
                    if len(prediction) >= self.max_detections:
                        logger.warning(
                            "Maximum detections reached (%s); consider re-running with a higher detection limit.",
                            self.max_detections,
                        )

            except RuntimeError as e:
                logger.error("Runtime error: %s", e)
                self.should_reload = True
            except Exception as e:  # pylint: disable=broad-except
                logger.error(
                    "Failed to run inference: %s. Attempting to reload model.", e
                )
                self.should_reload = True

        t_elapsed_s = time.time() - t_start_s
        logger.debug("Predicted tile in %1.2fs", t_elapsed_s)

        return predictions

    def visualise(
        self,
        image: npt.NDArray,
        results: Instances,
        confidence_thresh: float = 0.5,
        **kwargs: Any
    ) -> None:
        """Visualise model results using Detectron's provided utils

        Args:
            image (array): Numpy array for image (HWC)
            results (Instances): Instances from predictions
            confidence_thresh (float, optional): Confidence threshold to plot. Defaults to 0.5.
            **kwargs (Any): Passed to matplotlib figure
        """

        mask = results.scores > confidence_thresh
        viz = Visualizer(
            image,
            MetadataCatalog.get(self.config.data.name),
            scale=1.2,
            instance_mode=ColorMode.SEGMENTATION,
        )
        out = viz.draw_instance_predictions(results[mask].to("cpu"))

        plt.figure(**kwargs)
        plt.imshow(out.get_image())
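
As a rough usage sketch for predict_batch (the file path is a placeholder, and model is assumed to be a DetectronModel constructed as in the __init__ example below). The method accepts a single CHW tensor or a list of them, and only the first three (RGB) channels are forwarded to the network:

import numpy as np
import torch
from PIL import Image

model.load_model()

image = np.array(Image.open("tile.png"))                   # hypothetical path, HWC RGB
tensor = torch.from_numpy(image).permute(2, 0, 1).float()  # CHW tensor for the model

instances = model.predict_batch([tensor])[0]               # detectron2 Instances
print(len(instances), "detections")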

__init__(config)

Initialize the model.

Parameters:

Name     Type   Description                            Default
config   dict   The global configuration dictionary    required
Source code in src/tcd_pipeline/models/instance_segmentation.py
def __init__(self, config: dict):
    """Initialize the model.

    Args:
        config (dict): The global configuration dictionary
    """
    super().__init__(config)
    self.post_processor = InstanceSegmentationPostProcessor(config)
    self.predictor = None
    self.should_reload = False
    self._cfg = None

    self.load_config()
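
A minimal construction sketch. The keys below are assumptions inferred from the attributes this class reads (config.model.* and config.data.*), not a documented schema; the Model base class is assumed to handle device selection:

from omegaconf import OmegaConf

from tcd_pipeline.models.instance_segmentation import DetectronModel

# Hypothetical configuration mirroring the fields accessed above
config = OmegaConf.create(
    {
        "model": {
            # Base architecture resolved via detectron2's model zoo
            "architecture": "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
            "config": {},                # overrides merged into the base CfgNode
            "weights": "model.pth",      # local checkpoint path or HF Hub repo id
            "tta_batch_size": 1,
        },
        "data": {
            "name": "my_dataset",
            "classes": ["tree"],
            "test": "annotations/test.json",
            "images": "images/",
            "output": "output/",
        },
    }
)

model = DetectronModel(config)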

evaluate(annotation_file=None, image_folder=None, output_folder=None, prediction_file=None, evaluate=True)

Evaluate the model.

If no inputs are provided, then the evaluation is run on the test dataset as per the config file. Normally you should explicitly provide an annotation file and image folder to test against.

If you're running this after training a model then you can directly provide a prediction file to avoid running inference twice. In this case, the predictions must come from the dataset that the evaluator was set up with or you'll get nonsense results.

Source code in src/tcd_pipeline/models/instance_segmentation.py
def evaluate(
    self,
    annotation_file=None,
    image_folder=None,
    output_folder=None,
    prediction_file=None,
    evaluate=True,
) -> Union[npt.NDArray, None]:
    """Evaluate the model.

    If no inputs are provided, then the evaluation is run on the test dataset
    as per the config file. Normally you should explicitly provide an
    annotation file and image folder to test against.

    If you're running this after training a model then you can directly provide
    a prediction file to avoid running inference twice. In this case, the
    predictions must come from the dataset that the evaluator was set up with
    or you'll get nonsense results.
    """

    if self.model is None:
        self.load_model()
        assert self.model is not None

    if annotation_file is None:
        annotation_file = self.config.data.test
    elif image_folder is None:
        raise ValueError(
            "Please provide an image folder if using a custom annotation file."
        )

    if image_folder is None:
        image_folder = self.config.data.images

    logger.info("Image folder: %s", image_folder)
    logger.info("Annotation file: %s", annotation_file)

    # Setup the "test" dataset with the provided annotation file
    if "eval_test" not in DatasetCatalog.list():
        register_coco_instances("eval_test", {}, annotation_file, image_folder)
    else:
        logger.warning("Skipping test dataset registration, already registered.")

    assert self._cfg is not None

    test_loader = (
        build_detection_test_loader(  # pylint: disable=too-many-function-args
            self._cfg, "eval_test", batch_size=1
        )
    )

    if output_folder is None:
        output_folder = self.config.data.output
    os.makedirs(output_folder, exist_ok=True)

    # Use the segm task since we're doing instance segmentation
    if evaluate:
        evaluator = COCOEvaluator(
            dataset_name="eval_test",
            tasks=["segm"],
            distributed=False,
            output_dir=output_folder,
            max_dets_per_image=self._cfg.TEST.DETECTIONS_PER_IMAGE,
            allow_cached_coco=False,
        )
    else:
        evaluator = None

    if not prediction_file:
        inference_on_dataset(self.model, test_loader, evaluator)
    else:
        # Detectron2 ships a modified cocoeval (COCOevalMaxDets) that
        # supports an arbitrary number of detections per image
        from detectron2.evaluation.coco_evaluation import COCOevalMaxDets
        from pycocotools.coco import COCO

        gt = COCO(annotation_file)
        dt = gt.loadRes(prediction_file)

        coco_eval = COCOevalMaxDets(gt, dt)
        coco_eval.params.maxDets[2] = self._cfg.TEST.DETECTIONS_PER_IMAGE

        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        return coco_eval.stats
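
A hedged usage sketch (all paths are placeholders): evaluate against a custom COCO annotation file, or re-score an existing prediction file without running inference a second time:

# Run inference over the registered test set and compute COCO "segm" metrics
model.evaluate(
    annotation_file="data/annotations/test.json",
    image_folder="data/images",
    output_folder="results/eval",
)

# Re-score saved predictions instead; these must come from the same dataset
stats = model.evaluate(
    annotation_file="data/annotations/test.json",
    image_folder="data/images",
    prediction_file="results/eval/coco_instances_results.json",
)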

load_model()

Load a detectron2 model from the provided config.

Source code in src/tcd_pipeline/models/instance_segmentation.py
def load_model(self) -> None:
    """Load a detectron2 model from the provided config."""
    gc.collect()

    if torch.cuda.is_available():
        with torch.no_grad():
            torch.cuda.empty_cache()

    if not os.path.exists(self.config.model.weights):
        try:
            from huggingface_hub import HfApi

            api = HfApi()
            self._cfg.MODEL.WEIGHTS = api.hf_hub_download(
                self.config.model.weights, filename="model.pth"
            )
        except Exception as e:
            logger.warning("Failed to download checkpoint from HF hub: %s", e)
    else:
        self._cfg.MODEL.WEIGHTS = self.config.model.weights

    self.predictor = DefaultPredictor(self._cfg)

    if self._cfg.TEST.AUG.ENABLED:
        logger.info("Using Test-Time Augmentation")
        self.model = GeneralizedRCNNWithTTA(
            self._cfg,
            self.predictor.model,
            batch_size=self.config.model.tta_batch_size,
        )
    else:
        logger.info("Test-Time Augmentation is disabled")
        self.model = self.predictor.model
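
Note the weights fallback: config.model.weights is used directly if it is an existing local path, and is otherwise treated as a Hugging Face Hub repository id from which model.pth is downloaded. A standalone sketch of the same resolution logic, with a hypothetical repo id:

import os

from huggingface_hub import HfApi

weights = "some-org/some-tcd-model"  # hypothetical Hub repo id, not a local file

if os.path.exists(weights):
    resolved = weights
else:
    # Mirrors the fallback above: fetch "model.pth" from the Hub repository
    resolved = HfApi().hf_hub_download(weights, filename="model.pth")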

visualise(image, results, confidence_thresh=0.5, **kwargs)

Visualise model results using Detectron2's visualisation utilities.

Parameters:

Name                Type        Description                     Default
image               array       Numpy array for image (HWC)     required
results             Instances   Instances from predictions      required
confidence_thresh   float       Confidence threshold to plot    0.5
**kwargs            Any         Passed to matplotlib figure     {}
Source code in src/tcd_pipeline/models/instance_segmentation.py
def visualise(
    self,
    image: npt.NDArray,
    results: Instances,
    confidence_thresh: float = 0.5,
    **kwargs: Any
) -> None:
    """Visualise model results using Detectron's provided utils

    Args:
        image (array): Numpy array for image (HWC)
        results (Instances): Instances from predictions
        confidence_thresh (float, optional): Confidence threshold to plot. Defaults to 0.5.
        **kwargs (Any): Passed to matplotlib figure
    """

    mask = results.scores > confidence_thresh
    viz = Visualizer(
        image,
        MetadataCatalog.get(self.config.data.name),
        scale=1.2,
        instance_mode=ColorMode.SEGMENTATION,
    )
    out = viz.draw_instance_predictions(results[mask].to("cpu"))

    plt.figure(**kwargs)
    plt.imshow(out.get_image())
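
A short usage sketch (inputs are placeholders): keyword arguments are forwarded to plt.figure, so for example figsize controls the output size:

import matplotlib.pyplot as plt

# image is the HWC array that was predicted on; instances is the model output
model.visualise(image, instances, confidence_thresh=0.7, figsize=(10, 10))
plt.show()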