
datamodule

Semantic segmentation model framework, using SMP models

COCODataModule

Bases: LightningDataModule

Datamodule for TCD

Source code in src/tcd_pipeline/data/datamodule.py
class COCODataModule(pl.LightningDataModule):
    """Datamodule for TCD"""

    def __init__(
        self,
        data_root,
        train_path="train.json",
        val_path="val.json",
        test_path="test.json",
        num_workers=8,
        data_frac=1.0,
        batch_size=1,
        tile_size=1024,
        augment=True,
    ):
        """
        Initialise the datamodule

        Args:
            data_root (str): Path to the data directory
            num_workers (int, optional): Number of workers to use. Defaults to 8.
            data_frac (float, optional): Fraction of the data to use. Defaults to 1.0.
            batch_size (int, optional): Batch size. Defaults to 1.
            tile_size (int, optional): Tile size to return. Defaults to 1024.
            augment (bool, optional): Whether to apply data augmentation. Defaults to True.

        """
        super().__init__()
        self.data_frac = data_frac
        self.augment = augment
        self.batch_size = batch_size
        self.data_root = data_root
        self.train_path = os.path.join(self.data_root, train_path)
        self.val_path = os.path.join(self.data_root, val_path)
        self.test_path = os.path.join(self.data_root, test_path)
        self.num_workers = num_workers
        self.tile_size = tile_size

        logger.info("Data root: %s", self.data_root)

    def prepare_data(self) -> None:
        """
        Construct train/val/test datasets.

        Test datasets do not use data augmentation and simply
        return a tensor. This is to avoid stochastic results
        during evaluation.

        Tensors are returned **not** normalised, as this is
        handled by the forward functions in SMP and transformers.
        """
        logger.info("Preparing datasets")
        if self.augment:
            transform = A.Compose(
                [
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.5),
                    A.Rotate(),
                    A.RandomBrightnessContrast(),
                    A.OneOf([A.Blur(p=0.2), A.Sharpen(p=0.2)]),
                    A.HueSaturationValue(
                        hue_shift_limit=5, sat_shift_limit=4, val_shift_limit=5
                    ),
                    A.RandomCrop(width=1024, height=1024),
                    ToTensorV2(),
                ]
            )
            logger.debug("Train-time augmentation is enabled.")
        else:
            transform = None

        self.train_data = COCOSegmentationDataset(
            self.data_root,
            self.train_path,
            transform=transform,
            tile_size=self.tile_size,
        )

        self.test_data = COCOSegmentationDataset(
            self.data_root, self.test_path, transform=A.Compose([ToTensorV2()])
        )

        if os.path.exists(self.val_path):
            self.val_data = COCOSegmentationDataset(
                self.data_root, self.val_path, transform=None, tile_size=self.tile_size
            )
        else:
            self.val_data = self.test_data

    def train_dataloader(self) -> DataLoader:
        """Get training dataloaders:

        Returns:
            List[DataLoader]: List of training dataloaders
        """
        return get_dataloaders(
            self.train_data,
            data_frac=self.data_frac,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )[0]

    def val_dataloader(self) -> DataLoader:
        """Get validation dataloaders:

        Returns:
            List[DataLoader]: List of validation dataloaders
        """
        return get_dataloaders(
            self.val_data,
            data_frac=self.data_frac,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )[0]

    def test_dataloader(self) -> DataLoader:
        """Get test dataloaders:

        Returns:
            List[DataLoader]: List of test dataloaders
        """
        # Don't shuffle the test loader so we can
        # more easily compare runs on wandb
        return get_dataloaders(
            self.test_data,
            data_frac=self.data_frac,
            batch_size=1,
            shuffle=False,
            num_workers=1,
        )[0]
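
As a minimal usage sketch (the dataset path below is a placeholder and no segmentation model is shown), the datamodule is constructed from a data root containing the COCO annotation files and can then be handed to a Lightning Trainer:

from tcd_pipeline.data.datamodule import COCODataModule

# Hypothetical dataset location: it is assumed to contain the imagery referenced
# by the COCO annotation files train.json, val.json and test.json.
dm = COCODataModule(
    data_root="/path/to/dataset",
    batch_size=2,
    num_workers=4,
    tile_size=1024,
    augment=True,
)

# A Lightning Trainer calls prepare_data() and the *_dataloader() hooks itself, e.g.:
# trainer.fit(model, datamodule=dm)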

__init__(data_root, train_path='train.json', val_path='val.json', test_path='test.json', num_workers=8, data_frac=1.0, batch_size=1, tile_size=1024, augment=True)

Initialise the datamodule

Parameters:

Name         Type   Description                          Default
data_root    str    Path to the data directory           required
num_workers  int    Number of workers to use             8
data_frac    float  Fraction of the data to use          1.0
batch_size   int    Batch size                           1
tile_size    int    Tile size to return                  1024
augment      bool   Whether to apply data augmentation   True
Source code in src/tcd_pipeline/data/datamodule.py
def __init__(
    self,
    data_root,
    train_path="train.json",
    val_path="val.json",
    test_path="test.json",
    num_workers=8,
    data_frac=1.0,
    batch_size=1,
    tile_size=1024,
    augment=True,
):
    """
    Initialise the datamodule

    Args:
        data_root (str): Path to the data directory
        num_workers (int, optional): Number of workers to use. Defaults to 8.
        data_frac (float, optional): Fraction of the data to use. Defaults to 1.0.
        batch_size (int, optional): Batch size. Defaults to 1.
        tile_size (int, optional): Tile size to return. Defaults to 1024.
        augment (bool, optional): Whether to apply data augmentation. Defaults to True.

    """
    super().__init__()
    self.data_frac = data_frac
    self.augment = augment
    self.batch_size = batch_size
    self.data_root = data_root
    self.train_path = os.path.join(self.data_root, train_path)
    self.val_path = os.path.join(self.data_root, val_path)
    self.test_path = os.path.join(self.data_root, test_path)
    self.num_workers = num_workers
    self.tile_size = tile_size

    logger.info("Data root: %s", self.data_root)
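
Because the annotation paths are joined onto data_root, alternative COCO JSON filenames can be passed directly; the paths and fraction below are illustrative only, not part of the shipped configuration.

from tcd_pipeline.data.datamodule import COCODataModule

# Illustrative values: annotation files are resolved relative to data_root,
# e.g. /data/tcd/folds/train_fold0.json
dm = COCODataModule(
    data_root="/data/tcd",
    train_path="folds/train_fold0.json",
    val_path="folds/val_fold0.json",
    test_path="folds/test_fold0.json",
    data_frac=0.25,  # use a random 25% subset of each split
)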

prepare_data()

Construct train/val/test datasets.

Test datasets do not use data augmentation and simply return a tensor. This is to avoid stochastic results during evaluation.

Tensors are returned not normalised, as this is handled by the forward functions in SMP and transformers.

Source code in src/tcd_pipeline/data/datamodule.py
def prepare_data(self) -> None:
    """
    Construct train/val/test datasets.

    Test datasets do not use data augmentation and simply
    return a tensor. This is to avoid stochastic results
    during evaluation.

    Tensors are returned **not** normalised, as this is
    handled by the forward functions in SMP and transformers.
    """
    logger.info("Preparing datasets")
    if self.augment:
        transform = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(),
                A.RandomBrightnessContrast(),
                A.OneOf([A.Blur(p=0.2), A.Sharpen(p=0.2)]),
                A.HueSaturationValue(
                    hue_shift_limit=5, sat_shift_limit=4, val_shift_limit=5
                ),
                A.RandomCrop(width=1024, height=1024),
                ToTensorV2(),
            ]
        )
        logger.debug("Train-time augmentation is enabled.")
    else:
        transform = None

    self.train_data = COCOSegmentationDataset(
        self.data_root,
        self.train_path,
        transform=transform,
        tile_size=self.tile_size,
    )

    self.test_data = COCOSegmentationDataset(
        self.data_root, self.test_path, transform=A.Compose([ToTensorV2()])
    )

    if os.path.exists(self.val_path):
        self.val_data = COCOSegmentationDataset(
            self.data_root, self.val_path, transform=None, tile_size=self.tile_size
        )
    else:
        self.val_data = self.test_data
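
For reference, the train-time transform built above can be exercised on its own with a dummy image and mask; this sketch is independent of COCOSegmentationDataset and only illustrates the augmentation output shapes.

import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Same transform stack as prepare_data() builds when augment=True
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(),
        A.RandomBrightnessContrast(),
        A.OneOf([A.Blur(p=0.2), A.Sharpen(p=0.2)]),
        A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=4, val_shift_limit=5),
        A.RandomCrop(width=1024, height=1024),
        ToTensorV2(),
    ]
)

image = np.zeros((2048, 2048, 3), dtype=np.uint8)  # dummy RGB tile
mask = np.zeros((2048, 2048), dtype=np.uint8)      # dummy segmentation mask

out = transform(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)
# torch.Size([3, 1024, 1024]) torch.Size([1024, 1024])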

test_dataloader()

Get the test dataloader.

Returns:

Type        Description
DataLoader  Test dataloader

Source code in src/tcd_pipeline/data/datamodule.py
def test_dataloader(self) -> DataLoader:
    """Get test dataloaders:

    Returns:
        List[DataLoader]: List of test dataloaders
    """
    # Don't shuffle the test loader so we can
    # more easily compare runs on wandb
    return get_dataloaders(
        self.test_data,
        data_frac=self.data_frac,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )[0]

train_dataloader()

Get the training dataloader.

Returns:

Type        Description
DataLoader  Training dataloader

Source code in src/tcd_pipeline/data/datamodule.py
def train_dataloader(self) -> DataLoader:
    """Get training dataloaders:

    Returns:
        List[DataLoader]: List of training dataloaders
    """
    return get_dataloaders(
        self.train_data,
        data_frac=self.data_frac,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )[0]

val_dataloader()

Get the validation dataloader.

Returns:

Type        Description
DataLoader  Validation dataloader

Source code in src/tcd_pipeline/data/datamodule.py
def val_dataloader(self) -> DataLoader:
    """Get validation dataloaders:

    Returns:
        List[DataLoader]: List of validation dataloaders
    """
    return get_dataloaders(
        self.val_data,
        data_frac=self.data_frac,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )[0]

collate_fn(batch)

Collate function for dataloader

Default collation function, filtering out empty values in the batch.

Parameters:

Name   Type  Description  Default
batch  Any   data batch   required

Returns:

Type  Description
Any   Collated batch

Source code in src/tcd_pipeline/data/datamodule.py
def collate_fn(batch: Any) -> Any:
    """Collate function for dataloader

    Default collation function, filtering out empty
    values in the batch.

    Args:
        batch (Any): data batch

    Returns:
        Any: Collated batch
    """
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
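
As a self-contained illustration (the dict keys below are arbitrary and not necessarily what COCOSegmentationDataset yields), None entries are dropped before the standard PyTorch collation:

import torch
from tcd_pipeline.data.datamodule import collate_fn

batch = [
    {"image": torch.zeros(3, 4, 4), "mask": torch.zeros(4, 4)},
    None,  # e.g. a sample that failed to load
    {"image": torch.ones(3, 4, 4), "mask": torch.ones(4, 4)},
]

collated = collate_fn(batch)
print(collated["image"].shape)  # torch.Size([2, 3, 4, 4]) -- the None entry was removed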

get_dataloaders(*datasets, num_workers=8, data_frac=1, batch_size=1, shuffle=True)

Construct dataloaders from a list of datasets

Parameters:

Name         Type     Description                    Default
*datasets    Dataset  List of datasets to use        ()
num_workers  int      Number of workers to use       8
data_frac    float    Fraction of the data to use    1
batch_size   int      Batch size                     1
shuffle      bool     Whether to shuffle the data    True

Returns:

Type              Description
List[DataLoader]  List of dataloaders

Source code in src/tcd_pipeline/data/datamodule.py
def get_dataloaders(
    *datasets: List[Dataset],
    num_workers: int = 8,
    data_frac: float = 1,
    batch_size: int = 1,
    shuffle: bool = True
) -> List[DataLoader]:
    """Construct dataloaders from a list of datasets

    Args:
        *datasets (Dataset): List of datasets to use
        num_workers (int, optional): Number of workers to use. Defaults to 8.
        data_frac (float, optional): Fraction of the data to use. Defaults to 1.0.
        batch_size (int, optional): Batch size. Defaults to 1.
        shuffle (bool, optional): Whether to shuffle the data. Defaults to True.

    Returns:
        List[DataLoader]: List of dataloaders

    """
    if data_frac != 1.0:
        datasets = [
            torch.utils.data.Subset(
                dataset,
                np.random.choice(
                    len(dataset), int(len(dataset) * data_frac), replace=False
                ),
            )
            for dataset in datasets
        ]

    return [
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=int(num_workers),
            collate_fn=collate_fn,
        )
        for dataset in datasets
    ]
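
A minimal sketch of the data_frac subsetting behaviour using a toy TensorDataset (not the COCO dataset used elsewhere in this module):

import torch
from torch.utils.data import TensorDataset
from tcd_pipeline.data.datamodule import get_dataloaders

images = torch.zeros(100, 3, 8, 8)
masks = torch.zeros(100, 8, 8, dtype=torch.long)
dataset = TensorDataset(images, masks)

# data_frac=0.1 wraps the dataset in a random Subset of 10 samples
(loader,) = get_dataloaders(dataset, data_frac=0.1, batch_size=2, num_workers=0)
print(len(loader.dataset))  # 10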