amine_dubs committed
Commit be03516 · 1 Parent(s): dbe4e2f
Files changed (1):
  backend/main.py +178 -42
backend/main.py CHANGED
@@ -4,11 +4,12 @@ from fastapi.staticfiles import StaticFiles
  from fastapi.templating import Jinja2Templates
  from typing import List, Optional
  import shutil
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
+ import torch
+ import traceback
+ import time # For retries
  import os
- # Use AutoModel for flexibility
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- import torch # Ensure torch is imported if using generate directly
- import traceback # Ensure traceback is imported
+ import requests # For direct API access as final fallback

  # --- Configuration ---
  # Determine the base directory of the main.py script
@@ -29,36 +30,150 @@ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
  # Ensure the templates directory exists (FastAPI doesn't create it)
  templates = Jinja2Templates(directory=TEMPLATE_DIR)

- # --- Model Loading ---
+ # --- Model Loading Strategy ---
+ # Define model options in order of preference
+ MODEL_OPTIONS = [
+     {"name": "google/flan-t5-small", "type": "flan-t5"},
+     {"name": "Helsinki-NLP/opus-mt-en-ar", "type": "marian"},
+     {"name": "t5-small", "type": "t5-fallback"} # Smaller, more commonly available model
+ ]
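+ # Note: the Marian option is an English-to-Arabic model, so it can only
+ # cover requests whose source language is English.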

- # Define model name - Switched to FLAN-T5
- MODEL_NAME = "google/flan-t5-small"
- CACHE_DIR = "/app/.cache" # Explicitly define cache directory
+ CACHE_DIR = "/app/.cache"
  model = None
  tokenizer = None

+ # Set environment variables for cache locations
+ os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
+ os.environ["HF_HOME"] = CACHE_DIR
+ print(f"Cache directories set to: {CACHE_DIR}")
+ print(f"Environment TRANSFORMERS_CACHE: {os.environ.get('TRANSFORMERS_CACHE')}")
+ print(f"Environment HF_HOME: {os.environ.get('HF_HOME')}")
+
+ # Create cache directory with explicit permissions
  try:
-     print("--- Loading Model ---")
-     print(f"Loading tokenizer for {MODEL_NAME} using AutoTokenizer...")
-     # Use AutoTokenizer and specify cache_dir
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-     print(f"Loading model for {MODEL_NAME} using AutoModelForSeq2SeqLM...")
-     # Use AutoModelForSeq2SeqLM and specify cache_dir
-     model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-     print("--- Model Loaded Successfully ---")
+     os.makedirs(CACHE_DIR, exist_ok=True)
+     # Ensure the cache directory is writeable - set permissive permissions
+     os.chmod(CACHE_DIR, 0o777) # Read/write/execute for all
+     print(f"Cache directory {CACHE_DIR} created with full permissions")
  except Exception as e:
-     print(f"--- ERROR Loading Model ---")
-     print(f"Error loading model or tokenizer {MODEL_NAME}: {e}")
-     traceback.print_exc() # Print full traceback for loading error
-     # Keep model and tokenizer as None
+     print(f"Warning: Could not set permissions on cache dir: {e}")
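+ # The 0o777 above makes the cache world-writable; tolerable inside a
+ # single-user container, but too permissive on a shared host.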
+
+ # Try each model in order until one loads successfully
+ for model_option in MODEL_OPTIONS:
+     MODEL_NAME = model_option["name"]
+     MODEL_TYPE = model_option["type"]
+
+     print(f"--- Attempting to load model: {MODEL_NAME} (Type: {MODEL_TYPE}) ---")
+
+     # Try to load with retries
+     max_retries = 3
+     for attempt in range(max_retries):
+         try:
+             if MODEL_TYPE == "flan-t5" or MODEL_TYPE == "t5-fallback":
+                 print(f"Loading with AutoTokenizer/AutoModelForSeq2SeqLM (Attempt {attempt+1}/{max_retries})")
+                 tokenizer = AutoTokenizer.from_pretrained(
+                     MODEL_NAME,
+                     cache_dir=CACHE_DIR,
+                     local_files_only=False, # Allow downloading from the Hub
+                     resume_download=True # Resume if download was interrupted
+                 )
+
+                 model = AutoModelForSeq2SeqLM.from_pretrained(
+                     MODEL_NAME,
+                     cache_dir=CACHE_DIR,
+                     local_files_only=False, # Allow downloading from the Hub
+                     resume_download=True # Resume if download was interrupted
+                 )
+             elif MODEL_TYPE == "marian":
+                 print(f"Loading with MarianTokenizer/MarianMTModel (Attempt {attempt+1}/{max_retries})")
+                 tokenizer = MarianTokenizer.from_pretrained(
+                     MODEL_NAME,
+                     cache_dir=CACHE_DIR,
+                     local_files_only=False,
+                     resume_download=True
+                 )
+
+                 model = MarianMTModel.from_pretrained(
+                     MODEL_NAME,
+                     cache_dir=CACHE_DIR,
+                     local_files_only=False,
+                     resume_download=True
+                 )
+
+             print(f"--- Successfully loaded model: {MODEL_NAME} ---")
+             break # Break out of retry loop if successful
+
+         except Exception as e:
+             print(f"Error loading model {MODEL_NAME} (Attempt {attempt+1}): {e}")
+             traceback.print_exc()
+
+             if attempt < max_retries - 1:
+                 wait_time = 2 * (attempt + 1) # Linear backoff: wait 2s, 4s, 6s
+                 print(f"Waiting {wait_time} seconds before retry...")
+                 time.sleep(wait_time)
+             else:
+                 print(f"Failed to load model {MODEL_NAME} after {max_retries} attempts.")
+
+     if model is not None and tokenizer is not None:
+         # If we successfully loaded a model, break out of the model options loop
+         break
+
+ # --- Fallback Translation Logic ---
+ # If we couldn't load any model, we'll set up a simple fallback system
+
+ # Define a simple dictionary for common phrases (just as a last resort)
+ FALLBACK_PHRASES = {
+     "hello": "مرحبا",
+     "thank you": "شكرا لك",
+     "goodbye": "مع السلامة",
+     "welcome": "أهلا وسهلا",
+ }
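+ # The phrase lookup below is an exact match on the lowercased input, so
+ # only these literal phrases are covered by the dictionary.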
+
+ def fallback_translate(text, source_lang):
+     """Last resort fallback translation if all models fail to load."""
+     print("Using emergency fallback translation (very limited capability)")
+
+     # For longer text, try a direct API call to a free translation service
+     try:
+         # Try to use LibreTranslate API as fallback (no API key needed for some instances)
+         url = "https://translate.terraprint.co/translate"
+
+         payload = {
+             "q": text,
+             "source": source_lang if source_lang != "auto" else "auto",
+             "target": "ar",
+             "format": "text"
+         }
+
+         headers = {"Content-Type": "application/json"}
+
+         print("Attempting LibreTranslate API call...")
+         response = requests.post(url, json=payload, headers=headers, timeout=15) # Bound the wait so a hung instance can't stall the request
+
+         if response.status_code == 200:
+             result = response.json()
+             print("LibreTranslate API call successful")
+             return result.get("translatedText", f"[Translation Error: {response.text}]")
+
+     except Exception as e:
+         print(f"LibreTranslate API call failed: {e}")
+
+     # If that fails too, use our minimal dictionary
+     if text.lower() in FALLBACK_PHRASES:
+         return FALLBACK_PHRASES[text.lower()]
+
+     # For unknown text, return a message in Arabic explaining the issue
+     return "عذراً، لم نتمكن من تحميل نموذج الترجمة. هذه ترجمة محدودة جداً." # "Sorry, we couldn't load the translation model. This is a very limited translation."

  # --- Helper Functions ---
  def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar") -> str:
-     """Internal function to handle text translation using the loaded model via prompting."""
+     """Internal function to handle text translation using the loaded model or fallbacks."""
+
+     # Check if we successfully loaded a model
      if model is None or tokenizer is None:
-         # If the model/tokenizer failed to load, raise an error
-         raise HTTPException(status_code=503, detail="Translation service is unavailable (model not loaded).")
-
+         # No model available, use fallback
+         return fallback_translate(text, source_lang)
+
      # --- Enhanced Prompt Engineering ---
      # Map source language codes to full language names for better model understanding
      language_map = {
@@ -93,24 +208,43 @@ Text to translate:
      print(f"Translation Request - Source Lang: {source_lang} ({source_lang_name}), Target Lang: {target_lang}")
      print(f"Using Enhanced Prompt for Balagha and Cultural Sensitivity")

-     # --- Actual Translation Logic (using model.generate) ---
+     # --- Model-specific translation logic ---
      try:
-         # Tokenize the prompt
-         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-
-         # Generate the translation with parameters tuned for quality
-         outputs = model.generate(
-             **inputs,
-             max_length=512, # Adjust based on expected output length
-             num_beams=5, # Increased for better quality
-             length_penalty=1.0, # Encourage slightly longer outputs for natural flow
-             top_k=50, # More diverse word choices
-             top_p=0.95, # Sample from higher probability tokens for fluency
-             early_stopping=True
-         )
-
-         # Decode the generated tokens
-         translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         if MODEL_TYPE in ["flan-t5", "t5-fallback"]:
+             # Use prompt-based approach for T5 models
+             # Tokenize the prompt
+             inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+
+             # Generate the translation with parameters tuned for quality
+             outputs = model.generate(
+                 **inputs,
+                 max_length=512, # Adjust based on expected output length
+                 num_beams=5, # Increased for better quality
+                 length_penalty=1.0, # Encourage slightly longer outputs for natural flow
+                 top_k=50, # More diverse word choices
+                 top_p=0.95, # Sample from higher probability tokens for fluency
+                 early_stopping=True
+             )
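+             # Note: top_k and top_p only take effect when sampling is enabled
+             # (do_sample=True); with plain beam search they are ignored.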
+
+             # Decode the generated tokens
+             translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         elif MODEL_TYPE == "marian":
+             # Direct translation for Marian model (specialized for translation)
+             inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+             outputs = model.generate(
+                 **inputs,
+                 max_length=512,
+                 num_beams=5,
+                 length_penalty=1.0,
+                 early_stopping=True
+             )
+
+             translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         else:
+             # Unknown model type, use fallback
+             return fallback_translate(text, source_lang)

      print(f"Raw Translation Output: {translated_text}")
      return translated_text
@@ -118,7 +252,9 @@ Text to translate:
      except Exception as e:
          print(f"Error during model generation: {e}")
          traceback.print_exc()
-         raise HTTPException(status_code=500, detail=f"Translation failed during generation: {e}")
+
+         # If translation fails, use fallback
+         return fallback_translate(text, source_lang)

  # --- Function to extract text ---
  async def extract_text_from_file(file: UploadFile) -> str:
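
A quick way to exercise the loading cascade and fallback path this commit introduces (a minimal sketch, not part of the commit: it assumes backend/main.py is importable as backend.main, and the harness and sample strings are illustrative):

# Hypothetical smoke test for the fallback chain.
# Importing backend.main runs the model-loading loop as a side effect.
from backend import main

samples = [
    ("hello", "en"),                   # matches FALLBACK_PHRASES if no model loaded
    ("Good morning, friends.", "en"),  # exercises the model or LibreTranslate path
]

for text, lang in samples:
    # After this commit, translate_text_internal returns a fallback
    # translation instead of raising HTTPException.
    result = main.translate_text_internal(text, source_lang=lang)
    print(f"{text!r} -> {result!r}")

Running this with networking disabled forces the dictionary branch, which is the quickest way to confirm the last-resort path returns Arabic text instead of raising.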