|
|
|
|
|
|
| import json |
| import torchvision.transforms as transforms |
| from torch.utils.data.dataset import Dataset |
| |
| from PIL import Image |
| import os |
| import torch |
| import torchvision.transforms.functional as F |
def tokenize_captions(caption, tokenizer):
    """Tokenize one caption as a batch of size one.

    The caption is padded / truncated to ``tokenizer.model_max_length`` and
    encoded with ``return_tensors="pt"``.

    Args:
        caption: The text to encode.
        tokenizer: A callable tokenizer exposing ``model_max_length``
            (presumably a Hugging Face tokenizer — confirm against callers).

    Returns:
        The ``input_ids`` attribute of the tokenizer output; for a Hugging
        Face tokenizer this is presumably a ``(1, model_max_length)`` tensor.
    """
    encoded = tokenizer(
        [caption],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
|
|
|
|
|
|
|
|
class SquarePad:
    """Pad a PIL image with a constant colour so that it becomes square.

    The shorter side is padded on both ends. When the size difference is odd,
    the extra pixel goes to the right / bottom edge, so the result is exactly
    ``max(w, h)`` pixels on each side.
    """

    def __init__(self, fill=(255, 255, 255)):
        # Forwarded to ``F.pad`` as the fill colour. The default is white RGB,
        # the value previously hard-coded; torchvision requires a tuple fill to
        # have one entry per image channel, so pass an int for 1-channel images.
        self.fill = fill

    @staticmethod
    def _square_padding(w, h):
        """Return ``(left, top, right, bottom)`` padding that makes a ``w x h`` image square."""
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        # Give the remainder to right/bottom so odd differences still yield a square.
        return (left, top, side - w - left, side - h - top)

    def __call__(self, image):
        """Return ``image`` padded to a square; expects a PIL image (uses ``image.size``)."""
        w, h = image.size
        padding = self._square_padding(w, h)
        return F.pad(image, padding, fill=self.fill, padding_mode='constant')
|
|
class NormalSegDataset(Dataset):
    """RGB image dataset conditioned on a surface-normal map and a segmentation map.

    Samples are listed in ``<path>/meta_train_seg.json``. As read below, that
    file holds a ``keys`` list and a ``data`` mapping from each key to a record
    with ``rgb``, ``normal`` and ``seg`` file paths and a ``caption`` sequence
    (only the first caption is used).

    Each item is ``(image, conditioning_image, prompt_ids, text_prompt)``:

    * ``image``: the RGB image after a random resized crop, scaled to [-1, 1]
      by ``Normalize([0.5], [0.5])``.
    * ``conditioning_image``: the normal map (3 channels) concatenated with
      the segmentation map (1 channel), both in [0, 1].
    * ``prompt_ids``: the tokenized prompt.
    * ``text_prompt``: the raw prompt string.

    With probability ``cfg_prob`` the prompt is replaced by ``""``. Independently,
    with probability ``cfg_prob``, the conditioning maps are replaced by a blank
    white normal map and a black segmentation map (presumably dropout for
    classifier-free guidance, as the name suggests).
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        # One spatial pipeline is shared by the RGB, normal and seg images.
        # __getitem__ replays the torch RNG state so all three receive the
        # same random crop.
        self.image_transforms = transforms.Compose(
            [
                # BILINEAR is what the legacy integer code ``2`` meant; the
                # enum avoids torchvision's deprecated int-interpolation path.
                transforms.RandomResizedCrop(
                    args.resolution,
                    scale=(0.9, 1.0),
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        # Applied to the RGB target only: maps ToTensor's [0, 1] range to [-1, 1].
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5])]
        )

        meta_path = os.path.join(path, 'meta_train_seg.json')
        # JSON is UTF-8 by spec; do not depend on the platform default encoding.
        with open(meta_path, 'r', encoding='utf-8') as f:
            meta = json.load(f)

        self.keys = meta['keys']
        self.meta = meta['data']

        self.tokenizer = tokenizer
        self.cfg_prob = cfg_prob

    def __len__(self):
        """Return the number of samples listed under ``keys`` in the metadata."""
        return len(self.keys)

    @staticmethod
    def _load_image(path, mode):
        """Open ``path`` with PIL, convert it to ``mode``, and close the file handle."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so closing the
            # source file afterwards is safe.
            return img.convert(mode)

    def __getitem__(self, index):
        """Return ``(image, conditioning_image, prompt_ids, text_prompt)`` for ``index``."""
        meta_data = self.meta[self.keys[index]]
        text_prompt = meta_data['caption'][0]

        # Draw both dropout decisions before taking the RNG snapshot below.
        # Restoring that snapshot rewinds the global torch RNG, so a draw made
        # after the snapshot could be replayed by the next sample's first draw.
        drop_text = torch.rand(1).item() < self.cfg_prob
        drop_condition = torch.rand(1).item() < self.cfg_prob
        if drop_text:
            text_prompt = ""

        image = self._load_image(meta_data['rgb'], "RGB")
        state = torch.get_rng_state()
        image = self.image_transforms(image)  # (3, H, W) float tensor in [0, 1]

        if drop_condition:
            # PIL sizes are (width, height), while the tensor is (C, H, W).
            _, height, width = image.shape
            normal_image = Image.new('RGB', (width, height), (255, 255, 255))
            seg_image = Image.new('L', (width, height), 0)
        else:
            normal_image = self._load_image(meta_data['normal'], "RGB")
            seg_image = self._load_image(meta_data['seg'], "L")

        # Replay the RGB crop's RNG state so normal and seg get the same crop.
        # NOTE(review): RandomResizedCrop samples its box relative to the input
        # size, so the crops only coincide if the normal/seg files have the same
        # size as the RGB file — presumably true here; verify.
        # NOTE(review): seg maps use the same bilinear resize as RGB; if they
        # hold discrete class labels, interpolated values appear at boundaries.
        torch.set_rng_state(state)
        normal_image = self.image_transforms(normal_image)  # (3, H, W)
        torch.set_rng_state(state)
        seg_image = self.image_transforms(seg_image)  # (1, H, W)

        conditioning_image = torch.cat([normal_image, seg_image], dim=0)

        image = self.additional_image_transforms(image)

        prompt_ids = tokenize_captions(text_prompt, self.tokenizer)

        return image, conditioning_image, prompt_ids, text_prompt
|
|
|
|
|
|