| from typing import Dict, Optional, Tuple |
|
|
| import numpy as np |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| import torchvision.transforms.functional as TF |
| from PIL.Image import Image |
| from torch import Tensor |
| from transformers.image_processing_utils import BaseImageProcessor |
|
|
# Spatial size (height, width) every input image is resized to before inference.
INPUT_IMAGE_SIZE = (352, 352)

# Per-channel statistics shared by mean and std: (x - 0.5) / 0.5 maps a
# [0, 1] tensor onto [-1, 1].
_CHANNEL_HALF = (0.5, 0.5, 0.5)

# Shared preprocessing pipeline:
#   bicubic resize -> float tensor (ToTensor; [0, 1] for 8-bit images) -> normalize.
transform = transforms.Compose(
    [
        transforms.Resize(INPUT_IMAGE_SIZE, interpolation=TF.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_HALF, std=_CHANNEL_HALF),
    ]
)
|
|
|
|
class DPTDepthImageProcessor(BaseImageProcessor):
    """Image processor for a DPT-style depth model.

    ``preprocess`` turns a PIL image into a normalized, batched tensor using the
    module-level ``transform`` pipeline; ``postprocess`` resizes the model's
    output to a requested spatial size and returns it as a NumPy array.
    """

    # Input names declared to the transformers processor machinery.
    # NOTE(review): this does not match the "rgb" key returned by preprocess —
    # confirm which name downstream code expects.
    model_input_names = ["dptdepth_preprocessor"]

    def __init__(self, testsize: Optional[int] = 352, **kwargs) -> None:
        """Create the processor.

        Args:
            testsize: Stored as ``self.testsize``. NOTE(review): ``preprocess``
                does not read it; the resize size is fixed by the module-level
                pipeline. Confirm whether any caller relies on it.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.testsize = testsize

    def preprocess(
        self, inputs: Dict[str, Image], **kwargs
    ) -> Dict[str, Tensor]:
        """Convert ``inputs["rgb"]`` into a model-ready batch of one.

        Args:
            inputs: Mapping that must contain the key ``"rgb"`` holding a PIL image.
            **kwargs: Accepted for API compatibility; unused.

        Returns:
            ``{"rgb": tensor}`` where ``tensor`` has a leading batch dimension of 1.

        Raises:
            KeyError: If ``inputs`` has no ``"rgb"`` entry.
        """
        image = inputs["rgb"]
        # The normalization step uses 3-channel mean/std, so RGBA/grayscale
        # images would raise there and palette images would yield raw indices.
        # Convert them to RGB first; RGB images pass through untouched.
        if isinstance(image, Image) and image.mode != "RGB":
            image = image.convert("RGB")
        rgb: Tensor = transform(image)
        return dict(rgb=rgb.unsqueeze(0))

    def postprocess(
        self, logits: Tensor, size: Tuple[int, int], **kwargs
    ) -> np.ndarray:
        """Resize model output to ``size`` and return it as a NumPy array.

        Args:
            logits: Model output. Bilinear interpolation requires a 4-D input,
                so this is presumably ``(N, C, h, w)`` — TODO confirm with the model.
            size: Target spatial size passed to ``F.interpolate`` as ``(height, width)``.
            **kwargs: Accepted for API compatibility; unused.

        Returns:
            The resized output on the CPU with every size-1 dimension squeezed out.
        """
        # F.upsample is deprecated; F.interpolate has identical semantics.
        resized = F.interpolate(
            logits, size=size, mode="bilinear", align_corners=False
        )
        # detach() instead of the legacy .data accessor before leaving autograd.
        res: np.ndarray = resized.squeeze().detach().cpu().numpy()
        return res
|
|