Spaces:
Running
on
Zero
Running
on
Zero
initial commit
Browse files
app.py
CHANGED
@@ -393,42 +393,51 @@ TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
|
|
393 |
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
|
394 |
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
|
395 |
|
396 |
-
|
397 |
-
stt_processor = None
|
398 |
-
stt_model = None
|
399 |
-
mt_tokenizer = None
|
400 |
-
mt_model = None
|
401 |
|
402 |
# Wrap model loading in a function to clearly see when it happens or to potentially delay it.
|
403 |
# For Spaces, global loading is fine and preferred as it happens once.
|
404 |
print("--- Starting Model Loading ---")
|
405 |
try:
|
406 |
-
print(
|
|
|
407 |
tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
|
408 |
-
state = torch.load(tts_model_path, map_location=DEVICE)
|
409 |
-
TTS_MODEL = TransformerTTS(
|
410 |
-
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
412 |
TTS_MODEL.eval()
|
413 |
print("TTS model loaded successfully.")
|
414 |
except Exception as e:
|
415 |
print(f"Error loading TTS model: {e}")
|
|
|
416 |
|
|
|
417 |
try:
|
418 |
-
print(
|
419 |
stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
|
420 |
stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
|
421 |
print("STT model loaded successfully.")
|
422 |
except Exception as e:
|
423 |
print(f"Error loading STT model: {e}")
|
|
|
|
|
424 |
|
|
|
425 |
try:
|
426 |
-
print(
|
427 |
mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
|
428 |
mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
|
429 |
print("TTT model loaded successfully.")
|
430 |
except Exception as e:
|
431 |
print(f"Error loading TTT model: {e}")
|
|
|
|
|
432 |
print("--- Model Loading Complete ---")
|
433 |
|
434 |
|
|
|
393 |
ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
|
394 |
MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
|
395 |
|
396 |
+
|
|
|
|
|
|
|
|
|
397 |
|
398 |
# Wrap model loading in a function to clearly see when it happens or to potentially delay it.
|
399 |
# For Spaces, global loading is fine and preferred as it happens once.
|
400 |
print("--- Starting Model Loading ---")
|
401 |
try:
|
402 |
+
print("Loading TTS model...")
|
403 |
+
# Download the .pt file from its repo
|
404 |
tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
|
405 |
+
state = torch.load(tts_model_path, map_location=DEVICE)
|
406 |
+
TTS_MODEL = TransformerTTS().to(DEVICE)
|
407 |
+
# Check for the correct key in the state dictionary
|
408 |
+
if "model" in state:
|
409 |
+
TTS_MODEL.load_state_dict(state["model"])
|
410 |
+
elif "state_dict" in state:
|
411 |
+
TTS_MODEL.load_state_dict(state["state_dict"])
|
412 |
+
else:
|
413 |
+
TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
|
414 |
TTS_MODEL.eval()
|
415 |
print("TTS model loaded successfully.")
|
416 |
except Exception as e:
|
417 |
print(f"Error loading TTS model: {e}")
|
418 |
+
TTS_MODEL = None
|
419 |
|
420 |
+
# Load STT (Whisper) Model from Hub
|
421 |
try:
|
422 |
+
print("Loading STT (Whisper) model...")
|
423 |
stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
|
424 |
stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
|
425 |
print("STT model loaded successfully.")
|
426 |
except Exception as e:
|
427 |
print(f"Error loading STT model: {e}")
|
428 |
+
stt_processor = None
|
429 |
+
stt_model = None
|
430 |
|
431 |
+
# Load TTT (MarianMT) Model from Hub
|
432 |
try:
|
433 |
+
print("Loading TTT (MarianMT) model...")
|
434 |
mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
|
435 |
mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
|
436 |
print("TTT model loaded successfully.")
|
437 |
except Exception as e:
|
438 |
print(f"Error loading TTT model: {e}")
|
439 |
+
mt_tokenizer = None
|
440 |
+
mt_model = None
|
441 |
print("--- Model Loading Complete ---")
|
442 |
|
443 |
|