amine_dubs committed
Commit 4e86ac5 · 1 Parent(s): aded6a5
Files changed (1)
  1. backend/main.py +58 -7
backend/main.py CHANGED
@@ -11,6 +11,7 @@ import io
 import concurrent.futures
 import subprocess
 import sys
+import time

 # Import transformers for local model inference
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
@@ -55,14 +56,29 @@ os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
 translator = None
 tokenizer = None
 model = None
+model_initialization_attempts = 0
+max_model_initialization_attempts = 3
+last_initialization_attempt = 0
+initialization_cooldown = 300  # 5 minutes cooldown between retry attempts

 # --- Model initialization function ---
 def initialize_model():
     """Initialize the translation model and tokenizer."""
-    global translator, tokenizer, model
+    global translator, tokenizer, model, model_initialization_attempts, last_initialization_attempt
+
+    # Check if we've exceeded maximum attempts and if enough time has passed since last attempt
+    current_time = time.time()
+    if (model_initialization_attempts >= max_model_initialization_attempts and
+            current_time - last_initialization_attempt < initialization_cooldown):
+        print(f"Maximum initialization attempts reached. Waiting for cooldown period.")
+        return False
+
+    # Update attempt counter and timestamp
+    model_initialization_attempts += 1
+    last_initialization_attempt = current_time

     try:
-        print("Initializing model and tokenizer...")
+        print(f"Initializing model and tokenizer (attempt {model_initialization_attempts})...")

         # Use a smaller model that works well for instruction-based translation
         model_name = "google/flan-t5-small"
@@ -124,7 +140,9 @@ def initialize_model():
             if not test_result or not isinstance(test_result, list) or len(test_result) == 0:
                 print("Model test failed: Invalid output format")
                 return False
-
+
+            # Success - reset the attempt counter
+            model_initialization_attempts = 0
             print(f"Model {model_name} successfully initialized and tested")
             return True
     except Exception as inner_e:
@@ -143,9 +161,11 @@ def translate_text(text, source_lang, target_lang):

     print(f"Translation Request - Source Lang: {source_lang}, Target Lang: {target_lang}")

-    if not model or not tokenizer:
+    # Check if model is initialized, if not try to initialize it
+    if not model or not tokenizer or not translator:
         success = initialize_model()
         if not success:
+            print("Local model initialization failed, using fallback translation")
             return use_fallback_translation(text, source_lang, target_lang)

     try:
@@ -175,11 +195,45 @@ def translate_text(text, source_lang, target_lang):
         except concurrent.futures.TimeoutError:
             print(f"Model inference timed out after 15 seconds, falling back to online translation")
             return use_fallback_translation(text, source_lang, target_lang)
+        except Exception as e:
+            print(f"Error during model inference: {e}")
+
+            # If the model failed during inference, try to re-initialize it for next time
+            # but use fallback for this request
+            initialize_model()
+            return use_fallback_translation(text, source_lang, target_lang)
     except Exception as e:
         print(f"Error using local model: {e}")
         traceback.print_exc()
         return use_fallback_translation(text, source_lang, target_lang)

+# --- Function to check model status and trigger re-initialization if needed ---
+def check_and_reinitialize_model():
+    """Check if model needs to be reinitialized and do so if necessary"""
+    global translator, model, tokenizer
+
+    try:
+        # If model isn't initialized yet, try to initialize it
+        if not model or not tokenizer or not translator:
+            print("Model not initialized. Attempting initialization...")
+            return initialize_model()
+
+        # Test the existing model with a simple translation
+        test_text = "Translate from English to French: hello"
+        result = translator(test_text, max_length=128)
+
+        # If we got a valid result, model is working fine
+        if result and isinstance(result, list) and len(result) > 0:
+            print("Model check: Model is functioning correctly.")
+            return True
+        else:
+            print("Model check: Model returned invalid result. Reinitializing...")
+            return initialize_model()
+    except Exception as e:
+        print(f"Error checking model status: {e}")
+        print("Model may be in a bad state. Attempting reinitialization...")
+        return initialize_model()
+
 def use_fallback_translation(text, source_lang, target_lang):
     """Use various fallback online translation services."""
     # List of LibreTranslate servers to try in order
@@ -254,7 +308,6 @@ async def extract_text_from_file(file: UploadFile) -> str:
                 break
             except UnicodeDecodeError:
                 continue
-
     elif file_extension == '.docx':
         try:
             import docx
@@ -266,7 +319,6 @@ async def extract_text_from_file(file: UploadFile) -> str:
             extracted_text = '\n'.join([para.text for para in doc.paragraphs])
         except ImportError:
             raise HTTPException(status_code=501, detail="DOCX processing requires 'python-docx' library")
-
     elif file_extension == '.pdf':
         try:
             import fitz  # PyMuPDF
@@ -283,7 +335,6 @@ async def extract_text_from_file(file: UploadFile) -> str:
             doc.close()
         except ImportError:
             raise HTTPException(status_code=501, detail="PDF processing requires 'PyMuPDF' library")
-
     else:
         raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_extension}")

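The core of this commit is the attempt-budget-plus-cooldown guard around model loading. The sketch below isolates that pattern so it can be read or unit-tested outside backend/main.py; the counter names mirror the diff, but the try_initialize wrapper and the always-failing loader in the demo are illustrative assumptions, not code from the repository.

import time

model_initialization_attempts = 0
max_model_initialization_attempts = 3
last_initialization_attempt = 0.0
initialization_cooldown = 300  # seconds to wait before allowing a new burst of retries

def try_initialize(load_model) -> bool:
    """Attempt to load the model unless the retry budget is exhausted."""
    global model_initialization_attempts, last_initialization_attempt
    now = time.time()
    if (model_initialization_attempts >= max_model_initialization_attempts and
            now - last_initialization_attempt < initialization_cooldown):
        # Budget spent and cooldown not over: skip the expensive load entirely.
        return False
    model_initialization_attempts += 1
    last_initialization_attempt = now
    if load_model():
        model_initialization_attempts = 0  # success resets the budget
        return True
    return False

if __name__ == "__main__":
    # With a loader that always fails, only the first three calls do real work;
    # later calls short-circuit until the cooldown expires.
    for _ in range(5):
        print(try_initialize(lambda: False))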
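The except concurrent.futures.TimeoutError branch in translate_text implies the translator call runs in a worker with a 15-second budget. A minimal version of that guard, assuming a generic run_pipeline callable standing in for the real translator(...) call, might look like this:

import concurrent.futures

def run_with_timeout(run_pipeline, timeout_s: float = 15.0):
    """Run run_pipeline() in a worker thread; return None if it exceeds timeout_s."""
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(run_pipeline)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        return None  # caller then takes the use_fallback_translation() path
    finally:
        # Don't block on a hung inference call; let the worker thread finish on its own.
        pool.shutdown(wait=False)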
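Nothing in this commit calls the new check_and_reinitialize_model() helper yet. One plausible wiring, sketched below, is a FastAPI health route placed in backend/main.py after the helper's definition; the app object, route path, and response shape are assumptions for illustration, not taken from the repository.

from fastapi import FastAPI

app = FastAPI()

@app.get("/health/model")
def model_health():
    # Assumes check_and_reinitialize_model() is defined earlier in this module.
    # Reports whether the local translation model is usable and, if it is not,
    # lets the helper attempt a (cooldown-limited) reload.
    ready = check_and_reinitialize_model()
    return {"model_ready": bool(ready)}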