Files changed (2) hide show
  1. handler.py +172 -0
  2. requirements.txt +31 -0
handler.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Inference Endpoint Handler for SongFormer
3
+ Supports binary audio input (WAV, MP3, etc.) via base64 encoding or direct bytes
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import io
9
+ import base64
10
+ import json
11
+ import tempfile
12
+ from typing import Dict, Any, Union
13
+ import librosa
14
+ import numpy as np
15
+ import torch
16
+ from transformers import AutoModel
17
+
18
class EndpointHandler:
    """
    HuggingFace Inference Endpoint Handler for SongFormer model.

    Accepts base64-encoded audio (WAV, MP3, FLAC, etc.)
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler and load the SongFormer model.

        Args:
            path: Path to the model directory (provided by HuggingFace).
                Falls back to the current working directory when empty.
        """
        # Point SongFormer at the model directory and make its bundled
        # remote code importable before AutoModel pulls it in.
        self.model_path = path or os.getcwd()
        os.environ["SONGFORMER_LOCAL_DIR"] = self.model_path
        sys.path.insert(0, self.model_path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SongFormer model on {self.device}...")

        # Load model without device_map to avoid meta device initialization.
        # The SongFormerModel.__init__ now handles meta device detection.
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map=None,
        )
        self.model.to(self.device)
        self.model.eval()

        # Sampling rate the model expects; all input audio is resampled to it.
        self.target_sr = 24000

        print("SongFormer model loaded successfully!")

    def _decode_base64_audio(self, audio_b64: str) -> np.ndarray:
        """
        Decode base64-encoded audio to a mono waveform at the model rate.

        Args:
            audio_b64: Base64-encoded audio string (any format librosa/soundfile
                can read: WAV, MP3, FLAC, ...).

        Returns:
            1-D numpy array of audio samples at ``self.target_sr`` (24 kHz).

        Raises:
            ValueError: If the input is not valid base64 data.
        """
        # b64decode raises binascii.Error (a ValueError subclass) on bad
        # padding/characters, and TypeError on non-string-like input.
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError) as e:
            # Chain the cause so the original decode error is not lost.
            raise ValueError(f"Failed to decode base64 audio data: {e}") from e

        # librosa.load accepts a file-like object and handles container/codec
        # detection, resampling to target_sr, and downmixing to mono.
        audio_io = io.BytesIO(audio_bytes)
        audio_array, _ = librosa.load(audio_io, sr=self.target_sr, mono=True)

        return audio_array

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request with base64-encoded audio.

        Expected input:
            {
                "inputs": "<base64-encoded-audio-data>"
            }

        Returns:
            {
                "segments": [
                    {
                        "label": "intro",
                        "start": 0.0,
                        "end": 15.2
                    },
                    ...
                ],
                "duration": 180.5,
                "num_segments": 8
            }

        On failure, returns a structured error payload with "error" and
        "error_type" keys and empty segment fields instead of raising.
        """
        try:
            # Extract and validate the base64-encoded audio payload.
            audio_b64 = data.get("inputs")
            if not audio_b64:
                raise ValueError("Missing 'inputs' key with base64-encoded audio")

            if not isinstance(audio_b64, str):
                raise ValueError("Input must be a base64-encoded string")

            audio_array = self._decode_base64_audio(audio_b64)

            # Inference only; no gradients needed.
            with torch.no_grad():
                result = self.model(audio_array)

            # Duration follows directly from sample count at the fixed rate.
            duration = len(audio_array) / self.target_sr

            return {
                "segments": result,
                "duration": float(duration),
                "num_segments": len(result),
            }

        except Exception as e:
            # Top-level endpoint boundary: return the error in a structured
            # format rather than letting the serving stack produce a bare 500.
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "segments": [],
                "duration": 0.0,
                "num_segments": 0,
            }
144
+
145
+
146
# For local testing
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Test SongFormer handler locally")
    cli.add_argument("audio_file", help="Path to audio file to test")
    cli.add_argument("--model-path", default=".", help="Path to model directory")
    opts = cli.parse_args()

    # Build the handler against the requested model directory.
    handler = EndpointHandler(opts.model_path)

    # Encode the audio file the same way an endpoint client would.
    with open(opts.audio_file, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode('utf-8')

    # Exercise the normal endpoint path: base64 payload in, JSON out.
    print("\n=== Testing with base64-encoded audio ===")
    print(json.dumps(handler({"inputs": encoded}), indent=2))

    # Bypass the handler and call the model with a path, for comparison.
    print("\n=== Testing with direct file path (not typical for endpoint) ===")
    print(json.dumps(handler.model(opts.audio_file), indent=2))
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements for HuggingFace Inference Endpoint
2
+ # This file contains dependencies needed for the handler.py to work
3
+
4
+ # Core ML frameworks
5
+ transformers>=4.30.0
6
+ torch>=2.0.0
7
+
8
+ # Audio processing
9
+ librosa>=0.10.0
10
+ soundfile>=0.12.0
11
+ audioread>=3.0.0
12
+
13
+ # Numerical computing
14
+ numpy>=1.24.0
15
+ scipy>=1.10.0
16
+
17
+ # Additional dependencies for SongFormer model
18
+ # (these may already be installed by the model itself)
19
+ einops>=0.7.0
20
+ x-transformers>=1.0.0
21
+ ema-pytorch>=0.2.0
22
+ loguru>=0.7.0
23
+ omegaconf>=2.3.0
24
+ muq
25
+ msaf
26
+
27
+
28
+
29
+ # Note: For MP3 support, ffmpeg must be installed in the system
30
+ # Add to Dockerfile:
31
+ # RUN apt-get update && apt-get install -y ffmpeg