"""Extract self-supervised speech features (HuBERT / ContentVec / wav2vec).

Usage::

    python <script> DEVICE N_PART I_PART [I_GPU] EXP_DIR VERSION MODEL_PATH MODEL_NAME

Reads the ``I_PART``-th of ``N_PART`` interleaved shards of
``EXP_DIR/1_16k_wavs/*.wav``, writes one ``.npy`` feature file per wav into
``EXP_DIR/3_feature{dim}`` and appends progress to
``EXP_DIR/extract_f0_feature.log``.
"""

import os

# PyTorch reads these when it initialises its MPS backend (the fallback switch
# at import time), so they must be set *before* ``import torch``.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import json  # noqa: E402
import sys  # noqa: E402
import traceback  # noqa: E402

import librosa  # noqa: E402
import numpy as np  # noqa: E402
import soundfile as sf  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from fairseq import checkpoint_utils  # noqa: E402
from fairseq.tasks.audio_pretraining import AudioPretrainingTask  # noqa: E402
from transformers import HubertModel, Wav2Vec2Model  # noqa: E402


# Configuration class for device and precision management
class Config:
    """Choose the compute device / precision and load the per-rate JSON configs.

    Attributes:
        device: "cuda[:N]", "mps" or "cpu" after ``device_config`` has run.
        is_half: whether the model should be run in float16.
        json_config: mapping of config file name -> parsed JSON from ``configs/``.
    """

    def __init__(self, device):
        # The requested device is kept only when CUDA exists; otherwise start
        # from CPU and let device_config() upgrade to MPS when available.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.is_half = self.device != "cpu"
        self.version_config_paths = [
            os.path.join("", f"{k}.json")
            for k in ["32k", "40k", "48k", "48k_v2", "32k_v2"]
        ]
        self.json_config = self.load_config_json()
        self.device_config()

    def load_config_json(self):
        """Load every ``configs/<name>.json``; raises if any file is missing."""
        configs = {}
        for config_file in self.version_config_paths:
            config_path = os.path.join("configs", config_file)
            with open(config_path, "r") as f:
                configs[config_file] = json.load(f)
        return configs

    def device_config(self):
        """Finalize ``device`` and ``is_half`` from the available hardware."""
        if self.device.startswith("cuda"):
            # "cuda" without an index means the current device; int("cuda")
            # would raise ValueError.
            _, _, index = self.device.partition(":")
            i_device = int(index) if index else torch.cuda.current_device()
            gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
            # NOTE(review): fp16 is enabled only on V100 GPUs with > 4 GiB;
            # every other CUDA GPU runs fp32. Confirm this is intended.
            self.is_half = gpu_mem > 4 and "V100" in torch.cuda.get_device_name(i_device)
        elif torch.backends.mps.is_available():
            self.device = "mps"
            self.is_half = False
        else:
            self.device = "cpu"
            self.is_half = False


# Model-specific definitions
class HubertModelWithFinalProj(HubertModel):
    """HF ``HubertModel`` plus the ``final_proj`` linear head of the checkpoint."""

    def __init__(self, config):
        super().__init__(config)
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)


# ``ContentVecModel`` was referenced below but never defined, so selecting the
# "contentvec" pipeline raised NameError. Assumption: HF-format ContentVec
# checkpoints use the HuBERT architecture with a final_proj head — TODO confirm.
ContentVecModel = HubertModelWithFinalProj


def _model_dtype(model):
    """Return the floating dtype of the model's parameters (fp16 or fp32)."""
    return next(model.parameters()).dtype


def load_hubert_fairseq(model_path, device, is_half):
    """Load a fairseq HuBERT checkpoint, forcing the audio_pretraining task.

    Returns:
        dict with ``model`` (eval mode, on ``device``) and ``saved_cfg``.
    """
    saved_state = checkpoint_utils.load_checkpoint_to_cpu(model_path)
    saved_cfg = saved_state["cfg"]
    # Build the task explicitly from the saved task config rather than letting
    # fairseq instantiate the checkpoint's original task.
    task = AudioPretrainingTask.setup_task(saved_cfg.task)
    models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
        [model_path], task=task
    )
    model = models[0].to(device)
    if is_half and device not in ["mps", "cpu"]:
        model = model.half()
    model.eval()
    return {"model": model, "saved_cfg": saved_cfg}


def load_huggingface_model(model_path, device, is_half, model_class=HubertModelWithFinalProj):
    """Load a Hugging Face model with ``model_class.from_pretrained``.

    Uses float16 only on CUDA when ``is_half`` is set, float32 otherwise.
    """
    dtype = torch.float16 if is_half and "cuda" in device else torch.float32
    model = model_class.from_pretrained(model_path).to(device=device, dtype=dtype)
    model.eval()
    return {"model": model}


def hubert_preprocess(feats, saved_cfg):
    """Layer-normalize the waveform over all axes if the checkpoint's task asks for it."""
    if saved_cfg.task.normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    return feats


def hubert_prepare_input(feats, device, version):
    """Build keyword arguments for fairseq ``extract_features``.

    ``source`` is left in float32; ``hubert_extract_features`` casts it to the
    model's dtype so fp32 models on CUDA are not fed fp16 input.
    v1 reads layer 9 (then projected by ``final_proj``), v2 reads layer 12.
    """
    return {
        "source": feats.to(device),
        "padding_mask": torch.zeros(feats.shape, dtype=torch.bool, device=device),
        "output_layer": 9 if version == "v1" else 12,
    }


def hubert_extract_features(model, inputs):
    """Run fairseq HuBERT and return hidden states (projected for layer 9)."""
    with torch.no_grad():
        outputs = model.extract_features(
            source=inputs["source"].to(_model_dtype(model)),
            padding_mask=inputs["padding_mask"],
            output_layer=inputs["output_layer"],
        )
        feats = outputs[0]
        if inputs["output_layer"] == 9:
            feats = model.final_proj(feats)
    return feats


def general_preprocess(feats, *args):
    """Identity preprocessing for Hugging Face models."""
    return feats


def general_prepare_input(feats, device, *args):
    """Move the waveform to ``device``.

    Extra positional args (e.g. ``version``) are accepted and ignored so this
    shares a call signature with ``hubert_prepare_input``.
    """
    return feats.to(device)


def general_extract_features(model, inputs):
    """Return ``last_hidden_state`` from a Hugging Face speech model."""
    with torch.no_grad():
        feats = model(inputs.to(_model_dtype(model)))["last_hidden_state"]
    return feats


# Model configurations
model_configs = {
    "hubert": {
        "target_sr": 16000,
        "load_model": load_hubert_fairseq,
        "preprocess": hubert_preprocess,
        "prepare_input": hubert_prepare_input,
        "extract_features": hubert_extract_features,
    },
    "contentvec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(path, dev, half, ContentVecModel),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
    "wav2vec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(path, dev, half, Wav2Vec2Model),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
}

# Shortcut names accepted when MODEL_PATH ends in "Custom":
# name -> (checkpoint file, model_configs key).
_CUSTOM_MODEL_MAPPINGS = {
    "wav2vec": ("wav2vec_small_960h.pt", "wav2vec"),
    "hubert_base": ("hubert_base.pt", "hubert"),
    "contentvec_base": ("contentvec_base.pt", "contentvec"),
    "hubert_base_japanese": ("hubert_base_japanese.pt", "hubert_base_japanese"),
}

_USAGE = "usage: {prog} DEVICE N_PART I_PART [I_GPU] EXP_DIR VERSION MODEL_PATH MODEL_NAME"


# Utility functions
def load_audio(file, target_sr):
    """Read an audio file, downmix to mono and resample to ``target_sr``."""
    audio, sr = sf.read(file.strip())
    if audio.ndim > 1:
        audio = librosa.to_mono(audio.T)
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    return audio


def printt(f, strr):
    """Print ``strr`` and append it (flushed) to the open log file ``f``."""
    print(strr)
    f.write(f"{strr}\n")
    f.flush()


def _parse_args(argv):
    """Parse the 8-arg (no GPU) or 9-arg (with I_GPU) command line.

    Exits with a usage message on any other argument count.
    """
    if len(argv) >= 9:
        device, n_part, i_part, i_gpu, exp_dir, version, model_path, model_name = argv[1:9]
    elif len(argv) == 8:
        device, n_part, i_part, exp_dir, version, model_path, model_name = argv[1:8]
        i_gpu = None
    else:
        sys.exit(_USAGE.format(prog=argv[0] if argv else "extract_feature"))
    return device, int(n_part), int(i_part), i_gpu, exp_dir, version, model_path, model_name


def _resolve_model(model_path, model_name):
    """Map "Custom" shortcut names to (checkpoint path, pipeline name)."""
    if os.path.split(model_path)[-1] == "Custom" and model_name in _CUSTOM_MODEL_MAPPINGS:
        return _CUSTOM_MODEL_MAPPINGS[model_name]
    return model_path, model_name


def _select_model_config(model_name, log_file):
    """Return the pipeline for ``model_name``; hubert-family names fall back to "hubert"."""
    model_config = model_configs.get(model_name)
    if model_config is None and model_name.startswith("hubert"):
        printt(log_file, f"Unknown model name {model_name!r}; using the 'hubert' pipeline.")
        model_config = model_configs["hubert"]
    if model_config is None:
        printt(log_file, f"Error: unknown model name {model_name!r}.")
        sys.exit(1)
    return model_config


def _extract_file(model, model_config, additional_configs, device, version, wav_file):
    """Run one wav through the pipeline; return float32 features, batch axis squeezed."""
    wav = load_audio(wav_file, model_config["target_sr"])
    feats = torch.from_numpy(wav).float().view(1, -1)  # (1, num_samples)
    feats = model_config["preprocess"](feats, additional_configs)
    inputs = model_config["prepare_input"](feats, device, version)
    feats = model_config["extract_features"](model, inputs)
    return feats.squeeze(0).float().cpu().numpy()


# Main script
def main():
    """Parse argv, load the model, and extract features for this shard of wavs."""
    (device, n_part, i_part, i_gpu, exp_dir, version,
     model_path, model_name) = _parse_args(sys.argv)
    if i_gpu is not None:
        # Set before Config() first queries CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = i_gpu

    config = Config(device)
    with open(f"{exp_dir}/extract_f0_feature.log", "a+", encoding="utf-8") as log_file:
        printt(log_file, f"Args: {sys.argv}")

        model_path, model_name = _resolve_model(model_path, model_name)
        if not os.path.exists(model_path):
            printt(log_file, f"Error: {model_path} does not exist.")
            sys.exit(1)

        # Load model
        model_config = _select_model_config(model_name, log_file)
        model_dict = model_config["load_model"](model_path, config.device, config.is_half)
        model = model_dict["model"]
        additional_configs = model_dict.get("saved_cfg")
        printt(log_file, f"Loaded model from {model_path} on {config.device}")

        # Setup directories
        if version == "v1":
            feature_dim = 256
        elif model_name == "hubert_large_ll60k":
            feature_dim = 1024
        else:
            feature_dim = 768
        wav_path = f"{exp_dir}/1_16k_wavs"
        out_path = f"{exp_dir}/3_feature{feature_dim}"
        os.makedirs(out_path, exist_ok=True)

        # Process audio files: every n_part-th file starting at i_part.
        todo = sorted(os.listdir(wav_path))[i_part::n_part]
        printt(log_file, f"Total files to process: {len(todo)}")
        if not todo:
            printt(log_file, "No files to process.")
            return

        for file in todo:
            if not file.endswith(".wav"):
                continue
            # Per-file boundary: log the traceback and continue with the rest.
            try:
                wav_file = f"{wav_path}/{file}"
                out_file = f"{out_path}/{os.path.splitext(file)[0]}.npy"
                if os.path.exists(out_file):
                    continue
                feats = _extract_file(
                    model, model_config, additional_configs, config.device, version, wav_file
                )
                if np.isnan(feats).any():
                    printt(log_file, f"{file} contains NaN values")
                else:
                    np.save(out_file, feats, allow_pickle=False)
                    printt(log_file, f"Processed {file}: {feats.shape}")
            except Exception:
                printt(log_file, traceback.format_exc())

        printt(log_file, "Feature extraction completed.")


if __name__ == "__main__":
    main()