amine_dubs committed on
Commit aded6a5 · 1 Parent(s): decdde7
Files changed (1)
  1. backend/main.py +181 -38
backend/main.py CHANGED
@@ -9,6 +9,8 @@ import json
 import traceback
 import io
 import concurrent.futures
+import subprocess
+import sys
 
 # Import transformers for local model inference
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
@@ -66,47 +68,71 @@ def initialize_model():
         model_name = "google/flan-t5-small"
 
         # Check for available device - properly detect CPU/GPU
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Device set to use {device}")
+        device = "cpu"  # Default to CPU which is more reliable
+        if torch.cuda.is_available():
+            device = "cuda"
+            print(f"CUDA is available: {torch.cuda.get_device_name(0)}")
+        print(f"Device set to use: {device}")
 
         # Load the tokenizer with explicit cache directory
+        print(f"Loading tokenizer from {model_name}...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             cache_dir="/tmp/transformers_cache"
         )
+        if tokenizer is None:
+            print("Failed to load tokenizer")
+            return False
+        print("Tokenizer loaded successfully")
 
-        # Load the model with PyTorch approach which is more reliable
+        # Load the model with explicit device placement
+        print(f"Loading model from {model_name}...")
         try:
-            print("Loading model with PyTorch backend...")
             model = AutoModelForSeq2SeqLM.from_pretrained(
                 model_name,
                 cache_dir="/tmp/transformers_cache",
-                low_cpu_mem_usage=True,  # Add this for better memory usage
-                device_map="auto"  # Let the library decide optimal device mapping
+                low_cpu_mem_usage=True,  # Better memory usage
+                torch_dtype=torch.float32  # Explicit dtype for better compatibility
             )
+            # Move model to device after loading
+            model = model.to(device)
+            print(f"Model loaded with PyTorch and moved to {device}")
         except Exception as e:
-            print(f"PyTorch loading failed: {e}")
-            print("Attempting to load with TensorFlow...")
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name,
-                from_tf=True,
-                cache_dir="/tmp/transformers_cache"
-            )
+            print(f"Error loading model: {e}")
+            print("Model initialization failed")
+            return False
 
         # Create a pipeline with the loaded model and tokenizer
-        print("Creating pipeline with pre-loaded model...")
-        translator = pipeline(
-            "text2text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            device=device,  # Use detected device instead of hardcoding -1
-            max_length=512
-        )
-
-        print(f"Model {model_name} successfully initialized")
-        return True
+        print("Creating translation pipeline...")
+        try:
+            # Create the pipeline with explicit model and tokenizer
+            translator = pipeline(
+                "text2text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                device=0 if device == "cuda" else -1,  # Proper device mapping
+                framework="pt"  # Explicitly use PyTorch
+            )
+
+            if translator is None:
+                print("Failed to create translator pipeline")
+                return False
+
+            # Test the model with a simple translation to verify it works
+            test_result = translator("Translate from English to French: hello", max_length=128)
+            print(f"Model test result: {test_result}")
+            if not test_result or not isinstance(test_result, list) or len(test_result) == 0:
+                print("Model test failed: Invalid output format")
+                return False
+
+            print(f"Model {model_name} successfully initialized and tested")
+            return True
+        except Exception as inner_e:
+            print(f"Error creating translation pipeline: {inner_e}")
+            traceback.print_exc()
+            return False
     except Exception as e:
-        print(f"Error initializing model: {e}")
+        print(f"Critical error initializing model: {e}")
         traceback.print_exc()
         return False
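Reviewer note: the rewritten initializer drops the `device_map="auto"` and TensorFlow-fallback paths in favor of explicit CPU-first placement, a fixed float32 dtype, and a startup smoke test. For reference, a minimal standalone sketch of the same flow; the helper name `build_translator` is hypothetical, while the model name, cache path, and arguments come from the diff:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

def build_translator(model_name: str = "google/flan-t5-small"):
    # Mirror the diff's flow: detect device, load, place explicitly, smoke-test.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/transformers_cache")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        cache_dir="/tmp/transformers_cache",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    ).to(device)
    translator = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if device == "cuda" else -1,  # device index: 0 = first GPU, -1 = CPU
        framework="pt",
    )
    # Smoke-test once at startup so failures surface before the first request.
    sample = translator("Translate from English to French: hello", max_length=128)
    assert isinstance(sample, list) and sample, "pipeline produced no output"
    return translator
```

One nuance: `from_pretrained` and `pipeline` raise on failure rather than returning None, so the `if tokenizer is None` and `if translator is None` guards in the diff are effectively dead code; the enclosing try/except blocks are what actually catch load failures.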
 
@@ -276,22 +302,139 @@ async def read_root(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 @app.post("/translate/text")
-async def translate_text_endpoint(
-    text: str = Form(...),
-    source_lang: str = Form(...),
-    target_lang: str = Form("ar")
-):
-    """Translates direct text input."""
-    if not text:
-        raise HTTPException(status_code=400, detail="No text provided for translation.")
+async def translate_text(request: TranslationRequest):
+    global translator, model, tokenizer
+
+    source_lang = request.source_lang
+    target_lang = request.target_lang
+    text = request.text
+
+    print(f"Translation Request - Source Lang: {source_lang}, Target Lang: {target_lang}")
+
+    translation_result = ""
+    error_message = None
 
     try:
-        translated_text = translate_text(text, source_lang, target_lang)
-        return JSONResponse(content={"translated_text": translated_text, "source_lang": source_lang})
+        # Check if translator is initialized, if not, initialize it
+        if translator is None:
+            print("Translator not initialized. Attempting to initialize model...")
+            success = initialize_model()
+            if not success:
+                raise Exception("Failed to initialize translation model")
+
+        # Format the prompt for the model
+        lang_code_map = {
+            "en": "English", "es": "Spanish", "fr": "French", "de": "German",
+            "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ar": "Arabic",
+            "ru": "Russian", "pt": "Portuguese", "it": "Italian", "nl": "Dutch"
+        }
+
+        source_lang_name = lang_code_map.get(source_lang.lower(), source_lang)
+        target_lang_name = lang_code_map.get(target_lang.lower(), target_lang)
+
+        # Create a proper prompt for instruction-based models
+        prompt = f"Translate from {source_lang_name} to {target_lang_name}: {text}"
+        print(f"Using prompt: {prompt}")
+
+        # Check that translator is callable before proceeding
+        if not callable(translator):
+            print("Translator is not callable, attempting to reinitialize")
+            success = initialize_model()
+            if not success or not callable(translator):
+                raise Exception("Translator is not callable after reinitialization")
+
+        # Use a thread pool to execute the translation with a timeout
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(
+                lambda: translator(
+                    prompt,
+                    max_length=512,
+                    do_sample=False,
+                    temperature=0.7
+                )
+            )
+
+            try:
+                result = future.result(timeout=15)
+                translation_result = result[0]["generated_text"]
+
+                # Clean up the output - remove any prefix like "Translation:"
+                prefixes = ["Translation:", "Translation: ", f"{target_lang_name}:", f"{target_lang_name}: "]
+                for prefix in prefixes:
+                    if translation_result.startswith(prefix):
+                        translation_result = translation_result[len(prefix):].strip()
+
+                print(f"Local model translation result: {translation_result}")
+            except concurrent.futures.TimeoutError:
+                print("Translation timed out after 15 seconds")
+                raise Exception("Translation timed out")
+            except Exception as e:
+                print(f"Error using local model: {str(e)}")
+                raise Exception(f"Error using local model: {str(e)}")
+
     except Exception as e:
-        print(f"Translation error: {e}")
-        traceback.print_exc()
-        raise HTTPException(status_code=500, detail=f"Translation error: {str(e)}")
+        error_message = str(e)
+        print(f"Error using local model: {error_message}")
+
+        # Try the fallback options
+        try:
+            # Install googletrans if not present
+            try:
+                import googletrans
+            except ImportError:
+                print("Installing googletrans package...")
+                subprocess.call([sys.executable, "-m", "pip", "install", "googletrans==4.0.0-rc1"])
+
+            # Try LibreTranslate providers
+            libre_apis = [
+                "https://translate.terraprint.co/translate",
+                "https://libretranslate.de/translate",
+                "https://translate.argosopentech.com/translate",
+                "https://translate.fedilab.app/translate"
+            ]
+
+            for api_url in libre_apis:
+                try:
+                    print(f"Attempting fallback translation using LibreTranslate: {api_url}")
+                    payload = {
+                        "q": text,
+                        "source": source_lang,
+                        "target": target_lang,
+                        "format": "text",
+                        "api_key": ""
+                    }
+                    headers = {"Content-Type": "application/json"}
+                    response = requests.post(api_url, json=payload, headers=headers, timeout=5)
+
+                    if response.status_code == 200:
+                        result = response.json()
+                        if "translatedText" in result:
+                            translation_result = result["translatedText"]
+                            print(f"LibreTranslate successful: {translation_result}")
+                            break
+                except Exception as libre_error:
+                    print(f"Error with LibreTranslate {api_url}: {str(libre_error)}")
+
+            # If LibreTranslate failed, try Google Translate
+            if not translation_result:
+                try:
+                    print("Attempting fallback with Google Translate (no API key)")
+                    from googletrans import Translator
+                    google_translator = Translator()
+                    result = google_translator.translate(text, src=source_lang, dest=target_lang)
+                    translation_result = result.text
+                    print(f"Google Translate successful: {translation_result}")
+                except Exception as google_error:
+                    print(f"Error with Google Translate fallback: {str(google_error)}")
+
+        except Exception as fallback_error:
+            print(f"All fallback translation methods failed: {str(fallback_error)}")
+
+    # If all translation attempts failed
+    if not translation_result:
+        return {"success": False, "error": error_message or "All translation methods failed"}
+
+    return {"success": True, "translation": translation_result}
 
 @app.post("/translate/document")
 async def translate_document_endpoint(
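Reviewer note: the new handler runs the blocking pipeline call in a `ThreadPoolExecutor` and abandons it after 15 seconds via `future.result(timeout=15)`. Two caveats: the timeout only stops the wait, it cannot kill the worker thread; and since `do_sample=False` selects greedy decoding, the `temperature=0.7` argument has no effect in transformers. A condensed, self-contained sketch of the timeout pattern, where `slow_translate` is a stand-in for the translator call:

```python
import concurrent.futures
import time

def slow_translate(prompt: str) -> str:
    time.sleep(2)  # stand-in for the blocking translator(prompt, ...) call
    return f"translated({prompt})"

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(slow_translate, "Translate from English to Arabic: hello")
    try:
        # Raises concurrent.futures.TimeoutError once the deadline passes.
        print(future.result(timeout=15))
    except concurrent.futures.TimeoutError:
        print("translation timed out")  # the worker thread keeps running regardless
```

Note that exiting the `with` block calls `shutdown(wait=True)`, so a truly hung model call would still delay the response past the timeout; the pattern bounds the common case, not the worst case.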
 
 
 
302
  return templates.TemplateResponse("index.html", {"request": request})
303
 
304
  @app.post("/translate/text")
305
+ async def translate_text(request: TranslationRequest):
306
+ global translator, model, tokenizer
307
+
308
+ source_lang = request.source_lang
309
+ target_lang = request.target_lang
310
+ text = request.text
311
+
312
+ print(f"Translation Request - Source Lang: {source_lang}, Target Lang: {target_lang}")
313
+
314
+ translation_result = ""
315
+ error_message = None
316
 
317
  try:
318
+ # Check if translator is initialized, if not, initialize it
319
+ if translator is None:
320
+ print("Translator not initialized. Attempting to initialize model...")
321
+ success = initialize_model()
322
+ if not success:
323
+ raise Exception("Failed to initialize translation model")
324
+
325
+ # Format the prompt for the model
326
+ lang_code_map = {
327
+ "en": "English", "es": "Spanish", "fr": "French", "de": "German",
328
+ "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ar": "Arabic",
329
+ "ru": "Russian", "pt": "Portuguese", "it": "Italian", "nl": "Dutch"
330
+ }
331
+
332
+ source_lang_name = lang_code_map.get(source_lang.lower(), source_lang)
333
+ target_lang_name = lang_code_map.get(target_lang.lower(), target_lang)
334
+
335
+ # Create a proper prompt for instruction-based models
336
+ prompt = f"Translate from {source_lang_name} to {target_lang_name}: {text}"
337
+ print(f"Using prompt: {prompt}")
338
+
339
+ # Check that translator is callable before proceeding
340
+ if not callable(translator):
341
+ print("Translator is not callable, attempting to reinitialize")
342
+ success = initialize_model()
343
+ if not success or not callable(translator):
344
+ raise Exception("Translator is not callable after reinitialization")
345
+
346
+ # Use a thread pool to execute the translation with a timeout
347
+ with concurrent.futures.ThreadPoolExecutor() as executor:
348
+ future = executor.submit(
349
+ lambda: translator(
350
+ prompt,
351
+ max_length=512,
352
+ do_sample=False,
353
+ temperature=0.7
354
+ )
355
+ )
356
+
357
+ try:
358
+ result = future.result(timeout=15)
359
+ translation_result = result[0]["generated_text"]
360
+
361
+ # Clean up the output - remove any prefix like "Translation:"
362
+ prefixes = ["Translation:", "Translation: ", f"{target_lang_name}:", f"{target_lang_name}: "]
363
+ for prefix in prefixes:
364
+ if translation_result.startswith(prefix):
365
+ translation_result = translation_result[len(prefix):].strip()
366
+
367
+ print(f"Local model translation result: {translation_result}")
368
+ except concurrent.futures.TimeoutError:
369
+ print("Translation timed out after 15 seconds")
370
+ raise Exception("Translation timed out")
371
+ except Exception as e:
372
+ print(f"Error using local model: {str(e)}")
373
+ raise Exception(f"Error using local model: {str(e)}")
374
+
375
  except Exception as e:
376
+ error_message = str(e)
377
+ print(f"Error using local model: {error_message}")
378
+
379
+ # Try the fallback options
380
+ try:
381
+ # Install googletrans if not present
382
+ try:
383
+ import googletrans
384
+ except ImportError:
385
+ print("Installing googletrans package...")
386
+ subprocess.call([sys.executable, "-m", "pip", "install", "googletrans==4.0.0-rc1"])
387
+
388
+ # Try LibreTranslate providers
389
+ libre_apis = [
390
+ "https://translate.terraprint.co/translate",
391
+ "https://libretranslate.de/translate",
392
+ "https://translate.argosopentech.com/translate",
393
+ "https://translate.fedilab.app/translate"
394
+ ]
395
+
396
+ for api_url in libre_apis:
397
+ try:
398
+ print(f"Attempting fallback translation using LibreTranslate: {api_url}")
399
+ payload = {
400
+ "q": text,
401
+ "source": source_lang,
402
+ "target": target_lang,
403
+ "format": "text",
404
+ "api_key": ""
405
+ }
406
+ headers = {"Content-Type": "application/json"}
407
+ response = requests.post(api_url, json=payload, headers=headers, timeout=5)
408
+
409
+ if response.status_code == 200:
410
+ result = response.json()
411
+ if "translatedText" in result:
412
+ translation_result = result["translatedText"]
413
+ print(f"LibreTranslate successful: {translation_result}")
414
+ break
415
+ except Exception as libre_error:
416
+ print(f"Error with LibreTranslate {api_url}: {str(libre_error)}")
417
+
418
+ # If LibreTranslate failed, try Google Translate
419
+ if not translation_result:
420
+ try:
421
+ print("Attempting fallback with Google Translate (no API key)")
422
+ from googletrans import Translator
423
+ google_translator = Translator()
424
+ result = google_translator.translate(text, src=source_lang, dest=target_lang)
425
+ translation_result = result.text
426
+ print(f"Google Translate successful: {translation_result}")
427
+ except Exception as google_error:
428
+ print(f"Error with Google Translate fallback: {str(google_error)}")
429
+
430
+ except Exception as fallback_error:
431
+ print(f"All fallback translation methods failed: {str(fallback_error)}")
432
+
433
+ # If all translation attempts failed
434
+ if not translation_result:
435
+ return {"success": False, "error": error_message or "All translation methods failed"}
436
+
437
+ return {"success": True, "translation": translation_result}
438
 
439
  @app.post("/translate/document")
440
  async def translate_document_endpoint(
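Reviewer note: the endpoint's contract also changed. It now reads a JSON body via the `TranslationRequest` Pydantic model (referenced but not shown in this hunk; judging by the handler it carries `text`, `source_lang`, and `target_lang`) instead of form fields, and it reports failure in the response body rather than raising `HTTPException`, so fallback failures come back as 200 with `"success": false`. A client call, assuming the app runs locally on port 8000; the URL is illustrative:

```python
import requests

resp = requests.post(
    "http://localhost:8000/translate/text",  # assumed local dev address
    json={"text": "hello", "source_lang": "en", "target_lang": "ar"},
    timeout=30,
)
data = resp.json()
if data.get("success"):
    print(data["translation"])
else:
    print("translation failed:", data.get("error"))
```

One thing to verify: the handler now shares the name of the module-level `translate_text` helper the old endpoint called; if that helper still exists and is used elsewhere (the document endpoint, for instance), this definition shadows it.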