Commit
·
1896698
1
Parent(s):
3a7abf6
add decorator spaces.GPU
Browse files- core/bark/generate_audio.py +2 -1
- event_handlers.py +7 -7
core/bark/generate_audio.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import sys
|
2 |
import logging
|
3 |
from typing_extensions import Union, List
|
@@ -21,7 +22,7 @@ logging.basicConfig(
|
|
21 |
)
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
-
|
25 |
def generate_audio(
|
26 |
texts: List[str],
|
27 |
prompt: Union[BarkPrompt, None] = None,
|
|
|
1 |
+
import spaces
|
2 |
import sys
|
3 |
import logging
|
4 |
from typing_extensions import Union, List
|
|
|
22 |
)
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
25 |
+
@spaces.GPU
|
26 |
def generate_audio(
|
27 |
texts: List[str],
|
28 |
prompt: Union[BarkPrompt, None] = None,
|
event_handlers.py
CHANGED
@@ -31,17 +31,17 @@ logger = logging.getLogger(__name__)
|
|
31 |
|
32 |
# return list of available devices and the best device to be used as default for all inference
|
33 |
def get_available_torch_devices() -> Tuple[List[str], str]:
|
34 |
-
devices = ["cpu"]
|
35 |
-
best_device = "cpu"
|
36 |
# if torch.backend.mps.is_available():
|
37 |
# devices.append("mps")
|
38 |
# best_device = "mps"
|
39 |
-
if torch.cuda.is_available():
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
return devices, best_device
|
44 |
|
|
|
|
|
45 |
|
46 |
# --- Helper Functions ---
|
47 |
# (Keep get_wav_duration, load_existing_audio, get_safe_filename,
|
|
|
31 |
|
32 |
# return list of available devices and the best device to be used as default for all inference
|
33 |
def get_available_torch_devices() -> Tuple[List[str], str]:
|
34 |
+
# devices = ["cpu"]
|
35 |
+
# best_device = "cpu"
|
36 |
# if torch.backend.mps.is_available():
|
37 |
# devices.append("mps")
|
38 |
# best_device = "mps"
|
39 |
+
# if torch.cuda.is_available():
|
40 |
+
# devices.append("cuda")
|
41 |
+
# best_device = "cuda"
|
|
|
|
|
42 |
|
43 |
+
# return devices, best_device
|
44 |
+
return ["cuda"], "cuda"
|
45 |
|
46 |
# --- Helper Functions ---
|
47 |
# (Keep get_wav_duration, load_existing_audio, get_safe_filename,
|