Update app.py
Browse files
app.py
CHANGED
@@ -53,8 +53,18 @@ e2tts = E2TTS(
|
|
53 |
)
|
54 |
|
55 |
# Load the pre-trained model
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
e2tts.eval()
|
59 |
|
60 |
def generate_podcast_script(content, duration):
|
|
|
53 |
)
|
54 |
|
55 |
# Load the pre-trained model
|
56 |
+
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
57 |
+
if 'model_state_dict' in checkpoint:
|
58 |
+
state_dict = checkpoint['model_state_dict']
|
59 |
+
elif 'ema_model_state_dict' in checkpoint:
|
60 |
+
state_dict = checkpoint['ema_model_state_dict']
|
61 |
+
else:
|
62 |
+
state_dict = checkpoint # Assume the checkpoint is the state dict itself
|
63 |
+
|
64 |
+
# Filter out unexpected keys
|
65 |
+
model_dict = e2tts.state_dict()
|
66 |
+
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
|
67 |
+
e2tts.load_state_dict(filtered_state_dict, strict=False)
|
68 |
e2tts.eval()
|
69 |
|
70 |
def generate_podcast_script(content, duration):
|