| from typing import Sequence |
| import random |
| from typing import Any |
|
|
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import diffusers.schedulers as noise_schedulers |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin |
| from diffusers.utils.torch_utils import randn_tensor |
|
|
| from models.autoencoder.autoencoder_base import AutoEncoderBase |
| from models.content_encoder.content_encoder import ContentEncoder |
| from models.content_adapter import ContentAdapterBase, ContentEncoderAdapterMixin |
| import soundfile as sf |
| from models.common import ( |
| LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase, |
| ) |
| from utils.torch_utilities import ( |
| create_alignment_path, create_mask_from_length, loss_with_mask, |
| trim_or_pad_length |
| ) |
|
|
|
|
class DiffusionMixin:
    """Shared diffusion-training utilities built around a DDPM noise scheduler.

    Provides timestep sampling, noising of latents, target construction for
    the scheduler's prediction type, an (optionally Min-SNR weighted) masked
    MSE loss, and classifier-free-guidance rescaling.
    """

    def __init__(
        self,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = None,
        cfg_drop_ratio: float = 0.2
    ) -> None:
        """
        Args:
            noise_scheduler_name: identifier passed to
                `DDPMScheduler.from_pretrained`; the config is read from its
                `scheduler` subfolder.
            snr_gamma: clipping value for Min-SNR loss weighting. `None`
                disables the weighting (plain masked MSE).
            cfg_drop_ratio: probability used by subclasses to drop the
                condition during training; classifier-free guidance is
                considered enabled when it is greater than zero.
        """
        self.noise_scheduler_name = noise_scheduler_name
        self.snr_gamma = snr_gamma
        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio
        self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
            self.noise_scheduler_name, subfolder="scheduler"
        )

    def compute_snr(self, timesteps) -> torch.Tensor:
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

        Args:
            timesteps: integer tensor of scheduler timestep indices.

        Returns:
            Float tensor with the same shape as `timesteps` holding
            `(sqrt(alpha_bar_t) / sqrt(1 - alpha_bar_t)) ** 2`.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5

        # Indexing the 1-D tables with `timesteps` already yields tensors of
        # shape `timesteps.shape`, so no further unsqueeze/expand is needed.
        alpha = sqrt_alphas_cumprod.to(device=timesteps.device
                                       )[timesteps].float()
        sigma = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device
                                                 )[timesteps].float()
        return (alpha / sigma)**2

    def get_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        training: bool = True
    ) -> torch.Tensor:
        """Return one int64 timestep per batch item.

        During training the timesteps are drawn uniformly from
        `[0, num_train_timesteps)`; otherwise every item uses the fixed
        midpoint `num_train_timesteps // 2`, which makes evaluation
        deterministic with respect to the timestep.
        """
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        if training:
            # torch.randint produces int64 by default.
            return torch.randint(
                0, num_train_timesteps, (batch_size, ), device=device
            )
        return (num_train_timesteps // 2) * torch.ones(
            (batch_size, ), dtype=torch.int64, device=device
        )

    def get_input_target_and_timesteps(
        self,
        latent: torch.Tensor,
        training: bool,
    ):
        """Noise `latent` and build the matching regression target.

        Args:
            latent: clean latent; its first dimension is the batch size.
            training: forwarded to `get_timesteps`.

        Returns:
            Tuple `(noisy_latent, target, timesteps)`.
        """
        batch_size = latent.shape[0]
        device = latent.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
        # Timesteps are sampled before the noise so the RNG consumption order
        # stays the same as before.
        timesteps = self.get_timesteps(batch_size, device, training=training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)
        return noisy_latent, target, timesteps

    def get_target(
        self, latent: torch.Tensor, noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the target for loss depending on the prediction type

        Returns `noise` itself for "epsilon" and the scheduler's velocity for
        "v_prediction".

        Raises:
            ValueError: if the scheduler is configured with any other
                prediction type.
        """
        prediction_type = self.noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            return noise
        if prediction_type == "v_prediction":
            return self.noise_scheduler.get_velocity(latent, noise, timesteps)
        raise ValueError(
            f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
        )

    def _min_snr_weights(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Return the per-timestep Min-SNR loss weights.

        The SNR is clipped at `self.snr_gamma` and then normalised by `snr`
        for epsilon prediction, or by `snr + 1` for v-prediction (the
        v-prediction target already carries an `snr + 1` scaling relative to
        the epsilon target, so dividing by `snr` alone over-weights it).
        """
        snr = self.compute_snr(timesteps)
        clipped_snr = torch.clamp(snr, max=self.snr_gamma)
        if self.noise_scheduler.config.prediction_type == "v_prediction":
            return clipped_snr / (snr + 1)
        return clipped_snr / snr

    def loss_with_snr(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: torch.Tensor,
        reduce: bool = True
    ) -> torch.Tensor:
        """Masked MSE between `pred` and `target`, optionally Min-SNR weighted.

        Args:
            pred: model prediction.
            target: regression target from `get_target`.
            timesteps: timesteps used to noise each item; only read when
                `self.snr_gamma` is set.
            mask: validity mask forwarded to `loss_with_mask`.
            reduce: when True a reduced loss is returned; otherwise the
                unreduced output of `loss_with_mask` (weighted if enabled).

        Returns:
            The loss tensor.
        """
        loss = F.mse_loss(pred.float(), target.float(), reduction="none")
        if self.snr_gamma is None:
            return loss_with_mask(loss, mask, reduce=reduce)

        # NOTE(review): the weights have one entry per timestep, so this
        # presumably relies on `loss_with_mask(..., reduce=False)` returning
        # one value per batch item -- confirm against its implementation.
        loss = loss_with_mask(loss, mask, reduce=False) \
            * self._min_snr_weights(timesteps)
        if reduce:
            loss = loss.mean()
        return loss

    def rescale_cfg(
        self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
        guidance_rescale: float
    ):
        """
        Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
        Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

        The guided prediction is scaled so its per-sample standard deviation
        (over all non-batch dimensions) matches that of `pred_cond`, and the
        result is blended with the unscaled prediction using
        `guidance_rescale` as the mixing factor.
        """
        std_cond = pred_cond.std(
            dim=list(range(1, pred_cond.ndim)), keepdim=True
        )
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)

        pred_rescaled = pred_cfg * (std_cond / std_cfg)
        return guidance_rescale * pred_rescaled + (
            1 - guidance_rescale
        ) * pred_cfg
|
|
|
|
class SingleTaskCrossAttentionAudioDiffusion(
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    DiffusionMixin, ContentEncoderAdapterMixin
):
    """Latent audio diffusion model conditioned through a content encoder.

    A frozen autoencoder maps waveforms to latents, a content encoder
    produces both a global context and a time-aligned context, and the
    backbone predicts the diffusion target from the noisy latent and both
    conditions.
    """

    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        content_dim: int,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = None,
        cfg_drop_ratio: float = 0.2,
    ):
        """
        Args:
            autoencoder: waveform <-> latent codec; all of its parameters are
                frozen here.
            content_encoder: encoder handed to `ContentEncoderAdapterMixin`.
            backbone: denoising network.
            content_dim: accepted for interface compatibility; not used by
                this constructor.
            noise_scheduler_name: see `DiffusionMixin.__init__`.
            snr_gamma: see `DiffusionMixin.__init__`.
            cfg_drop_ratio: see `DiffusionMixin.__init__`.
        """
        nn.Module.__init__(self)
        DiffusionMixin.__init__(
            self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
        )
        ContentEncoderAdapterMixin.__init__(
            self, content_encoder=content_encoder
        )
        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        # Let the content encoder's audio branch (if any) reuse this very
        # autoencoder instance instead of holding its own model.
        if hasattr(self.content_encoder, "audio_encoder"):
            self.content_encoder.audio_encoder.model = self.autoencoder

        self.backbone = backbone
        # Zero-sized parameter used only to look up the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self, content: list[Any], task: list[str],
        waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
    ):
        """Compute the diffusion training loss for one batch.

        Args:
            content: per-item conditioning inputs for the content encoder.
            task: per-item task names for the content encoder.
            waveform: target audio; a channel dimension is inserted at
                position 1 before encoding (presumably (batch, samples) --
                TODO confirm).
            waveform_lengths: forwarded to the autoencoder's `encode`.
            **kwargs: ignored.

        Returns:
            The loss returned by `loss_with_snr`.
        """
        device = self.dummy_param.device

        # Autoencoder and content encoder are always run in eval mode and
        # without gradients.
        self.autoencoder.eval()
        self.content_encoder.eval()
        with torch.no_grad():
            # The mask returned by the autoencoder is not used: the mask of
            # the time-aligned content defines the valid latent frames.
            latent, _ = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths, pad_latent_len=500
            )

        with torch.no_grad():
            content_dict = self.content_encoder.encode_content(
                content, task, device
            )
        context = content_dict["content"]
        context_mask = content_dict["content_mask"]
        time_aligned_content = content_dict["length_aligned_content"]
        time_aligned_content_mask = content_dict["time_aligned_content_mask"]
        latent_mask = time_aligned_content_mask.to(device)

        # Classifier-free guidance: zero the global context of a random
        # subset of items so the model also learns the unconditional case.
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, self.training
        )

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            time_aligned_context=time_aligned_content,
            context=context,
            x_mask=latent_mask,
            context_mask=context_mask
        )

        # Swap dimension 1 with the autoencoder's time dimension before the
        # masked loss.
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        return self.loss_with_snr(pred, target, timesteps, latent_mask)

    def prepare_latent(
        self, batch_size: int, scheduler: SchedulerMixin,
        latent_shape: Sequence[int], dtype: torch.dtype, device: str
    ):
        """Draw the initial noise latent of shape `(batch_size, *latent_shape)`.

        The Gaussian sample is scaled by `scheduler.init_noise_sigma`.
        """
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=dtype
        )
        return latent * scheduler.init_noise_sigma

    def iterative_denoise(
        self,
        latent: torch.Tensor,
        scheduler: SchedulerMixin,
        verbose: bool,
        cfg: bool,
        cfg_scale: float,
        cfg_rescale: float,
        backbone_input: dict,
    ):
        """Run the sampling loop over `scheduler.timesteps`.

        Args:
            latent: initial noise latent.
            scheduler: sampler whose timesteps have already been set.
            verbose: show a progress bar when True.
            cfg: when True the latent is duplicated each step and the
                backbone output is split into an unconditional half followed
                by a conditional half; `backbone_input` must be batched
                accordingly.
            cfg_scale: guidance scale.
            cfg_rescale: guidance rescale factor; 0.0 disables rescaling.
            backbone_input: extra keyword arguments for the backbone.

        Returns:
            The denoised latent.
        """
        timesteps = scheduler.timesteps
        num_steps = len(timesteps)

        # Context manager guarantees the bar is closed even if a step raises.
        with tqdm(total=num_steps, disable=not verbose) as progress_bar:
            for i, timestep in enumerate(timesteps):
                latent_input = torch.cat([latent, latent]) if cfg else latent
                latent_input = scheduler.scale_model_input(
                    latent_input, timestep
                )

                noise_pred = self.backbone(
                    x=latent_input, timesteps=timestep, **backbone_input
                )

                if cfg:
                    noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cfg_scale * (
                        noise_pred_content - noise_pred_uncond
                    )
                    if cfg_rescale != 0.0:
                        noise_pred = self.rescale_cfg(
                            noise_pred_content, noise_pred, cfg_rescale
                        )

                latent = scheduler.step(
                    noise_pred, timestep, latent
                ).prev_sample

                # The former warm-up threshold was
                # len(timesteps) * (1 - scheduler.order), which is never
                # positive, so that part of the condition was always true.
                if i == num_steps - 1 or (i + 1) % scheduler.order == 0:
                    progress_bar.update(1)

        return latent

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        task: list[str],
        scheduler: SchedulerMixin,
        num_steps: int = 50,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        mask_time_aligned_content: bool = True,
        **kwargs
    ):
        """Generate audio for a batch of conditions.

        Args:
            content: per-item conditioning inputs for the content encoder.
            task: per-item task names for the content encoder.
            scheduler: sampler used for the denoising loop.
            num_steps: number of sampling steps set on `scheduler`.
            guidance_scale: classifier-free guidance scale; guidance is
                applied only when it is greater than 1.0.
            guidance_rescale: rescale factor passed to the denoising loop.
            disable_progress: hide the progress bar when True.
            mask_time_aligned_content: when True the unconditional branch
                uses a zeroed time-aligned content; otherwise it reuses the
                conditional one.
            **kwargs: ignored.

        Returns:
            Whatever `autoencoder.decode` returns for the generated latent.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        batch_size = len(content)

        content_dict = self.content_encoder.encode_content(
            content, task, device
        )
        context = content_dict["content"]
        context_mask = content_dict["content_mask"]
        time_aligned_content = content_dict["length_aligned_content"]
        time_aligned_content_mask = content_dict["time_aligned_content_mask"]

        # The latent takes its channel and frame counts from the time-aligned
        # content, whose last two dimensions are swapped to form the shape.
        _, T, C = time_aligned_content.shape
        latent_shape = (C, T)
        latent_mask = time_aligned_content_mask.to(device)
        # Keep the per-sample mask for decoding: the mask handed to the
        # backbone is duplicated below for guidance, while the final latent
        # still has `batch_size` items.
        decode_mask = latent_mask

        if classifier_free_guidance:
            if mask_time_aligned_content:
                uncond_time_aligned_content = torch.zeros_like(
                    time_aligned_content
                )
            else:
                uncond_time_aligned_content = (
                    time_aligned_content.detach().clone()
                )

            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            # Unconditional half first, conditional half second, matching the
            # chunk order used in `iterative_denoise`.
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        scheduler.set_timesteps(num_steps, device=device)

        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, context.dtype, device
        )

        latent = self.iterative_denoise(
            latent=latent,
            scheduler=scheduler,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            cfg_rescale=guidance_rescale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
                "time_aligned_context": time_aligned_content,
            }
        )
        waveform = self.autoencoder.decode(latent, decode_mask)

        return waveform
|
|
|
|
|
|