amine_dubs committed
Commit f259de7 · 1 Parent(s): 0350bc5
Files changed (1)
  1. backend/main.py +81 -133
backend/main.py CHANGED
@@ -8,6 +8,7 @@ import requests
 import json
 import traceback
 import io
+import concurrent.futures
 
 # Import transformers for local model inference
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
@@ -64,47 +65,33 @@ def initialize_model():
         # Use a smaller model that works well for instruction-based translation
         model_name = "google/flan-t5-small"
 
+        # Check for available device - properly detect CPU/GPU
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Device set to use {device}")
+
         # Load the tokenizer with explicit cache directory
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             cache_dir="/tmp/transformers_cache"
         )
 
-        # Check if TensorFlow and tf-keras are available
-        tf_available = False
-        try:
-            import tensorflow
-            # Try to import tf_keras which is the compatibility package
-            try:
-                import tf_keras
-                print("tf-keras is installed, using TensorFlow with compatibility layer")
-                tf_available = True
-            except ImportError:
-                print("tf-keras not found, will try to use PyTorch backend")
-            print("TensorFlow is available, will use from_tf=True")
-        except ImportError:
-            print("TensorFlow is not installed, will use default PyTorch loading")
-
-        # Load the model with appropriate settings based on TensorFlow availability
-        print(f"Loading model {'with from_tf=True' if tf_available else 'with default PyTorch settings'}...")
+        # Load the model with PyTorch approach which is more reliable
        try:
-            # First try with PyTorch approach which is more reliable
+            print("Loading model with PyTorch backend...")
             model = AutoModelForSeq2SeqLM.from_pretrained(
                 model_name,
-                from_tf=False,  # Use PyTorch first
-                cache_dir="/tmp/transformers_cache"
+                cache_dir="/tmp/transformers_cache",
+                low_cpu_mem_usage=True,  # Add this for better memory usage
+                device_map="auto"  # Let the library decide optimal device mapping
             )
         except Exception as e:
             print(f"PyTorch loading failed: {e}")
-            if tf_available:
-                print("Attempting to load with TensorFlow...")
-                model = AutoModelForSeq2SeqLM.from_pretrained(
-                    model_name,
-                    from_tf=True,
-                    cache_dir="/tmp/transformers_cache"
-                )
-            else:
-                raise  # Re-raise if we can't use TensorFlow either
+            print("Attempting to load with TensorFlow...")
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                from_tf=True,
+                cache_dir="/tmp/transformers_cache"
+            )
 
         # Create a pipeline with the loaded model and tokenizer
         print("Creating pipeline with pre-loaded model...")
@@ -112,7 +99,7 @@ def initialize_model():
             "text2text-generation",
             model=model,
             tokenizer=tokenizer,
-            device=-1,  # Use CPU for compatibility (-1) or GPU if available (0)
+            device=device,  # Use detected device instead of hardcoding -1
             max_length=512
         )
 
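Note on the device-detection hunks above: the new code references `torch`, which must already be imported near the top of backend/main.py (the hunk context does not show that import), and `low_cpu_mem_usage=True` plus `device_map="auto"` both require the `accelerate` package. Some `transformers` releases also reject an explicit `device=` argument to `pipeline()` when the model was already placed via `device_map="auto"`. A minimal standalone sketch of the detection pattern, loading without a device map so the pipeline is free to place the model; `build_translator` is a hypothetical name, not part of the commit:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

    def build_translator(model_name="google/flan-t5-small"):
        # Prefer a visible CUDA GPU, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir="/tmp/transformers_cache"
        )
        # No device_map here, so pipeline() may move the model explicitly.
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, cache_dir="/tmp/transformers_cache"
        )
        return pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,
            max_length=512,
        )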
@@ -124,135 +111,96 @@ def initialize_model():
         return False
 
 # --- Translation Function ---
-def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar") -> str:
-    """
-    Translate text using local T5 model with prompt engineering
-    """
-    global translator
+def translate_text(text, source_lang, target_lang):
+    """Translate text using local model or fallback to online services."""
+    global translator, tokenizer, model
 
-    if not text.strip():
-        return ""
-
     print(f"Translation Request - Source Lang: {source_lang}, Target Lang: {target_lang}")
 
-    # Get full language name for prompt
-    source_lang_name = LANGUAGE_MAP.get(source_lang, source_lang)
-
-    # Initialize the model if it hasn't been loaded yet
-    if translator is None:
+    if not model or not tokenizer:
         success = initialize_model()
         if not success:
-            print("Model initialization failed, falling back to online translation")
-            return fallback_translate(text, source_lang, target_lang)
+            return use_fallback_translation(text, source_lang, target_lang)
 
     try:
-        # Construct our eloquent Arabic translation prompt
-        prompt = f"""Translate the following {source_lang_name} text into Modern Standard Arabic (Fusha).
-        Focus on conveying the meaning elegantly using proper Balagha (Arabic eloquence).
-        Adapt any cultural references or idioms appropriately rather than translating literally.
-        Ensure the translation reads naturally to a native Arabic speaker.
-
-        Text to translate:
-        {text}"""
-
-        # Add timeout handling to prevent hanging
-        import threading
-        import queue
-
-        def model_inference():
-            try:
-                outputs = translator(prompt, max_length=512, do_sample=False)
-                result_queue.put(outputs)
-            except Exception as e:
-                result_queue.put(e)
-
-        # Create a queue to get the result or exception
-        result_queue = queue.Queue()
-
-        # Start the translation in a separate thread
-        thread = threading.Thread(target=model_inference)
-        thread.daemon = True
-        thread.start()
-
-        # Wait for the result with a timeout (10 seconds)
-        thread.join(timeout=10)
-
-        # Check if the thread completed within the timeout
-        if thread.is_alive():
-            print("Model inference timed out after 10 seconds, falling back to online translation")
-            return fallback_translate(text, source_lang, target_lang)
-
-        # Get the result from the queue
-        try:
-            result = result_queue.get(block=False)
-            if isinstance(result, Exception):
-                raise result
-
-            # Process the translation result
-            if result and len(result) > 0:
-                translated_text = result[0]['generated_text']
-                print(f"Translation successful using transformers model")
-                return culturally_adapt_arabic(translated_text)
-            else:
-                print("Model returned empty output")
-                return fallback_translate(text, source_lang, target_lang)
-        except queue.Empty:
-            print("No result in queue despite thread completing")
-            return fallback_translate(text, source_lang, target_lang)
-
+        # Prepare input with explicit instruction format for better results with flan-t5
+        input_text = f"Translate from {source_lang} to {target_lang}: {text}"
+
+        # Use a more reliable timeout approach with concurrent.futures
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: translator(
+                    input_text,
+                    max_length=512,
+                    num_beams=4,  # Increase beam search for better quality
+                    no_repeat_ngram_size=2
+                )[0]["generated_text"]
+            )
+
+            try:
+                # Set a reasonable timeout (15 seconds instead of 10)
+                result = future.result(timeout=15)
+
+                # Clean up result (remove any instruction preamble if present)
+                if ':' in result and len(result.split(':', 1)) > 1:
+                    result = result.split(':', 1)[1].strip()
+
+                return result
+            except concurrent.futures.TimeoutError:
+                print(f"Model inference timed out after 15 seconds, falling back to online translation")
+                return use_fallback_translation(text, source_lang, target_lang)
     except Exception as e:
-        print(f"Error in model translation: {e}")
+        print(f"Error using local model: {e}")
         traceback.print_exc()
-        return fallback_translate(text, source_lang, target_lang)
+        return use_fallback_translation(text, source_lang, target_lang)
 
-def fallback_translate(text: str, source_lang: str, target_lang: str = "ar") -> str:
-    """Fallback to online translation APIs if local model fails."""
-    # Try LibreTranslate
-    libre_translate_endpoints = [
+def use_fallback_translation(text, source_lang, target_lang):
+    """Use various fallback online translation services."""
+    # List of LibreTranslate servers to try in order
+    libre_servers = [
         "https://translate.terraprint.co/translate",
         "https://libretranslate.de/translate",
-        "https://translate.argosopentech.com/translate"
+        "https://translate.argosopentech.com/translate",
+        "https://translate.fedilab.app/translate"  # Added additional server
     ]
 
-    for endpoint in libre_translate_endpoints:
+    # Try each LibreTranslate server
+    for server in libre_servers:
         try:
-            print(f"Attempting fallback translation using LibreTranslate: {endpoint}")
+            print(f"Attempting fallback translation using LibreTranslate: {server}")
+            headers = {
+                "Content-Type": "application/json"
+            }
             payload = {
                 "q": text,
-                "source": source_lang if source_lang != "auto" else "auto",
-                "target": target_lang,
-                "format": "text"
+                "source": source_lang,
+                "target": target_lang
             }
 
-            response = requests.post(endpoint, json=payload, timeout=10)
+            # Use a shorter timeout for the request (5 seconds instead of 10)
+            response = requests.post(server, json=payload, headers=headers, timeout=5)
 
             if response.status_code == 200:
                 result = response.json()
-                translated_text = result.get("translatedText")
-
-                if translated_text:
-                    print(f"Translation successful using LibreTranslate {endpoint}")
-                    return culturally_adapt_arabic(translated_text)
+                if "translatedText" in result:
+                    return result["translatedText"]
         except Exception as e:
-            print(f"Error with LibreTranslate {endpoint}: {e}")
+            print(f"Error with LibreTranslate {server}: {str(e)}")
+            continue
 
-    # If all else fails, use a simple English-Arabic dictionary for common phrases
-    common_phrases = {
-        "hello": "مرحبا",
-        "thank you": "شكرا لك",
-        "goodbye": "مع السلامة",
-        "welcome": "أهلا وسهلا",
-        "yes": "نعم",
-        "no": "لا",
-        "please": "من فضلك",
-        "sorry": "آسف",
-    }
-
-    if text.lower().strip() in common_phrases:
-        return common_phrases[text.lower().strip()]
+    # If all LibreTranslate servers fail, try Google Translate API with a wrapper
+    # that doesn't need an API key for limited usage
+    try:
+        print("Attempting fallback with Google Translate (no API key)")
+        from googletrans import Translator
+        google_translator = Translator()
+        result = google_translator.translate(text, src=source_lang, dest=target_lang)
+        return result.text
+    except Exception as e:
+        print(f"Error with Google Translate fallback: {str(e)}")
 
-    # Last resort message
-    return "عذراً، لم نتمكن من ترجمة النص بسبب خطأ فني. الرجاء المحاولة لاحقاً."
+    # Final fallback - return original text with error message
+    return f"[Translation failed] {text}"
 
 def culturally_adapt_arabic(text: str) -> str:
     """Apply post-processing rules to enhance Arabic translation with cultural sensitivity."""
@@ -338,7 +286,7 @@ async def translate_text_endpoint(
         raise HTTPException(status_code=400, detail="No text provided for translation.")
 
     try:
-        translated_text = translate_text_internal(text, source_lang, target_lang)
+        translated_text = translate_text(text, source_lang, target_lang)
         return JSONResponse(content={"translated_text": translated_text, "source_lang": source_lang})
     except Exception as e:
         print(f"Translation error: {e}")
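Both endpoints now call the rewritten `translate_text`, which replaces the hand-rolled `threading`/`queue` timeout with `concurrent.futures`, so worker exceptions surface directly through `future.result()`. One subtlety: leaving a `ThreadPoolExecutor` context manager calls `shutdown(wait=True)`, so after a `TimeoutError` the `with` block still waits for the in-flight model call to finish before the fallback value is actually returned, and Python offers no way to kill the worker thread. A minimal sketch of a variant that hands control back immediately; `run_with_timeout` is a hypothetical helper, not part of the commit:

    import concurrent.futures

    def run_with_timeout(fn, timeout_s, *args, **kwargs):
        # Run fn in a single worker thread, waiting at most timeout_s seconds.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(fn, *args, **kwargs)
        try:
            return future.result(timeout=timeout_s)
        finally:
            # wait=False hands control back right away; the worker thread is
            # not killed and may keep running until fn completes on its own.
            executor.shutdown(wait=False)

A caller would wrap the pipeline invocation, e.g. `run_with_timeout(translator, 15, input_text, max_length=512)`, and catch `concurrent.futures.TimeoutError` to trigger the online fallback. Separately, the colon-based cleanup in `translate_text` keeps only the text after the first `:`, so a translation that legitimately contains a colon would be truncated.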
@@ -360,7 +308,7 @@ async def translate_document_endpoint(
         raise HTTPException(status_code=400, detail="Could not extract any text from the document.")
 
     # Translate the extracted text
-    translated_text = translate_text_internal(extracted_text, source_lang, target_lang)
+    translated_text = translate_text(extracted_text, source_lang, target_lang)
 
     return JSONResponse(content={
         "original_filename": file.filename,
 
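The fallback chain introduced in this commit tries several public LibreTranslate instances, then the unofficial `googletrans` package (which must be present in the project's requirements for that import to succeed), and finally returns the original text tagged `[Translation failed]`. Within the hunks shown, nothing routes results through `culturally_adapt_arabic` any more, although its definition is kept. A minimal sketch of the request shape a LibreTranslate endpoint expects; note the commit drops the `"format": "text"` field the old code sent, and some public instances additionally require an `api_key` field:

    import requests

    def libre_translate(text, source="en", target="ar",
                        server="https://libretranslate.de/translate"):
        # Public LibreTranslate servers may rate-limit or require "api_key".
        payload = {"q": text, "source": source, "target": target, "format": "text"}
        response = requests.post(server, json=payload, timeout=5)
        response.raise_for_status()
        return response.json().get("translatedText", "")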