amine_dubs committed
Commit ec41997 · 1 Parent(s): 050f2a9

changed model

Files changed (2):
  1. backend/main.py +25 -57
  2. static/script.js +25 -8
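
The commit swaps flan-t5-small's prompt-based usage ("Translate from X to Y: ...") for NLLB-200, a model trained specifically for translation: the pipeline task changes from "text2text-generation" to "translation", and language selection moves out of the prompt into FLORES-200 `src_lang`/`tgt_lang` codes. A minimal standalone sketch of the new call pattern, assuming only `transformers` (plus `torch`) is installed; the model name is the one the diff introduces:

```python
from transformers import pipeline

# NLLB-200 is a dedicated translation model: no instruction prompt,
# just FLORES-200 language codes for the source and target.
translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    device=-1,  # CPU, matching the diff's default
)

result = translator("hello", src_lang="eng_Latn", tgt_lang="fra_Latn", max_length=128)
print(result[0]["translation_text"])  # e.g. "Bonjour"
```

The `[0]["translation_text"]` access is what the diff reads out of the pipeline result, replacing flan-t5's `[0]["generated_text"]`.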
backend/main.py CHANGED
@@ -87,8 +87,8 @@ def initialize_model():
     try:
         print(f"Initializing model and tokenizer (attempt {model_initialization_attempts})...")

-        # Use a smaller model that works well for instruction-based translation
-        model_name = "google/flan-t5-small"
+        # Use a better translation model that handles multilingual tasks well
+        model_name = "facebook/nllb-200-distilled-600M"  # Better multilingual translation model

         # Check for available device - properly detect CPU/GPU
         device = "cpu"  # Default to CPU which is more reliable
@@ -101,7 +101,8 @@ def initialize_model():
         print(f"Loading tokenizer from {model_name}...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
-            cache_dir="/tmp/transformers_cache"
+            cache_dir="/tmp/transformers_cache",
+            use_fast=True  # Use faster tokenizer when possible
         )
         if tokenizer is None:
             print("Failed to load tokenizer")
@@ -130,7 +131,7 @@ def initialize_model():
     try:
         # Create the pipeline with explicit model and tokenizer
         translator = pipeline(
-            "text2text-generation",
+            "translation",
             model=model,
             tokenizer=tokenizer,
             device=0 if device == "cuda" else -1,  # Proper device mapping
@@ -142,7 +143,8 @@ def initialize_model():
         return False

     # Test the model with a simple translation to verify it works
-    test_result = translator("Translate from English to French: hello", max_length=128)
+    # NLLB needs language codes like "eng_Latn" and "ara_Arab"
+    test_result = translator("hello", src_lang="eng_Latn", tgt_lang="ara_Arab", max_length=128)
     print(f"Model test result: {test_result}")
     if not test_result or not isinstance(test_result, list) or len(test_result) == 0:
         print("Model test failed: Invalid output format")
@@ -176,32 +178,25 @@ def translate_text(text, source_lang, target_lang):
         return use_fallback_translation(text, source_lang, target_lang)

     try:
-        # Prepare input with explicit instruction format for better results with flan-t5
-        if target_lang == "Arabic" or target_lang == "ar":
-            # Special prompt for Arabic translations
-            input_text = f"You are a bilingual in {source_lang} and Arabic, a professional translator, translate this script from {source_lang} to Arabic MSA with cultural sensitivity and accuracy, with a focus on meaning and eloquence (Balagha), avoiding overly literal translations.: {text}"
-        else:
-            input_text = f"Translate from {source_lang} to {target_lang}: {text}"
+        # Map the language codes to the format NLLB expects
+        src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+        tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"

         # Use a more reliable timeout approach with concurrent.futures
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
                 lambda: translator(
-                    input_text,
-                    max_length=512,
-                    num_beams=4,  # Increase beam search for better quality
-                    no_repeat_ngram_size=2
-                )[0]["generated_text"]
+                    text,
+                    src_lang=src_lang_code,
+                    tgt_lang=tgt_lang_code,
+                    max_length=512
+                )[0]["translation_text"]
             )

             try:
                 # Set a reasonable timeout (15 seconds instead of 10)
                 result = future.result(timeout=15)

-                # Clean up result (remove any instruction preamble if present)
-                if ':' in result and len(result.split(':', 1)) > 1:
-                    result = result.split(':', 1)[1].strip()
-
                 return result
             except concurrent.futures.TimeoutError:
                 print(f"Model inference timed out after 15 seconds, falling back to online translation")
@@ -230,8 +225,8 @@ def check_and_reinitialize_model():
         return initialize_model()

     # Test the existing model with a simple translation
-    test_text = "Translate from English to French: hello"
-    result = translator(test_text, max_length=128)
+    test_text = "hello"
+    result = translator(test_text, src_lang="eng_Latn", tgt_lang="fra_Latn", max_length=128)

     # If we got a valid result, model is working fine
     if result and isinstance(result, list) and len(result) > 0:
@@ -388,51 +383,24 @@ async def translate_text_endpoint(request: TranslationRequest):
         raise Exception("Failed to initialize translation model")

     # Format the prompt for the model
-    lang_code_map = {
-        "en": "English", "es": "Spanish", "fr": "French", "de": "German",
-        "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ar": "Arabic",
-        "ru": "Russian", "pt": "Portuguese", "it": "Italian", "nl": "Dutch"
-    }
-
-    source_lang_name = lang_code_map.get(source_lang.lower(), source_lang)
-    target_lang_name = lang_code_map.get(target_lang.lower(), target_lang)
+    src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+    tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"

-    # Create a proper prompt for instruction-based models
-    prompt = f"Translate from {source_lang_name} to {target_lang_name}: {text}"
-    print(f"Using prompt: {prompt}")
-
-    # Check that translator is callable before proceeding
-    if not callable(translator):
-        print("[DEBUG] Translator is not callable, attempting to reinitialize")
-        success = initialize_model()
-        if not success or not callable(translator):
-            raise Exception("Translator is not callable after reinitialization")
     print("[DEBUG] Calling translator model...")
     # Use a thread pool to execute the translation with a timeout
     with concurrent.futures.ThreadPoolExecutor() as executor:
         future = executor.submit(
             lambda: translator(
-                prompt,
-                max_length=512,
-                do_sample=False,
-                temperature=0.7
-            )
+                text,
+                src_lang=src_lang_code,
+                tgt_lang=tgt_lang_code,
+                max_length=512
+            )[0]["translation_text"]
         )

         try:
             result = future.result(timeout=15)
-            # Check result format before accessing elements
-            if not result or not isinstance(result, list) or len(result) == 0:
-                raise Exception(f"Invalid model output format: {result}")
-
-            translation_result = result[0]["generated_text"]
-
-            # Clean up the output - remove any prefix like "Translation:"
-            prefixes = ["Translation:", "Translation: ", f"{target_lang_name}:", f"{target_lang_name}: "]
-            for prefix in prefixes:
-                if translation_result.startswith(prefix):
-                    translation_result = translation_result[len(prefix):].strip()
-
+            translation_result = result
             print(f"Local model translation result: {translation_result}")
         except concurrent.futures.TimeoutError:
             print("Translation timed out after 15 seconds")
static/script.js CHANGED
@@ -56,15 +56,16 @@ document.addEventListener('DOMContentLoaded', () => {
         docLoadingIndicator.style.display = 'none';
     }

-    // Improve the text form submission handler
+    // Fix the text form submission handler to use correct field IDs
     if (textForm) {
         textForm.addEventListener('submit', async (e) => {
             e.preventDefault();
             clearFeedback();

-            const sourceText = document.getElementById('source-text').value.trim();
-            const sourceLang = document.getElementById('text-source-lang').value;
-            const targetLang = document.getElementById('text-target-lang').value;
+            // Use correct field IDs matching the HTML
+            const sourceText = document.getElementById('text-input').value.trim();
+            const sourceLang = document.getElementById('source-lang-text').value;
+            const targetLang = document.getElementById('target-lang-text').value;

             if (!sourceText) {
                 displayError('Please enter text to translate');
@@ -72,8 +73,16 @@ document.addEventListener('DOMContentLoaded', () => {
             }

             try {
-                // Show loading state
-                document.getElementById('text-loading').style.display = 'block';
+                // Show loading state (create it if missing)
+                let textLoading = document.getElementById('text-loading');
+                if (!textLoading) {
+                    textLoading = document.createElement('div');
+                    textLoading.id = 'text-loading';
+                    textLoading.className = 'loading-spinner';
+                    textLoading.innerHTML = 'Translating...';
+                    textForm.appendChild(textLoading);
+                }
+                textLoading.style.display = 'block';

                 // Log payload for debugging
                 console.log('Sending payload:', { text: sourceText, source_lang: sourceLang, target_lang: targetLang });
@@ -91,7 +100,7 @@ document.addEventListener('DOMContentLoaded', () => {
             });

             // Hide loading state
-            document.getElementById('text-loading').style.display = 'none';
+            textLoading.style.display = 'none';

             // Log response status
             console.log('Response status:', response.status);
@@ -112,13 +121,21 @@ document.addEventListener('DOMContentLoaded', () => {
                 return;
             }

+            if (!data.translated_text) {
+                displayError('Translation returned empty text');
+                return;
+            }
+
             textOutput.textContent = data.translated_text;
             textResultBox.style.display = 'block';

         } catch (error) {
             console.error('Error:', error);
             displayError('Network error or invalid response format');
-            document.getElementById('text-loading').style.display = 'none';
+
+            // Hide loading if it exists
+            const textLoading = document.getElementById('text-loading');
+            if (textLoading) textLoading.style.display = 'none';
         }
     });
 }