| import argparse |
| import os |
| import shutil |
|
|
| from tqdm import tqdm |
|
|
| from ort_common import WenetONNXRunner, pack_calibration_dataset |
|
|
|
|
def get_args():
    """Parse the command-line options of the calibration-data generator.

    Required options are ``--input/-i`` (one or more values), ``--config``
    and ``--vocab``; every other option has a default.

    Returns:
        argparse.Namespace: the values parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Generate calibration_dataset for exported ONNX models")
    add = parser.add_argument

    # Inputs describing the audio and the model being calibrated.
    add("--input", "-i", nargs="+", required=True,
        help="Input wav file(s) or directory/directories")
    add("--config", required=True, help="yaml file in checkpoint path")
    add("--vocab", required=True,
        help="pretrained units.txt, for example pretrained/<model>/units.txt")
    add("--onnx_dir", default="onnx_model",
        help="directory containing exported ONNX models")

    # Output location and which model inputs to produce.
    add("--calib_data_path", default="calibration_dataset",
        help="output calibration dataset directory")
    add("--parts", nargs="+",
        choices=["all", "offline", "online", "decoder"],
        default=["all"],
        help="which model inputs to generate")

    # Integer settings that main() forwards to the runner.
    add("--offline_seq_len", type=int, default=1024)
    add("--decoder_len", type=int, default=32)
    add("--decoding_chunk_size", type=int, default=16)
    add("--num_decoding_left_chunks", type=int, default=5)

    add("--max_num", type=int, default=100,
        help="maximum number of audio files used for calibration; set <= 0 to use all")
    add("--keep_existing", action="store_true",
        help="append to an existing calibration directory")

    return parser.parse_args()
|
|
|
|
def expand_audio_inputs(inputs):
    """Resolve input paths to a sorted list of audio file paths.

    Each directory in ``inputs`` is walked recursively and only files whose
    extension (compared case-insensitively) is one of ``.wav``, ``.flac``,
    ``.mp3``, ``.m4a`` or ``.ogg`` are collected. Any other existing path is
    treated as an explicitly requested file and kept whatever its extension.

    Args:
        inputs: iterable of file or directory paths.

    Returns:
        list: the collected paths, sorted lexicographically.

    Raises:
        FileNotFoundError: if an input path does not exist, or if no audio
            file was collected at all.
    """
    audio_exts = {".wav", ".flac", ".mp3", ".m4a", ".ogg"}
    audio_files = []
    for path in inputs:
        if os.path.isdir(path):
            audio_files.extend(
                os.path.join(root, filename)
                for root, _, files in os.walk(path)
                for filename in files
                if os.path.splitext(filename)[1].lower() in audio_exts
            )
        elif os.path.exists(path):
            audio_files.append(path)
        else:
            # Fail here, naming the bad path, instead of letting a typo
            # surface later when the file is actually opened.
            raise FileNotFoundError(f"Input path does not exist: {path}")
    if not audio_files:
        raise FileNotFoundError("No audio files found")
    audio_files.sort()
    return audio_files
|
|
|
|
def normalize_parts(parts):
    """Turn the requested part names into a set, expanding ``"all"``.

    If ``"all"`` appears anywhere in ``parts`` the result is the full set
    ``{"offline", "online", "decoder"}``; otherwise it is the set of the
    given names.
    """
    every_part = {"offline", "online", "decoder"}
    return every_part if "all" in parts else set(parts)
|
|
|
|
def limit_audio_files(audio_files, max_num):
    """Keep at most ``max_num`` leading entries of ``audio_files``.

    When ``max_num`` is ``None`` or not positive no limit applies and the
    input list itself is returned; otherwise a sliced copy of its first
    ``max_num`` items is returned.
    """
    unlimited = max_num is None or max_num <= 0
    return audio_files if unlimited else audio_files[:max_num]
|
|
|
|
def main():
    """Generate and pack a calibration dataset from the command-line options.

    Steps, in order: parse the arguments, resolve and cap the audio file
    list, reset the output directory (unless ``--keep_existing`` is given),
    build the ONNX runner, save calibration inputs for every audio file
    while accumulating the per-part counts the runner returns, pack the
    dataset and print the totals.
    """
    args = get_args()
    selected_parts = normalize_parts(args.parts)
    candidates = expand_audio_inputs(args.input)
    audio_files = limit_audio_files(candidates, args.max_num)

    # Start from an empty output directory unless the caller asked to keep it.
    out_dir = args.calib_data_path
    if os.path.exists(out_dir) and not args.keep_existing:
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    runner = WenetONNXRunner(
        args.config,
        args.vocab,
        onnx_dir=args.onnx_dir,
        offline_seq_len=args.offline_seq_len,
        decoder_len=args.decoder_len,
        decoding_chunk_size=args.decoding_chunk_size,
        num_decoding_left_chunks=args.num_decoding_left_chunks,
    )

    # Running totals of the counts reported by the runner, per part.
    totals = dict.fromkeys(("offline", "online", "decoder"), 0)
    bar = tqdm(audio_files, desc="Generating calibration data", unit="wav")
    for index, wav_path in enumerate(bar):
        produced = runner.save_calibration_for_audio(
            wav_path, selected_parts, out_dir, index)
        for part_name, n_samples in produced.items():
            totals[part_name] += n_samples
        bar.set_postfix(offline=totals["offline"],
                        online=totals["online"],
                        decoder=totals["decoder"])

    print("Packing calibration dataset...")
    pack_calibration_dataset(out_dir)
    print(f"Generated calibration data in {out_dir}")
    print(f"offline samples: {totals['offline']}")
    print(f"online samples: {totals['online']}")
    print(f"decoder samples: {totals['decoder']}")
|
|
|
|
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|