bluenevus commited on
Commit
c06bbb8
·
verified ·
1 Parent(s): 425f9fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -53,8 +53,18 @@ e2tts = E2TTS(
53
  )
54
 
55
  # Load the pre-trained model
56
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
57
- e2tts.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
58
  e2tts.eval()
59
 
60
  def generate_podcast_script(content, duration):
 
53
  )
54
 
55
  # Load the pre-trained model
56
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
57
+ if 'model_state_dict' in checkpoint:
58
+ state_dict = checkpoint['model_state_dict']
59
+ elif 'ema_model_state_dict' in checkpoint:
60
+ state_dict = checkpoint['ema_model_state_dict']
61
+ else:
62
+ state_dict = checkpoint # Assume the checkpoint is the state dict itself
63
+
64
+ # Filter out unexpected keys
65
+ model_dict = e2tts.state_dict()
66
+ filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
67
+ e2tts.load_state_dict(filtered_state_dict, strict=False)
68
  e2tts.eval()
69
 
70
  def generate_podcast_script(content, duration):