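"""MMEdit audio editing demo for Hugging Face ZeroGPU Spaces.

Downloads the MMEdit checkpoint and the Qwen2-Audio text encoder from the Hub,
rebuilds the model inside the `@spaces.GPU` worker on every request, and edits
an uploaded clip according to a natural-language instruction.

Configuration is read from environment variables (see below). For example, a
fork pinned to a specific revision could be selected at deploy time with
hypothetical values like:

    MMEDIT_REPO_ID=my-org/MMEdit-fork MMEDIT_REVISION=<commit-sha> python app.py
"""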
import spaces
import traceback
import os
import time
import logging
from pathlib import Path
from typing import Tuple, Dict, Any
import gc

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")

# Configuration (all values can be overridden via environment variables).
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)

QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)

# Optional token for gated or private repos.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Mixed-precision flags (currently unused: inference below runs in FP32).
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16")

# Process-level caches keyed by repo/revision.
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}  # reserved for pipeline reuse
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}  # resolved local snapshot dirs


def resolve_model_dirs() -> Tuple[Path, Path]:
    """Download (or reuse) the MMEdit and Qwen snapshots and return their local roots."""
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    if cache_key in _MODEL_DIR_CACHE:
        return _MODEL_DIR_CACHE[cache_key]

    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
    repo_root = snapshot_download(
        repo_id=MMEDIT_REPO_ID,
        revision=MMEDIT_REVISION,
        token=HF_TOKEN,
    )
    repo_root = Path(repo_root).resolve()

    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
    qwen_root = snapshot_download(
        repo_id=QWEN_REPO_ID,
        revision=QWEN_REVISION,
        token=HF_TOKEN,
    )
    qwen_root = Path(qwen_root).resolve()

    _MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
    return repo_root, qwen_root
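# Usage sketch for the loader defined below (hypothetical file name):
#   wav = load_and_process_audio("input.wav", 24000)
# yields a 1-D mono torch tensor resampled to 24 kHz.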
def load_and_process_audio(audio_path: str, target_sr: int):
    """Load an audio file as a mono float tensor resampled to `target_sr`."""
    # Heavy imports are kept inside the function so the UI can boot without them.
    import torch
    import torchaudio
    import librosa

    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    waveform, orig_sr = torchaudio.load(str(path))

    # Downmix to mono.
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=0)
    elif waveform.ndim > 2:
        waveform = waveform.reshape(-1)

    if target_sr and int(target_sr) != int(orig_sr):
        waveform_np = waveform.cpu().numpy()

        # Resample through a fixed 16 kHz intermediate stage before converting
        # to the target rate.
        sr_mid = 16000
        if int(orig_sr) != sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
        orig_sr_mid = sr_mid

        if int(target_sr) != orig_sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))

        waveform = torch.from_numpy(waveform_np)

    return waveform
def assert_repo_layout(repo_root: Path) -> None:
    """Fail fast if the MMEdit snapshot is missing required files."""
    must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
    for p in must:
        if not p.exists():
            raise FileNotFoundError(f"Missing required path: {p}")

    vae_files = list((repo_root / "vae").glob("*.ckpt"))
    if len(vae_files) == 0:
        raise FileNotFoundError(f"No .ckpt found under: {repo_root / 'vae'}")
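# Snapshot layout expected by assert_repo_layout above:
#   <repo_root>/config.yaml
#   <repo_root>/model.safetensors
#   <repo_root>/vae/*.ckpt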
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
    """Rewrite checkpoint and encoder paths in the experiment config to local paths."""
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
    if vae_ckpt:
        vae_ckpt = str(vae_ckpt).replace("\\", "/")
        # Reduce whatever path the config carried to one relative to the repo root.
        idx = vae_ckpt.find("vae/")
        if idx != -1:
            vae_rel = vae_ckpt[idx:]
        elif vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
            vae_rel = f"vae/{vae_ckpt}"
        else:
            vae_rel = vae_ckpt

        vae_path = (repo_root / vae_rel).resolve()
        exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)

        if not vae_path.exists():
            raise FileNotFoundError(
                f"VAE ckpt not found after patch:\n"
                f"  original: {vae_ckpt}\n"
                f"  patched : {vae_path}\n"
                f"Repo root: {repo_root}\n"
                f"Expected : {repo_root / 'vae' / '*.ckpt'}"
            )

    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
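# Optional preflight sketch combining the helpers above (not wired into
# run_edit below, which inlines its own path patching):
#   repo_root, qwen_root = resolve_model_dirs()
#   assert_repo_layout(repo_root)
#   exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
#   patch_paths_in_exp_config(exp_cfg, repo_root, qwen_root)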
@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
    """Load the model, run one editing pass, and always free GPU memory afterwards."""
    # Heavy imports happen inside the ZeroGPU-decorated function so CUDA is
    # only touched inside the GPU worker.
    import torch
    import hydra
    from omegaconf import OmegaConf
    from safetensors.torch import load_file
    import diffusers.schedulers as noise_schedulers

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    try:
        from utils.config import register_omegaconf_resolvers
        register_omegaconf_resolvers()
    except Exception:
        # Resolvers are optional; they may be missing or already registered.
        pass

    if not audio_file:
        return None, "Please upload audio."

    model = None
    try:
        logger.info("🚀 Starting ZeroGPU task...")

        repo_root, qwen_root = resolve_model_dirs()
        exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)

        # Point the config at the local snapshots.
        vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
        if vae_ckpt:
            p1 = repo_root / "vae" / Path(vae_ckpt).name
            p2 = repo_root / Path(vae_ckpt).name
            if p1.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
            elif p2.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
        exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)

        logger.info("Instantiating model (Hydra)...")
        model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

        ckpt_path = str(repo_root / "model.safetensors")
        logger.info(f"Loading weights from {ckpt_path}...")
        sd = load_file(ckpt_path)
        model.load_pretrained(sd)
        del sd
        gc.collect()

        device = torch.device("cuda")

        def safe_move_model(m, dev):
            # Move submodules one at a time to keep peak GPU memory low.
            logger.info("🛡️ Moving model to GPU in FP32...")
            for name, child in m.named_children():
                logger.info(f"Moving {name} to GPU (fp32)...")
                child.to(dev, dtype=torch.float32)
            m.to(dev, dtype=torch.float32)
            return m

        model = safe_move_model(model, device)
        model.eval()
        logger.info("Model moved to CUDA.")

        try:
            scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
                exp_cfg["model"].get("noise_scheduler_name", ""),
                subfolder="scheduler",
                token=HF_TOKEN,
            )
        except Exception:
            # Fall back to default DDIM settings if no scheduler repo is configured.
            scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)

        target_sr = int(exp_cfg.get("sample_rate", 24000))
        torch.manual_seed(int(seed))
        np.random.seed(int(seed))

        wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)

        # Arguments forwarded to model.inference below.
        batch = {
            "audio_id": [Path(audio_file).stem],
            "content": [{"audio": wav, "caption": caption}],
            "task": ["audio_editing"],
            "num_steps": int(num_steps),
            "guidance_scale": float(guidance_scale),
            "guidance_rescale": float(guidance_rescale),
            "use_gt_duration": False,
            "mask_time_aligned_content": False,
        }

        logger.info("Inference running...")
        t0 = time.time()
        with torch.no_grad():
            out = model.inference(scheduler=scheduler, **batch)

        # First batch item, first channel.
        out_audio = out[0, 0].detach().float().cpu().numpy()
        out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
        sf.write(str(out_path), out_audio, samplerate=target_sr)

        return str(out_path), f"Success | {time.time() - t0:.2f}s"

    except Exception as e:
        err = traceback.format_exc()
        logger.error(f"❌ ERROR:\n{err}")
        return None, f"Runtime Error: {e}"

    finally:
        logger.info("Cleaning up...")
        if model is not None:
            del model
        torch.cuda.empty_cache()
        gc.collect()
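# Local smoke-test sketch for run_edit above (hypothetical file; requires a
# CUDA device or a ZeroGPU worker):
#   path, status = run_edit("sample.wav", "Remove the background music.", 50, 5.0, 0.5, 42)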
def build_demo():
    with gr.Blocks(title="MMEdit") as demo:
        gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input", type="filepath")
                caption = gr.Textbox(label="Instruction", lines=3)
                gr.Examples(
                    label="Examples (click to load)",
                    examples=[
                        ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
                        ["./YDKM2KjNkX18.wav", "Incorporate telephone bell ringing into the background."],
                        ["./drop_audiocaps_1.wav", "Erase the rain falling sound from the background."],
                        ["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."],
                    ],
                    inputs=[audio_in, caption],
                    cache_examples=False,
                )
                with gr.Row():
                    num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
                    guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
                    guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
                    seed = gr.Number(42, label="Seed")
                run_btn = gr.Button("Run", variant="primary")
            with gr.Column():
                out = gr.Audio(label="Output")
                status = gr.Textbox(label="Status")

        run_btn.click(
            run_edit,
            [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
            [out, status],
        )
    return demo
if __name__ == "__main__":
    print("[BOOT] entering main()", flush=True)
    demo = build_demo()
    port = int(os.environ.get("PORT", "7860"))
    print(f"[BOOT] launching gradio on 0.0.0.0:{port}", flush=True)
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        ssr_mode=False,
    )