test with global model
Browse files- sparktts/modules/speaker/perceiver_encoder.py +9 -18
- webui.py +11 -6
sparktts/modules/speaker/perceiver_encoder.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15 |
|
16 |
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
|
|
|
18 |
from functools import wraps
|
19 |
|
20 |
import torch
|
@@ -45,21 +46,6 @@ def once(fn):
|
|
45 |
|
46 |
print_once = once(print)
|
47 |
|
48 |
-
# Define config class at module level
|
49 |
-
class EfficientAttentionConfig:
|
50 |
-
def __init__(self, enable_flash, enable_math, enable_mem_efficient):
|
51 |
-
self.enable_flash = enable_flash
|
52 |
-
self.enable_math = enable_math
|
53 |
-
self.enable_mem_efficient = enable_mem_efficient
|
54 |
-
|
55 |
-
def _asdict(self):
|
56 |
-
return {
|
57 |
-
'enable_flash': self.enable_flash,
|
58 |
-
'enable_math': self.enable_math,
|
59 |
-
'enable_mem_efficient': self.enable_mem_efficient
|
60 |
-
}
|
61 |
-
|
62 |
-
|
63 |
# main class
|
64 |
|
65 |
|
@@ -77,7 +63,12 @@ class Attend(nn.Module):
|
|
77 |
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
78 |
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
81 |
self.cuda_config = None
|
82 |
|
83 |
if not torch.cuda.is_available() or not use_flash:
|
@@ -89,12 +80,12 @@ class Attend(nn.Module):
|
|
89 |
print_once(
|
90 |
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
91 |
)
|
92 |
-
self.cuda_config =
|
93 |
else:
|
94 |
print_once(
|
95 |
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
96 |
)
|
97 |
-
self.cuda_config =
|
98 |
|
99 |
def get_mask(self, n, device):
|
100 |
if exists(self.mask) and self.mask.shape[-1] >= n:
|
|
|
15 |
|
16 |
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
|
18 |
+
from collections import namedtuple
|
19 |
from functools import wraps
|
20 |
|
21 |
import torch
|
|
|
46 |
|
47 |
print_once = once(print)
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
# main class
|
50 |
|
51 |
|
|
|
63 |
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
64 |
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
65 |
|
66 |
+
# determine efficient attention configs for cuda and cpu
|
67 |
+
self.config = namedtuple(
|
68 |
+
"EfficientAttentionConfig",
|
69 |
+
["enable_flash", "enable_math", "enable_mem_efficient"],
|
70 |
+
)
|
71 |
+
self.cpu_config = self.config(True, True, True)
|
72 |
self.cuda_config = None
|
73 |
|
74 |
if not torch.cuda.is_available() or not use_flash:
|
|
|
80 |
print_once(
|
81 |
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
82 |
)
|
83 |
+
self.cuda_config = self.config(True, False, False)
|
84 |
else:
|
85 |
print_once(
|
86 |
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
87 |
)
|
88 |
+
self.cuda_config = self.config(False, True, True)
|
89 |
|
90 |
def get_mask(self, n, device):
|
91 |
if exists(self.mask) and self.mask.shape[-1] >= n:
|
webui.py
CHANGED
@@ -25,6 +25,8 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
|
|
25 |
from huggingface_hub import snapshot_download
|
26 |
import spaces
|
27 |
|
|
|
|
|
28 |
def initialize_model(model_dir=None, device="cpu"):
|
29 |
"""Load the model once at the beginning."""
|
30 |
|
@@ -38,8 +40,7 @@ def initialize_model(model_dir=None, device="cpu"):
|
|
38 |
return model
|
39 |
|
40 |
@spaces.GPU
|
41 |
-
def generate(
|
42 |
-
text,
|
43 |
prompt_speech,
|
44 |
prompt_text,
|
45 |
gender,
|
@@ -47,6 +48,10 @@ def generate(model,
|
|
47 |
speed,
|
48 |
):
|
49 |
"""Generate audio from text."""
|
|
|
|
|
|
|
|
|
50 |
# if gpu available, move model to gpu
|
51 |
if torch.cuda.is_available():
|
52 |
model = model.to("cuda")
|
@@ -66,7 +71,6 @@ def generate(model,
|
|
66 |
|
67 |
def run_tts(
|
68 |
text,
|
69 |
-
model,
|
70 |
prompt_text=None,
|
71 |
prompt_speech=None,
|
72 |
gender=None,
|
@@ -90,7 +94,7 @@ def run_tts(
|
|
90 |
logging.info("Starting inference...")
|
91 |
|
92 |
# Perform inference and save the output audio
|
93 |
-
wav = generate(
|
94 |
prompt_speech,
|
95 |
prompt_text,
|
96 |
gender,
|
@@ -109,6 +113,9 @@ def build_ui(model_dir, device=0):
|
|
109 |
|
110 |
# Initialize model
|
111 |
model = initialize_model(model_dir, device=device)
|
|
|
|
|
|
|
112 |
|
113 |
# Define callback function for voice cloning
|
114 |
def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
|
@@ -123,7 +130,6 @@ def build_ui(model_dir, device=0):
|
|
123 |
|
124 |
audio_output_path = run_tts(
|
125 |
text,
|
126 |
-
model,
|
127 |
prompt_text=prompt_text_clean,
|
128 |
prompt_speech=prompt_speech
|
129 |
)
|
@@ -141,7 +147,6 @@ def build_ui(model_dir, device=0):
|
|
141 |
speed_val = LEVELS_MAP_UI[int(speed)]
|
142 |
audio_output_path = run_tts(
|
143 |
text,
|
144 |
-
model,
|
145 |
gender=gender,
|
146 |
pitch=pitch_val,
|
147 |
speed=speed_val
|
|
|
25 |
from huggingface_hub import snapshot_download
|
26 |
import spaces
|
27 |
|
28 |
+
MODEL = None
|
29 |
+
|
30 |
def initialize_model(model_dir=None, device="cpu"):
|
31 |
"""Load the model once at the beginning."""
|
32 |
|
|
|
40 |
return model
|
41 |
|
42 |
@spaces.GPU
|
43 |
+
def generate(text,
|
|
|
44 |
prompt_speech,
|
45 |
prompt_text,
|
46 |
gender,
|
|
|
48 |
speed,
|
49 |
):
|
50 |
"""Generate audio from text."""
|
51 |
+
|
52 |
+
global MODEL
|
53 |
+
model = MODEL
|
54 |
+
|
55 |
# if gpu available, move model to gpu
|
56 |
if torch.cuda.is_available():
|
57 |
model = model.to("cuda")
|
|
|
71 |
|
72 |
def run_tts(
|
73 |
text,
|
|
|
74 |
prompt_text=None,
|
75 |
prompt_speech=None,
|
76 |
gender=None,
|
|
|
94 |
logging.info("Starting inference...")
|
95 |
|
96 |
# Perform inference and save the output audio
|
97 |
+
wav = generate(text,
|
98 |
prompt_speech,
|
99 |
prompt_text,
|
100 |
gender,
|
|
|
113 |
|
114 |
# Initialize model
|
115 |
model = initialize_model(model_dir, device=device)
|
116 |
+
|
117 |
+
global MODEL
|
118 |
+
MODEL = model
|
119 |
|
120 |
# Define callback function for voice cloning
|
121 |
def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
|
|
|
130 |
|
131 |
audio_output_path = run_tts(
|
132 |
text,
|
|
|
133 |
prompt_text=prompt_text_clean,
|
134 |
prompt_speech=prompt_speech
|
135 |
)
|
|
|
147 |
speed_val = LEVELS_MAP_UI[int(speed)]
|
148 |
audio_output_path = run_tts(
|
149 |
text,
|
|
|
150 |
gender=gender,
|
151 |
pitch=pitch_val,
|
152 |
speed=speed_val
|