kevinwang676 committed on
Commit
10d4c29
·
verified ·
1 Parent(s): a9f4026

Update api_v2.py

Browse files
Files changed (1) hide show
  1. api_v2.py +85 -50
api_v2.py CHANGED
@@ -129,14 +129,23 @@ cut_method_names = get_cut_method_names()
129
  import os
130
  import sys
131
  import traceback
132
- from typing import Generator
133
  import requests
134
  import tempfile
135
  import urllib.parse
136
  from pathlib import Path
137
 
138
  # Function to check if a path is a URL and download it if needed
139
- def process_audio_path(audio_path):
 
 
 
 
 
 
 
 
 
140
  if audio_path and (audio_path.startswith('http://') or audio_path.startswith('https://') or
141
  audio_path.startswith('s3://')):
142
  try:
@@ -175,13 +184,13 @@ def process_audio_path(audio_path):
175
  f.write(chunk)
176
 
177
  print(f"Downloaded to: {local_path}")
178
- return local_path
179
  except Exception as e:
180
  print(f"Error downloading audio file: {e}")
181
  raise Exception(f"Failed to download audio from URL: {e}")
182
 
183
  # Not a URL: return the original path unchanged (a failed download raises above)
184
- return audio_path
185
 
186
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
187
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
@@ -332,40 +341,14 @@ def check_params(req:dict):
332
  async def tts_handle(req:dict):
333
  """
334
  Text to speech handler.
335
-
336
- Args:
337
- req (dict):
338
- {
339
- "text": "", # str.(required) text to be synthesized
340
- "text_lang: "", # str.(required) language of the text to be synthesized
341
- "ref_audio_path": "", # str.(required) reference audio path or URL
342
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths or URLs
343
- "prompt_text": "", # str.(optional) prompt text for the reference audio
344
- "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
345
- "top_k": 5, # int. top k sampling
346
- "top_p": 1, # float. top p sampling
347
- "temperature": 1, # float. temperature for sampling
348
- "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
349
- "batch_size": 1, # int. batch size for inference
350
- "batch_threshold": 0.75, # float. threshold for batch splitting.
351
- "split_bucket: True, # bool. whether to split the batch into multiple buckets.
352
- "speed_factor":1.0, # float. control the speed of the synthesized audio.
353
- "fragment_interval":0.3, # float. to control the interval of the audio fragment.
354
- "seed": -1, # int. random seed for reproducibility.
355
- "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
356
- "streaming_mode": False, # bool. whether to return a streaming response.
357
- "parallel_infer": True, # bool.(optional) whether to use parallel inference.
358
- "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
359
- "sample_steps": 32, # int. number of sampling steps for VITS model V3.
360
- "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
361
- }
362
- returns:
363
- StreamingResponse: audio stream response.
364
  """
365
 
366
  streaming_mode = req.get("streaming_mode", False)
367
  return_fragment = req.get("return_fragment", False)
368
  media_type = req.get("media_type", "wav")
 
 
 
369
 
370
  check_res = check_params(req)
371
  if check_res is not None:
@@ -376,36 +359,69 @@ async def tts_handle(req:dict):
376
 
377
  try:
378
  # Process ref_audio_path (download if it's a URL)
379
- req["ref_audio_path"] = process_audio_path(req["ref_audio_path"])
 
 
 
380
 
381
  # Process aux_ref_audio_paths (download if they're URLs)
382
  if req.get("aux_ref_audio_paths"):
383
  aux_paths = []
384
  for aux_path in req["aux_ref_audio_paths"]:
385
- aux_paths.append(process_audio_path(aux_path))
 
 
 
386
  req["aux_ref_audio_paths"] = aux_paths
387
 
388
- tts_generator=tts_pipeline.run(req)
389
 
390
  if streaming_mode:
391
- def streaming_generator(tts_generator:Generator, media_type:str):
392
  if_frist_chunk = True
393
- for sr, chunk in tts_generator:
394
- if if_frist_chunk and media_type == "wav":
395
- yield wave_header_chunk(sample_rate=sr)
396
- media_type = "raw"
397
- if_frist_chunk = False
398
- yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
399
- # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
400
- return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
401
-
 
 
 
 
 
 
 
 
 
402
  else:
403
  sr, audio_data = next(tts_generator)
404
  audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
 
 
 
 
 
 
 
 
 
 
405
  return Response(audio_data, media_type=f"audio/{media_type}")
406
  except Exception as e:
407
- return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
408
-
 
 
 
 
 
 
 
 
409
 
410
 
411
 
@@ -479,13 +495,32 @@ async def tts_post_endpoint(request: TTS_Request):
479
 
480
  @APP.get("/set_refer_audio")
481
  async def set_refer_aduio(refer_audio_path: str = None):
 
482
  try:
483
  # Process the path (download if it's a URL)
484
- local_path = process_audio_path(refer_audio_path)
 
 
 
 
485
  tts_pipeline.set_ref_audio(local_path)
 
 
 
 
 
 
 
486
  except Exception as e:
 
 
 
 
 
 
 
 
487
  return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
488
- return JSONResponse(status_code=200, content={"message": "success"})
489
 
490
 
491
  # @APP.post("/set_refer_audio")
 
129
  import os
130
  import sys
131
  import traceback
132
+ from typing import Generator, Tuple
133
  import requests
134
  import tempfile
135
  import urllib.parse
136
  from pathlib import Path
137
 
138
  # Function to check if a path is a URL and download it if needed
139
+ def process_audio_path(audio_path) -> Tuple[str, bool]:
140
+ """
141
+ Process an audio path, downloading it if it's a URL.
142
+
143
+ Args:
144
+ audio_path (str): Path or URL to audio file
145
+
146
+ Returns:
147
+ Tuple[str, bool]: (local_path, is_temporary)
148
+ """
149
  if audio_path and (audio_path.startswith('http://') or audio_path.startswith('https://') or
150
  audio_path.startswith('s3://')):
151
  try:
 
184
  f.write(chunk)
185
 
186
  print(f"Downloaded to: {local_path}")
187
+ return local_path, True # Return path and flag indicating it's temporary
188
  except Exception as e:
189
  print(f"Error downloading audio file: {e}")
190
  raise Exception(f"Failed to download audio from URL: {e}")
191
 
192
  # Not a URL: return the original path unchanged (a failed download raises above)
193
+ return audio_path, False # Not a temporary file
194
 
195
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
196
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
 
341
  async def tts_handle(req:dict):
342
  """
343
  Text to speech handler.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  """
345
 
346
  streaming_mode = req.get("streaming_mode", False)
347
  return_fragment = req.get("return_fragment", False)
348
  media_type = req.get("media_type", "wav")
349
+ temp_files = [] # Track temporary files for cleanup
350
+
351
+ print(f"----------现在使用的模型版本是:{tts_config.version}----------")
352
 
353
  check_res = check_params(req)
354
  if check_res is not None:
 
359
 
360
  try:
361
  # Process ref_audio_path (download if it's a URL)
362
+ ref_path, is_temp = process_audio_path(req["ref_audio_path"])
363
+ req["ref_audio_path"] = ref_path
364
+ if is_temp:
365
+ temp_files.append(ref_path)
366
 
367
  # Process aux_ref_audio_paths (download if they're URLs)
368
  if req.get("aux_ref_audio_paths"):
369
  aux_paths = []
370
  for aux_path in req["aux_ref_audio_paths"]:
371
+ local_path, is_temp = process_audio_path(aux_path)
372
+ aux_paths.append(local_path)
373
+ if is_temp:
374
+ temp_files.append(local_path)
375
  req["aux_ref_audio_paths"] = aux_paths
376
 
377
+ tts_generator = tts_pipeline.run(req)
378
 
379
  if streaming_mode:
380
def streaming_generator(tts_generator: Generator, media_type: str):
    """
    Yield encoded audio chunks for a streaming response, then clean up.

    This is deliberately a plain (sync) generator, not ``async def``: the
    loop below blocks while ``tts_generator`` produces each chunk. Starlette
    runs sync iterators passed to ``StreamingResponse`` in a threadpool,
    whereas an async generator is iterated directly on the event loop and
    would stall every other request for the duration of inference.

    Args:
        tts_generator (Generator): yields ``(sample_rate, chunk)`` pairs.
        media_type (str): requested output format; when it is "wav", a WAV
            header is emitted once and the remaining chunks are packed
            as "raw".

    Yields:
        bytes: the WAV header (wav only), followed by packed audio chunks.
    """
    if_first_chunk = True
    try:
        for sr, chunk in tts_generator:
            if if_first_chunk and media_type == "wav":
                # Header goes out once; the data that follows is raw PCM.
                yield wave_header_chunk(sample_rate=sr)
                media_type = "raw"
                if_first_chunk = False
            yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
    finally:
        # Runs when streaming completes, fails, or the generator is closed
        # early. `temp_files` comes from the enclosing tts_handle scope.
        for temp_file in temp_files:
            try:
                os.remove(temp_file)
                print(f"Removed temporary file: {temp_file}")
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as e:
                print(f"Error removing temporary file {temp_file}: {e}")
398
+
399
+ return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")
400
  else:
401
  sr, audio_data = next(tts_generator)
402
  audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
403
+
404
+ # Clean up temporary files after generation completes
405
+ for temp_file in temp_files:
406
+ try:
407
+ if os.path.exists(temp_file):
408
+ os.remove(temp_file)
409
+ print(f"Removed temporary file: {temp_file}")
410
+ except Exception as e:
411
+ print(f"Error removing temporary file {temp_file}: {e}")
412
+
413
  return Response(audio_data, media_type=f"audio/{media_type}")
414
  except Exception as e:
415
+ # Clean up temporary files in case of error
416
+ for temp_file in temp_files:
417
+ try:
418
+ if os.path.exists(temp_file):
419
+ os.remove(temp_file)
420
+ print(f"Removed temporary file: {temp_file}")
421
+ except Exception as cleanup_error:
422
+ print(f"Error removing temporary file {temp_file}: {cleanup_error}")
423
+
424
+ return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
425
 
426
 
427
 
 
495
 
496
  @APP.get("/set_refer_audio")
497
  async def set_refer_aduio(refer_audio_path: str = None):
498
+ temp_file = None
499
  try:
500
  # Process the path (download if it's a URL)
501
+ local_path, is_temp = process_audio_path(refer_audio_path)
502
+ if is_temp:
503
+ temp_file = local_path
504
+
505
+ # Store reference to the audio
506
  tts_pipeline.set_ref_audio(local_path)
507
+
508
+ # If temporary, remove after setting (since TTS pipeline should load the audio into memory)
509
+ if temp_file and os.path.exists(temp_file):
510
+ os.remove(temp_file)
511
+ print(f"Removed temporary file: {temp_file}")
512
+
513
+ return JSONResponse(status_code=200, content={"message": "success"})
514
  except Exception as e:
515
+ # Clean up temp file in case of error
516
+ if temp_file and os.path.exists(temp_file):
517
+ try:
518
+ os.remove(temp_file)
519
+ print(f"Removed temporary file: {temp_file}")
520
+ except Exception as cleanup_error:
521
+ print(f"Error removing temporary file {temp_file}: {cleanup_error}")
522
+
523
  return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
 
524
 
525
 
526
  # @APP.post("/set_refer_audio")