Spaces:
Runtime error
Runtime error
XXXXRT666
committed on
Commit
·
7bdf3c3
1
Parent(s):
d2e713a
Fix
Browse files- AR/models/structs.py +1 -1
- AR/models/t2s_model_flash_attn.py +21 -15
- inference_webui.py +1 -0
- pre-requirements.txt +2 -1
- requirements.txt +1 -2
AR/models/structs.py
CHANGED
@@ -68,7 +68,7 @@ class T2SSession:
|
|
68 |
self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
|
69 |
|
70 |
# EOS
|
71 |
-
self.completed = [False] * len(self.x)
|
72 |
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
|
73 |
|
74 |
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
|
|
|
68 |
self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
|
69 |
|
70 |
# EOS
|
71 |
+
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
|
72 |
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
|
73 |
|
74 |
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
|
AR/models/t2s_model_flash_attn.py
CHANGED
@@ -245,7 +245,6 @@ class CUDAGraphRunner:
|
|
245 |
**kwds,
|
246 |
)
|
247 |
|
248 |
-
torch_profiler.start()
|
249 |
with torch_profiler.record("AR"):
|
250 |
if session.graph:
|
251 |
session.xy_pos_.copy_(session.xy_pos)
|
@@ -275,22 +274,28 @@ class CUDAGraphRunner:
|
|
275 |
top_p=request.top_p,
|
276 |
repetition_penalty=request.repetition_penalty,
|
277 |
temperature=request.temperature,
|
278 |
-
use_cuda_graph=
|
279 |
idx=idx,
|
280 |
)
|
281 |
|
282 |
session.y = torch.cat([session.y, samples], dim=1)
|
283 |
|
284 |
with torch_profiler.record("EOS"):
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
if session.y.size(1) == 0:
|
295 |
session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
|
296 |
tqdm.write("Bad Zero Prediction")
|
@@ -306,7 +311,7 @@ class CUDAGraphRunner:
|
|
306 |
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
307 |
):
|
308 |
for i in range(bsz):
|
309 |
-
if not session.completed[i]:
|
310 |
session.y_results[i] = session.y[i, session.y_len :]
|
311 |
session.completed[i] = True
|
312 |
break
|
@@ -316,10 +321,11 @@ class CUDAGraphRunner:
|
|
316 |
session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
|
317 |
|
318 |
if idx == 2:
|
|
|
319 |
t1 = time.perf_counter()
|
320 |
|
321 |
-
if idx == 51:
|
322 |
-
|
323 |
|
324 |
match session.device.type:
|
325 |
case "cuda":
|
@@ -331,7 +337,7 @@ class CUDAGraphRunner:
|
|
331 |
case "mtia":
|
332 |
torch.mtia.empty_cache()
|
333 |
gc.collect()
|
334 |
-
|
335 |
return session.y_results[: request.valid_length]
|
336 |
|
337 |
def generate(self, request: T2SRequest):
|
|
|
245 |
**kwds,
|
246 |
)
|
247 |
|
|
|
248 |
with torch_profiler.record("AR"):
|
249 |
if session.graph:
|
250 |
session.xy_pos_.copy_(session.xy_pos)
|
|
|
274 |
top_p=request.top_p,
|
275 |
repetition_penalty=request.repetition_penalty,
|
276 |
temperature=request.temperature,
|
277 |
+
use_cuda_graph=request.use_cuda_graph,
|
278 |
idx=idx,
|
279 |
)
|
280 |
|
281 |
session.y = torch.cat([session.y, samples], dim=1)
|
282 |
|
283 |
with torch_profiler.record("EOS"):
|
284 |
+
argmax_token = torch.argmax(logits, dim=-1)
|
285 |
+
sample_token = samples.squeeze(1)
|
286 |
+
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
287 |
+
with torch_profiler.record("EOS1"):
|
288 |
+
newly_done_mask = EOS_mask & (~session.completed)
|
289 |
+
with torch_profiler.record("EOS2"):
|
290 |
+
newly_done_indices = newly_done_mask.nonzero()
|
291 |
+
with torch_profiler.record("EOS3"):
|
292 |
+
if newly_done_indices.numel() > 0:
|
293 |
+
session.y_results[newly_done_indices[0]] = session.y[
|
294 |
+
newly_done_indices[0], session.y_len : -1
|
295 |
+
].squeeze(0)
|
296 |
+
session.completed[newly_done_indices] = True
|
297 |
+
with torch_profiler.record("EOS4"):
|
298 |
+
if torch.all(session.completed).item():
|
299 |
if session.y.size(1) == 0:
|
300 |
session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
|
301 |
tqdm.write("Bad Zero Prediction")
|
|
|
311 |
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
312 |
):
|
313 |
for i in range(bsz):
|
314 |
+
if not session.completed[i].item():
|
315 |
session.y_results[i] = session.y[i, session.y_len :]
|
316 |
session.completed[i] = True
|
317 |
break
|
|
|
321 |
session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
|
322 |
|
323 |
if idx == 2:
|
324 |
+
torch_profiler.start()
|
325 |
t1 = time.perf_counter()
|
326 |
|
327 |
+
# if idx == 51:
|
328 |
+
# torch_profiler.end()
|
329 |
|
330 |
match session.device.type:
|
331 |
case "cuda":
|
|
|
337 |
case "mtia":
|
338 |
torch.mtia.empty_cache()
|
339 |
gc.collect()
|
340 |
+
torch_profiler.end()
|
341 |
return session.y_results[: request.valid_length]
|
342 |
|
343 |
def generate(self, request: T2SRequest):
|
inference_webui.py
CHANGED
@@ -836,4 +836,5 @@ if __name__ == "__main__":
|
|
836 |
server_name="0.0.0.0",
|
837 |
inbrowser=True,
|
838 |
show_api=False,
|
|
|
839 |
)
|
|
|
836 |
server_name="0.0.0.0",
|
837 |
inbrowser=True,
|
838 |
show_api=False,
|
839 |
+
server_port=1111,
|
840 |
)
|
pre-requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
torch==2.5.1
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchaudio
|
requirements.txt
CHANGED
@@ -3,7 +3,6 @@ scipy>=1.11.3
|
|
3 |
tensorboard==2.15.1
|
4 |
librosa==0.9.2
|
5 |
numba==0.56.4
|
6 |
-
torchaudio
|
7 |
pytorch-lightning>=2.4
|
8 |
gradio==4.44.1
|
9 |
gradio_client==1.3.0
|
@@ -36,4 +35,4 @@ nltk==3.8.1
|
|
36 |
fast_langdetect==0.3.1
|
37 |
split_lang==2.1.0
|
38 |
ToJyutping==3.2.0
|
39 |
-
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.
|
|
|
3 |
tensorboard==2.15.1
|
4 |
librosa==0.9.2
|
5 |
numba==0.56.4
|
|
|
6 |
pytorch-lightning>=2.4
|
7 |
gradio==4.44.1
|
8 |
gradio_client==1.3.0
|
|
|
35 |
fast_langdetect==0.3.1
|
36 |
split_lang==2.1.0
|
37 |
ToJyutping==3.2.0
|
38 |
+
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|