tiezhen (HF Staff) committed
Commit 9a39ad9 · verified · 1 Parent(s): 4db13e7

use simple request/reply mode, instead of having multiple workers
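
The change collapses the previous worker-pool setup into a single synchronous request/reply round trip: each call builds its own queue pair, pushes exactly one task, runs the worker once, and reads back one reply. A minimal standalone sketch of that pattern, with illustrative names (run_once, handle_request) that are not part of the repository:

    import multiprocessing as mp

    def run_once(input_queue, output_queue):
        # Serve exactly one request: read a task, compute, write a single reply.
        task = input_queue.get()
        output_queue.put(f"reply for {task}")

    def handle_request(task):
        # One queue pair per request; the worker is called inline rather than
        # running as a pool of long-lived processes looping on a shared queue.
        manager = mp.Manager()
        input_queue = manager.Queue()
        output_queue = manager.Queue()
        input_queue.put(task)
        run_once(input_queue, output_queue)  # synchronous, single-shot call
        return output_queue.get()

    if __name__ == '__main__':
        print(handle_request(("sample.wav", "sample.npy", "some text to speak")))

The trade-off is simplicity over throughput: requests are served one at a time instead of being fanned out to a pool of worker processes.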

Files changed (1):
  1. tts/gradio_api.py +32 -34
tts/gradio_api.py CHANGED
@@ -33,38 +33,44 @@ def forward_gpu(file_content, wav_path, latent_file, inp_text, time_step, p_w, t
     return wav_bytes
 
 def model_worker(input_queue, output_queue, device_id):
-    while True:
-        task = input_queue.get()
-        inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
+
+    task = input_queue.get()
+    inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
 
-        if inp_npy_path is None or inp_audio_path is None:
-            output_queue.put(None)
-            raise gr.Error("Please provide .wav and .npy file")
-        if (inp_audio_path.split('/')[-1][:-4] != inp_npy_path.split('/')[-1][:-4]):
-            output_queue.put(None)
-            raise gr.Error(".npy and .wav mismatch")
-        if len(inp_text) > 200:
-            output_queue.put(None)
-            raise gr.Error("input text is too long")
-
-        try:
-            convert_to_wav(inp_audio_path)
-            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
-            cut_wav(wav_path, max_len=24)
-            with open(wav_path, 'rb') as file:
-                file_content = file.read()
-            wav_bytes = forward_gpu(file_content, wav_path, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
-            output_queue.put(wav_bytes)
-        except Exception as e:
-            traceback.print_exc()
-            print(task, str(e))
-            output_queue.put(None)
-            raise gr.Error("Generation failed")
+    if inp_npy_path is None or inp_audio_path is None:
+        output_queue.put(None)
+        raise gr.Error("Please provide .wav and .npy file")
+    if (inp_audio_path.split('/')[-1][:-4] != inp_npy_path.split('/')[-1][:-4]):
+        output_queue.put(None)
+        raise gr.Error(".npy and .wav mismatch")
+    if len(inp_text) > 200:
+        output_queue.put(None)
+        raise gr.Error("input text is too long")
+
+    try:
+        convert_to_wav(inp_audio_path)
+        wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
+        cut_wav(wav_path, max_len=24)
+        with open(wav_path, 'rb') as file:
+            file_content = file.read()
+        wav_bytes = forward_gpu(file_content, wav_path, inp_npy_path, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+        output_queue.put(wav_bytes)
+    except Exception as e:
+        traceback.print_exc()
+        print(task, str(e))
+        output_queue.put(None)
+        raise gr.Error("Generation failed")
 
 
 def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+    input_queue = mp_manager.Queue()
     print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
     input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
+
+    output_queue = mp_manager.Queue()
+
+    model_worker(input_queue, output_queue, 0)
+
     res = output_queue.get()
     if res is not None:
         return res
@@ -78,16 +84,8 @@ if __name__ == '__main__':
 
     num_workers = 1
     devices = [0]
-    input_queue = mp_manager.Queue()
-    output_queue = mp_manager.Queue()
     processes = []
 
-    print("Start open workers")
-    for i in range(num_workers):
-        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
-        p.start()
-        processes.append(p)
-
     api_interface = gr.Interface(fn=
                                  partial(main, processes=processes, input_queue=input_queue,
                                          output_queue=output_queue),