| | """ |
| | Segmentation Metrics |
| | |
| | This module provides comprehensive metrics for evaluating segmentation performance |
| | in few-shot and zero-shot learning scenarios. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from typing import Dict, List, Tuple, Optional |
| | from sklearn.metrics import precision_recall_curve, average_precision_score |
| | import cv2 |


class SegmentationMetrics:
    """Comprehensive segmentation metrics calculator."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def compute_metrics(
        self,
        pred_mask: torch.Tensor,
        gt_mask: torch.Tensor
    ) -> Dict[str, float]:
        """
        Compute comprehensive segmentation metrics.

        Args:
            pred_mask: Predicted mask tensor [H, W] or [1, H, W]
            gt_mask: Ground truth mask tensor [H, W] or [1, H, W]

        Returns:
            Dictionary containing the metric values
        """
        if pred_mask.dim() == 3:
            pred_mask = pred_mask.squeeze(0)
        if gt_mask.dim() == 3:
            gt_mask = gt_mask.squeeze(0)

        # Binarize to bool tensors: the bitwise operators (&, |) used by the
        # metric methods raise a TypeError on float tensors.
        pred_binary = pred_mask > self.threshold
        gt_binary = gt_mask > self.threshold

        metrics = {}

        # Region-overlap metrics
        metrics['iou'] = self.compute_iou(pred_binary, gt_binary)
        metrics['dice'] = self.compute_dice(pred_binary, gt_binary)

        # Pixel-classification metrics
        metrics['precision'] = self.compute_precision(pred_binary, gt_binary)
        metrics['recall'] = self.compute_recall(pred_binary, gt_binary)
        metrics['f1'] = self.compute_f1_score(pred_binary, gt_binary)
        metrics['accuracy'] = self.compute_accuracy(pred_binary, gt_binary)

        # Boundary-quality metrics
        metrics['boundary_iou'] = self.compute_boundary_iou(pred_binary, gt_binary)
        metrics['hausdorff_distance'] = self.compute_hausdorff_distance(pred_binary, gt_binary)

        # Size agreement between prediction and ground truth
        metrics['area_ratio'] = self.compute_area_ratio(pred_binary, gt_binary)

        return metrics

    def compute_iou(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute Intersection over Union."""
        pred, gt = pred.bool(), gt.bool()
        intersection = (pred & gt).sum()
        union = (pred | gt).sum()
        return (intersection / union).item() if union > 0 else 0.0

    def compute_dice(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute Dice coefficient."""
        pred, gt = pred.bool(), gt.bool()
        intersection = (pred & gt).sum()
        total = pred.sum() + gt.sum()
        return (2 * intersection / total).item() if total > 0 else 0.0

    def compute_precision(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute precision (fraction of predicted pixels that are correct)."""
        pred, gt = pred.bool(), gt.bool()
        intersection = (pred & gt).sum()
        return (intersection / pred.sum()).item() if pred.sum() > 0 else 0.0

    def compute_recall(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute recall (fraction of ground-truth pixels that are recovered)."""
        pred, gt = pred.bool(), gt.bool()
        intersection = (pred & gt).sum()
        return (intersection / gt.sum()).item() if gt.sum() > 0 else 0.0

    def compute_f1_score(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute F1 score (harmonic mean of precision and recall)."""
        precision = self.compute_precision(pred, gt)
        recall = self.compute_recall(pred, gt)
        # precision and recall are plain Python floats here, so no .item() call
        return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    def compute_accuracy(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute pixel accuracy."""
        correct = (pred == gt).sum()
        total = pred.numel()
        return (correct / total).item()

    def compute_boundary_iou(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute IoU restricted to the mask boundaries."""
        pred_boundary = self.extract_boundary(pred)
        gt_boundary = self.extract_boundary(gt)
        return self.compute_iou(pred_boundary, gt_boundary)

    def extract_boundary(self, mask: torch.Tensor) -> torch.Tensor:
        """Extract the boundary of a binary mask via a morphological gradient."""
        mask_np = mask.cpu().numpy().astype(np.uint8)

        # Dilation minus erosion leaves a thin band around the mask contour
        kernel = np.ones((3, 3), np.uint8)
        dilated = cv2.dilate(mask_np, kernel, iterations=1)
        eroded = cv2.erode(mask_np, kernel, iterations=1)
        boundary = dilated - eroded

        return torch.from_numpy(boundary).bool()

    def compute_hausdorff_distance(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute the Hausdorff distance between the two mask boundaries."""
        pred_boundary = self.extract_boundary(pred)
        gt_boundary = self.extract_boundary(gt)

        pred_np = pred_boundary.cpu().numpy()
        gt_np = gt_boundary.cpu().numpy()

        # Collect (row, col) coordinates of the boundary pixels
        pred_points = np.column_stack(np.where(pred_np > 0))
        gt_points = np.column_stack(np.where(gt_np > 0))

        if len(pred_points) == 0 or len(gt_points) == 0:
            return float('inf')

        return self._hausdorff_distance(pred_points, gt_points)

    def _hausdorff_distance(self, set1: np.ndarray, set2: np.ndarray) -> float:
        """Compute the Hausdorff distance between two point sets.

        H(A, B) = max(h(A, B), h(B, A)), where h(A, B) = max_a min_b ||a - b||.
        The nested loop is O(|A| * |B|), which is fine for boundary-sized sets.
        """
        def directed_hausdorff(set_a, set_b):
            min_distances = []
            for point_a in set_a:
                distances = np.linalg.norm(set_b - point_a, axis=1)
                min_distances.append(np.min(distances))
            return np.max(min_distances)

        d1 = directed_hausdorff(set1, set2)
        d2 = directed_hausdorff(set2, set1)
        return max(d1, d2)

    def compute_area_ratio(self, pred: torch.Tensor, gt: torch.Tensor) -> float:
        """Compute the ratio of predicted area to ground-truth area."""
        pred_area = pred.sum()
        gt_area = gt.sum()
        return (pred_area / gt_area).item() if gt_area > 0 else 0.0

    def compute_class_metrics(
        self,
        predictions: Dict[str, torch.Tensor],
        ground_truth: Dict[str, torch.Tensor]
    ) -> Dict[str, Dict[str, float]]:
        """Compute metrics for multiple classes."""
        class_metrics = {}

        for class_name in ground_truth.keys():
            if class_name in predictions:
                class_metrics[class_name] = self.compute_metrics(
                    predictions[class_name], ground_truth[class_name]
                )
            else:
                # A class with no prediction counts as a complete miss
                class_metrics[class_name] = {
                    'iou': 0.0,
                    'dice': 0.0,
                    'precision': 0.0,
                    'recall': 0.0,
                    'f1': 0.0,
                    'accuracy': 0.0,
                    'boundary_iou': 0.0,
                    'hausdorff_distance': float('inf'),
                    'area_ratio': 0.0
                }

        return class_metrics

    def compute_average_metrics(
        self,
        class_metrics: Dict[str, Dict[str, float]]
    ) -> Dict[str, float]:
        """Compute average metrics across all classes."""
        if not class_metrics:
            return {}

        metric_names = next(iter(class_metrics.values())).keys()

        averages = {}
        for metric_name in metric_names:
            values = [class_metrics[cls][metric_name] for cls in class_metrics]

            # Hausdorff distances may be infinite (missing predictions), so
            # average only the finite values
            if metric_name == 'hausdorff_distance':
                finite_values = [v for v in values if v != float('inf')]
                averages[metric_name] = np.mean(finite_values) if finite_values else float('inf')
            else:
                averages[metric_name] = np.mean(values)

        return averages
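

# Example usage of SegmentationMetrics (a minimal sketch; `pred` and `gt`
# below are hypothetical tensors, not part of this module):
#
#     calc = SegmentationMetrics(threshold=0.5)
#     pred = torch.rand(256, 256)                # soft prediction in [0, 1]
#     gt = (torch.rand(256, 256) > 0.5).float()  # binary ground truth
#     scores = calc.compute_metrics(pred, gt)
#     print(scores['iou'], scores['boundary_iou'])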


class FewShotMetrics:
    """Specialized metrics for few-shot learning evaluation."""

    def __init__(self):
        self.segmentation_metrics = SegmentationMetrics()

    def compute_episode_metrics(
        self,
        episode_results: List[Dict]
    ) -> Dict[str, float]:
        """Compute summary statistics across multiple episodes."""
        all_metrics = [episode['metrics'] for episode in episode_results if 'metrics' in episode]

        if not all_metrics:
            return {}

        # Aggregate each metric across episodes
        episode_stats = {}
        metric_names = all_metrics[0].keys()

        for metric_name in metric_names:
            values = [ep[metric_name] for ep in all_metrics if metric_name in ep]
            if values:
                episode_stats[f'mean_{metric_name}'] = np.mean(values)
                episode_stats[f'std_{metric_name}'] = np.std(values)
                episode_stats[f'min_{metric_name}'] = np.min(values)
                episode_stats[f'max_{metric_name}'] = np.max(values)

        return episode_stats

    def compute_shot_analysis(
        self,
        results_by_shots: Dict[int, List[Dict]]
    ) -> Dict[str, Dict[str, float]]:
        """Analyze performance across different numbers of shots."""
        shot_analysis = {}

        for num_shots, results in results_by_shots.items():
            shot_analysis[f'{num_shots}_shots'] = self.compute_episode_metrics(results)

        return shot_analysis
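

# Input shape expected by FewShotMetrics.compute_episode_metrics (a
# hypothetical sketch; the numbers are illustrative only):
#
#     episode_results = [
#         {'metrics': {'iou': 0.71, 'dice': 0.83}},
#         {'metrics': {'iou': 0.65, 'dice': 0.79}},
#     ]
#     stats = FewShotMetrics().compute_episode_metrics(episode_results)
#     # -> {'mean_iou': 0.68, 'std_iou': 0.03, ...}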


class ZeroShotMetrics:
    """Specialized metrics for zero-shot learning evaluation."""

    def __init__(self):
        self.segmentation_metrics = SegmentationMetrics()

    def compute_prompt_strategy_comparison(
        self,
        strategy_results: Dict[str, List[Dict]]
    ) -> Dict[str, Dict[str, float]]:
        """Compare different prompt strategies."""
        strategy_comparison = {}

        for strategy_name, results in strategy_results.items():
            avg_metrics = {}
            if results:
                metric_names = results[0].keys()
                for metric_name in metric_names:
                    values = [r[metric_name] for r in results if metric_name in r]
                    if values:
                        avg_metrics[f'mean_{metric_name}'] = np.mean(values)
                        avg_metrics[f'std_{metric_name}'] = np.std(values)

            strategy_comparison[strategy_name] = avg_metrics

        return strategy_comparison

    def compute_attention_analysis(
        self,
        with_attention: List[Dict],
        without_attention: List[Dict]
    ) -> Dict[str, Dict[str, float]]:
        """Analyze the impact of attention mechanisms."""
        if not with_attention or not without_attention:
            return {}

        # Average each metric over both result sets
        with_attention_avg = {}
        without_attention_avg = {}

        metric_names = with_attention[0].keys()
        for metric_name in metric_names:
            with_values = [r[metric_name] for r in with_attention if metric_name in r]
            without_values = [r[metric_name] for r in without_attention if metric_name in r]

            if with_values:
                with_attention_avg[metric_name] = np.mean(with_values)
            if without_values:
                without_attention_avg[metric_name] = np.mean(without_values)

        # Per-metric improvement from adding attention
        improvements = {}
        for metric_name in with_attention_avg:
            if metric_name in without_attention_avg:
                improvements[f'{metric_name}_improvement'] = (
                    with_attention_avg[metric_name] - without_attention_avg[metric_name]
                )

        return {
            'with_attention': with_attention_avg,
            'without_attention': without_attention_avg,
            'improvements': improvements
        }
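

if __name__ == "__main__":
    # Minimal smoke test (a sketch with synthetic data, not a benchmark):
    # score a noisy copy of a square mask against the clean original.
    torch.manual_seed(0)
    gt = torch.zeros(64, 64)
    gt[16:48, 16:48] = 1.0
    pred = (gt + 0.2 * torch.randn(64, 64)).clamp(0, 1)

    for name, value in SegmentationMetrics().compute_metrics(pred, gt).items():
        print(f"{name}: {value:.4f}")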