Reproducing CT-RATE retrieval numbers

#6
by gings - opened

Hi, thanks a lot for releasing COLIPRI.

I am currently trying to reproduce the CT-RATE retrieval numbers:

  1. The CT-RATE validation split has 1,564 unique reports, but the paper evaluates on 1,493 report-scan pairs. Could you release the list of the 1,493 .nii.gz volume names you used?

  2. If possible, could you also release the CT-RATE retrieval evaluation script? I implemented one, but it is currently ~10 percentage points below the paper on R@10, and I suspect differences in preprocessing, report formatting, or other details that are difficult to infer from the current repo and the paper.

Thanks for considering and best regards,

Hey @gings,

  1. If you exclude the studies that are not chest CTs (no_chest_valid.txt), you go from 1564 to 1551.
  2. If you then drop duplicates (by Findings_EN), you go from 1551 to 1493.

I'll see if we can help with the retrieval code, but note that our implementation was based on the retrieval code from CXR-CLIP.

Found it. Hope this helps:

    def evaluate_retrieval(self, overwrite: bool = False) -> dict[str, float]:
        """Compute CT-RATE image<->report retrieval metrics, save them, and return them.

        Vision embeddings and their paired findings are obtained from
        ``self.calculate_or_load_vision_embeddings(overwrite)``. The findings are
        embedded with the text encoder. Both embedding sets are L2-normalised and
        scored with ``retrieval_image_text`` in both directions.

        Args:
            overwrite: Forwarded unchanged to ``calculate_or_load_vision_embeddings``.

        Returns:
            The merged metrics dict: image-to-text keys prefixed ``ImageText_``
            and text-to-image keys prefixed ``TextImage_``. The same dict is
            written to ``<self.output_folder>/test_retrieval_CT-RATE.json``.
        """
        vision_embeds = self.calculate_or_load_vision_embeddings(overwrite)
        vision_embeddings = vision_embeds["image_embeddings"].squeeze(dim=-1)
        # L2-normalise each embedding along the last dimension.
        normed_vision_embeddings = vision_embeddings / torch.norm(vision_embeddings, dim=-1, keepdim=True)  # [N, 768]

        # Report texts; presumably aligned row-by-row with the image embeddings
        # (retrieval_image_text pairs them by index) -- confirm against the loader.
        findings = vision_embeds["findings"]

        language_embeddings = embed_text_prompts(
            findings,
            text_encoder=self.text_encoder,
            text_pooler=self.text_pooler,
            text_tokenizer=self.text_tokenizer,
        )
        language_embeddings = torch.cat(language_embeddings, dim=0).squeeze(dim=-1)
        normed_language_embeddings = (
            language_embeddings / torch.norm(language_embeddings, dim=-1, keepdim=True)
        ).cpu()

        # NOTE: retrieval_image_text ranks by cosine similarity, which is invariant
        # to a uniform positive rescaling. Assuming self.temperature is a positive
        # scalar, this division therefore does not change any metric.
        temp_scaled_vision_embeddings = normed_vision_embeddings / self.temperature
        temp_scaled_language_embeddings = normed_language_embeddings / self.temperature

        # Image -> text retrieval (ImageText_* keys).
        test_retrieval_metrics = retrieval_image_text(
            image_embeddings=temp_scaled_vision_embeddings.cpu(),
            text_embeddings=temp_scaled_language_embeddings.cpu(),
            text_list=findings,
        )
        # Text -> image retrieval (TextImage_* keys).
        test_retrieval_report_image = retrieval_image_text(
            image_embeddings=temp_scaled_vision_embeddings.cpu(),
            text_embeddings=temp_scaled_language_embeddings.cpu(),
            text_list=findings,
            do_image_text_retrieval=False,
        )
        test_retrieval_metrics.update(test_retrieval_report_image)
        save_json(test_retrieval_metrics, os.path.join(self.output_folder, "test_retrieval_CT-RATE.json"))
        # The signature promises a dict; previously the method fell through and returned None.
        return test_retrieval_metrics
import numpy as np
from sklearn import metrics

def retrieval_image_text(
    image_embeddings: np.ndarray,
    text_embeddings: np.ndarray,
    text_list: "list | None" = None,
    do_image_text_retrieval: bool = True,
):
    """Compute recall@k and rank statistics for image<->text retrieval.

    Row ``i`` of ``image_embeddings``, row ``i`` of ``text_embeddings`` and
    ``text_list[i]`` form one pair. Identical texts are collapsed into a single
    candidate, which uses the embedding of its first occurrence. Every image
    whose text equals that candidate therefore counts it as the correct match.

    Args:
        image_embeddings: Array-like of image embeddings, one row per sample.
        text_embeddings: Array-like of text embeddings, aligned with
            ``image_embeddings`` by row.
        text_list: Raw texts used to detect duplicates. Items must be hashable
            (e.g. strings). ``None`` is treated as an empty list.
        do_image_text_retrieval: If True, rank all unique texts for every image
            (keys prefixed ``ImageText_``). Otherwise, rank all images for every
            unique text and score each text by its best-ranked matching image
            (keys prefixed ``TextImage_``).

    Returns:
        Dict with ``RecallAt{1,5,10}``, ``MeanAveragePrecisionAt{1,5,10}``,
        ``MeanReciprocalRank``, ``MeanRank``, ``MedianRank`` and ``NumSamples``,
        each with the direction prefix. Ranks are 1-based (1 = most similar).
        If a ``ValueError`` is raised during scoring (presumably sklearn
        rejecting e.g. empty inputs), every value is NaN.

    Note:
        ``MeanAveragePrecisionAt{k}`` averages, over all queries, ``1/rank``
        when ``rank <= k`` and 0 otherwise. In other words, it is a truncated
        reciprocal rank of the single (best) match.
    """
    image_embeddings = np.array(image_embeddings)
    text_embeddings = np.array(text_embeddings)
    if text_list is None:
        text_list = []

    # Assign each distinct text a dense label in first-occurrence order. A dict
    # keeps membership/label lookups O(1) per text.
    text_to_label: dict = {}
    idx2label = {}
    identical_indexes = []  # index of the first occurrence of each distinct text
    for i, text in enumerate(text_list):
        if text not in text_to_label:
            text_to_label[text] = len(text_to_label)
            identical_indexes.append(i)
        idx2label[i] = text_to_label[text]

    identical_text_embedding = text_embeddings[identical_indexes]

    num_samples = image_embeddings.shape[0]
    n_text = len(text_to_label)

    result = {}

    try:
        if do_image_text_retrieval:
            similarities = metrics.pairwise.cosine_similarity(image_embeddings, identical_text_embedding)  # n x m

            recall_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            map_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            ranks = []
            for idx in range(num_samples):
                label = idx2label[idx]
                similarity = similarities[idx]
                # Ascending order: the most similar text is at position n_text - 1.
                similarity_args = similarity.argsort()

                # 1-based rank of the paired text (1 = most similar).
                rank = n_text - np.argwhere(similarity_args == label).ravel()[0]
                ranks.append(rank)

                reciprocal_rank = 1 / float(rank)
                for k in recall_dict:
                    if rank <= k:
                        recall_dict[k] += 1
                        map_dict[k] += reciprocal_rank
        else:
            # Text-to-image retrieval: for each unique text, find matching images
            similarities = metrics.pairwise.cosine_similarity(
                identical_text_embedding, image_embeddings
            )  # n_unique_texts x n_images

            recall_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            map_dict: dict[int, float] = {1: 0, 5: 0, 10: 0}
            ranks = []

            # Reverse mapping: unique_text_idx -> image indices that share that text.
            text_to_image_indices = {}
            for img_idx, unique_text_idx in idx2label.items():
                text_to_image_indices.setdefault(unique_text_idx, []).append(img_idx)

            for text_idx in range(n_text):
                similarity = similarities[text_idx]
                similarity_args = similarity.argsort()[::-1]  # Sort in descending order for text-to-image

                # The query succeeds as soon as any image carrying this text is retrieved.
                matching_image_indices = text_to_image_indices[text_idx]
                best_rank = float("inf")

                for img_idx in matching_image_indices:
                    rank = np.argwhere(similarity_args == img_idx).ravel()[0] + 1  # 1-indexed
                    best_rank = min(best_rank, rank)

                ranks.append(best_rank)
                reciprocal_rank = 1 / float(best_rank)

                for k in recall_dict:
                    if best_rank <= k:
                        recall_dict[k] += 1
                        map_dict[k] += reciprocal_rank
            # In this direction the queries are the unique texts.
            num_samples = n_text

        ranks = np.array(ranks)
        result.update({f"RecallAt{k}": v / num_samples for k, v in recall_dict.items()})
        result.update({f"MeanAveragePrecisionAt{k}": v / num_samples for k, v in map_dict.items()})
        result.update({"MeanReciprocalRank": float(np.mean(1 / ranks))})
        result.update({"MeanRank": float(np.mean(ranks))})
        result.update({"MedianRank": float(np.median(ranks))})
        result.update({"NumSamples": num_samples})
    except ValueError:
        # Assign NaN directly. Dividing NaN by num_samples (as before) raises
        # ZeroDivisionError when there are no samples.
        result.update({f"RecallAt{k}": np.nan for k in (1, 5, 10)})
        result.update({f"MeanAveragePrecisionAt{k}": np.nan for k in (1, 5, 10)})
        result.update({"MeanReciprocalRank": np.nan})
        result.update({"MeanRank": float(np.nan)})
        result.update({"MedianRank": float(np.nan)})
        result.update({"NumSamples": np.nan})

    if do_image_text_retrieval:
        result = {f"ImageText_{k}": v for k, v in result.items()}
    else:
        result = {f"TextImage_{k}": v for k, v in result.items()}

    return result

Sign up or log in to comment