XXXXRT666 committed
Commit 4ae2215 · 1 Parent(s): 301f27c
AR/models/embedding.py CHANGED
@@ -33,51 +33,6 @@ class TokenEmbedding(nn.Module):
         return x
 
 
-class SinePositionalEmbedding(nn.Module):
-    def __init__(
-        self,
-        embedding_dim: int,
-        dropout: float = 0.0,
-        scale: bool = False,
-        alpha: bool = False,
-    ):
-        super().__init__()
-        self.embedding_dim = embedding_dim
-        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
-        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
-        self.dropout = torch.nn.Dropout(p=dropout)
-
-        self.reverse = False
-        self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, 4000))
-
-    def extend_pe(self, x):
-        """Reset the positional encodings."""
-        if self.pe is not None:
-            if self.pe.size(1) >= x.size(1):
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
-        pe = torch.zeros(x.size(1), self.embedding_dim)
-        if self.reverse:
-            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
-        else:
-            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
-        )
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        self.extend_pe(x)
-        output = x.unsqueeze(-1) if x.ndim == 2 else x
-        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
-        return self.dropout(output)
-
-
 class SinePositionalEmbeddingNested(nn.Module):
     def __init__(
         self,
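The deleted class built the standard sinusoidal table from "Attention Is All You Need": PE[p, 2i] = sin(p · w_i), PE[p, 2i+1] = cos(p · w_i) with w_i = 10000^(−2i/d). As a reference for what the surviving nested variant must reproduce, a minimal standalone sketch (the function name is illustrative, not part of the repo):

```python
import math

import torch


def sine_positional_table(seq_len: int, dim: int) -> torch.Tensor:
    """Return the (1, seq_len, dim) sinusoidal positional-encoding table."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(10000.0) / dim)
    )  # the w_i frequencies, one per even channel
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe.unsqueeze(0)  # batch-broadcastable


print(sine_positional_table(4000, 512).shape)  # torch.Size([1, 4000, 512])
```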
AR/models/structs.py CHANGED
@@ -5,11 +5,11 @@ Modified From https://github.com/XXXXRT666/GPT-SoVITS
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import List, Literal, Optional
+from typing import List, Literal, MutableSequence, Optional
 
 import torch
 
-from AR.models.t2s_model_abc import Sampler, T2SDecoderABC
+from AR.models.t2s_model_abc import KVCacheABC, Sampler, T2SDecoderABC
 
 Tensor = torch.Tensor
 
@@ -53,6 +53,7 @@ class T2SSession:
         self.y_len = y_len
 
         # Cache
+        self.kv_cache: MutableSequence[KVCacheABC]
         self.sampler = Sampler(bsz, decoder.vocab_size)
 
         # Forward args
@@ -66,6 +67,11 @@ class T2SSession:
         self.input_pos = torch.zeros_like(self.prefill_len)
         self.input_pos.add_(self.prefill_len)
 
+        # CUDA Graph
+        self.graph: Optional[torch.cuda.CUDAGraph] = None
+        self.xy_pos_: Tensor
+        self.xy_dec_: Tensor
+
         # EOS
         self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
         self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
@@ -81,3 +87,5 @@ class T2SSession:
             mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
             attn_mask.append(mask)
         self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
+
+        self.id: int = -1
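The bare annotations added above (`self.kv_cache`, `self.xy_pos_`, `self.xy_dec_`) declare attributes without assigning them; the runner and the graph cache fill them in later. A self-contained sketch of that pattern (class and attribute names are illustrative stand-ins, not repo code):

```python
from typing import MutableSequence, Optional

import torch


class SessionSketch:  # illustrative stand-in for T2SSession
    def __init__(self) -> None:
        # A bare annotation reserves the name for type checkers but allocates
        # nothing; reading it before assignment raises AttributeError.
        self.kv_cache: MutableSequence[torch.Tensor]
        self.graph: Optional["torch.cuda.CUDAGraph"] = None  # set by the graph cache
        self.id: int = -1  # -1 means no CUDA graph has been assigned yet


s = SessionSketch()
print(s.id)          # -1
# print(s.kv_cache)  # AttributeError: never assigned
```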
AR/models/t2s_model_abc.py CHANGED
@@ -5,10 +5,10 @@ Modified From https://github.com/XXXXRT666/GPT-SoVITS
 from __future__ import annotations
 
 import os
-import time
+import random
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
-from typing import Any, Dict, List, MutableSequence, Optional, Tuple, Type
+from typing import Any, Dict, List, MutableSequence, Tuple, Type
 
 import torch
 import torch._inductor.config
@@ -85,6 +85,10 @@ class KVCacheABC(ABC, nn.Module):
     @abstractmethod
     def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int) -> None: ...
 
+    def sync_cache(self, kv_cache: KVCacheABC):
+        self.k_cache.copy_(kv_cache.k_cache)
+        self.v_cache.copy_(kv_cache.v_cache)
+
     def forward(self):
         raise NotImplementedError()
 
@@ -363,6 +367,8 @@ class T2SDecoderABC(ABC, nn.Module):
 
         self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND]
 
+        self.GraphCache: CUDAGraphCacheABC | None
+
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -396,6 +402,7 @@ class T2SDecoderABC(ABC, nn.Module):
         self.h.compile(fullgraph=True, mode="reduce-overhead")
 
     def capture(self, input_pos: Tensor, x: Tensor, x_dec: Tensor, *args, **kwds) -> CUDAGraph:
+        assert torch.cuda.is_available()
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())
 
@@ -419,6 +426,51 @@ class T2SDecoderABC(ABC, nn.Module):
     def post_forward(self, idx: int, session: Any) -> None: ...
 
 
+class CUDAGraphCacheABC(ABC):
+    def __init__(
+        self,
+        decoder: T2SDecoderABC,
+        device: torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        assert torch.cuda.is_available()
+
+        self.assigned: bool = False
+
+        self.decoder: T2SDecoderABC = decoder
+        self.kv_cache: MutableSequence[KVCacheABC] = decoder.init_cache(1)
+        self.xy_pos = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
+        self.xy_dec = torch.rand((1, 1, decoder.embedding_dim), device=device).to(dtype)
+        self.input_pos = torch.tensor([10]).int().cuda()
+        self.graph: torch.cuda.CUDAGraph | None = None
+
+        self.id: int = random.randint(1, 2**32 - 1)
+
+    def assign_graph(self, session: Any):
+        if self.graph is None:
+            args, kwds = self.decoder.pre_forward(session)
+            graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
+            self.graph = graph
+
+        if self.assigned is False:
+            self.get_cache_graph(session)
+            session.id = self.id
+            self.assigned = True
+        else:
+            self.capture_new_graph(session)
+
+    @abstractmethod
+    def release_graph(self, session: Any): ...
+
+    @abstractmethod
+    def get_cache_graph(self, session: Any):
+        pass
+
+    @abstractmethod
+    def capture_new_graph(self, session: Any):
+        pass
+
+
 class TorchProfiler:
     def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
         self.debug = debug
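The `capture` method above uses PyTorch's standard side-stream warm-up followed by graph recording, and `CUDAGraphCacheABC` replays the recorded graph against fixed buffers. A self-contained sketch of that idiom (requires a CUDA device; `step` here is a toy placeholder for one decoder forward, not the repo's model):

```python
import torch


def capture_step(step, static_in: torch.Tensor):
    """Warm up on a side stream, then record `step` into a CUDA graph."""
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):  # warm-up runs so lazy init happens outside capture
            static_out = step(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # record the exact kernel sequence
        static_out = step(static_in)
    return graph, static_out


if torch.cuda.is_available():
    weight = torch.randn(8, 8, device="cuda")
    x = torch.zeros(1, 8, device="cuda")  # static input buffer
    graph, out = capture_step(lambda t: t @ weight, x)

    x.copy_(torch.randn(1, 8, device="cuda"))  # refresh the static buffer in place
    graph.replay()        # rerun the recorded kernels
    result = out.clone()  # read from the static output buffer
```

Replay skips kernel launch overhead entirely, which is why the runner below copies each step's input into `xy_pos_` and clones the output out of `xy_dec_` instead of calling the model directly.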
AR/models/t2s_model_flash_attn.py CHANGED
@@ -2,13 +2,13 @@
 Modified From https://github.com/XXXXRT666/GPT-SoVITS
 """
 
+import gc
 import os
 import time
 import traceback
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
 import flash_attn  # type: ignore
-import gradio as gr
 import torch
 import torch.nn as nn
 from tqdm import tqdm
@@ -20,6 +20,7 @@ from AR.models.embedding import TokenEmbedding
 from AR.models.structs import T2SRequest, T2SResult, T2SSession
 from AR.models.t2s_model_abc import (
     AttentionABC,
+    CUDAGraphCacheABC,
     FeedForward,
     KVCacheABC,
     KVCacheNHD,
@@ -121,6 +122,7 @@ class T2SDecoder(T2SDecoderABC):
         max_batch_size=10,
         **kwds,
     ) -> None:
+        assert torch.cuda.is_available()
         super().__init__()
 
         hidden_dim = config["model"]["hidden_dim"]
@@ -205,6 +207,42 @@ class T2SDecoder(T2SDecoderABC):
         return list(), dict()
 
 
+class CUDAGraphCache(CUDAGraphCacheABC):
+    def __init__(
+        self,
+        decoder: T2SDecoderABC,
+        device: torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__(decoder, device, dtype)
+
+    def release_graph(self, session: T2SSession):
+        if session.id != self.id:
+            self.assigned = False
+        else:
+            del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
+
+    def get_cache_graph(self, session: T2SSession):
+        assert self.graph
+        session.graph = self.graph
+
+        session.xy_pos_ = self.xy_pos
+        session.xy_dec_ = self.xy_dec
+        session.input_pos = self.input_pos.copy_(session.input_pos)
+
+        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
+            cache.sync_cache(cache_)
+
+    def capture_new_graph(self, session: T2SSession):
+        session.xy_pos_ = self.xy_pos.clone()
+        session.xy_dec_ = self.xy_dec.clone()
+        session.input_pos = self.input_pos.clone().copy_(session.input_pos)
+
+        args, kwds = self.decoder.pre_forward(session)
+        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
+        session.graph = graph
+
+
 class CUDAGraphRunner:
     def __init__(
         self,
@@ -212,70 +250,51 @@ class CUDAGraphRunner:
         device: torch.device = torch.device("cpu"),
         dtype: torch.dtype = torch.float32,
     ) -> None:
-        assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
-        assert dtype in {torch.float16, torch.bfloat16, torch.float32}
+        assert device.type == "cuda"
        self.device = device
         self.dtype = dtype
 
-        self.decoder_path: os.PathLike
         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
 
-        self.graph: Optional[torch.cuda.CUDAGraph] = None
-        self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
-        self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
-        self.kv_cache = decoder_model.init_cache(1)
-        self.input_pos = torch.tensor([10]).int().cuda()
+        self.graphcache = CUDAGraphCache(decoder_model, device, dtype)
 
     def _handle_request(self, request: T2SRequest):
         with self.device:
-            for i in self.kv_cache:
-                i.empty()
-
             decoder = self.decoder_model
             session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
-            self.input_pos.copy_(session.input_pos)
 
             t1 = 0.0
             infer_speed = 0.0
-            y = session.y
-            bsz = y.size(0)
+
             torch_profiler = TorchProfiler(request.debug)
             with torch_profiler.profiler():
                 for idx in tqdm(range(1500)):
                     if idx == 0:
-                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, self.kv_cache)
+                        session.kv_cache = decoder.init_cache(session.bsz)
+                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
                         xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
                     else:
-                        if request.use_cuda_graph and self.graph is None and torch.cuda.is_available():
-                            self.xy_pos_.copy_(session.xy_pos)
-                            args, kwds = decoder.pre_forward(session)
-                            self.graph = decoder.capture(
-                                self.input_pos,
-                                self.xy_pos_,
-                                self.xy_dec_,
-                                kv_caches=self.kv_cache,
-                                *args,
-                                **kwds,
-                            )
+                        if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
+                            self.graphcache.assign_graph(session)
 
                         with torch_profiler.record("AR"):
-                            if self.graph:
-                                self.xy_pos_.copy_(session.xy_pos)
-                                self.graph.replay()
-                                xy_dec = self.xy_dec_.clone()
+                            if session.graph:
+                                session.xy_pos_.copy_(session.xy_pos)
+                                session.graph.replay()
+                                xy_dec = session.xy_dec_.clone()
                             else:
                                 args, kwds = decoder.pre_forward(session)
                                 xy_dec = decoder.h.forward(
-                                    self.input_pos,
+                                    session.input_pos,
                                     session.xy_pos,
-                                    self.kv_cache,
+                                    session.kv_cache,
                                     *args,
                                     **kwds,
                                 )
 
                     decoder.post_forward(idx, session)
                     logits = decoder.ar_predict_layer(xy_dec[:, -1])
-                    self.input_pos.add_(1)
+                    session.input_pos.add_(1)
 
                     if idx == 0:
                         logits[:, -1] = float("-inf")
@@ -322,7 +341,7 @@ class CUDAGraphRunner:
                         request.early_stop_num != -1
                         and (session.y.size(1) - session.y_len) > request.early_stop_num
                     ) or idx == 1499:
-                        for i in range(bsz):
+                        for i in range(session.bsz):
                             if not session.completed[i].item():
                                 session.y_results[i] = session.y[i, session.y_len :]
                                 session.completed[i] = True
@@ -330,7 +349,7 @@ class CUDAGraphRunner:
 
                     with torch_profiler.record("NextPos"):
                         y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
-                        session.xy_pos = decoder.ar_audio_position.forward(self.input_pos - session.x_lens, y_emb)
+                        session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
 
                     if idx == 2:
                         torch_profiler.start()
@@ -359,8 +378,11 @@ class CUDAGraphRunner:
                     torch.xpu.empty_cache()
                 case "mtia":
                     torch.mtia.empty_cache()
+                case "cpu":
+                    gc.collect()
 
             torch_profiler.end()
+            self.graphcache.release_graph(session)
            return session.y_results[: request.valid_length], infer_speed
 
     def generate(self, request: T2SRequest):
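`get_cache_graph` hands a new session the cache's static buffers and `copy_`s the session state into them: a replayed graph reads from fixed memory addresses, so state must be copied in place rather than rebound. A runnable toy of that constraint (shapes are illustrative, not the repo's real cache layout):

```python
import torch

if torch.cuda.is_available():
    # Buffers a captured graph would read from: their addresses are baked in.
    static_k = torch.zeros(1, 4, 8, device="cuda")
    static_pos = torch.tensor([10], device="cuda", dtype=torch.int32)

    # Fresh per-session state, e.g. produced by prefill.
    session_k = torch.randn(1, 4, 8, device="cuda")
    session_pos = torch.tensor([37], device="cuda", dtype=torch.int32)

    # In-place copy keeps the graph's pointers valid; `static_k = session_k`
    # would rebind the Python name and leave the graph reading stale memory.
    static_k.copy_(session_k)
    static_pos.copy_(session_pos)
```

This is also why the runner clones `session.xy_dec_` after each replay: the next replay overwrites the same output buffer.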
inference_webui.py CHANGED
@@ -1,7 +1,47 @@
+import logging
 import os
+import re
+import traceback
+from time import time as ttime
 
-os.makedirs("pretrained_models", exist_ok=True)
+import gradio as gr
+import gradio.themes as themes
+import librosa
+import nltk
+import numpy as np
+import spaces
+import torch
+import torchaudio
+from gradio.themes.utils import fonts
 from huggingface_hub import snapshot_download
+from transformers.models.auto.modeling_auto import AutoModelForMaskedLM
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+from AR.models.structs import T2SRequest
+from AR.models.t2s_model_flash_attn import CUDAGraphRunner
+from feature_extractor import cnhubert
+from module.mel_processing import spectrogram_torch
+from module.models import SynthesizerTrn
+from sv import SV
+from text import chinese, cleaned_text_to_sequence
+from text.cleaner import clean_text
+from text.LangSegmenter import LangSegmenter
+from tools.i18n.i18n import I18nAuto
+
+logging.getLogger("markdown_it").setLevel(logging.ERROR)
+logging.getLogger("urllib3").setLevel(logging.ERROR)
+logging.getLogger("httpcore").setLevel(logging.ERROR)
+logging.getLogger("httpx").setLevel(logging.ERROR)
+logging.getLogger("asyncio").setLevel(logging.ERROR)
+logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
+logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
+logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
+logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
+logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
+
+os.makedirs("pretrained_models", exist_ok=True)
+
+nltk.download("averaged_perceptron_tagger_eng")
 
 snapshot_download(
     repo_id="lj1995/GPT-SoVITS",
@@ -27,75 +67,20 @@ snapshot_download(
     allow_patterns="v2Pro/s2Gv2ProPlus.pth",
     local_dir="pretrained_models",
 )
-import logging
-import traceback
-
-logging.getLogger("markdown_it").setLevel(logging.ERROR)
-logging.getLogger("urllib3").setLevel(logging.ERROR)
-logging.getLogger("httpcore").setLevel(logging.ERROR)
-logging.getLogger("httpx").setLevel(logging.ERROR)
-logging.getLogger("asyncio").setLevel(logging.ERROR)
-logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
-logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
-logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
-logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
-
-import nltk
-import torchaudio
-
-from text.LangSegmenter import LangSegmenter
-
-nltk.download("averaged_perceptron_tagger_eng")
-import json
-import os
-import pdb
-import re
-import sys
-import threading
-
-import LangSegment
-import spaces
-import torch
-
-lock = threading.Lock()
 
 version = "v2"  # os.environ.get("version","v2")
 cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
 bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
-
-punctuation = set(["!", "?", "…", ",", ".", "-", " "])
-import gradio as gr
-import gradio.themes as themes
-import librosa
-import numpy as np
-from gradio.themes.utils import fonts
-from transformers import AutoModelForMaskedLM, AutoTokenizer
-
-from feature_extractor import cnhubert
-
 cnhubert.cnhubert_base_path = cnhubert_base_path
 
-from time import time as ttime
+punctuation = set(["!", "?", "…", ",", ".", "-", " "])
 
-from AR.models.structs import T2SRequest
-from AR.models.t2s_model_flash_attn import CUDAGraphRunner
-from module.mel_processing import spectrogram_torch
-from module.models import SynthesizerTrn
-from text import cleaned_text_to_sequence
-from text.cleaner import clean_text
-from tools.i18n.i18n import I18nAuto, scan_language_list
-from tools.my_utils import load_audio
 
-# language=os.environ.get("language","Auto")
-# language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language="Auto")
 
-# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # ensure this is also set when the inference UI is launched directly
-
 if torch.cuda.is_available():
     device = "cuda"
-    is_half = True  # eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+    is_half = True
 else:
     device = "cpu"
     is_half = False
@@ -125,7 +110,7 @@ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
 
 tokenizer = AutoTokenizer.from_pretrained(bert_path)
 bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
+if is_half is True:
     bert_model = bert_model.half().to(device)
 else:
     bert_model = bert_model.to(device)
@@ -176,7 +161,7 @@ class DictToAttrRecursive(dict):
 
 
 ssl_model = cnhubert.get_model()
-if is_half == True:
+if is_half is True:
     ssl_model = ssl_model.half().to(device)
 else:
     ssl_model = ssl_model.to(device)
@@ -248,7 +233,7 @@ def change_gpt_weights(gpt_path):
 
 
 change_gpt_weights("pretrained_models/s1v3.ckpt")
-from sv import SV
+
 
 sv_cn_model = SV(device, is_half)
 
@@ -288,7 +273,7 @@ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
         center=False,
     )
     spec = spec.to(dtype)
-    if is_v2pro == True:
+    if is_v2pro is True:
         audio = resample(audio, sr1, 16000, device).to(dtype)
     return spec, audio
 
@@ -300,7 +285,7 @@ def clean_text_inf(text, language, version):
     return phones, word2ph, norm_text
 
 
-dtype = torch.float16 if is_half == True else torch.float32
+dtype = torch.float16 if is_half is True else torch.float32
 
 
 def get_bert_inf(phones, word2ph, norm_text, language):
@@ -310,27 +295,13 @@ def get_bert_inf(phones, word2ph, norm_text, language):
    else:
         bert = torch.zeros(
             (1024, len(phones)),
-            dtype=torch.float16 if is_half == True else torch.float32,
+            dtype=torch.float16 if is_half is True else torch.float32,
         ).to(device)
 
     return bert
 
 
-splits = {
-    ",",
-    "。",
-    "?",
-    "!",
-    ",",
-    ".",
-    "?",
-    "!",
-    "~",
-    ":",
-    ":",
-    "—",
-    "…",
-}
+splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
 
 
 def get_first(text):
@@ -339,9 +310,6 @@ def get_first(text):
     return text
 
 
-from text import chinese
-
-
 def get_phones_and_bert(text, language, version, final=False):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
         formattext = text
@@ -363,7 +331,7 @@ def get_phones_and_bert(text, language, version, final=False):
         phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
         bert = torch.zeros(
             (1024, len(phones)),
-            dtype=torch.float16 if is_half == True else torch.float32,
+            dtype=torch.float16 if is_half is True else torch.float32,
         ).to(device)
     elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
         textlist = []
@@ -475,7 +443,7 @@ def get_tts_wav(
     print(i18n("实际输入的目标文本:"), text)
    zero_wav = np.zeros(
         int(hps.data.sampling_rate * 0.3),
-        dtype=np.float16 if is_half == True else np.float32,
+        dtype=np.float16 if is_half is True else np.float32,
     )
     if not ref_free:
         with torch.no_grad():
@@ -485,7 +453,7 @@ def get_tts_wav(
                raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
-            if is_half == True:
+            if is_half is True:
                wav16k = wav16k.half().to(device)
                zero_wav_torch = zero_wav_torch.half().to(device)
            else:
@@ -544,10 +512,10 @@ def get_tts_wav(
        t2 = ttime()
        # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
        # print(cache.keys(),if_freeze)
-        if i_text in cache and if_freeze == True:
+        if i_text in cache and if_freeze is True:
            pred_semantic = cache[i_text]
        else:
-            with torch.no_grad(), lock:
+            with torch.no_grad():
                t2s_request = T2SRequest(
                    [all_phoneme_ids.squeeze(0)],
                    all_phoneme_len,
@@ -564,9 +532,8 @@ def get_tts_wav(
                t2s_result = t2s_model.generate(t2s_request)
 
                if t2s_result.exception is not None:
-                    print(t2s_result.exception)
                    print(t2s_result.traceback)
-                    raise RuntimeError("")
+                    raise t2s_result.exception
 
                infer_speed.append(t2s_result.infer_speed)
                pred_semantic = t2s_result.result
@@ -608,8 +575,8 @@ def get_tts_wav(
        t.extend([t2 - t1, t3 - t2, t4 - t3])
        t1 = ttime()
    print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
-    gr.Info(f"Infer Speed: {sum(infer_speed) / len(infer_speed):.2f} Token/s")
-    gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), duration=4)
+    gr.Info(f"{sum(infer_speed) / len(infer_speed):.2f} Token/s", title="Infer Speed")
+    gr.Info("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])), title="Time Stamps")
    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
 
 
@@ -713,7 +680,7 @@ def cut5(inp):
 
 def custom_sort_key(s):
     # Use a regex to split the string into digit and non-digit parts
-    parts = re.split("(\d+)", s)
+    parts = re.split(r"(\d+)", s)
     # Convert the digit parts to integers; leave the rest unchanged
     parts = [int(part) if part.isdigit() else part for part in parts]
     return parts
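The raw-string fix in the last hunk matters because `"(\d+)"` relies on an invalid escape sequence that newer CPython versions flag (DeprecationWarning, now SyntaxWarning); `r"(\d+)"` is the literal regex. The capture group keeps the digit runs in the split output, which is what makes the natural-sort key work:

```python
import re


def custom_sort_key(s: str):
    # Split on digit runs, keeping them via the capture group, then
    # compare digit chunks numerically and everything else as text.
    parts = re.split(r"(\d+)", s)
    return [int(part) if part.isdigit() else part for part in parts]


print(re.split(r"(\d+)", "take10"))                      # ['take', '10', '']
print(sorted(["take10", "take2"], key=custom_sort_key))  # ['take2', 'take10']
```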