add ckpt load option for .safetensors
- model/utils.py             +25 -0
- requirements.txt            +1 -0
- test_infer_batch.py         +3 -8
- test_infer_single.py        +3 -8
- test_infer_single_edit.py   +3 -8
model/utils.py
CHANGED
@@ -545,3 +545,28 @@ def repetition_found(text, length = 2, tolerance = 10):
         if count > tolerance:
             return True
     return False
+
+
+# load model checkpoint for inference
+
+def load_checkpoint(model, ckpt_path, device, use_ema = True):
+    from ema_pytorch import EMA
+
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+        checkpoint = load_file(ckpt_path, device=device)
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+
+    if use_ema == True:
+        ema_model = EMA(model, include_online_model = False).to(device)
+        if ckpt_type == "safetensors":
+            ema_model.load_state_dict(checkpoint)
+        else:
+            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+        ema_model.copy_params_from_ema_to_model()
+    else:
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+    return model
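For context, a minimal usage sketch of the new loader (checkpoint paths and device are hypothetical; model is assumed to be an already-constructed CFM instance, as in the test scripts below). load_checkpoint dispatches on the file extension, so the same call covers both formats:

# usage sketch, not part of this commit
from model.utils import load_checkpoint

device = "cuda"  # or "cpu"

# .pt checkpoints keep the trainer dict layout
# ({'ema_model_state_dict': ..., 'model_state_dict': ...})
model = load_checkpoint(model, "ckpts/F5TTS_Base/model_1200000.pt", device, use_ema = True)

# .safetensors checkpoints are expected to hold the EMA state dict directly
model = load_checkpoint(model, "ckpts/F5TTS_Base/model_1200000.safetensors", device, use_ema = True)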
requirements.txt
CHANGED
@@ -10,6 +10,7 @@ jiwer
 librosa
 matplotlib
 pypinyin
+safetensors
 # torch>=2.0
 # torchaudio>=2.3.0
 torchdiffeq
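Since the safetensors branch of load_checkpoint expects the EMA weights stored flat, an existing .pt checkpoint can be repacked once and then loaded through the new path. A rough conversion sketch under that assumption (paths hypothetical, not part of this commit):

# one-off conversion sketch: .pt trainer checkpoint -> flat .safetensors
import torch
from safetensors.torch import save_file

ckpt = torch.load("ckpts/F5TTS_Base/model_1200000.pt", map_location="cpu")
ema_state = ckpt['ema_model_state_dict']

# safetensors stores tensors only; drop non-tensor entries and make contiguous
ema_state = {k: v.contiguous() for k, v in ema_state.items() if isinstance(v, torch.Tensor)}

save_file(ema_state, "ckpts/F5TTS_Base/model_1200000.safetensors")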
test_infer_batch.py
CHANGED
@@ -8,11 +8,11 @@ import torch
 import torchaudio
 from accelerate import Accelerator
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     get_seedtts_testset_metainfo,
     get_librispeech_test_clean_metainfo,
@@ -55,7 +55,7 @@ seed = args.seed
 dataset_name = args.dataset
 exp_name = args.expname
 ckpt_step = args.ckptstep
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 nfe_step = args.nfestep
 ode_method = args.odemethod
@@ -152,12 +152,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 if not os.path.exists(output_dir) and accelerator.is_main_process:
     os.makedirs(output_dir)
test_infer_single.py
CHANGED
@@ -4,11 +4,11 @@ import re
 import torch
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -50,7 +50,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
@@ -101,12 +101,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(ref_audio)
test_infer_single_edit.py
CHANGED
@@ -4,11 +4,11 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -112,12 +112,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)