
imagedataset

logger = logging.getLogger(__name__)
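This is a standard logging.Logger, so its INFO messages (such as the tile count reported by dataloader_from_image) are hidden by Python's default WARNING threshold. Below is a minimal sketch for surfacing them. It assumes the logger name resolves to the module path tcd_pipeline.data.imagedataset, which is inferred from the source file location and not confirmed.

```python
import logging

# Install a root handler once. The root level stays at its default (WARNING),
# so other libraries remain quiet.
logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")

# Assumption: the module logger is created with logging.getLogger(__name__)
# in src/tcd_pipeline/data/imagedataset.py, so its name is the import path.
# Records from this logger propagate to the root handler installed above.
logging.getLogger("tcd_pipeline.data.imagedataset").setLevel(logging.INFO)
```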

Dataloaders for tiling orthomosaic imagery.

dataloader_from_image(image, tile_size_px=1024, overlap_px=256, gsd_m=0.1, batch_size=1, pad_if_needed=True)

Returns a PyTorch DataLoader for a single (potentially large) image.

This function is a convenience utility that creates a dataloader for tiled inference.

The tile size [px] is the side length of the square model input and is typically chosen based on available VRAM. The GSD should likewise be chosen to suit the model. Together they define the ground footprint of each tile sampled from the input image, i.e. tile_size_px * gsd_m metres on a side (102.4 m for the defaults). The image is assumed to be in a metric CRS.
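As a worked example of the footprint relationship above, the sketch below computes a tile's ground size and estimates how many native raster pixels that footprint spans. It assumes the dataset reads a window at the raster's native resolution and resamples it to gsd_m; both helper names are hypothetical and not part of tcd_pipeline.

```python
def tile_footprint_m(tile_size_px: int, gsd_m: float) -> float:
    """Side length of one square tile on the ground, in metres."""
    return tile_size_px * gsd_m


def native_pixels_per_tile(tile_size_px: int, gsd_m: float, native_res_m: float) -> int:
    """Approximate side length, in native raster pixels, of the window behind one tile.

    Assumes a metric CRS, so native_res_m (the raster's pixel size) is in metres.
    """
    return round(tile_footprint_m(tile_size_px, gsd_m) / native_res_m)


# Defaults: 1024 px at 0.1 m/px cover roughly 102.4 m x 102.4 m on the ground.
footprint = tile_footprint_m(1024, 0.1)

# A 5 cm orthomosaic needs about 2048 x 2048 native pixels per tile, which
# would then be downsampled to the 1024 x 1024 model input.
window = native_pixels_per_tile(1024, 0.1, 0.05)
```

Working in ground units rather than pixels is what lets imagery captured at different resolutions be fed to the same model at a consistent scale.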

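Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str, DatasetReader or DatasetWriter | Path to the image, or an already-open rasterio dataset | required |
| tile_size_px | int | Tile size in pixels (side length of the square model input); must be a multiple of 32 | 1024 |
| overlap_px | int | Minimum overlap between adjacent tiles, in pixels | 256 |
| gsd_m | float | Target ground sample distance (GSD), in metres per pixel | 0.1 |
| batch_size | int | Number of tiles per batch | 1 |
| pad_if_needed | bool | Pad tiles to the specified tile size where needed | True |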

Returns: DataLoader: PyTorch DataLoader over the tiles of this image
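To illustrate how tile_size_px, overlap_px and batch_size interact, here is a sketch of one common tiling scheme: place the smallest number of tiles along each axis such that adjacent tiles overlap by at least overlap_px, spreading them evenly so the last tile ends at the image edge. The helper tile_origins is hypothetical and the scheme is an assumption; the library's datasets may place tiles differently.

```python
import math
from typing import List


def tile_origins(length_px: int, tile_px: int, min_overlap_px: int) -> List[int]:
    """Start offsets of tiles covering [0, length_px) with >= min_overlap_px overlap.

    Hypothetical helper. If the axis is shorter than one tile, a single tile at 0
    is returned and the caller is expected to pad it (cf. pad_if_needed).
    """
    if not 0 <= min_overlap_px < tile_px:
        raise ValueError("overlap must be non-negative and smaller than the tile size")
    if length_px <= tile_px:
        return [0]
    max_stride = tile_px - min_overlap_px
    # Smallest tile count whose stride does not exceed max_stride.
    n_tiles = math.ceil((length_px - tile_px) / max_stride) + 1
    # Even spacing: the last tile ends exactly at the edge, overlap >= minimum.
    stride = (length_px - tile_px) / (n_tiles - 1)
    return [round(i * stride) for i in range(n_tiles)]


# Example: a 3000 x 2000 px image with the default tile size and overlap.
rows = tile_origins(2000, 1024, 256)  # [0, 488, 976]
cols = tile_origins(3000, 1024, 256)  # [0, 659, 1317, 1976]
n_tiles = len(rows) * len(cols)
# With batch_size=4; DataLoader keeps a final partial batch by default.
n_batches = math.ceil(n_tiles / 4)  # -> 12 tiles in 3 batches
```

Spreading tiles evenly, rather than stepping by a fixed stride and padding the remainder, keeps every tile fully inside the image whenever the image is at least one tile wide.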

Source code in src/tcd_pipeline/data/imagedataset.py
def dataloader_from_image(
    image: Union[str, rasterio.DatasetReader, rasterio.io.DatasetWriter],
    tile_size_px: int = 1024,
    overlap_px: int = 256,
    gsd_m: float = 0.1,
    batch_size: int = 1,
    pad_if_needed: bool = True,
) -> DataLoader:
    """Returns a PyTorch DataLoader for a single (potentially large) image.

    This function is a convenience utility that creates a dataloader for tiled
    inference.

    The provided tile size [px] is the side length of the square model input,
    typically chosen based on available VRAM. The GSD should likewise be chosen
    to suit the model. Together they define the ground footprint of each tile
    sampled from the input image, i.e. tile_size_px * gsd_m metres on a side.
    The image is assumed to be in a metric CRS.

    Args:
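        image (str, DatasetReader or DatasetWriter): Path to the image, or an
            already-open rasterio dataset
        tile_size_px (int): Tile size in pixels; must be a multiple of 32
        overlap_px (int): Minimum overlap between adjacent tiles, in pixels
        gsd_m (float): Target GSD in metres per pixel, defaults to 0.1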
        batch_size (int): Batch size, defaults to 1
        pad_if_needed (bool): Pad to the specified tile size, defaults to True
    Returns:
        DataLoader: torch dataloader for this image
    """
    assert tile_size_px % 32 == 0, "tile_size_px must be a multiple of 32"

    if isinstance(image, str):
        image = rasterio.open(image)

    if image.res[0] != 0:
        logger.info("Geographic information present, loading as a geo dataset")
        dataset = SingleImageGeoDataset(
            image,
            target_gsd=gsd_m,
            tile_size=tile_size_px,
            overlap=overlap_px,
            pad_if_needed=pad_if_needed,
        )
    else:
        logger.warning(
            "Unable to determine GSD/resolution, loading as a plain image dataset"
        )
        dataset = SingleImageDataset(
            image,
            tile_size=tile_size_px,
            overlap=overlap_px,
            pad_if_needed=pad_if_needed,
        )

    logger.info(
        f"Dataset has {len(dataset)} tiles of size {tile_size_px}x{tile_size_px} px."
    )

    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_dicts)

    return dataloader
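The DataLoader above batches tiles with collate_fn=collate_dicts, whose definition is not shown on this page. Judging by its name, each sample is presumably a dictionary (tile data plus metadata), and PyTorch's default collation tries to stack every field into tensors, which can fail for non-tensor metadata. The sketch below is a hypothetical stand-in, not the library's implementation: it turns a list of per-tile dicts into a dict of lists without stacking anything. The sample keys and the (minx, miny, maxx, maxy) bounds layout are illustrative assumptions.

```python
from typing import Any, Dict, List


def collate_dicts_sketch(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical collate_fn: list of per-tile dicts -> dict of lists.

    Values are kept as-is (no tensor stacking), so metadata such as bounding
    boxes or CRS objects can travel with each batch unchanged.
    """
    if not batch:
        return {}
    keys = batch[0].keys()
    return {key: [sample[key] for sample in batch] for key in keys}


# Two neighbouring 1024 px tiles at 0.1 m/px with 256 px overlap: each spans
# 102.4 m, and the second starts (1024 - 256) * 0.1 = 76.8 m further along.
samples = [
    {"image": "tile_0", "bounds": (0.0, 0.0, 102.4, 102.4)},
    {"image": "tile_1", "bounds": (76.8, 0.0, 179.2, 102.4)},
]
batch = collate_dicts_sketch(samples)
```

A function like this could be passed as DataLoader(dataset, batch_size=..., collate_fn=collate_dicts_sketch). Keeping batches as plain Python containers pushes any tensor stacking into the model wrapper, which then decides how to handle padding and metadata.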