thunnai commited on
Commit
f4176b0
·
1 Parent(s): d123787

test with global model

Browse files
sparktts/modules/speaker/perceiver_encoder.py CHANGED
@@ -15,6 +15,7 @@
15
 
16
  # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
 
 
18
  from functools import wraps
19
 
20
  import torch
@@ -45,21 +46,6 @@ def once(fn):
45
 
46
  print_once = once(print)
47
 
48
- # Define config class at module level
49
- class EfficientAttentionConfig:
50
- def __init__(self, enable_flash, enable_math, enable_mem_efficient):
51
- self.enable_flash = enable_flash
52
- self.enable_math = enable_math
53
- self.enable_mem_efficient = enable_mem_efficient
54
-
55
- def _asdict(self):
56
- return {
57
- 'enable_flash': self.enable_flash,
58
- 'enable_math': self.enable_math,
59
- 'enable_mem_efficient': self.enable_mem_efficient
60
- }
61
-
62
-
63
  # main class
64
 
65
 
@@ -77,7 +63,12 @@ class Attend(nn.Module):
77
  use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
78
  ), "in order to use flash attention, you must be using pytorch 2.0 or above"
79
 
80
- self.cpu_config = EfficientAttentionConfig(True, True, True)
 
 
 
 
 
81
  self.cuda_config = None
82
 
83
  if not torch.cuda.is_available() or not use_flash:
@@ -89,12 +80,12 @@ class Attend(nn.Module):
89
  print_once(
90
  "A100 GPU detected, using flash attention if input tensor is on cuda"
91
  )
92
- self.cuda_config = EfficientAttentionConfig(True, False, False)
93
  else:
94
  print_once(
95
  "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
96
  )
97
- self.cuda_config = EfficientAttentionConfig(False, True, True)
98
 
99
  def get_mask(self, n, device):
100
  if exists(self.mask) and self.mask.shape[-1] >= n:
 
15
 
16
  # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
 
18
+ from collections import namedtuple
19
  from functools import wraps
20
 
21
  import torch
 
46
 
47
  print_once = once(print)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # main class
50
 
51
 
 
63
  use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
  ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
 
66
+ # determine efficient attention configs for cuda and cpu
67
+ self.config = namedtuple(
68
+ "EfficientAttentionConfig",
69
+ ["enable_flash", "enable_math", "enable_mem_efficient"],
70
+ )
71
+ self.cpu_config = self.config(True, True, True)
72
  self.cuda_config = None
73
 
74
  if not torch.cuda.is_available() or not use_flash:
 
80
  print_once(
81
  "A100 GPU detected, using flash attention if input tensor is on cuda"
82
  )
83
+ self.cuda_config = self.config(True, False, False)
84
  else:
85
  print_once(
86
  "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
  )
88
+ self.cuda_config = self.config(False, True, True)
89
 
90
  def get_mask(self, n, device):
91
  if exists(self.mask) and self.mask.shape[-1] >= n:
webui.py CHANGED
@@ -25,6 +25,8 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
25
  from huggingface_hub import snapshot_download
26
  import spaces
27
 
 
 
28
  def initialize_model(model_dir=None, device="cpu"):
29
  """Load the model once at the beginning."""
30
 
@@ -38,8 +40,7 @@ def initialize_model(model_dir=None, device="cpu"):
38
  return model
39
 
40
  @spaces.GPU
41
- def generate(model,
42
- text,
43
  prompt_speech,
44
  prompt_text,
45
  gender,
@@ -47,6 +48,10 @@ def generate(model,
47
  speed,
48
  ):
49
  """Generate audio from text."""
 
 
 
 
50
  # if gpu available, move model to gpu
51
  if torch.cuda.is_available():
52
  model = model.to("cuda")
@@ -66,7 +71,6 @@ def generate(model,
66
 
67
  def run_tts(
68
  text,
69
- model,
70
  prompt_text=None,
71
  prompt_speech=None,
72
  gender=None,
@@ -90,7 +94,7 @@ def run_tts(
90
  logging.info("Starting inference...")
91
 
92
  # Perform inference and save the output audio
93
- wav = generate(model, text,
94
  prompt_speech,
95
  prompt_text,
96
  gender,
@@ -109,6 +113,9 @@ def build_ui(model_dir, device=0):
109
 
110
  # Initialize model
111
  model = initialize_model(model_dir, device=device)
 
 
 
112
 
113
  # Define callback function for voice cloning
114
  def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
@@ -123,7 +130,6 @@ def build_ui(model_dir, device=0):
123
 
124
  audio_output_path = run_tts(
125
  text,
126
- model,
127
  prompt_text=prompt_text_clean,
128
  prompt_speech=prompt_speech
129
  )
@@ -141,7 +147,6 @@ def build_ui(model_dir, device=0):
141
  speed_val = LEVELS_MAP_UI[int(speed)]
142
  audio_output_path = run_tts(
143
  text,
144
- model,
145
  gender=gender,
146
  pitch=pitch_val,
147
  speed=speed_val
 
25
  from huggingface_hub import snapshot_download
26
  import spaces
27
 
28
+ MODEL = None
29
+
30
  def initialize_model(model_dir=None, device="cpu"):
31
  """Load the model once at the beginning."""
32
 
 
40
  return model
41
 
42
  @spaces.GPU
43
+ def generate(text,
 
44
  prompt_speech,
45
  prompt_text,
46
  gender,
 
48
  speed,
49
  ):
50
  """Generate audio from text."""
51
+
52
+ global MODEL
53
+ model = MODEL
54
+
55
  # if gpu available, move model to gpu
56
  if torch.cuda.is_available():
57
  model = model.to("cuda")
 
71
 
72
  def run_tts(
73
  text,
 
74
  prompt_text=None,
75
  prompt_speech=None,
76
  gender=None,
 
94
  logging.info("Starting inference...")
95
 
96
  # Perform inference and save the output audio
97
+ wav = generate(text,
98
  prompt_speech,
99
  prompt_text,
100
  gender,
 
113
 
114
  # Initialize model
115
  model = initialize_model(model_dir, device=device)
116
+
117
+ global MODEL
118
+ MODEL = model
119
 
120
  # Define callback function for voice cloning
121
  def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
 
130
 
131
  audio_output_path = run_tts(
132
  text,
 
133
  prompt_text=prompt_text_clean,
134
  prompt_speech=prompt_speech
135
  )
 
147
  speed_val = LEVELS_MAP_UI[int(speed)]
148
  audio_output_path = run_tts(
149
  text,
 
150
  gender=gender,
151
  pitch=pitch_val,
152
  speed=speed_val