| | """ |
| | SAM 2 Zero-Shot Segmentation Model |
| | |
| | This module implements zero-shot segmentation using SAM 2 with advanced |
| | text prompting, visual grounding, and attention-based prompt generation. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Dict, List, Optional, Tuple, Union |
| | import numpy as np |
| | from PIL import Image |
| | import clip |
| | from segment_anything_2 import sam_model_registry, SamPredictor |
| | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel |
| | import cv2 |
| |
|
| |
|
class SAM2ZeroShot(nn.Module):
    """
    SAM 2 Zero-Shot Segmentation Model

    Scores domain-specific text prompts against an image with CLIP and turns
    every class whose best prompt is similar enough into point / box prompts
    for the SAM predictor, whose masks are merged per class.
    """

    # OpenAI CLIP image normalization statistics (RGB, for pixels in [0, 1]).
    _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        sam2_checkpoint: str,
        clip_model_name: str = "ViT-B/32",
        device: str = "cuda",
        use_attention_maps: bool = True,
        use_grounding_dino: bool = False,
        temperature: float = 0.1,
        sam2_model_type: str = "vit_h",
        hf_clip_model_name: str = "openai/clip-vit-base-patch32",
    ):
        """
        Args:
            sam2_checkpoint: Checkpoint path passed to ``sam_model_registry``.
            clip_model_name: Model name for ``clip.load`` (text/image similarity).
            device: Device every sub-model is moved to.
            use_attention_maps: If True, load the Hugging Face CLIP encoders and
                use vision attention to place extra point prompts.
            use_grounding_dino: Stored on the instance; not used by any method
                of this class.
            temperature: Divisor applied to cosine similarities by
                ``compute_text_image_similarity``.
            sam2_model_type: Key into ``sam_model_registry`` (was hard-coded).
            hf_clip_model_name: Hugging Face id for the attention encoders
                (was hard-coded).
        """
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.use_attention_maps = use_attention_maps
        self.use_grounding_dino = use_grounding_dino

        # Segmentation backbone and its prompt-driven predictor.
        self.sam2 = sam_model_registry[sam2_model_type](checkpoint=sam2_checkpoint)
        self.sam2.to(device)
        self.sam2_predictor = SamPredictor(self.sam2)

        # OpenAI CLIP: global image/text similarity used to accept classes.
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.eval()

        # Hugging Face CLIP encoders: only the vision attention is used for
        # point placement; the text model/tokenizer are kept so existing
        # attributes and state_dict keys stay unchanged.
        if self.use_attention_maps:
            self.clip_text_model = CLIPTextModel.from_pretrained(hf_clip_model_name)
            self.clip_vision_model = CLIPVisionModel.from_pretrained(hf_clip_model_name)
            self.clip_tokenizer = CLIPTokenizer.from_pretrained(hf_clip_model_name)
            self.clip_text_model.to(device)
            self.clip_vision_model.to(device)
            self.clip_text_model.eval()
            self.clip_vision_model.eval()

        # domain -> class name -> hand-written base prompts.
        self.advanced_prompts = {
            "satellite": {
                "building": [
                    "satellite view of buildings", "aerial photograph of structures",
                    "overhead view of houses", "urban development from above",
                    "rooftop structures", "architectural features from space"
                ],
                "road": [
                    "satellite view of roads", "aerial photograph of streets",
                    "overhead view of highways", "transportation network from above",
                    "paved surfaces", "road infrastructure from space"
                ],
                "vegetation": [
                    "satellite view of vegetation", "aerial photograph of forests",
                    "overhead view of trees", "green areas from above",
                    "natural landscape", "plant life from space"
                ],
                "water": [
                    "satellite view of water", "aerial photograph of lakes",
                    "overhead view of rivers", "water bodies from above",
                    "aquatic features", "water resources from space"
                ]
            },
            "fashion": {
                "shirt": [
                    "fashion photography of shirts", "clothing item top",
                    "apparel garment", "upper body clothing",
                    "casual wear", "formal attire top"
                ],
                "pants": [
                    "fashion photography of pants", "lower body clothing",
                    "trousers garment", "leg wear",
                    "casual pants", "formal trousers"
                ],
                "dress": [
                    "fashion photography of dresses", "full body garment",
                    "formal dress", "evening wear",
                    "casual dress", "party dress"
                ],
                "shoes": [
                    "fashion photography of shoes", "footwear item",
                    "foot covering", "walking shoes",
                    "casual footwear", "formal shoes"
                ]
            },
            "robotics": {
                "robot": [
                    "robotics environment with robot", "automation equipment",
                    "mechanical arm", "industrial robot",
                    "automated system", "robotic device"
                ],
                "tool": [
                    "robotics environment with tools", "industrial equipment",
                    "mechanical tools", "work equipment",
                    "hand tools", "power tools"
                ],
                "safety": [
                    "robotics environment with safety equipment", "protective gear",
                    "safety helmet", "safety vest",
                    "protective clothing", "safety equipment"
                ]
            }
        }

        # Templates applied to the first two base prompts of each class.
        self.prompt_strategies = {
            "descriptive": lambda x: f"a clear image showing {x}",
            "contextual": lambda x: f"in a typical environment, {x}",
            "detailed": lambda x: f"high quality photograph of {x} with clear details",
            "contrastive": lambda x: f"{x} standing out from the background"
        }

    @staticmethod
    def _to_hwc_uint8(image: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a [C, H, W] tensor, an HWC array or a PIL image to an HWC uint8 array.

        Non-uint8 data is treated as [0, 1] floats when its maximum is <= 1
        (assumption: callers pass either [0, 1] or [0, 255] images) and is
        clipped to [0, 255] before casting.
        """
        if isinstance(image, torch.Tensor):
            array = image.detach().cpu().permute(1, 2, 0).numpy()
        elif isinstance(image, np.ndarray):
            array = image
        else:
            # e.g. a PIL image; force 3-channel RGB first.
            array = np.asarray(image.convert("RGB") if hasattr(image, "convert") else image)
        if array.dtype != np.uint8:
            array = array.astype(np.float32)
            if array.size and float(array.max()) <= 1.0:
                array = array * 255.0
            array = np.clip(array, 0.0, 255.0).astype(np.uint8)
        return array

    def _to_pil(self, image: Union[torch.Tensor, np.ndarray]) -> "Image.Image":
        """Return *image* as a PIL image, which ``clip_preprocess`` requires."""
        return Image.fromarray(self._to_hwc_uint8(image))

    def generate_attention_maps(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> Optional[torch.Tensor]:
        """Return a coarse saliency map from the CLIP vision encoder's last attention layer.

        ``CLIPTextModel`` and ``CLIPVisionModel`` are independent encoders with
        no cross-attention, so the map is how strongly the vision [CLS] token
        attends to each image patch, averaged over heads. It is class-agnostic:
        ``text_prompts`` is accepted for interface compatibility but not used.

        Returns:
            ``None`` when attention maps are disabled (or the patch grid is not
            square), otherwise a tensor of shape (1, grid, grid) covering the
            whole image (the image is resized, not cropped).
        """
        if not self.use_attention_maps:
            return None

        size = self.clip_vision_model.config.image_size
        hwc = self._to_hwc_uint8(image)
        pixels = torch.tensor(hwc).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        # Resize the full image (no crop) so grid cells map linearly to pixels.
        pixels = F.interpolate(
            pixels, size=(size, size), mode="bicubic", align_corners=False
        ).clamp(0.0, 1.0)
        mean = torch.tensor(self._CLIP_MEAN).view(1, 3, 1, 1)
        std = torch.tensor(self._CLIP_STD).view(1, 3, 1, 1)
        pixel_values = ((pixels - mean) / std).to(self.device)

        with torch.no_grad():
            vision_outputs = self.clip_vision_model(
                pixel_values=pixel_values, output_attentions=True
            )

        # (batch, heads, 1 + patches, 1 + patches); token 0 is [CLS].
        last_layer = vision_outputs.attentions[-1]
        cls_to_patches = last_layer[0, :, 0, 1:].float().mean(dim=0)
        side = int(round(cls_to_patches.numel() ** 0.5))
        if side * side != cls_to_patches.numel():
            return None
        return cls_to_patches.reshape(1, side, side)

    def extract_attention_points(
        self,
        attention_maps: Optional[torch.Tensor],
        num_points: int = 5,
        image_size: Optional[Tuple[int, int]] = None
    ) -> List[Tuple[int, int]]:
        """Return the (x, y) argmax of each of the first ``num_points`` maps.

        Args:
            attention_maps: (h, w), (N, h, w) or any (..., h, w) tensor; leading
                dimensions are flattened into a list of maps. ``None`` -> [].
            num_points: Maximum number of maps (and therefore points) used.
            image_size: Optional (height, width) of the image. When given, each
                grid cell's centre is mapped to pixel coordinates so the points
                can be passed to SAM; otherwise grid coordinates are returned.
        """
        if attention_maps is None:
            return []

        maps = attention_maps.detach().float().cpu()
        if maps.dim() == 2:
            maps = maps.unsqueeze(0)
        grid_h, grid_w = maps.shape[-2:]
        maps = maps.reshape(-1, grid_h, grid_w)

        points = []
        for attention_map in maps[:num_points]:
            y, x = divmod(int(torch.argmax(attention_map)), grid_w)
            if image_size is not None:
                img_h, img_w = image_size
                x = min(int((x + 0.5) * img_w / grid_w), img_w - 1)
                y = min(int((y + 0.5) * img_h / grid_h), img_h - 1)
            points.append((x, y))
        return points

    def _class_prompts(self, domain: str, class_name: str) -> List[str]:
        """Return the enhanced prompts for one class, in ``generate_enhanced_prompts`` order."""
        base_prompts = self.advanced_prompts.get(domain, {}).get(class_name)
        if base_prompts is None:
            # Unknown domain/class: fall back to the bare name.
            return [class_name, f"object: {class_name}"]

        prompts = list(base_prompts)
        for strategy_func in self.prompt_strategies.values():
            prompts.extend(strategy_func(base_prompt) for base_prompt in base_prompts[:2])
        return prompts

    def generate_enhanced_prompts(
        self,
        domain: str,
        class_names: List[str]
    ) -> List[str]:
        """Return the base prompts plus strategy-templated variants for every class.

        For a known (domain, class) pair this is the class's base prompts
        followed by each strategy applied to its first two base prompts;
        otherwise ``[class_name, "object: <class_name>"]``. Classes are
        concatenated in the given order.
        """
        return [
            prompt
            for class_name in class_names
            for prompt in self._class_prompts(domain, class_name)
        ]

    def compute_text_image_similarity(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> torch.Tensor:
        """Return CLIP cosine similarities divided by ``temperature``, shape (1, len(text_prompts)).

        ``image`` may be a [C, H, W] tensor, an HWC array or a PIL image; it is
        converted to PIL because ``clip_preprocess`` only accepts PIL images.
        """
        text_tokens = clip.tokenize(text_prompts).to(self.device)
        image_input = self.clip_preprocess(self._to_pil(image)).unsqueeze(0).to(self.device)

        with torch.no_grad():
            text_features = F.normalize(self.clip_model.encode_text(text_tokens), dim=-1)
            image_features = F.normalize(self.clip_model.encode_image(image_input), dim=-1)

        return torch.matmul(image_features, text_features.T) / self.temperature

    def generate_sam2_prompts(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        point_threshold: float = 0.2,
        box_threshold: float = 0.4
    ) -> List[Dict]:
        """Build SAM prompts for every class whose best text prompt matches the image.

        Each class is scored only against its own prompts. The score is the
        raw CLIP cosine similarity, and the thresholds apply to that score.
        Classes scoring above ``point_threshold`` get attention points plus a
        centre point. Classes scoring above ``box_threshold`` also get a
        centre box covering the middle half of the image.

        Returns:
            A list of dicts with keys ``type`` ('point' | 'box'), ``data``,
            ``label`` (points only), ``class``, ``confidence`` (cosine
            similarity) and ``source``.
        """
        per_class = [self._class_prompts(domain, name) for name in class_names]
        text_prompts = [prompt for group in per_class for prompt in group]
        if not text_prompts:
            return []

        # Undo the temperature so thresholds compare against cosine similarity.
        cosine = self.compute_text_image_similarity(image, text_prompts) * self.temperature

        h, w = self._to_hwc_uint8(image).shape[:2]
        attention_maps = self.generate_attention_maps(image, text_prompts)
        attention_points = self.extract_attention_points(attention_maps, image_size=(h, w))

        prompts = []
        start = 0
        for class_name, group in zip(class_names, per_class):
            stop = start + len(group)
            class_scores = cosine[0, start:stop]
            start = stop
            if not group:
                continue

            best_similarity = class_scores.max().item()
            if best_similarity <= point_threshold:
                continue

            for point in attention_points[:3]:
                prompts.append({
                    'type': 'point',
                    'data': point,
                    'label': 1,
                    'class': class_name,
                    'confidence': best_similarity,
                    'source': 'attention'
                })

            prompts.append({
                'type': 'point',
                'data': [w // 2, h // 2],
                'label': 1,
                'class': class_name,
                'confidence': best_similarity,
                'source': 'center'
            })

            if best_similarity > box_threshold:
                prompts.append({
                    'type': 'box',
                    'data': [w // 4, h // 4, 3 * w // 4, 3 * h // 4],
                    'class': class_name,
                    'confidence': best_similarity,
                    'source': 'similarity'
                })

        return prompts

    def segment(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """
        Perform zero-shot segmentation.

        Args:
            image: Input image, a [C, H, W] tensor or an HWC array; it is
                converted to the RGB uint8 HWC layout ``SamPredictor`` expects.
            domain: Domain name (satellite, fashion, robotics)
            class_names: List of class names to segment

        Returns:
            Dictionary mapping each detected class to a float {0, 1} mask
            (union of the best-scoring SAM mask of every accepted prompt).
            Classes with no accepted prompt are absent.
        """
        image_np = self._to_hwc_uint8(image)
        self.sam2_predictor.set_image(image_np)

        prompts = self.generate_sam2_prompts(image_np, domain, class_names)

        # Minimum prompt confidence for a predicted mask to be kept.
        min_confidence = {'point': 0.2, 'box': 0.3}
        results: Dict[str, torch.Tensor] = {}

        for prompt in prompts:
            prompt_type = prompt['type']
            if prompt_type not in min_confidence:
                continue
            if prompt['confidence'] <= min_confidence[prompt_type]:
                continue

            if prompt_type == 'point':
                masks, scores, _ = self.sam2_predictor.predict(
                    point_coords=np.array([prompt['data']]),
                    point_labels=np.array([prompt['label']]),
                    multimask_output=True
                )
            else:
                masks, scores, _ = self.sam2_predictor.predict(
                    box=np.array(prompt['data']),
                    multimask_output=True
                )

            mask = torch.from_numpy(masks[int(np.argmax(scores))]).float()
            class_name = prompt['class']
            if class_name in results:
                # Union of masks from several prompts for the same class.
                results[class_name] = torch.maximum(results[class_name], mask)
            else:
                results[class_name] = mask

        return results

    def forward(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Forward pass; delegates to ``segment``."""
        return self.segment(image, domain, class_names)
| |
|
| |
|
class ZeroShotEvaluator:
    """Evaluator for zero-shot segmentation: per-class and mean IoU / Dice."""

    def __init__(self):
        # Kept for API compatibility; not read by any method of this class.
        self.metrics = {}

    def compute_iou(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Return |pred & gt| / |pred | gt| for boolean masks; 0.0 when both are empty."""
        intersection = (pred_mask & gt_mask).sum()
        union = (pred_mask | gt_mask).sum()
        return float(intersection / union) if union > 0 else 0.0

    def compute_dice(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Return 2|pred & gt| / (|pred| + |gt|) for boolean masks; 0.0 when both are empty."""
        intersection = (pred_mask & gt_mask).sum()
        total = pred_mask.sum() + gt_mask.sum()
        return float(2 * intersection / total) if total > 0 else 0.0

    def evaluate(
        self,
        predictions: Dict[str, torch.Tensor],
        ground_truth: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """Evaluate zero-shot segmentation results.

        Masks are binarized with ``> 0.5``. A class present in
        ``ground_truth`` but missing from ``predictions`` is scored as an
        empty prediction instead of being skipped, because a missed class
        should lower the means rather than be ignored.

        Returns:
            ``"<class>_iou"`` and ``"<class>_dice"`` for every ground-truth
            class, plus ``"mean_iou"`` / ``"mean_dice"`` when there is at
            least one class.
        """
        results: Dict[str, float] = {}
        ious: List[float] = []
        dices: List[float] = []

        for class_name, gt in ground_truth.items():
            gt_mask = gt > 0.5
            pred = predictions.get(class_name)
            pred_mask = pred > 0.5 if pred is not None else torch.zeros_like(gt_mask)

            iou = self.compute_iou(pred_mask, gt_mask)
            dice = self.compute_dice(pred_mask, gt_mask)

            results[f"{class_name}_iou"] = iou
            results[f"{class_name}_dice"] = dice
            ious.append(iou)
            dices.append(dice)

        # Means are taken over explicit lists so class names containing
        # "iou"/"dice" cannot leak into the wrong aggregate.
        if ious:
            results['mean_iou'] = float(np.mean(ious))
            results['mean_dice'] = float(np.mean(dices))

        return results