def default(value, default_value):
    """Return ``value`` unless it is ``None``, in which case return ``default_value``."""
    return value if value is not None else default_value


def ensure_list(value):
    """
    Normalise ``value`` to a list.

    ``None`` becomes ``[]``, a list or tuple is shallow-copied into a new list,
    and any other object (strings included) is wrapped as a single element.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


class Resolution(object):
    """
    A ``height x width`` pair together with its height/width ratio.

    Accepted constructor forms:
        ``Resolution(1024)``          -> 1024 x 1024
        ``Resolution(768, 1024)``     -> height 768, width 1024
        ``Resolution("768x1024")``    -> height 768, width 1024
        ``Resolution("768")``         -> 768 x 768
        ``Resolution((768, 1024))``   -> any indexable ``(height, width)``

    Attributes:
        h / height: First component.
        w / width: Second component.
        r / ratio: ``height / width`` (true division).
    """

    def __init__(self, size, *args):
        if isinstance(size, str):
            if 'x' in size:
                # Only the first two 'x'-separated fields are used.
                parts = size.split('x')
                size = (int(parts[0]), int(parts[1]))
            else:
                size = int(size)
        if args:
            width = args[0]
            # A textual width is parsed the same way a textual height is;
            # otherwise the ratio below would divide an int by a str.
            size = (size, int(width) if isinstance(width, str) else width)
        if isinstance(size, int):
            size = (size, size)
        self.h = self.height = size[0]
        self.w = self.width = size[1]
        self.r = self.ratio = self.height / self.width

    def __getitem__(self, idx):
        """Index 0 is the height, index 1 the width; anything else raises IndexError."""
        if idx == 0:
            return self.h
        elif idx == 1:
            return self.w
        else:
            raise IndexError(f'Index {idx} out of range')

    def __str__(self):
        return f'{self.h}x{self.w}'

    def __repr__(self):
        return f'Resolution({self.h}, {self.w})'
class ResolutionGroup(object):
    """
    Bucket of target resolutions built around one square ``base_size``.

    Starting from ``base_size x base_size`` the group walks in increments of
    ``step`` towards taller shapes and towards wider shapes, keeping every
    side within ``[base_size // 2, base_size * 2]``. It can then map an
    arbitrary ``(width, height)`` to the member with the closest
    height/width ratio.

    Args:
        base_size: Side length of the square reference resolution (int).
        step: Increment applied per bucket; defaults to ``base_size // 16``.
        align: Every generated side is rounded down to a multiple of this.
        extra_resolutions: Optional iterable of objects exposing ``height``,
            ``width`` and ``ratio``. Each one is appended unless a member
            with exactly the same ratio is already present.

    Raises:
        ValueError: If ``base_size`` is not an int, if ``step`` exceeds
            ``base_size // 2``, or if ``step`` is not positive.
        AssertionError: If ``base_size`` is not divisible by ``align`` or
            ``align`` is larger than ``step``.
    """

    def __init__(self, base_size=None, step=None, align=1, extra_resolutions=None):
        self.align = align
        self.base_size = base_size
        # Validate the type first: the arithmetic below would otherwise fail
        # with an unrelated TypeError (notably for the default ``None``).
        if not isinstance(base_size, int):
            raise ValueError(f'base_size must be an int, but got {type(base_size)}')
        # Raised explicitly (same exception type as the former assert) so the
        # check is not stripped when Python runs with -O.
        if base_size % align != 0:
            raise AssertionError(f'base_size {base_size} is not divisible by align {align}')
        if step is None:
            step = base_size // 16
        if step > base_size // 2:
            raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}')

        self.step = step
        self.data = self._calc_by_step()

        if extra_resolutions is not None:
            for extra_resolution in extra_resolutions:
                ratio = extra_resolution.height / extra_resolution.width
                # Exact float comparison: only a bit-identical ratio counts as a duplicate.
                if not any(known.ratio == ratio for known in self.data):
                    self.data.append(extra_resolution)

        self.ratio = np.array([x.ratio for x in self.data])
        self.attr = ['' for _ in self.data]
        self.prefix_space = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __repr__(self):
        prefix = self.prefix_space * ' '
        prefix_close = (self.prefix_space - 4) * ' '
        attr_maxlen = max([len(x) for x in self.attr] + [5])
        header = (
            f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data='
            f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}'
        )
        rows = [
            f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} '
            # Parenthesised: without them `h // 16 * w // 16` evaluates as
            # ((h // 16) * w) // 16, which differs when w is not a multiple of 16.
            f'({x.h // 16:3d}, {x.w // 16:3d}) {(x.h // 16) * (x.w // 16):6d}'
            for i, x in enumerate(self.data)
        ]
        return header + ('\n' + prefix).join(rows) + f'\n{prefix_close})'

    def _snap(self, value):
        """Round ``value`` down to a multiple of ``self.align``."""
        return value // self.align * self.align

    def _calc_by_step(self):
        """Generate the member resolutions, sorted by ascending height/width ratio."""
        if self.align > self.step:
            raise AssertionError(f'align {self.align} must be smaller than step {self.step}')
        if self.step <= 0:
            # A non-positive step never reaches the bounds below and would loop forever.
            raise ValueError(f'step must be positive, but got {self.step}')

        lower = self.base_size // 2
        upper = self.base_size * 2
        resolutions = [Resolution(self.base_size, self.base_size)]

        # Taller shapes: grow the height and shrink the width until both hit their bounds.
        cur_height = cur_width = self.base_size
        while not (cur_height >= upper and cur_width <= lower):
            cur_height = min(cur_height + self.step, upper)
            cur_width = max(cur_width - self.step, lower)
            resolutions.append(Resolution(self._snap(cur_height), self._snap(cur_width)))

        # Wider shapes: the mirror image of the sweep above.
        cur_height = cur_width = self.base_size
        while not (cur_height <= lower and cur_width >= upper):
            cur_height = max(cur_height - self.step, lower)
            cur_width = min(cur_width + self.step, upper)
            resolutions.append(Resolution(self._snap(cur_height), self._snap(cur_width)))

        return sorted(resolutions, key=lambda x: x.ratio)

    def _nearest_index(self, width, height):
        """Index of the member whose ratio is closest to ``height / width``."""
        return np.argmin(np.abs(self.ratio - height / width))

    def get_target_size(self, width, height):
        """Return ``(width, height)`` of the member closest in ratio to the given size."""
        reso = self.data[self._nearest_index(width, height)]
        return reso.w, reso.h

    def get_base_size_and_ratio_index(self, width, height):
        """Return ``(base_size, index)`` where ``index`` selects the closest-ratio member."""
        return self.base_size, self._nearest_index(width, height)
class ImageInfo:
    """
    Class to store image information for processing and generation.

    Besides regular attribute access, instances support ``info[key]``,
    ``info[key] = value`` (existing attributes only) and ``key in info``.

    Args:
        image_type: Kind of image section. ``meta_info`` accepts "vae",
            "gen_image", "vit" and "siglip2"; ``num_special_tokens`` accepts
            "vae", "src_image" and "gen_image".
        image_tensor: Tensor attached to this description, if any.
        image_width, image_height: Stored as given; also exposed as ``w`` / ``h``.
        token_width, token_height: Stored as given; also exposed as ``tk_w`` / ``tk_h``.
        image_token_length: Defaults to ``token_width * token_height`` when
            both of those are provided, otherwise ``None``.
        base_size, ratio_index: Stored as given and reported by ``meta_info``.
        ori_image_width, ori_image_height: Stored as given.
        **kwargs: Optional boolean flags ``add_timestep_token`` (default True),
            ``add_guidance_token`` (False), ``use_front_boi_token`` (True),
            ``add_image_shape_token`` (True), ``add_timestep_r_token`` (False).
            Other keys are ignored.
    """

    def __init__(
            self,
            image_type: Optional[str] = None,
            image_tensor: Optional[torch.Tensor] = None,
            image_width: Optional[int] = None,
            image_height: Optional[int] = None,
            token_width: Optional[int] = None,
            token_height: Optional[int] = None,
            image_token_length: Optional[int] = None,
            base_size: Optional[int] = None,
            ratio_index: Optional[int] = None,
            ori_image_width: Optional[int] = None,
            ori_image_height: Optional[int] = None,
            **kwargs,
    ):
        self.image_type = image_type
        self.image_tensor = image_tensor
        self.ori_image_width = ori_image_width
        self.image_width = image_width
        self.w = image_width
        self.ori_image_height = ori_image_height
        self.image_height = image_height
        self.h = image_height
        self.token_width = token_width
        self.tk_w = token_width
        self.token_height = token_height
        self.tk_h = token_height
        # Derive the token count from the token grid only when the caller gave none.
        if image_token_length is None and token_width is not None and token_height is not None:
            image_token_length = token_width * token_height
        self.image_token_length = image_token_length
        self.base_size = base_size
        self.ratio_index = ratio_index

        self.add_timestep_token = kwargs.get("add_timestep_token", True)
        self.add_guidance_token = kwargs.get("add_guidance_token", False)
        self.use_front_boi_token = kwargs.get("use_front_boi_token", True)
        self.add_image_shape_token = kwargs.get("add_image_shape_token", True)
        self.add_timestep_r_token = kwargs.get("add_timestep_r_token", False)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-like access to attributes."""
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-like assignment to attributes that already exist."""
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __contains__(self, key: str) -> bool:
        """Check if the key exists in the ImageInfo object."""
        return hasattr(self, key)

    def __repr__(self):
        return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, "
                f"ori_image_width={self.ori_image_width}, ori_image_height={self.ori_image_height}, "
                f"image_width={self.image_width}, image_height={self.image_height}, "
                f"token_width={self.token_width}, token_height={self.token_height}, "
                f"image_token_length={self.image_token_length}, "
                f"base_size={self.base_size}, ratio_index={self.ratio_index})")

    @property
    def meta_info(self):
        """
        Dictionary describing this image section.

        Raises:
            ValueError: If ``image_type`` is not one of "vae", "gen_image",
                "vit" or "siglip2".
        """
        # Used for image sections of tkwrapper.encode_general()
        if self.image_type in ["vae", "gen_image"]:
            return dict(
                token_length=self.image_token_length,
                add_timestep_token=self.add_timestep_token,
                add_guidance_token=self.add_guidance_token,
                add_timestep_r_token=self.add_timestep_r_token,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                base_size=self.base_size,
                ratio_idx=self.ratio_index,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
                ori_image_width=self.ori_image_width,
                ori_image_height=self.ori_image_height,
            )
        elif self.image_type in ["vit", "siglip2"]:
            return dict(
                token_length=self.image_token_length,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
                ori_image_width=self.ori_image_width,
                ori_image_height=self.ori_image_height,
            )
        else:
            raise ValueError(f"Unknown image type '{self.image_type}'")

    @property
    def num_special_tokens(self):
        """
        Number of special tokens accompanying this image section.

        Raises:
            ValueError: If ``image_type`` is not "vae", "src_image" or "gen_image".
        """
        # The former guard on `self.args` was removed: that attribute is never
        # assigned anywhere in this class, so the property always raised
        # AttributeError before reaching the count.
        if self.image_type not in ["vae", "src_image", "gen_image"]:
            raise ValueError(f"Unknown image_type: {self.image_type}")
        count = 2  # two boundary tokens — presumably begin/end-of-image; confirm against the tokenizer
        if self.add_timestep_token:
            count += 1
        if self.add_guidance_token:
            count += 1
        if self.add_timestep_r_token:
            count += 1
        if self.add_image_shape_token:
            count += 2
        return count

    def copy(self, copy_image_tensor=True):
        """
        Return a new ``ImageInfo`` with the same geometry and the same flags.

        Args:
            copy_image_tensor: When True the tensor is cloned; when False the
                copy carries ``image_tensor=None``.

        Raises:
            ValueError: If a tensor copy is requested but ``image_tensor`` is None.
        """
        if copy_image_tensor and self.image_tensor is None:
            raise ValueError("image_tensor is None, cannot copy")
        return ImageInfo(
            image_type=self.image_type,
            image_tensor=self.image_tensor.clone() if copy_image_tensor else None,
            image_width=self.image_width,
            image_height=self.image_height,
            ori_image_width=self.ori_image_width,
            ori_image_height=self.ori_image_height,
            token_width=self.token_width,
            token_height=self.token_height,
            image_token_length=self.image_token_length,
            base_size=self.base_size,
            ratio_index=self.ratio_index,
            # Carry the flags over; without this a copy silently reverts to the defaults.
            add_timestep_token=self.add_timestep_token,
            add_guidance_token=self.add_guidance_token,
            use_front_boi_token=self.use_front_boi_token,
            add_image_shape_token=self.add_image_shape_token,
            add_timestep_r_token=self.add_timestep_r_token,
        )

    def zeros_(self):
        """Replace ``image_tensor`` with a zero tensor of the same shape, dtype and device."""
        self.image_tensor = torch.zeros_like(self.image_tensor)
class ImageTensor(torch.Tensor):
    # Typing aid only: nothing is implemented here. The attributes below are
    # meant to be attached to an ordinary torch.Tensor instance, e.g.
    # ``tensor.i = ImageInfo(...)``.
    i: "ImageInfo"
    vision_encoder_kwargs: dict


class JointImageInfo(object):
    """
    Pairs two image descriptions: one for the VAE side, one for the vision side.

    Args:
        vae_image_info: Description whose flags, ``base_size`` and
            ``ratio_index`` are taken as the joint ones.
        vision_image_info: Description contributing the second entry of each
            paired field in ``meta_info``.
        vision_encoder_kwargs: Stored as given and handed unchanged to copies.
    """

    def __init__(self, vae_image_info: "ImageInfo", vision_image_info: "ImageInfo", vision_encoder_kwargs: dict = None):
        self.vae_image_info = vae_image_info
        self.vision_image_info = vision_image_info
        self.vision_encoder_kwargs = vision_encoder_kwargs

        # Expose the same attribute names as ImageInfo so both kinds can be handled alike.
        self.image_type = "joint_image"
        vae_length = vae_image_info.image_token_length
        vision_length = vision_image_info.image_token_length
        self.image_token_length = vae_length + vision_length

        self.add_timestep_token = vae_image_info.add_timestep_token
        self.use_front_boi_token = vae_image_info.use_front_boi_token
        self.add_image_shape_token = vae_image_info.add_image_shape_token

    def __repr__(self):
        vae, vision = self.vae_image_info, self.vision_image_info
        return f"JointImageInfo(vae_image={vae}, vision_image={vision})"

    @property
    def meta_info(self):
        """Dictionary for this section; paired fields are ``[vae_value, vision_value]`` lists."""
        # Used for image sections of tkwrapper.encode_general()
        vae, vision = self.vae_image_info, self.vision_image_info

        def pair(attribute):
            return [getattr(vae, attribute), getattr(vision, attribute)]

        return {
            "token_length": pair("image_token_length"),
            "add_timestep_token": self.add_timestep_token,
            "use_front_boi_token": self.use_front_boi_token,
            "add_image_shape_token": self.add_image_shape_token,
            "base_size": vae.base_size,
            "ratio_idx": vae.ratio_index,
            # for rope 2d
            "token_height": pair("token_height"),
            "token_width": pair("token_width"),
            # for bc
            "image_height": pair("image_height"),
            "image_width": pair("image_width"),
        }

    @property
    def num_special_tokens(self):
        """Count of special tokens for the joint section, depending on the two flags."""
        # Fixed part: 2 + 1. NOTE(review): presumably two boundary tokens plus one
        # separator between the two images — the token names are not visible here.
        total = 3
        if self.add_timestep_token:
            total += 1
        if self.add_image_shape_token:
            total += 2
        return total

    def copy(self, copy_image_tensor=True):
        """
        Copy both member descriptions into a new ``JointImageInfo``.

        Raises:
            ValueError: If tensors are to be copied but either member has none.
        """
        members = (self.vae_image_info, self.vision_image_info)
        if copy_image_tensor and any(member.image_tensor is None for member in members):
            raise ValueError("image_tensor is None, cannot copy")
        vae_copy, vision_copy = [member.copy(copy_image_tensor) for member in members]
        return JointImageInfo(vae_copy, vision_copy, self.vision_encoder_kwargs)

    def zeros_(self):
        """Zero the tensors of both member descriptions."""
        for member in (self.vae_image_info, self.vision_image_info):
            member.zeros_()
class CondImage(object):
    """
    A conditioning image bundled with the info object and section tag that match its type.

    Args:
        image_type: One of "vae", "vit" or "vae_vit".
        vae_image: Tensor whose ``.i`` is used for "vae" and "vae_vit".
        vit_image: Tensor whose ``.i`` is used for "vit" and "vae_vit".

    Attributes:
        i: ``vae_image.i``, ``vit_image.i`` or a ``JointImageInfo`` of both.
        section_type: "cond_vae_image", "cond_vit_image" or "cond_joint_image".

    Raises:
        ValueError: For any other ``image_type``.
    """

    def __init__(self, image_type: str, vae_image: "ImageTensor", vit_image: "ImageTensor"):
        self.image_type = image_type
        self.vae_image = vae_image
        self.vit_image = vit_image

        # Reject unknown kinds up front; the branches below then cover every remaining case.
        if image_type not in ("vae", "vit", "vae_vit"):
            raise ValueError(f"Unknown image_type: {image_type}")

        if image_type == "vae_vit":
            info = JointImageInfo(vae_image.i, vit_image.i)
            section = "cond_joint_image"
        elif image_type == "vit":
            info = vit_image.i
            section = "cond_vit_image"
        else:
            info = vae_image.i
            section = "cond_vae_image"

        self.i = info
        self.section_type = section
class TokenizerEncodeOutput(BaseOutput):
    """
    Named bundle of encoding results; every field defaults to ``None``.

    NOTE(review): the code that fills this object is not visible in this part
    of the file, so the notes on the field groups below are inferred from the
    field names and annotations only — confirm against the encoding routine.
    """
    # Encoded sequence tensor — presumably token ids; TODO confirm shape and dtype.
    tokens: torch.Tensor = None

    # Lists of `slice` objects, one list per section kind named in the field.
    # Presumably each slice locates one section inside `tokens` — verify with the producer.
    text_slices: Optional[list[slice]] = None
    vae_image_slices: Optional[list[slice]] = None
    gen_image_slices: Optional[list[slice]] = None
    vit_image_slices: Optional[list[slice]] = None
    joint_image_slices: Optional[list[slice]] = None
    all_image_slices: Optional[list[slice]] = None

    # Tensor masks, one per section kind — presumably aligned with `tokens`; TODO confirm.
    text_mask: Optional[torch.Tensor] = None
    vae_image_mask: Optional[torch.Tensor] = None
    gen_image_mask: Optional[torch.Tensor] = None
    vit_image_mask: Optional[torch.Tensor] = None

    # NOTE(review): meaning of `real_pos` cannot be determined from this file section.
    real_pos: Optional[torch.Tensor] = None

    # Index tensors; the names suggest positions where guidance / timestep values
    # are scattered into the sequence — TODO confirm with the consumer.
    guidance_scatter_index: Optional[torch.Tensor] = None
    cond_timestep_scatter_index: Optional[torch.Tensor] = None
    gen_timestep_scatter_index: Optional[torch.Tensor] = None
    gen_timestep_r_scatter_index: Optional[torch.Tensor] = None
class SeparatorStyle(IntEnum):
    """Prompt layouts understood by ``Conversation``."""
    ADD_COLON_SPACE_SINGLE = auto()  # each turn rendered as "<role>: <message><sep>"
    NONE = auto()                    # each turn rendered as "<role><message><sep>"


@dataclass
class Conversation(object):
    """
    A prompt template plus the messages accumulated so far.

    Attributes:
        name: Key under which the template is registered.
        system_template: Format string containing ``{system_message}``.
        system_message: Text substituted into ``system_template``.
        roles: The two role names.
        messages: ``[role, message]`` pairs in conversation order.
        sep_style: How a role and its message are joined.
        sep: Separator appended after even-indexed messages (0, 2, ...).
        sep2: Separator appended after odd-indexed messages; ``None`` means
            "same as ``sep``".
        sep_sp: Separator appended after a non-empty system prompt; ``None``
            means "same as ``sep``".
        stop_token_ids: Stored as given; not used by this class.
    """
    name: str
    system_template: str = "{system_message}"
    system_message: str = ""
    roles: Tuple[str, str] = ("User", "Assistant")
    messages: List[List[str]] = ()
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SPACE_SINGLE
    sep: str = "\n"
    sep2: Optional[str] = None
    sep_sp: Optional[str] = None
    stop_token_ids: Optional[list[int]] = None

    def __post_init__(self):
        # The declared default is an (immutable) tuple, on which add_message()
        # cannot append. Give every instance its own list instead. A list passed
        # by the caller is kept as-is, so existing aliasing is unchanged.
        if isinstance(self.messages, tuple):
            self.messages = list(self.messages)

    def get_prompt(self, return_type="str", add_system=True):
        """
        Render the conversation.

        Args:
            return_type: "str" for one concatenated string; anything else
                returns the list of ``(role, text)`` pieces.
            add_system: Whether to emit the leading ("System", ...) piece.

        Raises:
            NotImplementedError: If ``sep_style`` is not a supported style.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style not in (SeparatorStyle.ADD_COLON_SPACE_SINGLE, SeparatorStyle.NONE):
            raise NotImplementedError(f"Unsupported sep_style: {self.sep_style}")

        # An unset separator falls back to `sep`; formatting None directly would
        # put the text "None" into the prompt (sep2) or raise TypeError (sep_sp).
        sep2 = self.sep if self.sep2 is None else self.sep2
        sep_sp = self.sep if self.sep_sp is None else self.sep_sp
        seps = (self.sep, sep2)

        prompt_list = []
        if add_system:
            prompt_list.append(("System", system_prompt + sep_sp if system_prompt else ""))
        for i, (role, message) in enumerate(self.messages):
            prefix = self.get_role_prefix(role)
            # A falsy message leaves the turn open: only the role prefix is emitted.
            text = f"{prefix}{message}{seps[i % 2]}" if message else prefix
            prompt_list.append((role, text))

        if return_type == "str":
            return "".join(text for _, text in prompt_list)
        return prompt_list

    def get_role_prefix(self, role):
        """Return the text that introduces a turn of ``role`` under the current style."""
        if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            return f"{role}: "
        elif self.sep_style == SeparatorStyle.NONE:
            return f"{role}"
        else:
            raise NotImplementedError(f"Unsupported sep_style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def add_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def copy(self):
        """Return a deep copy, so the copy's messages can be edited independently."""
        return deepcopy(self)

    def empty(self, name=None):
        """Return an empty conversation with the same template."""
        return Conversation(
            name=name or self.name,
            system_template=self.system_template,
            system_message="",
            roles=self.roles,
            messages=[],
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            sep_sp=self.sep_sp,
            stop_token_ids=self.stop_token_ids,
        )


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """
    Register a new conversation template.

    Raises:
        AssertionError: If the name is already registered and ``override`` is False.
    """
    # Raised explicitly (same exception type as the former assert) so a duplicate
    # is still rejected when Python runs with -O.
    if not override and template.name in conv_templates:
        raise AssertionError(f"{template.name} has been registered.")
    conv_templates[template.name] = template


register_conv_template(
    Conversation(
        name="hunyuan-image-3",
        system_template="{system_message}",
        system_message="",
        roles=("User", "Assistant"),
        messages=[],
        sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
        sep="\n\n",
        sep2="<|endoftext|>",
        sep_sp="\n\n",
        stop_token_ids=[127957],
    )
)


def get_conversation_template(name: str) -> Conversation:
    """Get a conversation template (an independent copy of the registered one)."""
    return conv_templates[name].copy()
""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # A convenience mapping for special tokens special_tokens = self.special_tokens_map.get('additional_special_tokens', []) if len(special_tokens) > 0: special_token_ids = self.convert_tokens_to_ids(special_tokens) self._sp_dict = dict(zip(special_tokens, special_token_ids)) else: self._sp_dict = dict() # Assign commonly used special tokens to attributes for easy access. self.setup_special_tokens() # Define decorator section self.conversation_template = kwargs.get("conversation_template", "hunyuan-image-3") self.conversation = get_conversation_template(self.conversation_template) self.sequence_template = kwargs.get("sequence_template", "instruct") self.decorator_section = DecoratorSections( self, conv=self.conversation, sequence_template=self.sequence_template, ) def setup_special_tokens(self): # Define names for commonly used special tokens predefined_name_mapping = dict( boi="", eoi="", boa="", eoa="", bov="", eov="", img="", audio="