SWivid committed
Commit cb28313 · 1 Parent(s): 7a1ca18

add ckpt load opt. for .safetensor

model/utils.py CHANGED
@@ -545,3 +545,28 @@ def repetition_found(text, length = 2, tolerance = 10):
         if count > tolerance:
             return True
     return False
+
+
+# load model checkpoint for inference
+
+def load_checkpoint(model, ckpt_path, device, use_ema = True):
+    from ema_pytorch import EMA
+
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+        checkpoint = load_file(ckpt_path, device=device)
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+
+    if use_ema == True:
+        ema_model = EMA(model, include_online_model = False).to(device)
+        if ckpt_type == "safetensors":
+            ema_model.load_state_dict(checkpoint)
+        else:
+            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+        ema_model.copy_params_from_ema_to_model()
+    else:
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+    return model
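
For reference, a minimal usage sketch of the new helper as the scripts below call it. The stand-in module, checkpoint step, and path are hypothetical placeholders, not part of this commit:

import torch
import torch.nn as nn
from model.utils import load_checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"

# hypothetical stand-in; the real scripts build a CFM model here
model = nn.Linear(4, 4).to(device)

exp_name = "E2TTS_Base"
ckpt_step = 1200000  # hypothetical step count

# a .safetensors path routes through safetensors.torch.load_file and is
# expected to hold the EMA state dict flat; any other extension routes
# through torch.load and is expected to hold the trainer's
# 'ema_model_state_dict' / 'model_state_dict' keys
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
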
requirements.txt CHANGED
@@ -10,6 +10,7 @@ jiwer
 librosa
 matplotlib
 pypinyin
+safetensors
 # torch>=2.0
 # torchaudio>=2.3.0
 torchdiffeq
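
The new safetensors dependency also makes it easy to convert an existing .pt training checkpoint into the flat layout that load_checkpoint expects from a .safetensors file. A one-off conversion sketch, assuming hypothetical paths and that the EMA state dict contains only tensors (as ema_pytorch buffers are):

import torch
from safetensors.torch import save_file

# hypothetical paths; adjust to your checkpoint
pt_path = "ckpts/E2TTS_Base/model_1200000.pt"
sf_path = "ckpts/E2TTS_Base/model_1200000.safetensors"

ckpt = torch.load(pt_path, map_location="cpu")

# load_checkpoint() passes a .safetensors file straight to
# EMA.load_state_dict(), so save the EMA state dict flat
save_file(ckpt["ema_model_state_dict"], sf_path)
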
test_infer_batch.py CHANGED
@@ -8,11 +8,11 @@ import torch
 import torchaudio
 from accelerate import Accelerator
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     get_seedtts_testset_metainfo,
     get_librispeech_test_clean_metainfo,
@@ -55,7 +55,7 @@ seed = args.seed
 dataset_name = args.dataset
 exp_name = args.expname
 ckpt_step = args.ckptstep
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 
 nfe_step = args.nfestep
 ode_method = args.odemethod
@@ -152,12 +152,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 if not os.path.exists(output_dir) and accelerator.is_main_process:
     os.makedirs(output_dir)
test_infer_single.py CHANGED
@@ -4,11 +4,11 @@ import re
 import torch
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -50,7 +50,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
@@ -101,12 +101,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(ref_audio)
test_infer_single_edit.py CHANGED
@@ -4,11 +4,11 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
-from ema_pytorch import EMA
 from vocos import Vocos
 
 from model import CFM, UNetT, DiT, MMDiT
 from model.utils import (
+    load_checkpoint,
     get_tokenizer,
     convert_char_to_pinyin,
     save_spectrogram,
@@ -49,7 +49,7 @@ elif exp_name == "E2TTS_Base":
     model_cls = UNetT
     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
 
-checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
 output_dir = "tests"
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
@@ -112,12 +112,7 @@ model = CFM(
     vocab_char_map = vocab_char_map,
 ).to(device)
 
-if use_ema == True:
-    ema_model = EMA(model, include_online_model = False).to(device)
-    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-    ema_model.copy_params_from_ema_to_model()
-else:
-    model.load_state_dict(checkpoint['model_state_dict'])
+model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
 
 # Audio
 audio, sr = torchaudio.load(audio_to_edit)