kevinwang676 commited on
Commit
a9f4026
·
verified ·
1 Parent(s): c16b111

Update api_v2.py

Browse files
Files changed (1) hide show
  1. api_v2.py +72 -3
api_v2.py CHANGED
@@ -126,6 +126,63 @@ from pydantic import BaseModel
126
  i18n = I18nAuto()
127
  cut_method_names = get_cut_method_names()
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
130
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
131
  parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
@@ -281,8 +338,8 @@ async def tts_handle(req:dict):
281
  {
282
  "text": "", # str.(required) text to be synthesized
283
  "text_lang: "", # str.(required) language of the text to be synthesized
284
- "ref_audio_path": "", # str.(required) reference audio path
285
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
286
  "prompt_text": "", # str.(optional) prompt text for the reference audio
287
  "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
288
  "top_k": 5, # int. top k sampling
@@ -318,6 +375,16 @@ async def tts_handle(req:dict):
318
  req["return_fragment"] = True
319
 
320
  try:
 
 
 
 
 
 
 
 
 
 
321
  tts_generator=tts_pipeline.run(req)
322
 
323
  if streaming_mode:
@@ -413,7 +480,9 @@ async def tts_post_endpoint(request: TTS_Request):
413
  @APP.get("/set_refer_audio")
414
  async def set_refer_aduio(refer_audio_path: str = None):
415
  try:
416
- tts_pipeline.set_ref_audio(refer_audio_path)
 
 
417
  except Exception as e:
418
  return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
419
  return JSONResponse(status_code=200, content={"message": "success"})
 
126
  i18n = I18nAuto()
127
  cut_method_names = get_cut_method_names()
128
 
129
+ import os
130
+ import sys
131
+ import traceback
132
+ from typing import Generator
133
+ import requests
134
+ import tempfile
135
+ import urllib.parse
136
+ from pathlib import Path
137
+
138
+ # Function to check if a path is a URL and download it if needed
139
+ def process_audio_path(audio_path):
140
+ if audio_path and (audio_path.startswith('http://') or audio_path.startswith('https://') or
141
+ audio_path.startswith('s3://')):
142
+ try:
143
+ # Create temp directory if it doesn't exist
144
+ temp_dir = os.path.join(now_dir, "temp_audio")
145
+ os.makedirs(temp_dir, exist_ok=True)
146
+
147
+ # Generate a filename from the URL
148
+ parsed_url = urllib.parse.urlparse(audio_path)
149
+ filename = os.path.basename(parsed_url.path)
150
+ if not filename:
151
+ filename = f"temp_audio_{hash(audio_path)}.wav"
152
+
153
+ # Full path for downloaded file
154
+ local_path = os.path.join(temp_dir, filename)
155
+
156
+ # Download file
157
+ if audio_path.startswith('s3://'):
158
+ # For S3 URLs, you would use boto3 here
159
+ # This is a placeholder - you'll need to add boto3 import and proper S3 handling
160
+ print(f"Downloading from S3: {audio_path}")
161
+ # Example boto3 code (commented out as boto3 import not in original code)
162
+ # import boto3
163
+ # s3_client = boto3.client('s3')
164
+ # bucket = parsed_url.netloc
165
+ # key = parsed_url.path.lstrip('/')
166
+ # s3_client.download_file(bucket, key, local_path)
167
+ raise NotImplementedError("S3 download not implemented. Add boto3 library and implementation.")
168
+ else:
169
+ # HTTP/HTTPS download
170
+ print(f"Downloading from URL: {audio_path}")
171
+ response = requests.get(audio_path, stream=True)
172
+ response.raise_for_status()
173
+ with open(local_path, 'wb') as f:
174
+ for chunk in response.iter_content(chunk_size=8192):
175
+ f.write(chunk)
176
+
177
+ print(f"Downloaded to: {local_path}")
178
+ return local_path
179
+ except Exception as e:
180
+ print(f"Error downloading audio file: {e}")
181
+ raise Exception(f"Failed to download audio from URL: {e}")
182
+
183
+ # If not a URL or download failed, return the original path
184
+ return audio_path
185
+
186
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
187
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
188
  parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 
338
  {
339
  "text": "", # str.(required) text to be synthesized
340
  "text_lang: "", # str.(required) language of the text to be synthesized
341
+ "ref_audio_path": "", # str.(required) reference audio path or URL
342
+ "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths or URLs
343
  "prompt_text": "", # str.(optional) prompt text for the reference audio
344
  "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
345
  "top_k": 5, # int. top k sampling
 
375
  req["return_fragment"] = True
376
 
377
  try:
378
+ # Process ref_audio_path (download if it's a URL)
379
+ req["ref_audio_path"] = process_audio_path(req["ref_audio_path"])
380
+
381
+ # Process aux_ref_audio_paths (download if they're URLs)
382
+ if req.get("aux_ref_audio_paths"):
383
+ aux_paths = []
384
+ for aux_path in req["aux_ref_audio_paths"]:
385
+ aux_paths.append(process_audio_path(aux_path))
386
+ req["aux_ref_audio_paths"] = aux_paths
387
+
388
  tts_generator=tts_pipeline.run(req)
389
 
390
  if streaming_mode:
 
480
  @APP.get("/set_refer_audio")
481
  async def set_refer_aduio(refer_audio_path: str = None):
482
  try:
483
+ # Process the path (download if it's a URL)
484
+ local_path = process_audio_path(refer_audio_path)
485
+ tts_pipeline.set_ref_audio(local_path)
486
  except Exception as e:
487
  return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
488
  return JSONResponse(status_code=200, content={"message": "success"})