| |
| import argparse |
| import math |
| import os |
| from copy import deepcopy |
|
|
| import torch |
| from audio_diffusion.models import DiffusionAttnUnet1D |
| from diffusion import sampling |
| from torch import nn |
|
|
| from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel |
|
|
|
|
| MODELS_MAP = { |
| "gwf-440k": { |
| "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt", |
| "sample_rate": 48000, |
| "sample_size": 65536, |
| }, |
| "jmann-small-190k": { |
| "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt", |
| "sample_rate": 48000, |
| "sample_size": 65536, |
| }, |
| "jmann-large-580k": { |
| "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt", |
| "sample_rate": 48000, |
| "sample_size": 131072, |
| }, |
| "maestro-uncond-150k": { |
| "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt", |
| "sample_rate": 16000, |
| "sample_size": 65536, |
| }, |
| "unlocked-uncond-250k": { |
| "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt", |
| "sample_rate": 16000, |
| "sample_size": 65536, |
| }, |
| "honk-140k": { |
| "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt", |
| "sample_rate": 16000, |
| "sample_size": 65536, |
| }, |
| } |
|
|
|
|
| def alpha_sigma_to_t(alpha, sigma): |
| """Returns a timestep, given the scaling factors for the clean image and for |
| the noise.""" |
| return torch.atan2(sigma, alpha) / math.pi * 2 |
|
|
|
|
| def get_crash_schedule(t): |
| sigma = torch.sin(t * math.pi / 2) ** 2 |
| alpha = (1 - sigma**2) ** 0.5 |
| return alpha_sigma_to_t(alpha, sigma) |
|
|
|
|
| class Object(object): |
| pass |
|
|
|
|
| class DiffusionUncond(nn.Module): |
| def __init__(self, global_args): |
| super().__init__() |
|
|
| self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4) |
| self.diffusion_ema = deepcopy(self.diffusion) |
| self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
|
|
|
|
| def download(model_name): |
| url = MODELS_MAP[model_name]["url"] |
| os.system(f"wget {url} ./") |
|
|
| return f"./{model_name}.ckpt" |
|
|
|
|
| DOWN_NUM_TO_LAYER = { |
| "1": "resnets.0", |
| "2": "attentions.0", |
| "3": "resnets.1", |
| "4": "attentions.1", |
| "5": "resnets.2", |
| "6": "attentions.2", |
| } |
| UP_NUM_TO_LAYER = { |
| "8": "resnets.0", |
| "9": "attentions.0", |
| "10": "resnets.1", |
| "11": "attentions.1", |
| "12": "resnets.2", |
| "13": "attentions.2", |
| } |
| MID_NUM_TO_LAYER = { |
| "1": "resnets.0", |
| "2": "attentions.0", |
| "3": "resnets.1", |
| "4": "attentions.1", |
| "5": "resnets.2", |
| "6": "attentions.2", |
| "8": "resnets.3", |
| "9": "attentions.3", |
| "10": "resnets.4", |
| "11": "attentions.4", |
| "12": "resnets.5", |
| "13": "attentions.5", |
| } |
| DEPTH_0_TO_LAYER = { |
| "0": "resnets.0", |
| "1": "resnets.1", |
| "2": "resnets.2", |
| "4": "resnets.0", |
| "5": "resnets.1", |
| "6": "resnets.2", |
| } |
|
|
| RES_CONV_MAP = { |
| "skip": "conv_skip", |
| "main.0": "conv_1", |
| "main.1": "group_norm_1", |
| "main.3": "conv_2", |
| "main.4": "group_norm_2", |
| } |
|
|
| ATTN_MAP = { |
| "norm": "group_norm", |
| "qkv_proj": ["query", "key", "value"], |
| "out_proj": ["proj_attn"], |
| } |
|
|
|
|
| def convert_resconv_naming(name): |
| if name.startswith("skip"): |
| return name.replace("skip", RES_CONV_MAP["skip"]) |
|
|
| |
| if not name.startswith("main."): |
| raise ValueError(f"ResConvBlock error with {name}") |
|
|
| return name.replace(name[:6], RES_CONV_MAP[name[:6]]) |
|
|
|
|
| def convert_attn_naming(name): |
| for key, value in ATTN_MAP.items(): |
| if name.startswith(key) and not isinstance(value, list): |
| return name.replace(key, value) |
| elif name.startswith(key): |
| return [name.replace(key, v) for v in value] |
| raise ValueError(f"Attn error with {name}") |
|
|
|
|
| def rename(input_string, max_depth=13): |
| string = input_string |
|
|
| if string.split(".")[0] == "timestep_embed": |
| return string.replace("timestep_embed", "time_proj") |
|
|
| depth = 0 |
| if string.startswith("net.3."): |
| depth += 1 |
| string = string[6:] |
| elif string.startswith("net."): |
| string = string[4:] |
|
|
| while string.startswith("main.7."): |
| depth += 1 |
| string = string[7:] |
|
|
| if string.startswith("main."): |
| string = string[5:] |
|
|
| |
| if string[:2].isdigit(): |
| layer_num = string[:2] |
| string_left = string[2:] |
| else: |
| layer_num = string[0] |
| string_left = string[1:] |
|
|
| if depth == max_depth: |
| new_layer = MID_NUM_TO_LAYER[layer_num] |
| prefix = "mid_block" |
| elif depth > 0 and int(layer_num) < 7: |
| new_layer = DOWN_NUM_TO_LAYER[layer_num] |
| prefix = f"down_blocks.{depth}" |
| elif depth > 0 and int(layer_num) > 7: |
| new_layer = UP_NUM_TO_LAYER[layer_num] |
| prefix = f"up_blocks.{max_depth - depth - 1}" |
| elif depth == 0: |
| new_layer = DEPTH_0_TO_LAYER[layer_num] |
| prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0" |
|
|
| if not string_left.startswith("."): |
| raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.") |
|
|
| string_left = string_left[1:] |
|
|
| if "resnets" in new_layer: |
| string_left = convert_resconv_naming(string_left) |
| elif "attentions" in new_layer: |
| new_string_left = convert_attn_naming(string_left) |
| string_left = new_string_left |
|
|
| if not isinstance(string_left, list): |
| new_string = prefix + "." + new_layer + "." + string_left |
| else: |
| new_string = [prefix + "." + new_layer + "." + s for s in string_left] |
| return new_string |
|
|
|
|
| def rename_orig_weights(state_dict): |
| new_state_dict = {} |
| for k, v in state_dict.items(): |
| if k.endswith("kernel"): |
| |
| continue |
|
|
| new_k = rename(k) |
|
|
| |
| if isinstance(new_k, list): |
| new_state_dict = transform_conv_attns(new_state_dict, new_k, v) |
| else: |
| new_state_dict[new_k] = v |
|
|
| return new_state_dict |
|
|
|
|
| def transform_conv_attns(new_state_dict, new_k, v): |
| if len(new_k) == 1: |
| if len(v.shape) == 3: |
| |
| new_state_dict[new_k[0]] = v[:, :, 0] |
| else: |
| |
| new_state_dict[new_k[0]] = v |
| else: |
| |
| trippled_shape = v.shape[0] |
| single_shape = trippled_shape // 3 |
| for i in range(3): |
| if len(v.shape) == 3: |
| new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0] |
| else: |
| new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape] |
| return new_state_dict |
|
|
|
|
| def main(args): |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| model_name = args.model_path.split("/")[-1].split(".")[0] |
| if not os.path.isfile(args.model_path): |
| assert ( |
| model_name == args.model_path |
| ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" |
| args.model_path = download(model_name) |
|
|
| sample_rate = MODELS_MAP[model_name]["sample_rate"] |
| sample_size = MODELS_MAP[model_name]["sample_size"] |
|
|
| config = Object() |
| config.sample_size = sample_size |
| config.sample_rate = sample_rate |
| config.latent_dim = 0 |
|
|
| diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate) |
| diffusers_state_dict = diffusers_model.state_dict() |
|
|
| orig_model = DiffusionUncond(config) |
| orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"]) |
| orig_model = orig_model.diffusion_ema.eval() |
| orig_model_state_dict = orig_model.state_dict() |
| renamed_state_dict = rename_orig_weights(orig_model_state_dict) |
|
|
| renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys()) |
| diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys()) |
|
|
| assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}" |
| assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" |
|
|
| for key, value in renamed_state_dict.items(): |
| assert ( |
| diffusers_state_dict[key].squeeze().shape == value.squeeze().shape |
| ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" |
| if key == "time_proj.weight": |
| value = value.squeeze() |
|
|
| diffusers_state_dict[key] = value |
|
|
| diffusers_model.load_state_dict(diffusers_state_dict) |
|
|
| steps = 100 |
| seed = 33 |
|
|
| diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps) |
|
|
| generator = torch.manual_seed(seed) |
| noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device) |
|
|
| t = torch.linspace(1, 0, steps + 1, device=device)[:-1] |
| step_list = get_crash_schedule(t) |
|
|
| pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler) |
|
|
| generator = torch.manual_seed(33) |
| audio = pipe(num_inference_steps=steps, generator=generator).audios |
|
|
| generated = sampling.iplms_sample(orig_model, noise, step_list, {}) |
| generated = generated.clamp(-1, 1) |
|
|
| diff_sum = (generated - audio).abs().sum() |
| diff_max = (generated - audio).abs().max() |
|
|
| if args.save: |
| pipe.save_pretrained(args.checkpoint_path) |
|
|
| print("Diff sum", diff_sum) |
| print("Diff max", diff_max) |
|
|
| assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/" |
|
|
| print(f"Conversion for {model_name} successful!") |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
|
|
| parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") |
| parser.add_argument( |
| "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not." |
| ) |
| parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") |
| args = parser.parse_args() |
|
|
| main(args) |
|
|