cluster

generate_clip_embeddings(model, processor, instances, device)

Generates CLIP embeddings for every instance image with the given model and preprocessor

Source code in src/tcd_pipeline/scripts/cluster.py
from typing import Any

import torch
from PIL import Image
from sklearn.preprocessing import normalize  # assumed source of `normalize`
from tqdm import tqdm

from tcd_pipeline.result.instance import Instance  # import path assumed


def generate_clip_embeddings(
    model: torch.nn.Module, processor: Any, instances: list[Instance], device: str
):
    """Generates CLIP embeddings for every instance image with the
    given model and preprocessor.
    """

    embeddings = []

    with torch.no_grad():
        try:
            for instance in tqdm(instances):
                # Convert the instance's pixel data to a model-ready tensor
                inputs = processor(Image.fromarray(instance.raster))

                # Add a batch dimension if the processor returned a single image
                if len(inputs.shape) == 3:
                    inputs = inputs.unsqueeze(0)

                # The first element of the model output is the image embedding
                emb = model(inputs.to(device))[0].cpu()
                embeddings.append(emb)
        except KeyboardInterrupt:
            # Allow early interruption; keep whatever has been embedded so far
            pass

    embeddings = torch.cat(embeddings, 0)
    # L2-normalise each row so cosine similarity reduces to a dot product
    embeddings = normalize(embeddings)
    return embeddings
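
A minimal usage sketch for illustration, assuming an open_clip backbone; the model name, checkpoint tag, and the `instances` list below are placeholders, not prescribed by the pipeline. Any model whose forward pass returns image features as its first output, paired with a preprocessor that maps a PIL image to a tensor, matches the interface above.

import open_clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative model/checkpoint; substitute any CLIP-style backbone.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
model = model.to(device).eval()

# `instances` is a list[Instance] produced elsewhere in the pipeline.
embeddings = generate_clip_embeddings(model, preprocess, instances, device)
# One L2-normalised row per instance.

Because the rows are L2-normalised, Euclidean distance between them is monotonic in cosine distance, so the output can be passed directly to distance-based clustering such as k-means.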