Skip to content

postprocessor

PostProcessor

Processes results from a model, providing support for caching model results and for keeping track of tile results in the context of the "source" image.

Source code in src/tcd_pipeline/postprocess/postprocessor.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class PostProcessor:
    """Processes results from a model, provides support for caching model results
    and keeping track of tile results in the context of the "source" image.

    Subclasses (e.g. instance or semantic post-processors) provide the cache
    setup, per-tile caching, merging and final processing, plus a
    ``_transform(result)`` method used by ``add`` in non-stateful mode.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below document intent but are not enforced when
    instantiating.
    """

    def __init__(self, config: dict, image: Optional[rasterio.DatasetReader] = None):
        """Initializes the PostProcessor

        Args:
            config (DotMap): the configuration. It is read via attribute access;
                this class uses ``postprocess.confidence_threshold``,
                ``postprocess.cache_folder``, ``postprocess.stateful`` and
                ``postprocess.debug_images``.
            image (DatasetReader, optional): input rasterio image. If given,
                ``initialise`` is called immediately.
        """
        self.config = config
        self.threshold = config.postprocess.confidence_threshold

        self.cache_root = config.postprocess.cache_folder
        self.cache_folder = None
        self.cache_suffix = None
        self.cache: ResultsCache = None

        # Per-image state. Defaults are set here so the object is consistent
        # even before initialise() runs. Previously these attributes only
        # existed after initialise(), so e.g. add() raised AttributeError.
        self.results = []
        self.image = None
        self.tile_count = 0
        self.warm_start = False

        if image is not None:
            self.initialise(image)

    @abstractmethod
    def setup_cache(self):
        """
        Initialise the cache. Abstract method; the implementation depends on the
        type of post-processor (instance, semantic, etc.).
        """
        raise NotImplementedError

    def initialise(self, image, warm_start=False) -> None:
        """Initialise the processor for a new image and create cache folders if required.

        Resets the per-image state (``results``, ``image``, ``tile_count``).
        If ``config.postprocess.stateful`` is false, nothing else happens.
        Otherwise the cache folder path is derived from the image file name
        (``<cache_root>/<image stem>_cache``) and ``setup_cache()`` is called.
        Then one of two paths runs:

        * Warm start: if the cache folder exists, the cache is loaded and
          processing resumes after the ``len(self.cache)`` tiles it holds.
        * Cold start, or a warm start that found no cached tiles: the cache is
          cleared and re-initialised.

        Args:
            image (DatasetReader): input rasterio image
            warm_start (bool, optional): Whether or not to continue from where one
                left off. Defaults to False to avoid unexpected behaviour.
        """
        self.results = []
        self.image = image
        self.tile_count = 0

        # Break early if we aren't doing stateful (cached) post-processing
        if not self.config.postprocess.stateful:
            return

        self.warm_start = warm_start
        self.cache_folder = os.path.abspath(
            os.path.join(
                self.cache_root,
                os.path.splitext(os.path.basename(self.image.name))[0] + "_cache",
            )
        )

        self.setup_cache()

        if warm_start:
            logger.info(f"Attempting to use cached result from {self.cache_folder}")

            if self.image is not None and os.path.exists(self.cache_folder):
                self.cache.load()

                # We should probably have a strict mode that will error out
                # if there's a cache mismatch
                self.tile_count = len(self.cache)

                if self.tile_count > 0:
                    logger.info(f"Starting from tile {self.tile_count + 1}.")
                    return

            # Nothing usable was cached. Fall through to a cold start so the
            # cache is actually initialised before tiles are added. Previously
            # this path left the cache uninitialised.
            logger.info("No cached tiles found, starting from scratch.")
            self.tile_count = 0

        # Cold start (or failed warm start): clear any existing cache and
        # create a fresh one.
        logger.debug("Attempting to clear existing cache")
        self.cache.clear()
        self.cache.initialise()

    def add(self, results: List[dict]):
        """
        Add results to the post processor.

        Each result is assigned the next sequential tile id, and a
        ``{"tile_id", "bbox"}`` record is appended to ``self.results``. In
        stateful mode the raw result is passed to ``cache_result`` (and, if
        ``debug_images`` is enabled, the tile window is cached as an image).
        Otherwise the output of ``self._transform(result)`` is merged into the
        record.

        Args:
            results (List[dict]): per-tile results. Each must contain "bbox",
                and "window" when debug images are enabled.
        """
        for result in results:
            self.tile_count += 1

            # We always want to keep a list of bounding boxes
            new_result = {"tile_id": self.tile_count, "bbox": result["bbox"]}

            # Either cache results, or add to in-memory list
            if self.config.postprocess.stateful:
                logger.debug(f"Saving cache for tile {self.tile_count}")
                self.cache_result(result)

                if self.config.postprocess.debug_images:
                    self.cache.cache_image(self.image, result["window"])

            else:
                # _transform is expected to be provided by subclasses
                new_result |= self._transform(result)

            self.results.append(new_result)

    @abstractmethod
    def merge(self):
        """
        Merge results from overlapping tiles
        """
        raise NotImplementedError

    @abstractmethod
    def process(self) -> ProcessedResult:
        """
        Processes the stored results into a ProcessedResult object that represents
        the complete prediction over the tiled input
        """
        raise NotImplementedError

    @abstractmethod
    def cache_result(self, result: dict) -> None:
        """
        Store a prediction in the cache.

        Args:
            result (dict): a single tile result, exactly as passed to ``add``.
        """
        raise NotImplementedError

__init__(config, image=None)

Initializes the PostProcessor

Parameters:

Name Type Description Default
config DotMap

the configuration

required
image DatasetReader

input rasterio image

None
Source code in src/tcd_pipeline/postprocess/postprocessor.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def __init__(self, config: dict, image: Optional[rasterio.DatasetReader] = None):
    """Initializes the PostProcessor

    Args:
        config (DotMap): the configuration. It is read via attribute access;
            this method uses ``postprocess.confidence_threshold`` and
            ``postprocess.cache_folder``.
        image (DatasetReader, optional): input rasterio image. If given,
            ``initialise`` is called immediately.
    """
    self.config = config
    self.threshold = config.postprocess.confidence_threshold

    self.cache_root = config.postprocess.cache_folder
    self.cache_folder = None
    self.cache_suffix = None
    self.cache: ResultsCache = None

    # Per-image state. Defaults are set here so the object is consistent even
    # before initialise() runs. Otherwise add() on an instance created without
    # an image raises AttributeError.
    self.results = []
    self.image = None
    self.tile_count = 0
    self.warm_start = False

    if image is not None:
        self.initialise(image)

add(results)

Add results to the post processor

Source code in src/tcd_pipeline/postprocess/postprocessor.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def add(self, results: List[dict]):
    """
    Register a batch of per-tile results with the post processor.

    Every entry is given the next sequential tile id and produces a
    ``{"tile_id", "bbox"}`` record in ``self.results``. When
    ``config.postprocess.stateful`` is set, the raw entry goes to
    ``self.cache_result``, and a debug image of its window is cached if
    ``debug_images`` is enabled. Otherwise the mapping returned by
    ``self._transform`` is folded into the record.

    Args:
        results (List[dict]): per-tile results. Each needs a "bbox" key, plus
            "window" when debug images are enabled.
    """
    for tile in results:
        self.tile_count += 1
        record = dict(tile_id=self.tile_count, bbox=tile["bbox"])

        if not self.config.postprocess.stateful:
            # In-memory mode: merge the transformed prediction into the record
            record.update(self._transform(tile))
        else:
            logger.debug(f"Saving cache for tile {self.tile_count}")
            self.cache_result(tile)
            if self.config.postprocess.debug_images:
                self.cache.cache_image(self.image, tile["window"])

        self.results.append(record)

cache_result() abstractmethod

Store a prediction in the cache

Source code in src/tcd_pipeline/postprocess/postprocessor.py
131
132
133
134
135
136
@abstractmethod
def cache_result(self, result: dict) -> None:
    """
    Store a prediction in the cache.

    ``add`` calls this as ``self.cache_result(result)``, so the signature
    includes the ``result`` argument. Previously the abstract signature took
    no argument.

    Args:
        result (dict): a single tile result, exactly as passed to ``add``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

initialise(image, warm_start=False)

Initialise the processor for a new image and creates cache folders if required.

Parameters:

Name Type Description Default
image DatasetReader

input rasterio image

required
warm_start (bool, optional)

Whether or not to continue from where one left off. Defaults to False to avoid unexpected behaviour.

False
Source code in src/tcd_pipeline/postprocess/postprocessor.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def initialise(self, image, warm_start=False) -> None:
    """Initialise the processor for a new image and create cache folders if required.

    Resets the per-image state (``results``, ``image``, ``tile_count``). If
    ``config.postprocess.stateful`` is false, nothing else happens. Otherwise the
    cache folder path is derived from the image file name
    (``<cache_root>/<image stem>_cache``) and ``setup_cache()`` is called.
    Then one of two paths runs:

    * Warm start: if the cache folder exists, the cache is loaded and
      processing resumes after the ``len(self.cache)`` tiles it holds.
    * Cold start, or a warm start that found no cached tiles: the cache is
      cleared and re-initialised.

    Args:
        image (DatasetReader): input rasterio image
        warm_start (bool, optional): Whether or not to continue from where one
            left off. Defaults to False to avoid unexpected behaviour.
    """
    self.results = []
    self.image = image
    self.tile_count = 0

    # Break early if we aren't doing stateful (cached) post-processing
    if not self.config.postprocess.stateful:
        return

    self.warm_start = warm_start
    self.cache_folder = os.path.abspath(
        os.path.join(
            self.cache_root,
            os.path.splitext(os.path.basename(self.image.name))[0] + "_cache",
        )
    )

    self.setup_cache()

    if warm_start:
        logger.info(f"Attempting to use cached result from {self.cache_folder}")

        if self.image is not None and os.path.exists(self.cache_folder):
            self.cache.load()

            # We should probably have a strict mode that will error out
            # if there's a cache mismatch
            self.tile_count = len(self.cache)

            if self.tile_count > 0:
                logger.info(f"Starting from tile {self.tile_count + 1}.")
                return

        # Nothing usable was cached. Fall through to a cold start so the cache
        # is actually initialised before tiles are added. Previously this path
        # left the cache uninitialised.
        logger.info("No cached tiles found, starting from scratch.")
        self.tile_count = 0

    # Cold start (or failed warm start): clear any existing cache and create a
    # fresh one.
    logger.debug("Attempting to clear existing cache")
    self.cache.clear()
    self.cache.initialise()

merge() abstractmethod

Merge results from overlapping tiles

Source code in src/tcd_pipeline/postprocess/postprocessor.py
116
117
118
119
120
121
@abstractmethod
def merge(self):
    """
    Combine the results of tiles that overlap one another.

    Concrete post-processors must override this.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

process() abstractmethod

Processes the stored results into a ProcessedResult object that represents the complete prediction over the tiled input

Source code in src/tcd_pipeline/postprocess/postprocessor.py
123
124
125
126
127
128
129
@abstractmethod
def process(self) -> ProcessedResult:
    """
    Turn the stored tile results into a single ProcessedResult object
    describing the complete prediction over the tiled input.

    Concrete post-processors must override this.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()

setup_cache() abstractmethod

Initialise the cache. Abstract method; the implementation depends on the type of post-processor (instance, semantic, etc.).

Source code in src/tcd_pipeline/postprocess/postprocessor.py
37
38
39
40
41
42
@abstractmethod
def setup_cache(self):
    """
    Create the results cache for this post-processor.

    The appropriate kind of cache depends on the post-processor flavour
    (instance, semantic, etc.), so subclasses supply the implementation.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()