import argparse import os import shutil from tqdm import tqdm from ort_common import WenetONNXRunner, pack_calibration_dataset def get_args(): parser = argparse.ArgumentParser( description="Generate calibration_dataset for exported ONNX models") parser.add_argument("--input", "-i", nargs="+", required=True, help="Input wav file(s) or directory/directories") parser.add_argument("--config", required=True, help="yaml file in checkpoint path") parser.add_argument( "--vocab", required=True, help="pretrained units.txt, for example pretrained//units.txt", ) parser.add_argument("--onnx_dir", default="onnx_model", help="directory containing exported ONNX models") parser.add_argument("--calib_data_path", default="calibration_dataset", help="output calibration dataset directory") parser.add_argument("--parts", nargs="+", choices=["all", "offline", "online", "decoder"], default=["all"], help="which model inputs to generate") parser.add_argument("--offline_seq_len", type=int, default=1024) parser.add_argument("--decoder_len", type=int, default=32) parser.add_argument("--decoding_chunk_size", type=int, default=16) parser.add_argument("--num_decoding_left_chunks", type=int, default=5) parser.add_argument("--max_num", type=int, default=100, help="maximum number of audio files used for calibration; set <= 0 to use all") parser.add_argument("--keep_existing", action="store_true", help="append to an existing calibration directory") return parser.parse_args() def expand_audio_inputs(inputs): audio_exts = {".wav", ".flac", ".mp3", ".m4a", ".ogg"} audio_files = [] for path in inputs: if os.path.isdir(path): for root, _, files in os.walk(path): for filename in files: if os.path.splitext(filename)[1].lower() in audio_exts: audio_files.append(os.path.join(root, filename)) else: audio_files.append(path) audio_files = sorted(audio_files) if not audio_files: raise FileNotFoundError("No audio files found") return audio_files def normalize_parts(parts): if "all" in parts: return {"offline", "online", "decoder"} return set(parts) def limit_audio_files(audio_files, max_num): if max_num is None or max_num <= 0: return audio_files return audio_files[:max_num] def main(): args = get_args() parts = normalize_parts(args.parts) audio_files = limit_audio_files(expand_audio_inputs(args.input), args.max_num) if os.path.exists(args.calib_data_path) and not args.keep_existing: shutil.rmtree(args.calib_data_path) os.makedirs(args.calib_data_path, exist_ok=True) runner = WenetONNXRunner( args.config, args.vocab, onnx_dir=args.onnx_dir, offline_seq_len=args.offline_seq_len, decoder_len=args.decoder_len, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, ) counts = {"offline": 0, "online": 0, "decoder": 0} progress = tqdm(audio_files, desc="Generating calibration data", unit="wav") for audio_idx, audio_file in enumerate(progress): sample_counts = runner.save_calibration_for_audio( audio_file, parts, args.calib_data_path, audio_idx) for key, value in sample_counts.items(): counts[key] += value progress.set_postfix(offline=counts["offline"], online=counts["online"], decoder=counts["decoder"]) print("Packing calibration dataset...") pack_calibration_dataset(args.calib_data_path) print(f"Generated calibration data in {args.calib_data_path}") print(f"offline samples: {counts['offline']}") print(f"online samples: {counts['online']}") print(f"decoder samples: {counts['decoder']}") if __name__ == "__main__": main()