from contextlib import nullcontext
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from PIL import Image
from tqdm.auto import tqdm

from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

from .constants import SUPPORTED_IMAGE_SIZES


PromptType = Union[str, List[str]]


def _get_pkv_seq_len(past_key_values) -> int:
    """Get cached sequence length from past_key_values (supports tuple and DynamicCache)."""
    if hasattr(past_key_values, "get_seq_length"):
        return past_key_values.get_seq_length()
    return past_key_values[0][0].shape[2]


class BitDanceDiffusionPipeline(DiffusionPipeline):
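    """Text-to-image pipeline that autoregressively generates blocks of binary latent tokens
    with a language-model backbone and a diffusion sampling head, then decodes them to pixels
    with the autoencoder.

    Example (illustrative sketch; assumes a checkpoint directory saved via ``save_pretrained``):

        >>> pipe = BitDanceDiffusionPipeline.from_pretrained("path/to/checkpoint")
        >>> pipe = pipe.to("cuda")
        >>> images = pipe.generate("a watercolor fox in a snowy forest", seed=0)
        >>> images[0].save("fox.png")
    """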
    model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder"

    def __init__(
        self,
        tokenizer,
        text_encoder,
        autoencoder,
        diffusion_head,
        projector,
        supported_image_sizes: Optional[List[List[int]]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            autoencoder=autoencoder,
            diffusion_head=diffusion_head,
            projector=projector,
        )

        image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES
        self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes])

        self.hidden_size = self.text_encoder.config.hidden_size
        self.vae_patch_size = self.autoencoder.patch_size
        self.parallel_num = int(self.diffusion_head.config.parallel_num)
        self.ps = int(self.parallel_num**0.5)
        if self.ps * self.ps != self.parallel_num:
            raise ValueError(
                f"parallel_num must be a perfect square (got {self.parallel_num})."
            )

        self._build_pos_embed()

    @property
    def supported_image_sizes(self) -> List[List[int]]:
        return [list(size) for size in self.config.supported_image_sizes]

    def _execution_device_fallback(self) -> torch.device:
        if getattr(self, "_execution_device", None) is not None:
            return self._execution_device
        return next(self.text_encoder.parameters()).device

    def _build_pos_embed(self) -> None:
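        # Precompute a 1D sin/cos table long enough for the largest supported latent axis;
        # 2D embeddings are later assembled from two hidden-size halves in `_get_2d_embed`.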
        max_resolution = max(max(size) for size in self.supported_image_sizes)
        max_len = max_resolution // self.vae_patch_size
        pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len)
        self.pos_embed_1d = pos_embed_1d

    @staticmethod
    def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor:
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        omega = torch.arange(dim // 2, dtype=torch.float32)
        omega /= dim / 2.0
        omega = 1.0 / 10000**omega
        pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation
        out = torch.einsum("m,d->md", pos, omega)
        emb_sin = torch.sin(out)
        emb_cos = torch.cos(out)
        return torch.cat([emb_sin, emb_cos], dim=1)

    def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor:
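        # Build an (h, w) grid from the 1D table (width half + height half of the hidden size),
        # then flatten it so each group of ps*ps entries covers one parallel decoding block.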
        emb_v = self.pos_embed_1d[:h]
        emb_h = self.pos_embed_1d[:w]
        grid_v = emb_v.view(h, 1, self.hidden_size // 2).repeat(1, w, 1)
        grid_h = emb_h.view(1, w, self.hidden_size // 2).repeat(h, 1, 1)
        pos_embed = torch.cat([grid_h, grid_v], dim=-1)
        return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps)

    def _encode_prompt_to_embeds(
        self,
        prompt: str,
        image_size: Tuple[int, int],
        num_images_per_prompt: int,
        guidance_scale: float,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        device = self._execution_device_fallback()
        model = self.text_encoder.model
        tokenizer = self.tokenizer

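        # The conditional branch uses the chat-style prompt template; the unconditional
        # branch (for classifier-free guidance) keeps only the assistant prefix.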
        cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        uncond_prompt = "<|im_start|>assistant\n"

        cond_ids = torch.tensor(tokenizer.encode(cond_prompt), device=device, dtype=torch.long)
        cond_emb = model.embed_tokens(cond_ids)
        uncond_emb = None
        if guidance_scale > 1.0:
            uncond_ids = torch.tensor(tokenizer.encode(uncond_prompt), device=device, dtype=torch.long)
            uncond_emb = model.embed_tokens(uncond_ids)

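        # Special tokens announce the start of the image and its latent grid size
        # (height and width divided by the autoencoder patch size).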
        image_h, image_w = image_size
        img_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>")
        res_h_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_h // self.vae_patch_size}|>")
        res_w_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_w // self.vae_patch_size}|>")
        img_start_emb = model.embed_tokens(torch.tensor([img_start_id, res_h_token_id, res_w_token_id], device=device))

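        # Query tokens <|query_1|> ... <|query_{parallel_num - 1}|> pad the prefix so the
        # first decoding step already spans a full block of parallel_num positions.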
        for i in range(1, self.parallel_num):
            query_token_id = tokenizer.convert_tokens_to_ids(f"<|query_{i}|>")
            query_token = torch.tensor([query_token_id], device=device, dtype=torch.long)
            query_embed = model.embed_tokens(query_token)
            img_start_emb = torch.cat([img_start_emb, query_embed], dim=0)

        input_embeds_cond = torch.cat([cond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
        input_embeds_uncond = None
        if guidance_scale > 1.0 and uncond_emb is not None:
            input_embeds_uncond = torch.cat([uncond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1)
        return input_embeds_cond, input_embeds_uncond, img_start_emb

    def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor:
        h, w = image_size
        image_latents = rearrange(image_latents, "b (h w p1 p2) c -> b c (h p1) (w p2)", h=h // ps, w=w // ps, p1=ps, p2=ps)
        return self.autoencoder.decode(image_latents)

    @torch.no_grad()
    def _generate_single_prompt(
        self,
        prompt: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        num_images_per_prompt: int,
        generator: Optional[torch.Generator],
        show_progress_bar: bool,
    ) -> torch.Tensor:
        image_size = (height, width)
        if list(image_size) not in self.supported_image_sizes:
            raise ValueError(
                f"image_size {list(image_size)} is not supported. "
                f"Please choose from {self.supported_image_sizes}"
            )

        h, w = height // self.vae_patch_size, width // self.vae_patch_size
        max_length = h * w
        step_width = self.parallel_num
        if h % self.ps != 0 or w % self.ps != 0:
            raise ValueError(
                f"Latent grid {h}x{w} must be divisible by ps={self.ps} "
                f"(parallel_num={self.parallel_num}) in both spatial dimensions."
            )
        num_steps = max_length // step_width

        device = self._execution_device_fallback()
        model = self.text_encoder.model
        dtype = next(self.text_encoder.parameters()).dtype

        input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds(
            prompt=prompt,
            image_size=image_size,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
        )
        pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype)

        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )

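        # Prefill in two phases: run the prompt causally to build the KV cache, then run the
        # last `parallel_num` prefix tokens with a bidirectional (all-ones) mask so their
        # hidden states can condition the first parallel decoding step.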
        with autocast_ctx:
            outputs_c = model(inputs_embeds=input_embeds_cond[:, :-step_width, :], use_cache=True)
            pkv_c = outputs_c.past_key_values

            bi_attn_mask = torch.ones(
                (input_embeds_cond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_c)),
                dtype=torch.bool,
                device=device,
            )
            outputs_c = model(
                inputs_embeds=input_embeds_cond[:, -step_width:, :],
                past_key_values=pkv_c,
                use_cache=True,
                attention_mask=bi_attn_mask,
            )
            pkv_c = outputs_c.past_key_values
            hidden_c = outputs_c.last_hidden_state[:, -step_width:]

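            # Repeat the same two-phase prefill for the unconditional prompt when CFG is active.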
            hidden_u = None
            pkv_u = None
            if guidance_scale > 1.0 and input_embeds_uncond is not None:
                outputs_u = model(inputs_embeds=input_embeds_uncond[:, :-step_width, :], use_cache=True)
                pkv_u = outputs_u.past_key_values
                bi_attn_mask_u = torch.ones(
                    (input_embeds_uncond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_u)),
                    dtype=torch.bool,
                    device=device,
                )
                outputs_u = model(
                    inputs_embeds=input_embeds_uncond[:, -step_width:, :],
                    past_key_values=pkv_u,
                    use_cache=True,
                    attention_mask=bi_attn_mask_u,
                )
                pkv_u = outputs_u.past_key_values
                hidden_u = outputs_u.last_hidden_state[:, -step_width:]

            out_tokens = []
            step_iter = range(num_steps)
            if show_progress_bar:
                step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps")

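            # Each step: add positional embeddings to the hidden states, sample continuous
            # latents with the diffusion head, binarize them with sign(), and feed the
            # projected tokens back to advance the KV caches.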
            for step in step_iter:
                if guidance_scale > 1.0 and hidden_u is not None:
                    h_fused = torch.cat([hidden_c, hidden_u], dim=0)
                else:
                    h_fused = hidden_c

                pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :]
                h_fused = h_fused + pos_slice
                pred_latents = self.diffusion_head.sample(
                    h_fused,
                    num_sampling_steps=num_inference_steps,
                    cfg=guidance_scale,
                    generator=generator,
                )
                curr_tokens = torch.sign(pred_latents)
                curr_embeds = self.projector(curr_tokens)
                out_tokens.append(curr_tokens[:num_images_per_prompt])

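                # Feed the projected tokens (plus positional embeddings) back through the
                # backbone to advance the conditional KV cache.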
                model_input = curr_embeds + pos_slice
                bi_attn_mask = torch.ones(
                    (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_c)),
                    dtype=torch.bool,
                    device=device,
                )
                outputs_c = model(
                    inputs_embeds=model_input[:num_images_per_prompt],
                    past_key_values=pkv_c,
                    use_cache=True,
                    attention_mask=bi_attn_mask[:num_images_per_prompt],
                )
                pkv_c = outputs_c.past_key_values
                hidden_c = outputs_c.last_hidden_state[:, -step_width:]

                if guidance_scale > 1.0 and hidden_u is not None and pkv_u is not None:
                    bi_attn_mask_u = torch.ones(
                        (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_u)),
                        dtype=torch.bool,
                        device=device,
                    )
                    outputs_u = model(
                        inputs_embeds=model_input[num_images_per_prompt:],
                        past_key_values=pkv_u,
                        use_cache=True,
                        attention_mask=bi_attn_mask_u[num_images_per_prompt:],
                    )
                    pkv_u = outputs_u.past_key_values
                    hidden_u = outputs_u.last_hidden_state[:, -step_width:]

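            # Concatenate every generated token block and decode the full latent grid to pixels.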
            full_output = torch.cat(out_tokens, dim=1)
            return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps)

    @torch.no_grad()
    def __call__(
        self,
        prompt: PromptType,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        show_progress_bar: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if len(prompts) == 0:
            raise ValueError("prompt must be a non-empty string or list of strings.")

        if isinstance(generator, list) and len(generator) != len(prompts):
            raise ValueError("When passing a list of generators, its length must equal len(prompt).")

        image_tensors = []
        for i, prompt_text in enumerate(prompts):
            prompt_generator = generator[i] if isinstance(generator, list) else generator
            images = self._generate_single_prompt(
                prompt=prompt_text,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                generator=prompt_generator,
                show_progress_bar=show_progress_bar,
            )
            image_tensors.append(images)

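        # Decoder outputs are in [-1, 1]; map to [0, 1] for the "pt" and "np" output types.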
        images_pt = torch.cat(image_tensors, dim=0)
        images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0)

        if output_type == "pt":
            output_images = images_pt_01
        elif output_type == "np":
            output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy()
        elif output_type == "pil":
            images_uint8 = (
                torch.clamp(127.5 * images_pt + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )
            output_images = [Image.fromarray(image) for image in images_uint8]
        else:
            raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.")

        if not return_dict:
            return (output_images,)
        return ImagePipelineOutput(images=output_images)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        height: int = 1024,
        width: int = 1024,
        num_sampling_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
    ) -> List[Image.Image]:
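        """Convenience wrapper around ``__call__`` that seeds a generator and returns PIL images."""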
        generator = None
        if seed is not None:
            device = self._execution_device_fallback()
            generator_device = "cuda" if device.type == "cuda" else "cpu"
            generator = torch.Generator(device=generator_device).manual_seed(seed)
        output = self(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_sampling_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            generator=generator,
            output_type="pil",
            return_dict=True,
            show_progress_bar=True,
        )
        return output.images
|