sudhanm committed on
Commit
3940c6b
·
verified ·
1 Parent(s): c2ad75f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
8
  from indic_transliteration import sanscript
9
  from indic_transliteration.sanscript import transliterate
 
10
 
11
  # ---------------- CONFIG ---------------- #
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -63,17 +64,18 @@ SENTENCE_BANK = {
63
  ]
64
  }
65
 
66
- # ---------------- LOAD MODELS ---------------- #
67
- print("Loading Whisper models...")
68
  whisper_models = {}
69
  whisper_processors = {}
70
 
71
- for lang, model_id in MODEL_CONFIGS.items():
72
- print(f"Loading {lang} model: {model_id}")
73
- whisper_models[lang] = WhisperForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
74
- whisper_processors[lang] = WhisperProcessor.from_pretrained(model_id)
75
-
76
- print("All models loaded successfully!")
 
 
77
 
78
  # ---------------- HELPERS ---------------- #
79
  def get_random_sentence(language_choice):
@@ -91,7 +93,11 @@ def transliterate_to_hk(text, lang_choice):
91
  }
92
  return transliterate(text, mapping[lang_choice], sanscript.HK) if mapping[lang_choice] else text
93
 
 
94
  def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
 
 
 
95
  # Get the appropriate model and processor for the language
96
  model = whisper_models[language_choice]
97
  processor = whisper_processors[language_choice]
@@ -151,6 +157,7 @@ def char_level_highlight(ref, hyp):
151
  return "".join(out)
152
 
153
  # ---------------- MAIN ---------------- #
 
154
  def compare_pronunciation(audio, language_choice, intended_sentence,
155
  pass1_beam, pass1_temp, pass1_condition):
156
  if audio is None or not intended_sentence.strip():
 
7
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
8
  from indic_transliteration import sanscript
9
  from indic_transliteration.sanscript import transliterate
10
+ import spaces
11
 
12
  # ---------------- CONFIG ---------------- #
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
64
  ]
65
  }
66
 
67
+ # Global variables for models (will be loaded lazily)
 
68
  whisper_models = {}
69
  whisper_processors = {}
70
 
71
def load_model(language_choice):
    """Lazily load the Whisper model/processor pair for *language_choice*.

    On first use, populates the module-level ``whisper_models`` and
    ``whisper_processors`` caches; repeat calls for an already-loaded
    language are no-ops. The checkpoint id is looked up in
    ``MODEL_CONFIGS`` and the model is moved to ``DEVICE``.
    """
    if language_choice in whisper_models:
        # Already cached — nothing to load.
        return
    model_id = MODEL_CONFIGS[language_choice]
    print(f"Loading {language_choice} model: {model_id}")
    loaded = WhisperForConditionalGeneration.from_pretrained(model_id)
    whisper_models[language_choice] = loaded.to(DEVICE)
    whisper_processors[language_choice] = WhisperProcessor.from_pretrained(model_id)
    print(f"{language_choice} model loaded successfully!")
79
 
80
  # ---------------- HELPERS ---------------- #
81
  def get_random_sentence(language_choice):
 
93
  }
94
  return transliterate(text, mapping[lang_choice], sanscript.HK) if mapping[lang_choice] else text
95
 
96
+ @spaces.GPU
97
  def transcribe_once(audio_path, language_choice, initial_prompt, beam_size, temperature, condition_on_previous_text):
98
+ # Load model if not already loaded
99
+ load_model(language_choice)
100
+
101
  # Get the appropriate model and processor for the language
102
  model = whisper_models[language_choice]
103
  processor = whisper_processors[language_choice]
 
157
  return "".join(out)
158
 
159
  # ---------------- MAIN ---------------- #
160
+ @spaces.GPU
161
  def compare_pronunciation(audio, language_choice, intended_sentence,
162
  pass1_beam, pass1_temp, pass1_condition):
163
  if audio is None or not intended_sentence.strip():