# Environment bootstrap, executed at import time: clone and install the CLIP
# and guided-diffusion repositories plus lpips, then download the 256x256
# unconditional diffusion checkpoint that inference() loads from the working
# directory. The os.system return codes are discarded, so a failed
# clone/install/download is not reported here; it only surfaces later (for
# example as an ImportError below, or a missing-file error at torch.load).
import os
import sys
import gradio as gr
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/crowsonkb/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install lpips')
os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
# The imports below must stay after the commands above: lpips, clip and
# guided_diffusion may not be installed until those commands have run.
import io
import math
import sys  # NOTE(review): duplicate of the `import sys` above; harmless.
import lpips
from PIL import Image
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
# Make the freshly cloned checkouts importable from the working directory.
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import numpy as np
import imageio
def fetch(url_or_path, timeout=60):
    """Open a local path or an http(s) URL and return a readable binary file object.

    Args:
        url_or_path: A filesystem path, or anything whose ``str()`` starts with
            ``http://`` or ``https://``.
        timeout: Seconds to wait on the HTTP server (forwarded to
            ``requests.get``). Ignored for local paths.

    Returns:
        For URLs, an in-memory ``io.BytesIO`` holding the response body and
        positioned at its start. For paths, the result of ``open(path, 'rb')``.
        In both cases the caller is responsible for closing the object.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        OSError: If the local path cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        # BytesIO(initial_bytes) starts at position 0, so no seek is needed.
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')
def parse_prompt(prompt):
    """Split a prompt of the form ``'<text or URL>:<weight>'`` into its parts.

    The weight is the text after the last ``':'`` when that text parses as a
    float. Otherwise the whole prompt is returned unchanged with weight 1.0.
    This keeps ``http(s)://`` URLs (including ones with a port before the
    path), Windows drive paths and free text containing colons intact,
    instead of raising ``ValueError``. Every prompt that previously parsed
    yields the same result as before.

    Args:
        prompt: The prompt string.

    Returns:
        A ``(text, weight)`` tuple where ``weight`` is a float.

    Examples:
        >>> parse_prompt('a cat:2')
        ('a cat', 2.0)
        >>> parse_prompt('https://example.com/cat.png')
        ('https://example.com/cat.png', 1.0)
        >>> parse_prompt('portrait: oil painting')
        ('portrait: oil painting', 1.0)
    """
    head, sep, tail = prompt.rpartition(':')
    if sep:
        try:
            return head, float(tail)
        except ValueError:
            # The trailing segment is not a weight (e.g. '//host/x.png' or
            # ' oil painting'), so the colon belongs to the prompt itself.
            pass
    return prompt, 1.0
class MakeCutouts(nn.Module):
    """Sample random square crops from an image batch and resize them for CLIP.

    Each forward pass draws ``cutn`` crops. A crop's side length lies between
    ``min(height, width, cut_size)`` and ``min(height, width)``. Since a
    uniform draw is raised to ``cut_pow``, values above 1 bias crops toward
    the small end. Every crop is average-pooled to ``cut_size`` x ``cut_size``.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _random_cutout(self, image, height, width, smallest, largest):
        """Crop one random square from ``image`` and pool it to ``cut_size``."""
        # RNG order per crop: size, then horizontal offset, then vertical offset.
        side = int(torch.rand([]) ** self.cut_pow * (largest - smallest) + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        patch = image[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(patch, self.cut_size)

    def forward(self, input):
        """Return all crops concatenated along dim 0 (``cutn`` x batch rows).

        Reads dims 2 and 3 of ``input`` as height and width.
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        return torch.cat([
            self._random_cutout(input, height, width, smallest, largest)
            for _ in range(self.cutn)
        ])
def spherical_dist_loss(x, y):
    """Return half the squared angle between ``x`` and ``y`` along the last dim.

    Both inputs are L2-normalised over the last dimension. The chord length c
    between the unit vectors gives the angle ``theta = 2 * asin(c / 2)``, and
    the result is ``2 * asin(c / 2) ** 2 == theta ** 2 / 2``. Inputs broadcast
    against each other, and the last dimension is reduced away.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = torch.norm(x_unit - y_unit, dim=-1)
    half_angle = torch.asin(chord / 2)
    return 2 * half_angle ** 2
def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The last two dims are padded by one row and one column with replicated
    edge values, so the border differences are zero. Returns the mean of the
    squared horizontal plus vertical neighbour differences over dims 1-3,
    giving one value per item along dim 0.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - anchor
    dy = padded[..., 1:, :-1] - anchor
    return (dx.pow(2) + dy.pow(2)).mean(dim=[1, 2, 3])
def range_loss(input):
    """Penalise values outside [-1, 1].

    Returns the mean over dims 1-3 of the squared distance by which each
    element exceeds the [-1, 1] interval (0 for in-range elements), giving
    one value per item along dim 0.
    """
    overshoot = input - torch.clamp(input, min=-1, max=1)
    return torch.mean(overshoot ** 2, dim=[1, 2, 3])
def inference(text, init_image, skip_timesteps, clip_guidance_scale, tv_scale, range_scale, init_scale, seed, image_prompts, timestep_respacing, cutn):
    """Generate an image from a text prompt with CLIP-guided diffusion.

    On every call this loads the 256x256 unconditional guided-diffusion
    checkpoint, CLIP ViT-B/16 and LPIPS (VGG). It then runs one sampling batch
    in which each denoising step is steered by the gradient of a weighted
    CLIP + total-variation + range loss, plus an LPIPS loss against the init
    image when one is given.

    Args:
        text: Text prompt, optionally suffixed with ``:weight`` (parsed by
            ``parse_prompt``).
        init_image: Falsy, or an uploaded file object whose ``.name`` path is
            used as the starting image.
        skip_timesteps: Number of leading diffusion steps to skip (cast to int).
        clip_guidance_scale: Weight of the CLIP similarity loss.
        tv_scale: Weight of the total-variation (smoothness) loss.
        range_scale: Weight of the out-of-[-1, 1] penalty.
        init_scale: Weight of the LPIPS loss against ``init_image``. Only used
            when an init image is given and this value is non-zero.
        seed: Passed to ``torch.manual_seed`` unless None.
        image_prompts: Falsy, or an uploaded file object whose ``.name`` path
            is used as an image prompt with weight 1.
        timestep_respacing: Respacing spec for guided-diffusion, e.g. ``50`` or
            ``'ddim50'``. An integral float is converted to int first.
        cutn: Number of random cutouts fed to CLIP per step (cast to int).

    Returns:
        ``(image, 'video.mp4')``: the final step's predicted image as a PIL
        image, and the path of an mp4 (5 fps) containing every step's
        prediction.

    Raises:
        RuntimeError: If the prompt weights sum to ~0, or if sampling yields
            no steps (e.g. ``skip_timesteps`` >= the respaced step count).
    """
    # Gradio numeric widgets may deliver floats. skip_timesteps feeds cur_t,
    # which indexes an array; cutn feeds range() and view(); and str(50.0)
    # would produce the respacing spec '50.0', which guided-diffusion's
    # integer respacing parser is not expected to accept.
    skip_timesteps = int(skip_timesteps)
    cutn = int(cutn)
    if isinstance(timestep_respacing, float) and timestep_respacing.is_integer():
        timestep_respacing = int(timestep_respacing)

    model_config = model_and_diffusion_defaults()
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': 1000,
        'rescale_timesteps': True,
        'timestep_respacing': str(timestep_respacing),
        'image_size': 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': True,
        'use_scale_shift_norm': True,
    })

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    model, diffusion = create_model_and_diffusion(**model_config)
    model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
    model.requires_grad_(False).eval().to(device)
    # Re-enable gradients on a subset of parameters (names containing
    # 'qkv', 'norm' or 'proj').
    for name, param in model.named_parameters():
        if 'qkv' in name or 'norm' in name or 'proj' in name:
            param.requires_grad_()
    if model_config['use_fp16']:
        model.convert_to_fp16()
    clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
    clip_size = clip_model.visual.input_resolution
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    lpips_model = lpips.LPIPS(net='vgg').to(device)

    prompts = [text]
    # Uploaded files carry no ':weight' suffix, so their temp paths are used
    # verbatim; parse_prompt would misread paths containing ':' (e.g. 'C:\\...').
    image_prompt_paths = [image_prompts.name] if image_prompts else []
    init_path = init_image.name if init_image else None
    batch_size = 1
    n_batches = 1

    if seed is not None:
        torch.manual_seed(seed)
    make_cutouts = MakeCutouts(clip_size, cutn)
    side_x = side_y = model_config['image_size']

    target_embeds, weights = [], []
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)
    for path in image_prompt_paths:
        weight = 1.0
        with fetch(path) as fh:
            img = Image.open(fh).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        # One embedding per cutout, so the prompt's weight is split across them.
        weights.extend([weight / cutn] * cutn)
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_path is not None:
        with fetch(init_path) as fh:
            init = Image.open(fh).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # Map pixel values from [0, 1] to [-1, 1].
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    cur_t = None

    def cond_fn(x, t, y=None):
        # Guidance gradient for the sampler. It reads cur_t from the enclosing
        # scope (decremented by the loop below) instead of the t argument;
        # presumably because t arrives rescaled (rescale_timesteps=True).
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            # Blend the predicted clean image with the noisy input, weighted
            # by sqrt(1 - alphas_cumprod) at the current step.
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            return -torch.autograd.grad(loss, x)[0]

    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    all_frames = []
    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=False,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=True,
        )
        # The sampler is a lazy generator: cond_fn for a step runs before that
        # step is yielded, so cur_t is decremented after each yield.
        for j, sample in enumerate(samples):
            cur_t -= 1
            for k, image in enumerate(sample['pred_xstart']):
                all_frames.append(TF.to_pil_image(image.add(1).div(2).clamp(0, 1)))
                tqdm.write(f'Batch {i}, step {j}, output {k}:')

    if not all_frames:
        raise RuntimeError('Sampling produced no frames; lower skip_timesteps.')

    # The context manager closes (finalises) the video even if a write fails.
    with imageio.get_writer('video.mp4', fps=5) as writer:
        for frame in all_frames:
            writer.append_data(np.array(frame))
    return all_frames[-1], 'video.mp4'
# Gradio UI: one input component per inference() parameter, in call order.
title = "CLIP Guided Diffusion"
input_components = [
    "text",
    gr.inputs.Image(type="file", label='initial image (optional)', optional=True),
    gr.inputs.Slider(minimum=0, maximum=45, step=1, default=10, label="skip_timesteps"),
    gr.inputs.Slider(minimum=0, maximum=3000, step=1, default=600, label="clip guidance scale (Controls how much the image should look like the prompt)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="tv_scale (Controls the smoothness of the final output)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="range_scale (Controls how far out of range RGB values are allowed to be)"),
    gr.inputs.Slider(minimum=0, maximum=1000, step=1, default=0, label="init_scale (This enhances the effect of the init image)"),
    gr.inputs.Number(default=0, label="Seed"),
    gr.inputs.Image(type="file", label='image prompt (optional)', optional=True),
    gr.inputs.Slider(minimum=50, maximum=500, step=1, default=50, label="timestep respacing"),
    gr.inputs.Slider(minimum=1, maximum=64, step=1, default=32, label="cutn"),
]
# One example row, with values in the same order as input_components.
example_rows = [
    ["coral reef city by artistation artists", None, 0, 1000, 150, 50, 0, 0, None, 90, 32],
]
iface = gr.Interface(
    inference,
    inputs=input_components,
    outputs=["image", "video"],
    title=title,
    examples=example_rows,
    enable_queue=True,
)
iface.launch()
|