| | """ |
| | Data Loader Utilities |
| | |
| | This module provides data loading utilities for different domains |
| | (satellite, fashion, robotics) with support for few-shot and zero-shot learning. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from PIL import Image |
| | import os |
| | import json |
| | from typing import List, Dict, Tuple, Optional |
| | import random |
| | from torch.utils.data import Dataset, DataLoader |
| | import torchvision.transforms as transforms |
| | from torchvision.transforms import functional as F |
| | import cv2 |
| |
|
| |
|
class BaseDataLoader:
    """Base class for domain-specific segmentation data loaders.

    Holds the dataset root directory and the shared image / mask preprocessing
    pipelines. Subclasses implement the dataset-specific sampling methods.
    """

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        """Set up the preprocessing transforms.

        Args:
            data_dir: Root directory of the dataset.
            image_size: Output size passed to ``transforms.Resize`` for both
                images and masks.
        """
        self.data_dir = data_dir
        self.image_size = image_size

        # Images: resize, convert to a float tensor, then normalise per channel
        # with the commonly used ImageNet mean/std values.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Masks: nearest-neighbour resize so label values are never blended into
        # new values. Note that ToTensor rescales 8-bit pixels to [0, 1]
        # (value / 255), so loaded masks do NOT hold the raw stored pixel values.
        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def load_image(self, image_path: str) -> torch.Tensor:
        """Load ``image_path`` as RGB and apply ``self.transform``.

        The file is opened in a ``with`` block so its handle is released right
        away. Pillow can otherwise keep multi-frame files (e.g. TIFF) open after
        the pixel data has been read.
        """
        with Image.open(image_path) as image:
            rgb = image.convert('RGB')
        return self.transform(rgb)

    def load_mask(self, mask_path: str) -> torch.Tensor:
        """Load ``mask_path`` as single-channel ('L') and apply ``self.mask_transform``.

        Returns:
            A float tensor whose values are the stored pixel values divided by 255.
        """
        with Image.open(mask_path) as mask:
            gray = mask.convert('L')
        return self.mask_transform(gray)

    def get_random_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random sample from the dataset. Must be overridden."""
        raise NotImplementedError

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Get up to ``num_examples`` (image, mask) examples of ``class_name``. Must be overridden."""
        raise NotImplementedError
| |
|
| |
|
class SatelliteDataLoader(BaseDataLoader):
    """Data loader for satellite imagery segmentation.

    Mask pixels hold 0 for background and ``class_to_id[name] + 1`` for each
    class, so building = 1 and road = 2. Dummy masks are written multiplied by
    ``_MASK_SCALE`` so the classes are visible in an image viewer. They are
    decoded back to class IDs when read.
    """

    # Multiplier applied to class IDs when dummy masks are written to disk.
    _MASK_SCALE = 85
    # File extensions accepted as input images.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.tif')

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["building", "road", "vegetation", "water"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Collect (image, mask) path pairs and bucket them by dominant class.

        If ``<data_dir>/images`` or ``<data_dir>/masks`` is missing, dummy data
        is generated instead.
        """
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            self.create_dummy_data()
            return

        # os.listdir order is arbitrary; sort so sample order (and seeded
        # sampling on top of it) is reproducible.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask file for image ``filename``, or None if there is none.

        Prefers ``<stem>_mask.png`` (the naming used by ``create_dummy_data``)
        for every image extension. Falls back to a mask with the image's own
        file name.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (f"{stem}_mask.png", filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    def create_dummy_data(self):
        """Write 100 synthetic 512x512 tiles plus matching masks and register them.

        Buildings (rectangles) and roads (1-pixel lines) are painted into the
        image and the mask together, so the mask labels exactly what is drawn.
        """
        print("Creating dummy satellite data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        building_id = self.class_to_id["building"] + 1
        road_id = self.class_to_id["road"] + 1

        for i in range(100):
            # Noisy background; 0 in the mask means "no class".
            image = np.random.randint(50, 200, (512, 512, 3), dtype=np.uint8)
            mask = np.zeros((512, 512), dtype=np.uint8)

            for _ in range(5):
                self._draw_building(image, mask, building_id)
            # Roads are drawn last, so they overwrite buildings in both arrays.
            for _ in range(3):
                self._draw_road(image, mask, road_id)

            image_path = os.path.join(images_dir, f"satellite_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask_path = os.path.join(masks_dir, f"satellite_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _draw_building(image: np.ndarray, mask: np.ndarray, label: int) -> None:
        """Paint one randomly placed, randomly coloured rectangle into ``image`` and ``mask``."""
        x, y = np.random.randint(0, 400), np.random.randint(0, 400)
        w, h = np.random.randint(20, 80), np.random.randint(20, 80)
        image[y:y + h, x:x + w] = np.random.randint(100, 150, 3)
        mask[y:y + h, x:x + w] = label

    @staticmethod
    def _draw_road(image: np.ndarray, mask: np.ndarray, label: int) -> None:
        """Paint one random straight gray line into ``image`` and ``mask``, clipped to the tile."""
        height, width = mask.shape
        x, y = np.random.randint(0, width), np.random.randint(0, height)
        length = np.random.randint(50, 150)
        angle = np.random.uniform(0, 2 * np.pi)
        for j in range(length):
            px = int(x + j * np.cos(angle))
            py = int(y + j * np.sin(angle))
            if 0 <= px < width and 0 <= py < height:
                image[py, px] = [80, 80, 80]
                mask[py, px] = label

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values to integer class IDs.

        A mask whose maximum does not exceed the number of classes is assumed to
        hold raw class IDs already and is returned unchanged (as int64).
        Otherwise the values are treated as ``class_id * _MASK_SCALE``.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self._MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class covering the most mask pixels.

        Class ``self.classes[i]`` corresponds to mask ID ``i + 1``; ID 0 is
        background and never counts. Samples with no class pixels at all are
        not filed under any class.
        """
        with Image.open(mask_path) as mask_image:
            class_ids = self._decode_class_ids(np.array(mask_image))

        class_counts = {
            class_name: int(np.sum(class_ids == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, dominant_count = max(class_counts.items(), key=lambda kv: kv[1])
        if dominant_count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_random_query(self, class_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a random (image, mask) pair whose dominant class is ``class_name``.

        Falls back to a uniformly random sample when no sample is filed under
        ``class_name``.

        Raises:
            ValueError: if the dataset contains no samples at all.
        """
        candidates = self.class_samples.get(class_name)
        if candidates:
            image_path, mask_path = random.choice(candidates)
        else:
            if not self.images:
                raise ValueError(f"No image/mask pairs found under {self.data_dir!r}")
            idx = random.randint(0, len(self.images) - 1)
            image_path, mask_path = self.images[idx], self.masks[idx]

        image = self.load_image(image_path)
        mask = self.load_mask(mask_path)
        return image, mask

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Return up to ``num_examples`` distinct (image, mask) pairs filed under ``class_name``.

        Returns an empty list for unknown classes.
        """
        if class_name not in self.class_samples:
            return []
        available = self.class_samples[class_name]
        chosen = random.sample(available, min(num_examples, len(available)))
        return [(self.load_image(img), self.load_mask(msk)) for img, msk in chosen]
| |
|
| |
|
class FashionDataLoader(BaseDataLoader):
    """Data loader for fashion segmentation.

    Mask pixels hold 0 for background and ``i + 1`` for ``self.classes[i]``.
    Dummy masks are written multiplied by ``_MASK_SCALE`` so the classes are
    visible in an image viewer. They are decoded back to class IDs when read.
    """

    # Multiplier applied to class IDs when dummy masks are written to disk.
    _MASK_SCALE = 51
    # File extensions accepted as input images.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    # One garment per dummy image, indexed like ``self.classes``:
    # ((center_x, center_y), (width, height), RGB colour).
    _DUMMY_GARMENTS = (
        ((256, 256), (150, 200), (100, 150, 200)),  # shirt
        ((256, 300), (120, 180), (50, 100, 150)),   # pants
        ((256, 250), (140, 220), (200, 100, 150)),  # dress
        ((256, 400), (100, 60), (80, 80, 80)),      # shoes
    )

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["shirt", "pants", "dress", "shoes"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Collect (image, mask) path pairs and bucket them by dominant class.

        If ``<data_dir>/images`` or ``<data_dir>/masks`` is missing, dummy data
        is generated instead.
        """
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            self.create_dummy_data()
            return

        # os.listdir order is arbitrary; sort for a reproducible sample order.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask file for image ``filename``, or None if there is none.

        Prefers ``<stem>_mask.png`` (the naming used by ``create_dummy_data``)
        for every image extension. Falls back to a mask with the image's own
        file name.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (f"{stem}_mask.png", filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def _paint_box(array: np.ndarray, center, size, value) -> None:
        """Fill the axis-aligned box of ``size`` (w, h) centred on ``center`` (x, y) with ``value``."""
        cx, cy = center
        w, h = size
        array[cy - h // 2:cy + h // 2, cx - w // 2:cx + w // 2] = value

    def create_dummy_data(self):
        """Write 100 synthetic 512x512 images with one garment each, plus masks.

        Classes are cycled (image ``i`` shows class ``i % len(self.classes)``),
        and the garment box is painted identically into image and mask.
        """
        print("Creating dummy fashion data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for i in range(100):
            image = np.random.randint(200, 255, (512, 512, 3), dtype=np.uint8)
            mask = np.zeros((512, 512), dtype=np.uint8)

            class_id = i % len(self.classes)
            center, size, colour = self._DUMMY_GARMENTS[class_id]
            self._paint_box(image, center, size, colour)
            self._paint_box(mask, center, size, class_id + 1)

            image_path = os.path.join(images_dir, f"fashion_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask_path = os.path.join(masks_dir, f"fashion_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values (0-255) to integer class IDs.

        A mask whose maximum does not exceed the number of classes is assumed to
        hold raw class IDs already and is returned unchanged (as int64).
        Otherwise the values are treated as ``class_id * _MASK_SCALE``.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self._MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class covering the most mask pixels.

        Background (ID 0) never counts. Samples with no class pixels at all
        are not filed under any class.
        """
        with Image.open(mask_path) as mask_image:
            class_ids = self._decode_class_ids(np.array(mask_image))

        class_counts = {
            class_name: int(np.sum(class_ids == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, dominant_count = max(class_counts.items(), key=lambda kv: kv[1])
        if dominant_count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random image with one binary float mask per class.

        Returns:
            ``(image, ground_truth)`` where ``ground_truth[name]`` has the shape
            of the loaded mask and is 1.0 where that class is present.

        Raises:
            ValueError: if the dataset contains no samples.
        """
        if not self.images:
            raise ValueError(f"No image/mask pairs found under {self.data_dir!r}")
        idx = random.randint(0, len(self.images) - 1)
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask's ToTensor divides pixels by 255; undo that to get the
        # stored values, then decode them to class IDs before comparing.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        class_ids = torch.from_numpy(self._decode_class_ids(raw))

        ground_truth = {
            class_name: (class_ids == i + 1).float()
            for i, class_name in enumerate(self.classes)
        }
        return image, ground_truth
| |
|
| |
|
class RoboticsDataLoader(BaseDataLoader):
    """Data loader for robotics segmentation.

    Mask pixels hold 0 for background and ``i + 1`` for ``self.classes[i]``.
    Dummy masks are written multiplied by ``_MASK_SCALE`` so the classes are
    visible in an image viewer. They are decoded back to class IDs when read.
    """

    # Multiplier applied to class IDs when dummy masks are written to disk.
    _MASK_SCALE = 85
    # File extensions accepted as input images.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    # One object per dummy image, indexed like ``self.classes``:
    # ((center_x, center_y), (width, height), RGB colour).
    _DUMMY_OBJECTS = (
        ((256, 256), (120, 160), (100, 100, 100)),  # robot
        ((256, 256), (80, 120), (150, 100, 50)),    # tool
        ((256, 256), (100, 100), (200, 200, 50)),   # safety
    )

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)

        self.classes = ["robot", "tool", "safety"]
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}

        self.load_dataset_structure()

    def load_dataset_structure(self):
        """Collect (image, mask) path pairs and bucket them by dominant class.

        If ``<data_dir>/images`` or ``<data_dir>/masks`` is missing, dummy data
        is generated instead.
        """
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            self.create_dummy_data()
            return

        # os.listdir order is arbitrary; sort for a reproducible sample order.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.endswith(self._IMAGE_EXTENSIONS):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the mask file for image ``filename``, or None if there is none.

        Prefers ``<stem>_mask.png`` (the naming used by ``create_dummy_data``)
        for every image extension. Falls back to a mask with the image's own
        file name.
        """
        stem = os.path.splitext(filename)[0]
        for candidate in (f"{stem}_mask.png", filename):
            path = os.path.join(masks_dir, candidate)
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def _paint_box(array: np.ndarray, center, size, value) -> None:
        """Fill the axis-aligned box of ``size`` (w, h) centred on ``center`` (x, y) with ``value``."""
        cx, cy = center
        w, h = size
        array[cy - h // 2:cy + h // 2, cx - w // 2:cx + w // 2] = value

    def create_dummy_data(self):
        """Write 100 synthetic 512x512 images with one object each, plus masks.

        Classes are cycled (image ``i`` shows class ``i % len(self.classes)``),
        and the object box is painted identically into image and mask.
        """
        print("Creating dummy robotics data...")

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        for i in range(100):
            image = np.random.randint(50, 150, (512, 512, 3), dtype=np.uint8)
            mask = np.zeros((512, 512), dtype=np.uint8)

            class_id = i % len(self.classes)
            center, size, colour = self._DUMMY_OBJECTS[class_id]
            self._paint_box(image, center, size, colour)
            self._paint_box(mask, center, size, class_id + 1)

            image_path = os.path.join(images_dir, f"robotics_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask_path = os.path.join(masks_dir, f"robotics_{i:03d}_mask.png")
            Image.fromarray(mask * self._MASK_SCALE).save(mask_path)

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert stored mask pixel values (0-255) to integer class IDs.

        A mask whose maximum does not exceed the number of classes is assumed to
        hold raw class IDs already and is returned unchanged (as int64).
        Otherwise the values are treated as ``class_id * _MASK_SCALE``.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self._MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class covering the most mask pixels.

        Background (ID 0) never counts. Samples with no class pixels at all
        are not filed under any class.
        """
        with Image.open(mask_path) as mask_image:
            class_ids = self._decode_class_ids(np.array(mask_image))

        class_counts = {
            class_name: int(np.sum(class_ids == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, dominant_count = max(class_counts.items(), key=lambda kv: kv[1])
        if dominant_count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random image with one binary float mask per class.

        Returns:
            ``(image, ground_truth)`` where ``ground_truth[name]`` has the shape
            of the loaded mask and is 1.0 where that class is present.

        Raises:
            ValueError: if the dataset contains no samples.
        """
        if not self.images:
            raise ValueError(f"No image/mask pairs found under {self.data_dir!r}")
        idx = random.randint(0, len(self.images) - 1)
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask's ToTensor divides pixels by 255; undo that to get the
        # stored values, then decode them to class IDs before comparing.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        class_ids = torch.from_numpy(self._decode_class_ids(raw))

        ground_truth = {
            class_name: (class_ids == i + 1).float()
            for i, class_name in enumerate(self.classes)
        }
        return image, ground_truth