XXXXRT666 committed
Commit 7bdf3c3 · 1 Parent(s): d2e713a
AR/models/structs.py CHANGED
@@ -68,7 +68,7 @@ class T2SSession:
         self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
 
         # EOS
-        self.completed = [False] * len(self.x)
+        self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
         self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
 
         self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
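Note: the new `completed` is a device-resident boolean mask rather than a Python list, which lets the EOS bookkeeping in t2s_model_flash_attn.py below run as tensor ops instead of per-element host round-trips. A minimal sketch of the same construction, with `bsz` and `device` assumed from the surrounding constructor; allocating the mask directly is equivalent and skips the intermediate float tensor:

import torch

bsz = 4  # assumed batch size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same result as torch.Tensor([False] * bsz).bool().to(device),
# but built as a boolean tensor on the target device in one step.
completed = torch.zeros(bsz, dtype=torch.bool, device=device)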
AR/models/t2s_model_flash_attn.py CHANGED
@@ -245,7 +245,6 @@ class CUDAGraphRunner:
             **kwds,
         )
 
-        torch_profiler.start()
         with torch_profiler.record("AR"):
             if session.graph:
                 session.xy_pos_.copy_(session.xy_pos)
@@ -275,22 +274,28 @@ class CUDAGraphRunner:
                 top_p=request.top_p,
                 repetition_penalty=request.repetition_penalty,
                 temperature=request.temperature,
-                use_cuda_graph=False,
+                use_cuda_graph=request.use_cuda_graph,
                 idx=idx,
             )
 
             session.y = torch.cat([session.y, samples], dim=1)
 
             with torch_profiler.record("EOS"):
-                EOS_mask = (samples[:, 0] == decoder.EOS) | (torch.argmax(logits, dim=-1) == decoder.EOS)
-                EOS_indices: List[int] = torch.where(EOS_mask)[0].tolist()
-
-                for i in EOS_indices:
-                    if not session.completed[i]:
-                        session.y_results[i] = session.y[i, session.y_len : -1]
-                        session.completed[i] = True
-
-                if all(session.completed):
+                argmax_token = torch.argmax(logits, dim=-1)
+                sample_token = samples.squeeze(1)
+                EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
+                with torch_profiler.record("EOS1"):
+                    newly_done_mask = EOS_mask & (~session.completed)
+                with torch_profiler.record("EOS2"):
+                    newly_done_indices = newly_done_mask.nonzero()
+                with torch_profiler.record("EOS3"):
+                    if newly_done_indices.numel() > 0:
+                        session.y_results[newly_done_indices[0]] = session.y[
+                            newly_done_indices[0], session.y_len : -1
+                        ].squeeze(0)
+                        session.completed[newly_done_indices] = True
+                with torch_profiler.record("EOS4"):
+                    if torch.all(session.completed).item():
                     if session.y.size(1) == 0:
                         session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
                         tqdm.write("Bad Zero Prediction")
@@ -306,7 +311,7 @@ class CUDAGraphRunner:
                 and (session.y.size(1) - session.y_len) > request.early_stop_num
             ):
                 for i in range(bsz):
-                    if not session.completed[i]:
+                    if not session.completed[i].item():
                         session.y_results[i] = session.y[i, session.y_len :]
                         session.completed[i] = True
                         break
@@ -316,10 +321,11 @@ class CUDAGraphRunner:
             session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
 
         if idx == 2:
+            torch_profiler.start()
             t1 = time.perf_counter()
 
-        if idx == 51:
-            torch_profiler.end()
+        # if idx == 51:
+        #     torch_profiler.end()
 
         match session.device.type:
             case "cuda":
@@ -331,7 +337,7 @@ class CUDAGraphRunner:
             case "mtia":
                 torch.mtia.empty_cache()
         gc.collect()
-
+        torch_profiler.end()
         return session.y_results[: request.valid_length]
 
     def generate(self, request: T2SRequest):
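Note: the EOS hunk above replaces a per-row Python loop with masked tensor ops against the boolean `completed` mask introduced in structs.py, so only the `torch.all(...).item()` check under `EOS4` forces a host sync. A minimal self-contained sketch of that pattern; the batch size, vocabulary size, `EOS` id, and `y_len` here are illustrative assumptions, not values from the repo:

import torch

bsz, vocab, EOS, y_len = 4, 16, 15, 3           # illustrative sizes
y = torch.randint(0, EOS, (bsz, 10))            # tokens decoded so far (per row)
samples = torch.randint(0, vocab, (bsz, 1))     # tokens sampled this step
logits = torch.randn(bsz, vocab)                # this step's logits
completed = torch.zeros(bsz, dtype=torch.bool)  # device-resident mask in the real code
y_results = [None] * bsz

# A row finishes when either the sampled token or the greedy token is EOS.
eos_mask = (samples.squeeze(1) == EOS) | (logits.argmax(dim=-1) == EOS)
newly_done = eos_mask & ~completed              # skip rows that already finished
for i in newly_done.nonzero(as_tuple=True)[0].tolist():
    y_results[i] = y[i, y_len:-1]               # drop the prompt and the trailing EOS
completed |= newly_done                         # single vectorized flag update
if bool(completed.all()):
    print("all rows finished")

One detail worth flagging: the committed hunk writes `y_results` only for `newly_done_indices[0]` while marking every newly finished row in `completed`; the sketch above keeps one result per row.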
inference_webui.py CHANGED
@@ -836,4 +836,5 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         inbrowser=True,
         show_api=False,
+        server_port=1111,
     )
pre-requirements.txt CHANGED
@@ -1 +1,2 @@
-torch==2.5.1
+torch==2.5.1
+torchaudio
requirements.txt CHANGED
@@ -3,7 +3,6 @@ scipy>=1.11.3
 tensorboard==2.15.1
 librosa==0.9.2
 numba==0.56.4
-torchaudio
 pytorch-lightning>=2.4
 gradio==4.44.1
 gradio_client==1.3.0
@@ -36,4 +35,4 @@ nltk==3.8.1
 fast_langdetect==0.3.1
 split_lang==2.1.0
 ToJyutping==3.2.0
-https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
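Note: the switch from the `cxx11abiTRUE` to the `cxx11abiFALSE` flash-attn wheel matters because the prebuilt binary must use the same C++ ABI as the installed torch, and the PyPI `torch==2.5.1` wheels are built with the pre-C++11 ABI. A quick way to confirm which variant matches, assuming torch is already installed:

import torch

# False -> install the ...cxx11abiFALSE... wheel; True -> ...cxx11abiTRUE...
print(torch.compiled_with_cxx11_abi())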