ceymox committed
Commit a2fb11d · verified · 1 parent: 4ab43d0

Update app.py

Files changed (1): app.py (+50, -14)
app.py CHANGED

@@ -73,7 +73,6 @@ class TTSModelWrapper:
 def load_tts_model_with_retry(max_retries=3, retry_delay=5):
     global tts_model, tts_model_wrapper
 
-    # First, check if model is already in cache
     print("Checking if TTS model is in cache...")
     try:
         cache_info = scan_cache_dir()
@@ -83,15 +82,15 @@ def load_tts_model_with_retry(max_retries=3, retry_delay=5):
         tts_model = AutoModel.from_pretrained(
             tts_repo_id,
             trust_remote_code=True,
-            local_files_only=True
-        ).to(device)
+            local_files_only=True,
+            device_map="auto"  # <-- Use device_map instead of .to(device)
+        )
         tts_model_wrapper = TTSModelWrapper(tts_model)
         print("TTS model loaded from cache successfully!")
         return
     except Exception as e:
         print(f"Cache check failed: {e}")
 
-    # If not in cache or cache check failed, try loading with retries
     for attempt in range(max_retries):
         try:
             print(f"Loading {tts_repo_id} model (attempt {attempt+1}/{max_retries})...")
@@ -100,21 +99,19 @@ def load_tts_model_with_retry(max_retries=3, retry_delay=5):
                 trust_remote_code=True,
                 revision="main",
                 use_auth_token=HF_TOKEN,
-                low_cpu_mem_usage=True
-            ).to(device)
-
+                low_cpu_mem_usage=True,
+                device_map="auto"  # <-- Use device_map here as well
+            )
             tts_model_wrapper = TTSModelWrapper(tts_model)
             print(f"TTS model loaded successfully! Type: {type(tts_model)}")
-            return  # Success, exit function
-
+            return
         except Exception as e:
             print(f"⚠️ Attempt {attempt+1}/{max_retries} failed: {e}")
             if attempt < max_retries - 1:
                 print(f"Waiting {retry_delay} seconds before retrying...")
                 time.sleep(retry_delay)
-                retry_delay *= 1.5  # Exponential backoff
+                retry_delay *= 1.5
 
-    # If all attempts failed, try one last time with fallback options
     try:
         print("Trying with fallback options...")
         tts_model = AutoModel.from_pretrained(
@@ -124,14 +121,53 @@ def load_tts_model_with_retry(max_retries=3, retry_delay=5):
             local_files_only=False,
             use_auth_token=HF_TOKEN,
             force_download=False,
-            resume_download=True
-        ).to(device)
+            resume_download=True,
+            device_map="auto"  # <-- And here too
+        )
         tts_model_wrapper = TTSModelWrapper(tts_model)
         print("TTS model loaded with fallback options!")
     except Exception as e2:
         print(f"❌ All attempts to load TTS model failed: {e2}")
         print("Will continue without TTS model loaded.")
 
+# Reduce chunk size for faster streaming and lower latency
+def split_into_chunks(text, max_length=15):  # Reduced from 30 to 15
+    sentence_markers = ['.', '?', '!', ';', ':', '।', '॥']
+    chunks = []
+    current = ""
+
+    for char in text:
+        current += char
+        if char in sentence_markers and current.strip():
+            chunks.append(current.strip())
+            current = ""
+
+    if current.strip():
+        chunks.append(current.strip())
+
+    final_chunks = []
+    for chunk in chunks:
+        if len(chunk) <= max_length:
+            final_chunks.append(chunk)
+        else:
+            comma_splits = chunk.split(',')
+            current_part = ""
+            for part in comma_splits:
+                if len(current_part) + len(part) <= max_length:
+                    if current_part:
+                        current_part += ","
+                    current_part += part
+                else:
+                    if current_part:
+                        final_chunks.append(current_part.strip())
+                    current_part = part
+            if current_part:
+                final_chunks.append(current_part.strip())
+
+    print(f"Split text into {len(final_chunks)} chunks")
+    return final_chunks
+)
+
 def load_asr_model():
     global asr_model
     try:
@@ -362,7 +398,7 @@ def enhance_audio(audio_data):
 
     return audio_data
 
-def split_into_chunks(text, max_length=30):
+def split_into_chunks(text, max_length=20):
     """Split text into smaller chunks based on punctuation and length"""
     # First split by sentences
     sentence_markers = ['.', '?', '!', ';', ':', '।', '॥']
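
The recurring change above replaces the eager .to(device) move with device_map="auto" passed to from_pretrained, which is why the trailing .to(device) calls are dropped on every load path. Below is a minimal sketch of that pattern, not taken from app.py: the repo id is a placeholder, and it assumes the accelerate package is installed (Transformers needs it whenever device_map is passed).

from transformers import AutoModel

repo_id = "some-org/some-tts-model"  # placeholder, not the repo app.py loads

# Old pattern: materialize the model, then move every weight to one device.
# model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).to("cuda")

# New pattern: let accelerate place the weights while loading, so no
# .to(device) call is needed afterwards.
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    device_map="auto",       # dispatch across available GPU(s)/CPU
    low_cpu_mem_usage=True,  # avoid building a full extra copy in CPU RAM
)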
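
The cache check at the top of load_tts_model_with_retry relies on scan_cache_dir() from huggingface_hub before attempting a local_files_only load. For reference, a small sketch of how that result can be used to see whether a repo is already cached; the repo id is a placeholder and this is not code from app.py.

from huggingface_hub import scan_cache_dir

repo_id = "some-org/some-tts-model"  # placeholder

cache_info = scan_cache_dir()
cached_repo_ids = {repo.repo_id for repo in cache_info.repos}

if repo_id in cached_repo_ids:
    print(f"{repo_id} is already in the local Hugging Face cache")
else:
    print(f"{repo_id} is not cached yet; loading will need network access")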
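
The split_into_chunks helper added after the loader cuts text at sentence markers first (including the danda '।' and double danda '॥'), then re-splits any over-length piece at commas; a piece longer than max_length with no comma is kept whole. Note that app.py still defines a module-level split_into_chunks further down (its default drops from 30 to 20 in this commit), and the later definition is the one Python keeps. A quick usage sketch with made-up input:

# Hypothetical input, purely for illustration.
sample = "Hello there. This clause is deliberately long, with a comma, so it is re-split!"

chunks = split_into_chunks(sample, max_length=15)
# The helper prints "Split text into N chunks"; each chunk ends at a sentence
# marker or at a comma boundary introduced by the re-split.
for chunk in chunks:
    print(repr(chunk))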