Instructions to use microsoft/colipri with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- COLIPRI
How to use microsoft/colipri with COLIPRI:
pip install colipri
from colipri import get_model
from colipri import get_processor
from colipri import load_sample_ct
from colipri import ZeroShotImageClassificationPipeline

model = get_model().cuda()
processor = get_processor()
pipeline = ZeroShotImageClassificationPipeline("microsoft/colipri", processor)
image = load_sample_ct()
pipeline(image, ["No lung nodules", "Lung nodules"])
- Notebooks
- Google Colab
- Kaggle
Reproducing CT-RATE retrieval numbers
Hi, thanks a lot for releasing COLIPRI.
I am currently trying to reproduce the CT-RATE retrieval numbers:
The CT-RATE validation split has 1,564 unique reports, but the paper evaluates on 1,493 report-scan pairs. Could you release the list of the 1,493
`.nii.gz` volume names you used? If possible, could you also release the CT-RATE retrieval evaluation script? I implemented one, but its R@10 is currently about 10 percentage points below the paper's, and I suspect preprocessing, report-formatting, or other differences that are difficult to infer from the current repo and the paper.
Thanks for considering this, and best regards,
Hey @gings.
- If you exclude the studies that are not chest CTs (`no_chest_valid.txt`), you go from 1,564 to 1,551.
- If you then drop duplicates (by `Findings_EN`), you go from 1,551 to 1,493.
I'll see if we can help with the code for retrieval, but the implementation was based on this code from CXR-CLIP.
Found it. Hope this helps:
def evaluate_retrieval(self, overwrite: bool = False) -> dict[str, float]:
    """Evaluate CT-RATE image-to-report and report-to-image retrieval.

    Vision embeddings and their paired ``findings`` texts are obtained from
    ``self.calculate_or_load_vision_embeddings``. The findings are embedded
    with the model's text encoder. Both sides are L2-normalised and divided
    by ``self.temperature``. ``retrieval_image_text`` is then run in both
    directions. The merged metrics are saved to
    ``<self.output_folder>/test_retrieval_CT-RATE.json``.

    Args:
        overwrite: Forwarded to ``calculate_or_load_vision_embeddings``.
            Presumably forces recomputation of cached embeddings (confirm).

    Returns:
        The merged metric dictionary (``ImageText_*`` and ``TextImage_*``
        keys) that was written to disk. Previously nothing was returned
        despite the annotation.
    """
    vision_embeds = self.calculate_or_load_vision_embeddings(overwrite)
    vision_embeddings = vision_embeds["image_embeddings"].squeeze(dim=-1)
    normed_vision_embeddings = vision_embeddings / torch.norm(vision_embeddings, dim=-1, keepdim=True)  # [N, 768]
    findings = vision_embeds["findings"]

    language_embeddings = embed_text_prompts(
        findings,
        text_encoder=self.text_encoder,
        text_pooler=self.text_pooler,
        text_tokenizer=self.text_tokenizer,
    )
    # embed_text_prompts appears to return a sequence of tensors; stack them row-wise.
    language_embeddings = torch.cat(language_embeddings, dim=0).squeeze(dim=-1)
    normed_language_embeddings = (
        language_embeddings / torch.norm(language_embeddings, dim=-1, keepdim=True)
    ).cpu()

    # retrieval_image_text ranks candidates by cosine similarity, which is
    # invariant to positive rescaling. Assuming self.temperature > 0, this
    # division therefore does not change the resulting metrics. It is kept
    # to match the original evaluation.
    temp_scaled_vision_embeddings = (normed_vision_embeddings / self.temperature).cpu()
    temp_scaled_language_embeddings = (normed_language_embeddings / self.temperature).cpu()

    # Image -> report retrieval ("ImageText_*" keys).
    test_retrieval_metrics = retrieval_image_text(
        image_embeddings=temp_scaled_vision_embeddings,
        text_embeddings=temp_scaled_language_embeddings,
        text_list=findings,
    )
    # Report -> image retrieval ("TextImage_*" keys).
    test_retrieval_report_image = retrieval_image_text(
        image_embeddings=temp_scaled_vision_embeddings,
        text_embeddings=temp_scaled_language_embeddings,
        text_list=findings,
        do_image_text_retrieval=False,
    )
    test_retrieval_metrics.update(test_retrieval_report_image)
    save_json(test_retrieval_metrics, os.path.join(self.output_folder, "test_retrieval_CT-RATE.json"))
    return test_retrieval_metrics
import numpy as np
from sklearn import metrics
def _group_identical_texts(text_list):
    """Collapse duplicate texts, keeping first-occurrence order.

    Texts are assumed hashable (e.g. ``str``), so a dict replaces the
    original O(n^2) ``list`` membership / ``list.index`` lookups.

    Returns:
        n_unique: number of distinct texts.
        first_indexes: index in ``text_list`` of each distinct text's first
            occurrence, in order of first appearance.
        idx2label: maps every index of ``text_list`` to its distinct-text id.
    """
    label_of_text = {}
    first_indexes = []
    idx2label = {}
    for i, text in enumerate(text_list):
        if text not in label_of_text:
            label_of_text[text] = len(first_indexes)
            first_indexes.append(i)
        idx2label[i] = label_of_text[text]
    return len(first_indexes), first_indexes, idx2label


def _image_to_text_ranks(similarities, idx2label, num_images, n_text):
    """Return the 1-based rank of each image's paired text among the unique texts."""
    ranks = []
    for img_idx in range(num_images):
        ascending = similarities[img_idx].argsort()
        # Position in ascending order -> rank in descending order (1 = most similar).
        # argsort (not a tie-aware count) is kept so tie-breaking matches earlier results.
        position = np.argwhere(ascending == idx2label[img_idx]).ravel()[0]
        ranks.append(n_text - position)
    return ranks


def _text_to_image_ranks(similarities, idx2label, n_text):
    """Return, per unique text, the best 1-based rank among the images paired with it."""
    images_of_text = {}
    for img_idx, label in idx2label.items():
        images_of_text.setdefault(label, []).append(img_idx)
    ranks = []
    for text_idx in range(n_text):
        descending = similarities[text_idx].argsort()[::-1]
        best_rank = float("inf")
        for img_idx in images_of_text[text_idx]:
            rank = np.argwhere(descending == img_idx).ravel()[0] + 1  # 1-indexed
            best_rank = min(best_rank, rank)
        ranks.append(best_rank)
    return ranks


def _summarize_ranks(ranks, num_samples, top_k):
    """Turn per-query ranks into the (unprefixed) metric dictionary."""
    recall_counts = dict.fromkeys(top_k, 0)
    reciprocal_sums = dict.fromkeys(top_k, 0)
    # Plain sequential accumulation (not built-in sum(), which uses compensated
    # float summation since Python 3.12) so results match the original bit-for-bit.
    for rank in ranks:
        reciprocal_rank = 1 / float(rank)
        for k in recall_counts:
            if rank <= k:
                recall_counts[k] += 1
                reciprocal_sums[k] += reciprocal_rank
    rank_array = np.array(ranks)
    result = {f"RecallAt{k}": v / num_samples for k, v in recall_counts.items()}
    result.update({f"MeanAveragePrecisionAt{k}": v / num_samples for k, v in reciprocal_sums.items()})
    result.update({"MeanReciprocalRank": float(np.mean(1 / rank_array))})
    result.update({"MeanRank": float(np.mean(rank_array))})
    result.update({"MedianRank": float(np.median(rank_array))})
    result.update({"NumSamples": num_samples})
    return result


def _nan_metrics(top_k):
    """Build the all-NaN metric dictionary used when metrics cannot be computed."""
    # NaN is assigned directly. The original divided NaN by num_samples, which
    # raised ZeroDivisionError when there were zero images.
    result = {f"RecallAt{k}": np.nan for k in top_k}
    result.update({f"MeanAveragePrecisionAt{k}": np.nan for k in top_k})
    result.update({"MeanReciprocalRank": np.nan})
    result.update({"MeanRank": float(np.nan)})
    result.update({"MedianRank": float(np.nan)})
    result.update({"NumSamples": np.nan})
    return result


def retrieval_image_text(
    image_embeddings: np.ndarray,
    text_embeddings: np.ndarray,
    text_list: "list | None" = None,
    do_image_text_retrieval: bool = True,
    *,
    top_k: "tuple[int, ...]" = (1, 5, 10),
):
    """Compute retrieval metrics between paired image and text embeddings.

    Row ``i`` of ``image_embeddings`` is paired with ``text_list[i]``.
    Identical texts are collapsed, so each distinct text is a single
    candidate whose embedding is the row of its first occurrence. Presumably
    ``len(text_list)`` equals the number of images (not validated here).

    * ``do_image_text_retrieval=True``: each image queries the unique texts.
      Its rank is the 1-based position of its paired text, sorted by
      decreasing cosine similarity.
    * ``do_image_text_retrieval=False``: each unique text queries all images.
      Its rank is the best position among the images paired with that text.

    Args:
        image_embeddings: Array-like image embeddings, converted with ``np.array``.
        text_embeddings: Array-like text embeddings, row-aligned with ``text_list``.
        text_list: Texts paired with the embeddings. ``None`` means no texts.
        do_image_text_retrieval: Retrieval direction, see above.
        top_k: Cut-offs for the ``RecallAt{k}`` / ``MeanAveragePrecisionAt{k}`` keys.

    Returns:
        Dict with keys ``RecallAt{k}``, ``MeanAveragePrecisionAt{k}``,
        ``MeanReciprocalRank``, ``MeanRank``, ``MedianRank`` and
        ``NumSamples``. Each key is prefixed with ``ImageText_`` or
        ``TextImage_``. If a ``ValueError`` is raised while computing
        similarities or ranks (e.g. sklearn rejecting an empty array), every
        value is NaN.

    Note:
        ``MeanAveragePrecisionAt{k}`` averages ``1 / rank`` over queries with
        ``rank <= k`` (0 otherwise). With a single relevant candidate per
        query, this is the reciprocal rank truncated at ``k``.
    """
    image_embeddings = np.array(image_embeddings)
    text_embeddings = np.array(text_embeddings)
    n_text, identical_indexes, idx2label = _group_identical_texts(
        [] if text_list is None else text_list
    )
    identical_text_embedding = text_embeddings[identical_indexes]
    num_images = image_embeddings.shape[0]

    try:
        if do_image_text_retrieval:
            similarities = metrics.pairwise.cosine_similarity(
                image_embeddings, identical_text_embedding
            )  # n_images x n_unique_texts
            ranks = _image_to_text_ranks(similarities, idx2label, num_images, n_text)
            num_samples = num_images
        else:
            similarities = metrics.pairwise.cosine_similarity(
                identical_text_embedding, image_embeddings
            )  # n_unique_texts x n_images
            ranks = _text_to_image_ranks(similarities, idx2label, n_text)
            num_samples = n_text
        result = _summarize_ranks(ranks, num_samples, top_k)
    except ValueError:
        result = _nan_metrics(top_k)

    prefix = "ImageText_" if do_image_text_retrieval else "TextImage_"
    return {f"{prefix}{k}": v for k, v in result.items()}