johnbridges committed on
Commit
df3ba3c
·
verified ·
1 Parent(s): db0a2ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -83
app.py CHANGED
@@ -11,15 +11,14 @@ import uuid
11
  import logging
12
  from flask_cors import CORS
13
  import threading
 
14
  import tempfile
15
  from huggingface_hub import snapshot_download
16
- from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
17
- import time
18
  from tts_processor import preprocess_all
19
  import hashlib
20
 
21
  # Configure logging
22
- logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
  app = Flask(__name__)
@@ -38,13 +37,52 @@ SERVE_DIR = os.environ.get("SERVE_DIR", "./files") # Default to './files' if no
38
 
39
  os.makedirs(SERVE_DIR, exist_ok=True)
40
  def validate_audio_file(file):
41
- if file.content_type not in ["audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3"]:
42
- raise ValueError("Unsupported file type")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  file.seek(0, os.SEEK_END)
44
  file_size = file.tell()
45
  file.seek(0) # Reset file pointer
46
- if file_size > 10 * 1024 * 1024: # 10 MB limit
47
- raise ValueError("File is too large (max 10 MB)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def validate_text_input(text):
50
  if not isinstance(text, str):
@@ -66,72 +104,59 @@ def is_cached(cached_file_path):
66
  exists = os.path.exists(cached_file_path) # Perform disk check
67
  file_cache[cached_file_path] = exists # Update the cache
68
  return exists
69
- import time
70
- from huggingface_hub import snapshot_download
71
- from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
72
 
 
73
  def initialize_models():
74
  global sess, voice_style, processor, whisper_model
75
 
76
- max_retries = 5 # Maximum number of retries
77
- retry_delay = 2 # Initial delay in seconds (will double after each retry)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- for attempt in range(max_retries):
80
- try:
81
- # Download the ONNX model if not already downloaded
82
- if not os.path.exists(model_path):
83
- logger.info(f"Attempt {attempt + 1} to download and load Kokoro model...")
84
- kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
85
- logger.info(f"Kokoro model directory: {kokoro_dir}")
86
- else:
87
- kokoro_dir = model_path
88
- logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")
89
-
90
- # Validate ONNX file path
91
- onnx_path = None
92
- for root, _, files in os.walk(kokoro_dir):
93
- if 'model.onnx' in files:
94
- onnx_path = os.path.join(root, 'model.onnx')
95
- break
96
-
97
- if not onnx_path or not os.path.exists(onnx_path):
98
- raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
99
-
100
- logger.info("Loading ONNX session...")
101
- sess = InferenceSession(onnx_path)
102
- logger.info(f"ONNX session loaded successfully from {onnx_path}")
103
-
104
- # Load the voice style vector
105
- voice_style_path = None
106
- for root, _, files in os.walk(kokoro_dir):
107
- if f'{voice_name}.bin' in files:
108
- voice_style_path = os.path.join(root, f'{voice_name}.bin')
109
- break
110
-
111
- if not voice_style_path or not os.path.exists(voice_style_path):
112
- raise FileNotFoundError(f"Voice style file not found at {voice_style_path}")
113
-
114
- logger.info("Loading voice style vector...")
115
- voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
116
- logger.info(f"Voice style vector loaded successfully from {voice_style_path}")
117
-
118
- # Initialize Whisper model for S2T
119
- logger.info("Downloading and loading Whisper model...")
120
- processor = WhisperProcessor.from_pretrained("openai/whisper-base")
121
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
122
- whisper_model.config.forced_decoder_ids = None
123
- logger.info("Whisper model loaded successfully")
124
-
125
- # If everything succeeds, break out of the retry loop
126
- break
127
-
128
- except (RepositoryNotFoundError, HfHubHTTPError, FileNotFoundError) as e:
129
- logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
130
- if attempt == max_retries - 1:
131
- logger.error("Max retries reached. Failed to initialize models.")
132
- raise # Re-raise the exception if max retries are reached
133
- time.sleep(retry_delay)
134
- retry_delay *= 2 # Exponential backoff
135
 
136
  # Initialize models
137
  initialize_models()
@@ -221,24 +246,60 @@ def generate_audio():
221
  return jsonify({"status": "error", "message": str(e)}), 500
222
 
223
  # Speech-to-Text (S2T) Endpoint
 
 
 
 
 
 
224
  @app.route('/transcribe_audio', methods=['POST'])
225
  def transcribe_audio():
226
- """Speech-to-Text (S2T) Endpoint"""
227
  with global_lock: # Acquire global lock to ensure only one instance runs
228
- audio_path = None
 
229
  try:
230
  logger.debug("Received request to /transcribe_audio")
231
  file = request.files['file']
232
- validate_audio_file(file)
233
- # Generate a unique filename using uuid
234
- unique_filename = f"{uuid.uuid4().hex}_{file.filename}"
235
- audio_path = os.path.join("/tmp", unique_filename)
236
- file.save(audio_path)
237
- logger.debug(f"Audio file saved to {audio_path}")
238
-
239
- # Load and preprocess audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  logger.debug("Processing audio for transcription...")
241
- audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
242
 
243
  input_features = processor(
244
  audio_array,
@@ -257,10 +318,14 @@ def transcribe_audio():
257
  logger.error(f"Error transcribing audio: {str(e)}")
258
  return jsonify({"status": "error", "message": str(e)}), 500
259
  finally:
260
- # Ensure temporary file is removed
261
- if audio_path and os.path.exists(audio_path):
262
- os.remove(audio_path)
263
- logger.debug(f"Temporary file {audio_path} removed")
 
 
 
 
264
 
265
  @app.route('/files/<filename>', methods=['GET'])
266
  def serve_wav_file(filename):
 
11
  import logging
12
  from flask_cors import CORS
13
  import threading
14
+ import werkzeug
15
  import tempfile
16
  from huggingface_hub import snapshot_download
 
 
17
  from tts_processor import preprocess_all
18
  import hashlib
19
 
20
  # Configure logging
21
+ logging.basicConfig(level=logging.DEBUG)
22
  logger = logging.getLogger(__name__)
23
 
24
  app = Flask(__name__)
 
37
 
38
  os.makedirs(SERVE_DIR, exist_ok=True)
39
def validate_audio_file(file):
    """Validate an uploaded audio file: storage type, MIME type, size, and magic bytes.

    Accepts WAV, MP3, WebM, and Ogg (Opus) uploads up to 10 MB.

    Args:
        file: the uploaded ``werkzeug.datastructures.FileStorage`` object.

    Raises:
        ValueError: if the file is not a FileStorage, has an unsupported
            content type, exceeds the size limit, or its header bytes do not
            match the declared content type.
    """
    if not isinstance(file, werkzeug.datastructures.FileStorage):
        raise ValueError("Invalid file type")

    # Supported MIME types (including WebM/Opus)
    supported_types = [
        "audio/wav",
        "audio/x-wav",
        "audio/mpeg",
        "audio/mp3",
        "audio/webm",
        "audio/ogg",  # For Opus in Ogg container
    ]

    # Browsers commonly append codec parameters (e.g. "audio/webm;codecs=opus"),
    # and content_type may be missing entirely — compare only the base MIME type.
    base_type = (file.content_type or "").split(";")[0].strip().lower()
    if base_type not in supported_types:
        raise ValueError(f"Unsupported file type. Must be one of: {', '.join(supported_types)}")

    # Check file size by seeking to the end, then rewind so later reads work.
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)  # Reset file pointer

    max_size = 10 * 1024 * 1024  # 10 MB
    if file_size > max_size:
        raise ValueError(f"File is too large (max {max_size//(1024*1024)} MB)")

    # Verify the file's magic bytes match the declared content type.
    if not verify_audio_header(file):
        raise ValueError("File header doesn't match declared content type")
70
def verify_audio_header(file):
    """Quickly check whether the file's magic bytes match its declared format.

    Reads the first 4 bytes and rewinds the stream, so the caller can still
    read the full file afterwards.

    Args:
        file: a file-like object with ``read``/``seek`` and a
            ``content_type`` attribute (e.g. a werkzeug FileStorage).

    Returns:
        bool: True when the header is consistent with ``file.content_type``
        (or when no check exists for that type), False otherwise.
    """
    header = file.read(4)
    file.seek(0)  # Rewind after reading

    # Normalize away codec parameters such as "audio/webm;codecs=opus".
    content_type = (file.content_type or "").split(";")[0].strip().lower()

    if content_type == "audio/webm":
        # WebM is an EBML/Matroska container: starts with \x1aE\xdf\xa3.
        return header.startswith(b'\x1aE\xdf\xa3')
    if content_type == "audio/ogg":
        # Every Ogg page starts with the capture pattern "OggS".
        return header.startswith(b'OggS')
    if content_type in ("audio/wav", "audio/x-wav"):
        # WAV files are RIFF containers.
        return header.startswith(b'RIFF')
    if content_type in ("audio/mpeg", "audio/mp3"):
        # MP3 files start either with an ID3v2 tag ("ID3") or directly with an
        # MPEG frame sync: 0xFF followed by a byte whose top 3 bits are set.
        # Checking only b'\xff\xfb' (as before) rejected most real MP3s.
        if header.startswith(b'ID3'):
            return True
        return (
            len(header) >= 2
            and header[0] == 0xFF
            and (header[1] & 0xE0) == 0xE0
        )
    return True  # Skip verification for other types
86
 
87
  def validate_text_input(text):
88
  if not isinstance(text, str):
 
104
  exists = os.path.exists(cached_file_path) # Perform disk check
105
  file_cache[cached_file_path] = exists # Update the cache
106
  return exists
 
 
 
107
 
108
+ # Initialize models
109
def _find_file(root_dir, filename):
    """Return the full path of *filename* anywhere under *root_dir*, or None."""
    for root, _, files in os.walk(root_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None


def initialize_models():
    """Download (or reuse) the Kokoro TTS model and load the Whisper S2T model.

    Populates the module-level globals:
        sess          -- InferenceSession for the Kokoro ONNX model
        voice_style   -- np.float32 style vector reshaped to (-1, 1, 256)
        processor     -- WhisperProcessor for openai/whisper-base
        whisper_model -- WhisperForConditionalGeneration for openai/whisper-base

    Raises:
        Exception: any failure (download error, missing model/voice files)
            is logged and re-raised so startup aborts visibly.
    """
    global sess, voice_style, processor, whisper_model

    try:
        # Download the model snapshot only when no cached copy exists.
        if not os.path.exists(model_path):
            logger.info("Downloading and loading Kokoro model...")
            kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
            logger.info(f"Kokoro model directory: {kokoro_dir}")
        else:
            kokoro_dir = model_path
            logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")

        # Locate and load the ONNX graph.
        onnx_path = _find_file(kokoro_dir, 'model.onnx')
        if not onnx_path:
            raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")

        logger.info("Loading ONNX session...")
        sess = InferenceSession(onnx_path)
        logger.info(f"ONNX session loaded successfully from {onnx_path}")

        # Locate and load the style vector for the configured voice.
        voice_style_path = _find_file(kokoro_dir, f'{voice_name}.bin')
        if not voice_style_path:
            raise FileNotFoundError(f"Voice style file not found at {voice_style_path}")

        logger.info("Loading voice style vector...")
        voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
        logger.info(f"Voice style vector loaded successfully from {voice_style_path}")

        # Initialize Whisper model for speech-to-text.
        logger.info("Downloading and loading Whisper model...")
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
        whisper_model.config.forced_decoder_ids = None
        logger.info("Whisper model loaded successfully")

    except Exception as e:
        logger.error(f"Error initializing models: {str(e)}")
        raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  # Initialize models
162
  initialize_models()
 
246
  return jsonify({"status": "error", "message": str(e)}), 500
247
 
248
  # Speech-to-Text (S2T) Endpoint
249
+ # Add these imports at the top with the other imports
250
+ import subprocess
251
+ import tempfile
252
+ from pathlib import Path
253
+
254
+ # Then update the transcribe_audio function:
255
  @app.route('/transcribe_audio', methods=['POST'])
256
  def transcribe_audio():
257
+ """Speech-to-Text (S2T) Endpoint with automatic format conversion"""
258
  with global_lock: # Acquire global lock to ensure only one instance runs
259
+ input_audio_path = None
260
+ converted_audio_path = None
261
  try:
262
  logger.debug("Received request to /transcribe_audio")
263
  file = request.files['file']
264
+
265
+ # Create temporary files for both input and output
266
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as input_temp:
267
+ input_audio_path = input_temp.name
268
+ file.save(input_audio_path)
269
+ logger.debug(f"Original audio file saved to {input_audio_path}")
270
+
271
+ # Create a temporary file for the converted WAV
272
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as output_temp:
273
+ converted_audio_path = output_temp.name
274
+
275
+ # Convert to WAV with ffmpeg (16kHz, mono)
276
+ logger.debug(f"Converting audio to 16kHz mono WAV format...")
277
+ conversion_command = [
278
+ 'ffmpeg',
279
+ '-y', # Force overwrite without prompting
280
+ '-i', input_audio_path,
281
+ '-acodec', 'pcm_s16le', # 16-bit PCM
282
+ '-ac', '1', # mono
283
+ '-ar', '16000', # 16kHz sample rate
284
+ '-af', 'highpass=f=80,lowpass=f=7500,afftdn=nr=10:nf=-25,loudnorm=I=-16:TP=-1.5:LRA=11', # Audio cleanup filters
285
+ converted_audio_path
286
+ ]
287
+ result = subprocess.run(
288
+ conversion_command,
289
+ stdout=subprocess.PIPE,
290
+ stderr=subprocess.PIPE,
291
+ text=True
292
+ )
293
+
294
+ if result.returncode != 0:
295
+ logger.error(f"FFmpeg conversion error: {result.stderr}")
296
+ raise Exception(f"Audio conversion failed: {result.stderr}")
297
+
298
+ logger.debug(f"Audio successfully converted to {converted_audio_path}")
299
+
300
+ # Load and process the converted audio
301
  logger.debug("Processing audio for transcription...")
302
+ audio_array, sampling_rate = librosa.load(converted_audio_path, sr=16000)
303
 
304
  input_features = processor(
305
  audio_array,
 
318
  logger.error(f"Error transcribing audio: {str(e)}")
319
  return jsonify({"status": "error", "message": str(e)}), 500
320
  finally:
321
+ # Clean up temporary files
322
+ for path in [input_audio_path, converted_audio_path]:
323
+ if path and os.path.exists(path):
324
+ try:
325
+ os.remove(path)
326
+ logger.debug(f"Temporary file {path} removed")
327
+ except Exception as e:
328
+ logger.warning(f"Failed to remove temporary file {path}: {e}")
329
 
330
  @app.route('/files/<filename>', methods=['GET'])
331
  def serve_wav_file(filename):