| from typing import Dict, Tuple, Union |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
| from PIL.Image import Image as PilImage |
| from torchvision import transforms |
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| from transformers.image_utils import ImageInput |
|
|
|
|
class RescaleT(object):
    """Resize an ``{"image", "label"}`` sample to a fixed spatial size.

    The image is resized with area interpolation and scaled by ``1/255``;
    the label is resized with nearest-neighbor interpolation (so label
    values are never blended) and clipped back to its original value range.
    """

    def __init__(self, output_size: Union[int, Tuple[int, int]]) -> None:
        """
        Args:
            output_size: Target size. An int produces a square
                ``output_size x output_size`` output; a ``(height, width)``
                tuple produces exactly that shape.
        """
        super().__init__()
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample) -> Dict[str, np.ndarray]:
        image, label = sample["image"], sample["label"]

        # Resolve the target (height, width) once. The original computed
        # aspect-ratio-preserving sizes that were never used, and crashed in
        # cv2.resize when output_size was a tuple; both are fixed here.
        if isinstance(self.output_size, int):
            new_h = new_w = self.output_size
        else:
            new_h, new_w = (int(s) for s in self.output_size)

        # cv2.resize takes the size as (width, height).
        img = (
            cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
            / 255.0
        )

        lbl = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        lbl = np.expand_dims(lbl, axis=-1)  # restore the trailing channel axis
        # Keep resized label values inside the original label's value range.
        lbl = np.clip(lbl, np.min(label), np.max(label))

        return {"image": img, "label": lbl}
|
|
|
|
class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    ``flag`` selects how the image is normalized:

    * ``flag == 2``: six channels — per-channel min-max normalized RGB plus
      min-max normalized LAB, each channel then standardized to zero mean /
      unit std.
    * ``flag == 1``: three LAB channels, min-max normalized and standardized.
    * ``flag == 0`` (default): three RGB channels divided by the image max and
      normalized with the ImageNet mean/std.

    The label is rescaled to [0, 1] by its maximum (if non-zero) and both
    arrays are transposed HWC -> CHW before being wrapped as tensors.
    """

    def __init__(self, flag: int = 0) -> None:
        # flag selects the normalization scheme (see class docstring).
        self.flag = flag

    @staticmethod
    def _expand_to_rgb(image: np.ndarray) -> np.ndarray:
        """Replicate a single-channel image to 3 channels; pass RGB through."""
        if image.shape[2] == 1:
            out = np.zeros((image.shape[0], image.shape[1], 3))
            out[:, :, 0] = image[:, :, 0]
            out[:, :, 1] = image[:, :, 0]
            out[:, :, 2] = image[:, :, 0]
            return out
        return image

    @staticmethod
    def _min_max_per_channel(src: np.ndarray) -> np.ndarray:
        """Min-max normalize each channel independently.

        NOTE(review): a constant channel divides by zero here, exactly as the
        original per-channel expressions did — deliberately not changed.
        """
        out = np.zeros(src.shape)
        for c in range(src.shape[2]):
            ch = src[:, :, c]
            out[:, :, c] = (ch - np.min(ch)) / (np.max(ch) - np.min(ch))
        return out

    @staticmethod
    def _standardize_per_channel(img: np.ndarray) -> np.ndarray:
        """Standardize each channel to zero mean and unit std."""
        for c in range(img.shape[2]):
            ch = img[:, :, c]
            img[:, :, c] = (ch - np.mean(ch)) / np.std(ch)
        return img

    def __call__(self, sample):
        image, label = sample["image"], sample["label"]

        # Rescale the label to [0, 1] unless it is (near) all-zero.
        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        if self.flag == 2:
            # Six channels: min-max normalized RGB followed by min-max
            # normalized LAB, then per-channel standardization.
            rgb = self._expand_to_rgb(image)
            # BUGFIX: cv2.cvtColor rejects float64 input; cast to float32.
            # NOTE(review): for float input cv2 expects RGB in [0, 1] —
            # RescaleT's /255 upstream satisfies this; confirm for other callers.
            lab = cv2.cvtColor(rgb.astype(np.float32), cv2.COLOR_RGB2LAB)
            tmpImg = np.concatenate(
                (
                    self._min_max_per_channel(rgb),
                    self._min_max_per_channel(lab),
                ),
                axis=2,
            )
            tmpImg = self._standardize_per_channel(tmpImg)

        elif self.flag == 1:
            tmpImg = self._expand_to_rgb(image)
            # BUGFIX: removed a leftover debug ``print``/``exit()`` pair that
            # unconditionally terminated the process before this conversion.
            # BUGFIX: cast to float32 — cv2.cvtColor rejects float64 input.
            tmpImg = cv2.cvtColor(tmpImg.astype(np.float32), cv2.COLOR_RGB2LAB)
            tmpImg = self._min_max_per_channel(tmpImg)
            tmpImg = self._standardize_per_channel(tmpImg)

        else:
            # ImageNet mean/std normalization on the max-scaled image.
            image = image / np.max(image)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                # Grayscale: replicate channel 0, using the red-channel stats,
                # as the original did.
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # HWC -> CHW for both outputs. (The original's ``tmpLbl`` zero-buffer
        # was dead code — ``label.transpose`` was always the returned value.)
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {"image": torch.from_numpy(tmpImg), "label": torch.from_numpy(tmpLbl)}
|
|
|
|
def apply_transform(
    data: Dict[str, np.ndarray], rescale_size: int, to_tensor_lab_flag: int
) -> Dict[str, torch.Tensor]:
    """Run the preprocessing pipeline on a ``{"image", "label"}`` sample.

    Applies ``RescaleT(rescale_size)`` followed by ``ToTensorLab(flag)`` and
    returns the resulting tensor dict.
    """
    rescaled = RescaleT(output_size=rescale_size)(data)
    return ToTensorLab(flag=to_tensor_lab_flag)(rescaled)
|
|
|
|
class BASNetImageProcessor(BaseImageProcessor):
    """Image processor that prepares inputs for — and decodes saliency
    predictions from — a BASNet model."""

    model_input_names = ["pixel_values"]

    def __init__(
        self, rescale_size: int = 256, to_tensor_lab_flag: int = 0, **kwargs
    ) -> None:
        # rescale_size: target side length handed to RescaleT.
        # to_tensor_lab_flag: normalization mode handed to ToTensorLab.
        super().__init__(**kwargs)
        self.rescale_size = rescale_size
        self.to_tensor_lab_flag = to_tensor_lab_flag

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
        """Turn a single RGB PIL image into a batched ``pixel_values`` feature."""
        if not isinstance(images, PilImage):
            raise ValueError(f"Expected PIL.Image, got {type(images)}")

        pil_image = images
        array = np.array(pil_image, dtype=np.uint8)
        w, h = pil_image.size
        # The transform pipeline expects an {"image", "label"} pair; supply an
        # all-zero dummy label of matching spatial size.
        dummy_label = np.zeros((h, w), dtype=np.uint8)

        assert array.shape[-1] == 3  # RGB input only
        transformed = apply_transform(
            {"image": array, "label": dummy_label},
            rescale_size=self.rescale_size,
            to_tensor_lab_flag=self.to_tensor_lab_flag,
        )
        pixel_values = transformed["image"]
        assert isinstance(pixel_values, torch.Tensor)

        # Add the leading batch dimension expected by the model.
        batched = pixel_values.float().unsqueeze(dim=0)
        return BatchFeature(data={"pixel_values": batched}, tensor_type="pt")

    def postprocess(
        self, prediction: torch.Tensor, width: int, height: int
    ) -> PilImage:
        """Convert a saliency prediction into a ``width x height`` RGB image."""

        def _norm_prediction(d: torch.Tensor) -> torch.Tensor:
            # Min-max normalize to [0, 1]; eps keeps a constant map from
            # dividing by zero.
            hi, lo = torch.max(d), torch.min(d)
            return (d - lo) / ((hi - lo) + torch.finfo(torch.float32).eps)

        scaled = _norm_prediction(prediction).squeeze()
        scaled = (scaled * 255 + 0.5).clamp(0, 255)

        out = Image.fromarray(scaled.cpu().numpy()).convert("RGB")
        return out.resize((width, height), resample=Image.Resampling.BILINEAR)
|
|