amine_dubs committed
Commit ec41997 · Parent(s): 050f2a9
changed model

Browse files:
- backend/main.py +25 -57
- static/script.js +25 -8
backend/main.py CHANGED

@@ -87,8 +87,8 @@ def initialize_model():
     try:
         print(f"Initializing model and tokenizer (attempt {model_initialization_attempts})...")

-        # Use a …
-        model_name = "…
+        # Use a better translation model that handles multilingual tasks well
+        model_name = "facebook/nllb-200-distilled-600M"  # Better multilingual translation model

         # Check for available device - properly detect CPU/GPU
         device = "cpu"  # Default to CPU which is more reliable

@@ -101,7 +101,8 @@ def initialize_model():
         print(f"Loading tokenizer from {model_name}...")
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
-            cache_dir="/tmp/transformers_cache"
+            cache_dir="/tmp/transformers_cache",
+            use_fast=True  # Use faster tokenizer when possible
         )
         if tokenizer is None:
             print("Failed to load tokenizer")

@@ -130,7 +131,7 @@ def initialize_model():
         try:
             # Create the pipeline with explicit model and tokenizer
             translator = pipeline(
-                "…
+                "translation",
                 model=model,
                 tokenizer=tokenizer,
                 device=0 if device == "cuda" else -1,  # Proper device mapping

@@ -142,7 +143,8 @@ def initialize_model():
            return False

        # Test the model with a simple translation to verify it works
-       …
+       # NLLB needs language codes in format like "eng_Latn" and "ara_Arab"
+       test_result = translator("hello", src_lang="eng_Latn", tgt_lang="ara_Arab", max_length=128)
        print(f"Model test result: {test_result}")
        if not test_result or not isinstance(test_result, list) or len(test_result) == 0:
            print("Model test failed: Invalid output format")

@@ -176,32 +178,25 @@ def translate_text(text, source_lang, target_lang):
         return use_fallback_translation(text, source_lang, target_lang)

     try:
-        # Prepare input with explicit instruction format for better results with …
-        …
-        …
-            input_text = f"You are a bilingual in {source_lang} and Arabic, a professional translator, translate this script from {source_lang} to Arabic MSA with cultural sensitivity and accuracy, with a focus on meaning and eloquence (Balagha), avoiding overly literal translations.: {text}"
-        else:
-            input_text = f"Translate from {source_lang} to {target_lang}: {text}"
+        # Prepare input with explicit instruction format for better results with NLLB
+        src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+        tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"

         # Use a more reliable timeout approach with concurrent.futures
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
                 lambda: translator(
-                    …
-                    …
-                    …
-                    …
-                )[0]["…
+                    text,
+                    src_lang=src_lang_code,
+                    tgt_lang=tgt_lang_code,
+                    max_length=512
+                )[0]["translation_text"]
             )

         try:
             # Set a reasonable timeout (15 seconds instead of 10)
             result = future.result(timeout=15)

-            # Clean up result (remove any instruction preamble if present)
-            if ':' in result and len(result.split(':', 1)) > 1:
-                result = result.split(':', 1)[1].strip()
-
             return result
         except concurrent.futures.TimeoutError:
             print(f"Model inference timed out after 15 seconds, falling back to online translation")

@@ -230,8 +225,8 @@ def check_and_reinitialize_model():
         return initialize_model()

     # Test the existing model with a simple translation
-    test_text = "…
-    result = translator(test_text, max_length=128)
+    test_text = "hello"
+    result = translator(test_text, src_lang="eng_Latn", tgt_lang="fra_Latn", max_length=128)

     # If we got a valid result, model is working fine
     if result and isinstance(result, list) and len(result) > 0:

@@ -388,51 +383,24 @@ async def translate_text_endpoint(request: TranslationRequest):
             raise Exception("Failed to initialize translation model")

         # Format the prompt for the model
-        …
-        …
-            "zh": "Chinese", "ja": "Japanese", "ko": "Korean", "ar": "Arabic",
-            "ru": "Russian", "pt": "Portuguese", "it": "Italian", "nl": "Dutch"
-        }
-
-        source_lang_name = lang_code_map.get(source_lang.lower(), source_lang)
-        target_lang_name = lang_code_map.get(target_lang.lower(), target_lang)
+        src_lang_code = f"{source_lang}_Latn" if source_lang != "ar" else f"{source_lang}_Arab"
+        tgt_lang_code = f"{target_lang}_Latn" if target_lang != "ar" else f"{target_lang}_Arab"

-        # Create a proper prompt for instruction-based models
-        prompt = f"Translate from {source_lang_name} to {target_lang_name}: {text}"
-        print(f"Using prompt: {prompt}")
-
-        # Check that translator is callable before proceeding
-        if not callable(translator):
-            print("[DEBUG] Translator is not callable, attempting to reinitialize")
-            success = initialize_model()
-            if not success or not callable(translator):
-                raise Exception("Translator is not callable after reinitialization")
         print("[DEBUG] Calling translator model...")
         # Use a thread pool to execute the translation with a timeout
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
                 lambda: translator(
-                    …
-                    …
-                    …
-                    …
-                )
+                    text,
+                    src_lang=src_lang_code,
+                    tgt_lang=tgt_lang_code,
+                    max_length=512
+                )[0]["translation_text"]
             )

         try:
             result = future.result(timeout=15)
-            …
-            if not result or not isinstance(result, list) or len(result) == 0:
-                raise Exception(f"Invalid model output format: {result}")
-
-            translation_result = result[0]["generated_text"]
-
-            # Clean up the output - remove any prefix like "Translation:"
-            prefixes = ["Translation:", "Translation: ", f"{target_lang_name}:", f"{target_lang_name}: "]
-            for prefix in prefixes:
-                if translation_result.startswith(prefix):
-                    translation_result = translation_result[len(prefix):].strip()
-
+            translation_result = result
             print(f"Local model translation result: {translation_result}")
         except concurrent.futures.TimeoutError:
             print("Translation timed out after 15 seconds")
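Note on the new language codes: source_lang and target_lang arrive as two-letter UI codes ("en", "ar", ...), but NLLB-200 expects FLORES-200 codes built from ISO 639-3 plus a script tag ("eng_Latn", "arb_Arab"), so the f-strings above produce strings like "en_Latn" that likely map to the unknown token rather than a real language tag. A minimal sketch of an explicit lookup instead, assuming a small fixed language set; nllb_code() and its table are illustrative, not part of this commit:

# Map two-letter UI codes to the FLORES-200 codes used by NLLB-200.
# Hypothetical helper: extend the table to whatever languages the UI offers.
FLORES_200 = {
    "en": "eng_Latn", "fr": "fra_Latn", "es": "spa_Latn", "de": "deu_Latn",
    "ar": "arb_Arab", "zh": "zho_Hans", "ja": "jpn_Jpan", "ru": "rus_Cyrl",
}

def nllb_code(ui_code: str) -> str:
    """Return the FLORES-200 code for a UI language code, or raise ValueError."""
    try:
        return FLORES_200[ui_code.lower()]
    except KeyError:
        raise ValueError(f"No NLLB code known for language {ui_code!r}")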
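Note on the timeout pattern: both call sites submit the pipeline call to a ThreadPoolExecutor and wait with future.result(timeout=15). Exiting a "with ThreadPoolExecutor()" block calls shutdown(wait=True), and a timed-out worker thread cannot be cancelled, so the 15-second timeout bounds only the wait, not the underlying model call. A minimal sketch of the same pattern as a shared helper that at least avoids blocking in shutdown; run_with_timeout is a hypothetical name, not from this commit:

import concurrent.futures

def run_with_timeout(fn, timeout_s=15):
    """Run fn() in a worker thread; raise TimeoutError if it outlasts timeout_s.

    The worker is not killed on timeout: it keeps running in the background
    until fn returns, so a timeout means "stop waiting", not "cancelled".
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fn)
        return future.result(timeout=timeout_s)
    finally:
        executor.shutdown(wait=False)  # don't block on a still-running task

Either call site then reduces to a one-liner such as
result = run_with_timeout(lambda: translator(text, src_lang=src_lang_code, tgt_lang=tgt_lang_code, max_length=512)[0]["translation_text"]).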
static/script.js CHANGED

@@ -56,15 +56,16 @@ document.addEventListener('DOMContentLoaded', () => {
         docLoadingIndicator.style.display = 'none';
     }

-    // …
+    // Fix the text form submission handler to use correct field IDs
     if (textForm) {
         textForm.addEventListener('submit', async (e) => {
             e.preventDefault();
             clearFeedback();

-            …
-            const …
-            const …
+            // Use correct field IDs matching the HTML
+            const sourceText = document.getElementById('text-input').value.trim();
+            const sourceLang = document.getElementById('source-lang-text').value;
+            const targetLang = document.getElementById('target-lang-text').value;

             if (!sourceText) {
                 displayError('Please enter text to translate');

@@ -72,8 +73,16 @@ document.addEventListener('DOMContentLoaded', () => {
             }

             try {
-                // Show loading state
-                document.getElementById('text-loading')…
+                // Show loading state (create it if missing)
+                let textLoading = document.getElementById('text-loading');
+                if (!textLoading) {
+                    textLoading = document.createElement('div');
+                    textLoading.id = 'text-loading';
+                    textLoading.className = 'loading-spinner';
+                    textLoading.innerHTML = 'Translating...';
+                    textForm.appendChild(textLoading);
+                }
+                textLoading.style.display = 'block';

                 // Log payload for debugging
                 console.log('Sending payload:', { text: sourceText, source_lang: sourceLang, target_lang: targetLang });

@@ -91,7 +100,7 @@ document.addEventListener('DOMContentLoaded', () => {
                 });

                 // Hide loading state
-                …
+                textLoading.style.display = 'none';

                 // Log response status
                 console.log('Response status:', response.status);

@@ -112,13 +121,21 @@ document.addEventListener('DOMContentLoaded', () => {
                     return;
                 }

+                if (!data.translated_text) {
+                    displayError('Translation returned empty text');
+                    return;
+                }
+
                 textOutput.textContent = data.translated_text;
                 textResultBox.style.display = 'block';

             } catch (error) {
                 console.error('Error:', error);
                 displayError('Network error or invalid response format');
-                …
+
+                // Hide loading if it exists
+                const textLoading = document.getElementById('text-loading');
+                if (textLoading) textLoading.style.display = 'none';
             }
         });
     }
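Note on the API contract: the handler sends { text, source_lang, target_lang } and reads translated_text back, which is enough to smoke-test the backend outside the browser. A sketch; the route and port below are assumptions (neither appears in this diff), only the field names come from the code above:

# Hypothetical smoke test for the endpoint the form handler calls.
import requests

resp = requests.post(
    "http://localhost:8000/translate",  # assumed route and port
    json={"text": "hello", "source_lang": "en", "target_lang": "ar"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["translated_text"])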