Skip to content

semantic

GeotiffSemanticCache

Bases: SemanticSegmentationCache

Source code in src/tcd_pipeline/cache/semantic.py
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
class GeotiffSemanticCache(SemanticSegmentationCache):
    """
    Semantic segmentation cache that stores predictions in a grid of GeoTIFF
    tiles covering the source image.

    Each prediction mask is split across the cache tiles it overlaps and written
    into the matching window of each tile. Tiles are created lazily and are
    georeferenced against the source image. When `set_prediction_tiles` has been
    called, each tile is compressed as soon as every expected prediction has been
    written to it.
    """

    def __init__(
        self,
        cache_folder,
        cache_tile_size: int = 10000,
        compress: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            cache_folder: folder that cache tiles are written to
            cache_tile_size: edge length, in pixels, of each square cache tile.
                Clamped so that it never exceeds the image width or height.
            compress: if True, tiles are created with deflate compression
            *args, **kwargs: forwarded to the parent cache constructor
        """
        super().__init__(cache_folder=cache_folder, *args, **kwargs)

        # NOTE(review): image_path / classes are presumably set by the parent
        # ResultsCache constructor (not shown here).
        with rasterio.open(self.image_path) as src:
            self.image_width = src.width
            self.image_height = src.height
            self.src_transform = src.transform
            self.src_meta = src.meta

        # A cache tile can never be larger than the image along either axis.
        self.cache_tile_size = min(
            self.image_height, self.image_width, cache_tile_size
        )
        self.compress = compress

        # The first (background) class is not stored: one band per remaining class.
        if self.classes is not None:
            self.band_count = max(1, len(self.classes) - 1)
        else:
            self.band_count = 1

        self.tile_names = []
        # Filled in by set_prediction_tiles; None means no precomputed schedule.
        self.tile_prediction_count = None
        self.intersections = None
        self._generate_tiles()

    def _generate_tiles(self):
        """
        Internal function that generates non-overlapping tile extents covering the source image.

        Tiles are stored as bounding boxes in an R-Tree. When predictions are saved to the
        cache, this lets us efficiently look up which cache tiles overlap and where the tile
        data should be saved. The tiler is deliberately simple: it computes the number of
        tiles required to cover each axis and generates a bounding box for each tile. Tiles
        on the right/bottom edges are clipped to the image extent.
        """
        n_tiles_x = int(math.ceil(self.image_width / self.cache_tile_size))
        n_tiles_y = int(math.ceil(self.image_height / self.cache_tile_size))

        self.index = rtree.Index()

        idx = 0
        total_pixels = 0
        for tx in range(n_tiles_x):
            for ty in range(n_tiles_y):
                minx = tx * self.cache_tile_size
                miny = ty * self.cache_tile_size

                tile_box = box(
                    minx,
                    miny,
                    min(minx + self.cache_tile_size, self.image_width),
                    min(miny + self.cache_tile_size, self.image_height),
                )
                self.index.add(id=idx, coordinates=tile_box.bounds, obj=tile_box)

                total_pixels += tile_box.area
                idx += 1

        # Worst case (uncompressed uint8): one byte per pixel per stored band.
        estimated_gb = total_pixels * self.band_count / 1e9
        logger.info(
            f"Caching to {idx} tiles, approximately {estimated_gb:1.2f}GB needed for temporary storage during inference"
        )

    def set_prediction_tiles(self, dataset_tiles):
        """
        Pre-compute intersections for dataset tiles, which allows compression to
        happen during prediction. This keeps working storage much lower than
        waiting for the prediction to complete before compressing.

        For each cache tile we count how many prediction tiles are expected to hit
        it. `save_tile` decrements that counter and compresses the cache tile once
        it reaches zero.

        Args:
            dataset_tiles: iterable of objects exposing `.bounds` in global pixel
                coordinates (the same bounds later passed to `save`)
        """
        self.tile_prediction_count = defaultdict(int)
        self.intersections = defaultdict(list)
        for tile in dataset_tiles:
            for result in self.index.intersection(tile.bounds, objects=True):
                self.tile_prediction_count[result.id] += 1
                self.intersections[tile.bounds].append(result)

    def _create_cache_tile(self, tile_bbox: box, path: str) -> None:
        """
        Create an empty cache tile covering ``tile_bbox``.

        The tile reuses the source image metadata (e.g. CRS) together with a transform
        derived from the source transform and the tile's pixel offset. The raster is
        sized to the tile's actual extent, so right/bottom edge tiles are smaller
        than ``cache_tile_size``; this matches the width/height encoded in the tile
        filename.

        By default, no compression is enabled to speed up cache writes during inference.

        Args:
            tile_bbox: tile extent in global image pixel coordinates
            path: output GeoTIFF path
        """
        minx, miny, maxx, maxy = (int(v) for v in tile_bbox.bounds)
        width = maxx - minx
        height = maxy - miny

        # Copy so creating a tile does not mutate the shared source metadata.
        meta = dict(self.src_meta)
        meta.update(
            {
                "driver": "GTiff",
                "width": width,
                "height": height,
                "count": self.band_count,
                "dtype": "uint8",
                "nodata": 0,
                "compress": "deflate" if self.compress else None,
                # Window is (col_off, row_off, width, height), not a bounds tuple.
                "transform": rasterio.windows.transform(
                    Window(minx, miny, width, height), self.src_transform
                ),
            }
        )

        self.tile_names.append(os.path.basename(path))

        with rasterio.open(path, "w+", **meta) as dst:
            # Touch a very small window to create the tile
            dst.write(
                np.zeros((1, 1), dtype=np.uint8), window=Window(0, 0, 1, 1), indexes=1
            )

    def save_tile(self, mask: npt.NDArray, bbox: box) -> list[str]:
        """
        Save a model prediction to the tile cache.

        This function determines which cache tiles overlap the provided bounding
        box and splits the prediction between them. Cache tiles are created lazily,
        i.e. only when data is first written to them. As a result, not every tile in
        the index is necessarily created if the input image has large regions
        without predictions.

        If the mask has more than one channel, the first channel is treated as
        background and is not stored. A single-channel mask is stored as-is.

        Args:
            mask: prediction result, indexed as (channel, row, column) over bbox
            bbox: tile bounding box in global image pixel coordinates

        Returns:
            Paths of the cache tiles that received data from this prediction.
        """
        # Use the precomputed lookup only when this bbox was registered with
        # set_prediction_tiles. A plain defaultdict lookup would return [] for an
        # unknown bbox and silently drop the prediction.
        tracked = (
            self.intersections is not None and bbox.bounds in self.intersections
        )
        if tracked:
            intersecting_cache_tiles = self.intersections[bbox.bounds]
        else:
            if self.intersections is not None:
                logger.warning(
                    f"Prediction bbox {bbox.bounds} was not registered with set_prediction_tiles, querying the tile index directly"
                )
            intersecting_cache_tiles = list(
                self.index.intersection(bbox.bounds, objects=True)
            )

        # Drop the background channel only when other channels exist.
        band_offset = 1 if mask.shape[0] > 1 else 0

        output_paths = []
        for hit in intersecting_cache_tiles:
            tile_idx = hit.id
            tile = hit.object

            # Tile shape for filename
            tile_width = int(tile.bounds[2] - tile.bounds[0])
            tile_height = int(tile.bounds[3] - tile.bounds[1])
            file_name = f"{int(tile.bounds[0])}_{int(tile.bounds[1])}_{tile_width}_{tile_height}_{tile_idx}{self.cache_suffix}.tif"
            output_path = os.path.join(self.cache_folder, file_name)

            # The R-tree also reports tiles that only touch the bbox along an edge
            # or corner. Such overlaps contain no pixels to write.
            overlap = bbox.intersection(tile)
            if overlap.area > 0:
                minx, miny, maxx, maxy = [int(i) for i in overlap.bounds]
                width = maxx - minx
                height = maxy - miny
            else:
                width = height = 0

            if width > 0 and height > 0:
                # Coordinates within the mask
                mask_offset_x = int(minx - bbox.bounds[0])
                mask_offset_y = int(miny - bbox.bounds[1])
                mask_crop = mask[
                    band_offset:,
                    mask_offset_y : mask_offset_y + height,
                    mask_offset_x : mask_offset_x + width,
                ]

                # Coordinates within the tile
                tile_offset_x = int(minx - tile.bounds[0])
                tile_offset_y = int(miny - tile.bounds[1])
                window = Window(tile_offset_x, tile_offset_y, width, height)

                if not os.path.exists(output_path):
                    self._create_cache_tile(tile, output_path)

                with rasterio.open(output_path, "r+", nodata=0, dtype="uint8") as dst:
                    for band_idx, band in enumerate(mask_crop):
                        dst.write(band, indexes=band_idx + 1, window=window)

                output_paths.append(output_path)

            # Every hit counted by set_prediction_tiles is consumed here, even one
            # without pixels, so the counter reaches zero exactly when the tile is done.
            if tracked and self.tile_prediction_count is not None:
                self.tile_prediction_count[tile_idx] -= 1
                if self.tile_prediction_count[tile_idx] == 0 and os.path.exists(
                    output_path
                ):
                    self.compress_tile(output_path)

        return output_paths

    def save(self, mask: npt.NDArray, bbox: box):
        """
        Save a prediction mask into the cache. See `save_tile` for
        more information on the internal details.

        The stored values are unsigned 8-bit, scaled from 0 to 255; 0 is used as
        the nodata value in the cache tile. If the maximum value of the array is
        at most 1 (e.g. probabilities in [0, 1]), the array is multiplied by 255,
        clipped to [0, 255], rounded and cast to uint8. Otherwise the mask is
        passed through unchanged.

        The mask can have multiple bands, corresponding to multiple class
        predictions. In that case the first channel is assumed to be background
        and is not stored, since it can be reconstructed from the remaining bands.

        Args:
            mask: prediction result, 2D (single channel) or (channel, row, column)
            bbox: tile bounding box in global image coordinates
        """
        if mask.ndim == 2:
            mask = np.expand_dims(mask, 0)

        # Maybe more efficient to just assume float results should be normed
        if mask.max() <= 1:
            # Clip so that small negative values cannot wrap around on the uint8 cast.
            mask = np.round(np.clip(mask * 255, 0, 255)).astype(np.uint8)

        output_paths = self.save_tile(mask, bbox)
        self.tile_count += 1
        self.write_tile_meta(self.tile_count, bbox, output_paths)

    def _find_cache_files(self) -> list[str]:
        """
        Locate cache files.

        This method is overridden because in a tiled semantic segmentation cache,
        each prediction tile can be split over several cache tiles, so each metadata
        entry holds a list of files. Also updates ``self.tile_count`` when entries exist.

        Returns:
            Unique cache file paths, in order of first appearance in the metadata.
        """
        with open(self.meta_path, "r") as fp:
            reader = jsonlines.Reader(fp)
            lines = list(reader.iter())

        cache_files = []

        # The first line is the header; the remaining lines are one entry per prediction.
        if len(lines) > 1:
            tiles = lines[1:]
            self.tile_count = len(tiles)

            for tile in tiles:
                cache_files.extend(tile["cache_file"])

        # Deduplicate while keeping a deterministic order.
        return list(dict.fromkeys(cache_files))

    def compress_tile(self, path):
        """
        Re-write a single cache tile in place with deflate compression.

        The compressed copy is written to ``path + ".temp"`` and then atomically
        replaces the original. If writing fails, the temporary file is removed and
        the error is re-raised.
        """
        logger.debug(f"Compressing {path}")
        temp_path = path + ".temp"
        try:
            with rasterio.open(path) as src:
                meta = src.meta
                meta.update({"driver": "GTiff", "compress": "deflate"})

                with rasterio.open(temp_path, "w", **meta) as dst:
                    dst.write(src.read())
        except Exception:
            # Do not leave a partially written temporary file behind.
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise

        os.replace(temp_path, path)

    def compress_tiles(self):
        """
        Iterate over the tiles in the cache and re-write them as compressed
        GeoTIFFs. This usually results in a significant reduction in file
        size. The `deflate` compression method is used because `packbits` can
        sometimes fail, and `lzw` is often not supported and is slow.

        A temporary file is created before being moved to overwrite the source image.
        """
        for path in self.cache_files:
            self.compress_tile(path)

    def generate_vrt(self, filename="overview.vrt", files=None, root=None):
        """
        Generate a virtual raster from the tiles in the cache. This should be called
        at the end of inference to create an "overview" file that can be used to
        read all the tiles as a single image.

        ``gdalbuildvrt`` runs with ``root`` (default: the cache folder) as its
        working directory, so ``filename`` and any relative paths in ``files``
        (default: ``self.cache_files``) are resolved against ``root``.
        Raises ``subprocess.CalledProcessError`` if ``gdalbuildvrt`` fails.
        """
        if root is None:
            root = self.cache_folder

        if files is None:
            files = self.cache_files

        import subprocess

        logger.info(f"Saving vrt to: {root}")

        subprocess.check_output(
            [
                "gdalbuildvrt",
                "-srcnodata",
                "0",
                "-vrtnodata",
                "0",
                filename,
                *files,
            ],
            cwd=root,
        )

    def _load_file(self, cache_file: str) -> dict:
        """
        Open a cache tile and recover its bounding box from its filename.

        Filenames follow
        ``{tile_offset_x}_{tile_offset_y}_{tile_width}_{tile_height}_{tile_idx}{cache_suffix}.tif``
        (see `save_tile`).

        Returns:
            dict with "bbox" (the tile extent in global pixel coordinates) and
            "mask" (an opened rasterio dataset, left open for the caller to close).
        """
        offset_x, offset_y, width, height = [
            int(float(i)) for i in os.path.basename(cache_file).split("_")[:4]
        ]

        return {
            # shapely box takes (minx, miny, maxx, maxy), not (x, y, width, height).
            "bbox": box(offset_x, offset_y, offset_x + width, offset_y + height),
            "mask": rasterio.open(cache_file),
        }

    # TODO: track predicted tile count and return this on load
    def __len__(self):
        """Number of predictions recorded in the cache metadata."""
        self._find_cache_files()
        return self.tile_count

compress_tiles()

Iterate over the tiles in the cache and re-write them as compressed GeoTIFFs. This usually results in a significant reduction in file size. The deflate compression method is used as packbits can sometimes fail, and lzw is often not supported and is slow.

A temporary file is created before being moved to overwrite the source image.

Source code in src/tcd_pipeline/cache/semantic.py
325
326
327
328
329
330
331
332
333
334
335
def compress_tiles(self):
    """
    Compress every tile currently listed in the cache.

    Each path in ``self.cache_files`` is handed to ``compress_tile``, which
    re-writes it as a deflate-compressed GeoTIFF through a temporary file that
    then replaces the original. Deflate is preferred because packbits can fail
    and lzw is often unsupported and slow.
    """
    compress = self.compress_tile
    for cache_path in self.cache_files:
        compress(cache_path)

generate_vrt(filename='overview.vrt', files=None, root=None)

Generate a virtual raster from the tiles in the cache. This should be called at the end of inference to create an "overview" file that can be used to read all the tiles as a single image.

Source code in src/tcd_pipeline/cache/semantic.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
def generate_vrt(self, filename="overview.vrt", files=None, root=None):
    """
    Build a GDAL virtual raster over the cache tiles.

    Call this once inference has finished. The result is an "overview" file
    that reads all cache tiles as a single image.

    Args:
        filename: name of the VRT to create, relative to ``root``
        files: tile paths to include; defaults to ``self.cache_files``
        root: working directory for ``gdalbuildvrt``; defaults to the cache folder
    """
    import subprocess

    working_dir = self.cache_folder if root is None else root
    vrt_inputs = self.cache_files if files is None else files

    logger.info(f"Saving vrt to: {working_dir}")

    # 0 is the nodata value for both the source tiles and the output VRT.
    command = [
        "gdalbuildvrt",
        "-srcnodata",
        "0",
        "-vrtnodata",
        "0",
        filename,
        *vrt_inputs,
    ]
    subprocess.check_output(command, cwd=working_dir)

save(mask, bbox)

Save a prediction mask into the cache. See save_tile for more information on the internal details.

The provided mask should be an unsigned 8-bit array containing prediction values scaled from 0 to 255. Nominally, 0 is used as the nodata value in the cache tile. If the maximum value of the array is at most 1 (for example, probabilities in [0, 1]), the array is multiplied by 255, rounded and cast to uint8.

The mask can have multiple bands, corresponding to multiple class predictions, but the first channel is assumed to be background and is not stored (as it can be reconstructed from the remaining bands).

Parameters:

Name Type Description Default
mask NDArray

prediction result

required
bbox box

tile bounding box in global image coordinates

required
Source code in src/tcd_pipeline/cache/semantic.py
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def save(self, mask: npt.NDArray, bbox: box):
    """
    Save a prediction mask into the cache. See `save_tile` for
    more information on the internal details.

    The provided mask should be an unsigned 8-bit array containing prediction
    values scaled from 0 to 255. Nominally, 0 is used as the nodata
    value in the cache tile. If the maximum value of the array is
    greater than 1, then the array is multiplied by 255 and cast
    to uint8.

    The mask can have multiple bands, corresponding to multiple class
    predictions, but the first channel is assumed to be background and is
    not stored (as it can be reconstructed from the remaining bands).

    Args:
        mask: prediction result
        bbox: tile bounding box in global image coordinates
    """

    if len(mask.shape) == 2:
        mask = np.expand_dims(mask, 0)

    # Maybe more efficient to just assume float results should be normed
    if mask.max() <= 1:
        mask = np.round(mask * 255).astype(np.uint8)

    output_paths = self.save_tile(mask, bbox)
    self.tile_count += 1
    self.write_tile_meta(self.tile_count, bbox, output_paths)

save_tile(mask, bbox)

Save a model prediction to the tile cache. This function determines which tiles in the cache overlap the provided bounding box and splits the prediction between them. Cache tiles are created lazily, i.e. only when data is first written to them, so not every tile in the index is necessarily created if the input image contains large regions of empty data.

Parameters:

Name Type Description Default
mask NDArray

prediction result

required
bbox box

tile bounding box in global image pixel coordinates

required
Source code in src/tcd_pipeline/cache/semantic.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def save_tile(self, mask: npt.NDArray, bbox: box) -> list[str]:
    """
    Save a model prediction to the tile cache.

    This function determines which cache tiles overlap the provided bounding
    box and splits the prediction between them. Cache tiles are created lazily,
    i.e. only when data is first written to them. As a result, not every tile in
    the index is necessarily created if the input image has large regions
    without predictions.

    If the mask has more than one channel, the first channel is treated as
    background and is not stored. A single-channel mask is stored as-is.

    Args:
        mask: prediction result, indexed as (channel, row, column) over bbox
        bbox: tile bounding box in global image pixel coordinates

    Returns:
        Paths of the cache tiles that received data from this prediction.
    """
    # Use the precomputed lookup only when this bbox was registered with
    # set_prediction_tiles. A plain defaultdict lookup would return [] for an
    # unknown bbox and silently drop the prediction.
    tracked = self.intersections is not None and bbox.bounds in self.intersections
    if tracked:
        intersecting_cache_tiles = self.intersections[bbox.bounds]
    else:
        if self.intersections is not None:
            logger.warning(
                f"Prediction bbox {bbox.bounds} was not registered with set_prediction_tiles, querying the tile index directly"
            )
        intersecting_cache_tiles = list(
            self.index.intersection(bbox.bounds, objects=True)
        )

    # Drop the background channel only when other channels exist.
    band_offset = 1 if mask.shape[0] > 1 else 0

    output_paths = []
    for hit in intersecting_cache_tiles:
        tile_idx = hit.id
        tile = hit.object

        # Tile shape for filename
        tile_width = int(tile.bounds[2] - tile.bounds[0])
        tile_height = int(tile.bounds[3] - tile.bounds[1])
        file_name = f"{int(tile.bounds[0])}_{int(tile.bounds[1])}_{tile_width}_{tile_height}_{tile_idx}{self.cache_suffix}.tif"
        output_path = os.path.join(self.cache_folder, file_name)

        # The R-tree also reports tiles that only touch the bbox along an edge
        # or corner. Such overlaps contain no pixels to write.
        overlap = bbox.intersection(tile)
        if overlap.area > 0:
            minx, miny, maxx, maxy = [int(i) for i in overlap.bounds]
            width = maxx - minx
            height = maxy - miny
        else:
            width = height = 0

        if width > 0 and height > 0:
            # Coordinates within the mask
            mask_offset_x = int(minx - bbox.bounds[0])
            mask_offset_y = int(miny - bbox.bounds[1])
            mask_crop = mask[
                band_offset:,
                mask_offset_y : mask_offset_y + height,
                mask_offset_x : mask_offset_x + width,
            ]

            # Coordinates within the tile
            tile_offset_x = int(minx - tile.bounds[0])
            tile_offset_y = int(miny - tile.bounds[1])
            window = Window(tile_offset_x, tile_offset_y, width, height)

            if not os.path.exists(output_path):
                self._create_cache_tile(tile, output_path)

            with rasterio.open(output_path, "r+", nodata=0, dtype="uint8") as dst:
                for band_idx, band in enumerate(mask_crop):
                    dst.write(band, indexes=band_idx + 1, window=window)

            output_paths.append(output_path)

        # Every hit counted by set_prediction_tiles is consumed here, even one
        # without pixels, so the counter reaches zero exactly when the tile is done.
        if tracked and self.tile_prediction_count is not None:
            self.tile_prediction_count[tile_idx] -= 1
            if self.tile_prediction_count[tile_idx] == 0 and os.path.exists(
                output_path
            ):
                self.compress_tile(output_path)

    return output_paths

set_prediction_tiles(dataset_tiles)

Pre-compute intersections for dataset tiles which allows for compression operations to happen during prediction. This can keep working storage space much lower than waiting for the prediction to complete before compressing.

By counting how many prediction tiles are expected to overlap each cache tile, we can compress a cache tile as soon as its counter reaches zero, i.e. once every expected prediction has been written to it.

Source code in src/tcd_pipeline/cache/semantic.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def set_prediction_tiles(self, dataset_tiles):
    """
    Precompute which cache tiles each prediction tile will hit.

    Knowing these intersections up front lets cache tiles be compressed while
    prediction is still running, instead of after it finishes. This keeps peak
    working storage much lower.

    Two mappings are stored on the instance:
      - ``tile_prediction_count``: cache tile id -> expected number of hits
      - ``intersections``: prediction tile bounds -> list of R-tree hits

    A cache tile can be compressed once its hit counter has counted down to zero.
    """
    expected_hits = self.tile_prediction_count = defaultdict(int)
    hits_by_bounds = self.intersections = defaultdict(list)

    for dataset_tile in dataset_tiles:
        for cache_hit in self.index.intersection(dataset_tile.bounds, objects=True):
            expected_hits[cache_hit.id] += 1
            hits_by_bounds[dataset_tile.bounds].append(cache_hit)

PickleSemanticCache

Bases: SemanticSegmentationCache

Source code in src/tcd_pipeline/cache/semantic.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class PickleSemanticCache(SemanticSegmentationCache):
    """
    Semantic segmentation cache that writes each prediction to its own pickle
    file, together with its bounding box and the source image path.
    """

    def save(self, mask, bbox: box):
        """
        Save one prediction to a new pickle file in the cache folder.

        Args:
            mask: prediction result (stored as-is, without rescaling)
            bbox: tile bounding box in global image coordinates
        """
        output = {"mask": mask, "bbox": bbox, "image": self.image_path}

        file_name = f"{self.tile_count}_{self.cache_suffix}.pkl"
        output_path = os.path.join(self.cache_folder, file_name)

        with open(output_path, "wb") as fp:
            pickle.dump(output, fp)

        self.tile_count += 1
        self.write_tile_meta(self.tile_count, bbox, output_path)

    def _load_file(self, cache_file: str) -> dict:
        """Load pickled cache results

        Args:
            cache_file (str): Cache filename

        Returns:
            dict: dictionary with the "mask", "bbox" and "image" keys written by `save`.
        """
        # NOTE: unpickling can execute arbitrary code. Only load cache files
        # that this pipeline wrote itself.
        with open(cache_file, "rb") as fp:
            annotations = pickle.load(fp)

        return annotations

SemanticSegmentationCache

Bases: ResultsCache

Source code in src/tcd_pipeline/cache/semantic.py
23
24
25
26
27
28
29
30
31
32
33
34
35
class SemanticSegmentationCache(ResultsCache):
    """Common base class for the semantic segmentation result caches."""

    @property
    def results(self) -> list[dict]:
        """
        Cached semantic segmentation results.

        Each entry should be a dictionary with the keys ``mask``, ``bbox``,
        ``image`` and ``tile_id``.

        Returns:
            The list held in ``self._results``. Presumably it is populated by the
            ResultsCache base when the cache is loaded (not shown here).
        """
        stored_results = self._results
        return stored_results

results: list[dict] property

Should return a list of dictionaries with the keys:

  • mask
  • bbox
  • image
  • tile_id