Spaces:
Configuration error
Configuration error
98440
commited on
Commit
·
9640640
1
Parent(s):
9e51878
Added intel XPU support
Browse files- README.md +4 -0
- src/f5_tts/api.py +9 -1
- src/f5_tts/eval/eval_utmos.py +1 -1
- src/f5_tts/infer/speech_edit.py +9 -1
- src/f5_tts/infer/utils_infer.py +9 -1
- src/f5_tts/socket_server.py +7 -1
- src/f5_tts/train/finetune_gradio.py +30 -1
README.md
CHANGED
@@ -32,6 +32,10 @@ pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https:/
|
|
32 |
|
33 |
# AMD GPU: install pytorch with your ROCm version, e.g.
|
34 |
pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
|
|
|
|
|
|
|
|
|
35 |
```
|
36 |
|
37 |
Then you can choose from a few options below:
|
|
|
32 |
|
33 |
# AMD GPU: install pytorch with your ROCm version, e.g.
|
34 |
pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
|
35 |
+
|
36 |
+
# intel GPU: install pytorch with your XPU version, e.g.
|
37 |
+
# Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
|
38 |
+
pip install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
|
39 |
```
|
40 |
|
41 |
Then you can choose from a few options below:
|
src/f5_tts/api.py
CHANGED
@@ -47,7 +47,15 @@ class F5TTS:
|
|
47 |
else:
|
48 |
import torch
|
49 |
|
50 |
-
self.device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# Load models
|
53 |
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
|
|
|
47 |
else:
|
48 |
import torch
|
49 |
|
50 |
+
self.device = (
|
51 |
+
"cuda"
|
52 |
+
if torch.cuda.is_available()
|
53 |
+
else "xpu"
|
54 |
+
if torch.xpu.is_available()
|
55 |
+
else "mps"
|
56 |
+
if torch.backends.mps.is_available()
|
57 |
+
else "cpu"
|
58 |
+
)
|
59 |
|
60 |
# Load models
|
61 |
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
|
src/f5_tts/eval/eval_utmos.py
CHANGED
@@ -13,7 +13,7 @@ def main():
|
|
13 |
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
|
14 |
args = parser.parse_args()
|
15 |
|
16 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
19 |
predictor = predictor.to(device)
|
|
|
13 |
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
|
14 |
args = parser.parse_args()
|
15 |
|
16 |
+
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
|
17 |
|
18 |
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
19 |
predictor = predictor.to(device)
|
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -10,7 +10,15 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
|
|
10 |
from f5_tts.model import CFM, DiT, UNetT
|
11 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
12 |
|
13 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
# --------------------- Dataset Settings -------------------- #
|
|
|
10 |
from f5_tts.model import CFM, DiT, UNetT
|
11 |
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
12 |
|
13 |
+
device = (
|
14 |
+
"cuda"
|
15 |
+
if torch.cuda.is_available()
|
16 |
+
else "xpu"
|
17 |
+
if torch.xpu.is_available()
|
18 |
+
else "mps"
|
19 |
+
if torch.backends.mps.is_available()
|
20 |
+
else "cpu"
|
21 |
+
)
|
22 |
|
23 |
|
24 |
# --------------------- Dataset Settings -------------------- #
|
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -33,7 +33,15 @@ from f5_tts.model.utils import (
|
|
33 |
|
34 |
_ref_audio_cache = {}
|
35 |
|
36 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# -----------------------------------------
|
39 |
|
|
|
33 |
|
34 |
_ref_audio_cache = {}
|
35 |
|
36 |
+
device = (
|
37 |
+
"cuda"
|
38 |
+
if torch.cuda.is_available()
|
39 |
+
else "xpu"
|
40 |
+
if torch.xpu.is_available()
|
41 |
+
else "mps"
|
42 |
+
if torch.backends.mps.is_available()
|
43 |
+
else "cpu"
|
44 |
+
)
|
45 |
|
46 |
# -----------------------------------------
|
47 |
|
src/f5_tts/socket_server.py
CHANGED
@@ -17,7 +17,13 @@ from model.backbones.dit import DiT
|
|
17 |
class TTSStreamingProcessor:
|
18 |
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
|
19 |
self.device = device or (
|
20 |
-
"cuda"
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
)
|
22 |
|
23 |
# Load the model using the provided checkpoint and vocab files
|
|
|
17 |
class TTSStreamingProcessor:
|
18 |
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
|
19 |
self.device = device or (
|
20 |
+
"cuda"
|
21 |
+
if torch.cuda.is_available()
|
22 |
+
else "xpu"
|
23 |
+
if torch.xpu.is_available()
|
24 |
+
else "mps"
|
25 |
+
if torch.backends.mps.is_available()
|
26 |
+
else "cpu"
|
27 |
)
|
28 |
|
29 |
# Load the model using the provided checkpoint and vocab files
|
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -46,7 +46,15 @@ path_data = str(files("f5_tts").joinpath("../../data"))
|
|
46 |
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
|
47 |
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
|
48 |
|
49 |
-
device =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
# Save settings from a JSON file
|
@@ -889,6 +897,13 @@ def calculate_train(
|
|
889 |
gpu_properties = torch.cuda.get_device_properties(i)
|
890 |
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
892 |
elif torch.backends.mps.is_available():
|
893 |
gpu_count = 1
|
894 |
total_memory = psutil.virtual_memory().available / (1024**3)
|
@@ -1284,7 +1299,21 @@ def get_gpu_stats():
|
|
1284 |
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
|
1285 |
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
|
1286 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1288 |
elif torch.backends.mps.is_available():
|
1289 |
gpu_count = 1
|
1290 |
gpu_stats += "MPS GPU\n"
|
|
|
46 |
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
|
47 |
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
|
48 |
|
49 |
+
device = (
|
50 |
+
"cuda"
|
51 |
+
if torch.cuda.is_available()
|
52 |
+
else "xpu"
|
53 |
+
if torch.xpu.is_available()
|
54 |
+
else "mps"
|
55 |
+
if torch.backends.mps.is_available()
|
56 |
+
else "cpu"
|
57 |
+
)
|
58 |
|
59 |
|
60 |
# Save settings from a JSON file
|
|
|
897 |
gpu_properties = torch.cuda.get_device_properties(i)
|
898 |
total_memory += gpu_properties.total_memory / (1024**3) # in GB
|
899 |
|
900 |
+
elif torch.xpu.is_available():
|
901 |
+
gpu_count = torch.xpu.device_count()
|
902 |
+
total_memory = 0
|
903 |
+
for i in range(gpu_count):
|
904 |
+
gpu_properties = torch.xpu.get_device_properties(i)
|
905 |
+
total_memory += gpu_properties.total_memory / (1024**3)
|
906 |
+
|
907 |
elif torch.backends.mps.is_available():
|
908 |
gpu_count = 1
|
909 |
total_memory = psutil.virtual_memory().available / (1024**3)
|
|
|
1299 |
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
|
1300 |
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
|
1301 |
)
|
1302 |
+
elif torch.xpu.is_available():
|
1303 |
+
gpu_count = torch.xpu.device_count()
|
1304 |
+
for i in range(gpu_count):
|
1305 |
+
gpu_name = torch.xpu.get_device_name(i)
|
1306 |
+
gpu_properties = torch.xpu.get_device_properties(i)
|
1307 |
+
total_memory = gpu_properties.total_memory / (1024**3) # in GB
|
1308 |
+
allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB
|
1309 |
+
reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB
|
1310 |
|
1311 |
+
gpu_stats += (
|
1312 |
+
f"GPU {i} Name: {gpu_name}\n"
|
1313 |
+
f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
|
1314 |
+
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
|
1315 |
+
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
|
1316 |
+
)
|
1317 |
elif torch.backends.mps.is_available():
|
1318 |
gpu_count = 1
|
1319 |
gpu_stats += "MPS GPU\n"
|