MoHamdyy commited on
Commit
772925a
·
1 Parent(s): 9e8a757

initial commit

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -393,42 +393,51 @@ TTS_MODEL_HUB_ID = "MoHamdyy/transformer-tts-ljspeech"
393
  ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
394
  MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
395
 
396
- TTS_MODEL = None
397
- stt_processor = None
398
- stt_model = None
399
- mt_tokenizer = None
400
- mt_model = None
401
 
402
  # Wrap model loading in a function to clearly see when it happens or to potentially delay it.
403
  # For Spaces, global loading is fine and preferred as it happens once.
404
  print("--- Starting Model Loading ---")
405
  try:
406
- print(f"Loading TTS model ({TTS_MODEL_HUB_ID}) to {DEVICE}...")
 
407
  tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
408
- state = torch.load(tts_model_path, map_location=DEVICE) # Load to target device directly
409
- TTS_MODEL = TransformerTTS(device=DEVICE).to(DEVICE)
410
- model_state_dict = state.get("model", state.get("state_dict", state))
411
- TTS_MODEL.load_state_dict(model_state_dict)
 
 
 
 
 
412
  TTS_MODEL.eval()
413
  print("TTS model loaded successfully.")
414
  except Exception as e:
415
  print(f"Error loading TTS model: {e}")
 
416
 
 
417
  try:
418
- print(f"Loading STT (Whisper) model ({ASR_HUB_ID}) to {DEVICE}...")
419
  stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
420
  stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
421
  print("STT model loaded successfully.")
422
  except Exception as e:
423
  print(f"Error loading STT model: {e}")
 
 
424
 
 
425
  try:
426
- print(f"Loading TTT (MarianMT) model ({MARIAN_HUB_ID}) to {DEVICE}...")
427
  mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
428
  mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
429
  print("TTT model loaded successfully.")
430
  except Exception as e:
431
  print(f"Error loading TTT model: {e}")
 
 
432
  print("--- Model Loading Complete ---")
433
 
434
 
 
393
  ASR_HUB_ID = "MoHamdyy/whisper-stt-model"
394
  MARIAN_HUB_ID = "MoHamdyy/marian-ar-en-translation"
395
 
396
+
 
 
 
 
397
 
398
  # Wrap model loading in a function to clearly see when it happens or to potentially delay it.
399
  # For Spaces, global loading is fine and preferred as it happens once.
400
  print("--- Starting Model Loading ---")
401
  try:
402
+ print("Loading TTS model...")
403
+ # Download the .pt file from its repo
404
  tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
405
+ state = torch.load(tts_model_path, map_location=DEVICE)
406
+ TTS_MODEL = TransformerTTS().to(DEVICE)
407
+ # Check for the correct key in the state dictionary
408
+ if "model" in state:
409
+ TTS_MODEL.load_state_dict(state["model"])
410
+ elif "state_dict" in state:
411
+ TTS_MODEL.load_state_dict(state["state_dict"])
412
+ else:
413
+ TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
414
  TTS_MODEL.eval()
415
  print("TTS model loaded successfully.")
416
  except Exception as e:
417
  print(f"Error loading TTS model: {e}")
418
+ TTS_MODEL = None
419
 
420
+ # Load STT (Whisper) Model from Hub
421
  try:
422
+ print("Loading STT (Whisper) model...")
423
  stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
424
  stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
425
  print("STT model loaded successfully.")
426
  except Exception as e:
427
  print(f"Error loading STT model: {e}")
428
+ stt_processor = None
429
+ stt_model = None
430
 
431
+ # Load TTT (MarianMT) Model from Hub
432
  try:
433
+ print("Loading TTT (MarianMT) model...")
434
  mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
435
  mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
436
  print("TTT model loaded successfully.")
437
  except Exception as e:
438
  print(f"Error loading TTT model: {e}")
439
+ mt_tokenizer = None
440
+ mt_model = None
441
  print("--- Model Loading Complete ---")
442
 
443