amine_dubs commited on
Commit
c38e2fa
·
1 Parent(s): 7dfe957

Implement transformers library with T5 model and custom Arabic prompt

Browse files
Files changed (2) hide show
  1. backend/main.py +85 -62
  2. backend/requirements.txt +3 -0
backend/main.py CHANGED
@@ -9,6 +9,10 @@ import json
9
  import traceback
10
  import io
11
 
 
 
 
 
12
  # --- Configuration ---
13
  # Determine the base directory of the main.py script
14
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -37,18 +41,58 @@ LANGUAGE_MAP = {
37
  "it": "Italian"
38
  }
39
 
40
- # --- Free translation APIs ---
41
- LIBRE_TRANSLATE_ENDPOINTS = [
42
- "https://translate.terraprint.co/translate",
43
- "https://libretranslate.de/translate",
44
- "https://translate.argosopentech.com/translate"
45
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # --- Translation Function ---
48
  def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar") -> str:
49
  """
50
- Translate text using Hugging Face Inference API and LibreTranslate as backup
51
  """
 
 
52
  if not text.strip():
53
  return ""
54
 
@@ -57,8 +101,15 @@ def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar"
57
  # Get full language name for prompt
58
  source_lang_name = LANGUAGE_MAP.get(source_lang, source_lang)
59
 
60
- # Construct our eloquent Arabic translation prompt
61
- prompt = f"""Translate the following {source_lang_name} text into Modern Standard Arabic (Fusha).
 
 
 
 
 
 
 
62
  Focus on conveying the meaning elegantly using proper Balagha (Arabic eloquence).
63
  Adapt any cultural references or idioms appropriately rather than translating literally.
64
  Ensure the translation reads naturally to a native Arabic speaker.
@@ -66,62 +117,34 @@ Ensure the translation reads naturally to a native Arabic speaker.
66
  Text to translate:
67
  {text}"""
68
 
69
- # Try Hugging Face Inference API with models that are reliably available on the free tier
70
- hf_models = [
71
- "facebook/m2m100_418M", # Very reliable multilingual model
72
- "Helsinki-NLP/opus-mt-tc-big-en-ar" # Good for English to Arabic
73
- ]
74
-
75
- for model in hf_models:
76
- try:
77
- print(f"Attempting translation via Hugging Face Inference API: {model}")
78
- api_url = f"https://api-inference.huggingface.co/models/{model}"
79
-
80
- # Different payloads based on model architecture
81
- if "m2m" in model:
82
- payload = {
83
- "inputs": text,
84
- "parameters": {
85
- "src_lang": source_lang.upper() if source_lang != "zh" else "ZH",
86
- "tgt_lang": "AR"
87
- }
88
- }
89
- elif "opus-mt" in model:
90
- payload = {"inputs": text}
91
- else:
92
- payload = {"inputs": prompt}
93
-
94
- # No auth header for public models on free tier
95
- response = requests.post(api_url, json=payload, timeout=30)
96
-
97
- if response.status_code == 200:
98
- result = response.json()
99
- translated_text = None
100
-
101
- # Extract text from various response formats
102
- if isinstance(result, list) and len(result) > 0:
103
- if isinstance(result[0], dict):
104
- translated_text = result[0].get("translation_text") or result[0].get("generated_text")
105
- else:
106
- translated_text = str(result[0])
107
- elif isinstance(result, dict):
108
- translated_text = result.get("translation_text") or result.get("generated_text")
109
-
110
- if translated_text:
111
- print(f"Translation successful using {model}")
112
- return culturally_adapt_arabic(translated_text)
113
-
114
- print(f"Unexpected response format: {response.text}")
115
- else:
116
- print(f"API error: {response.status_code}")
117
 
118
- except Exception as e:
119
- print(f"Error with Hugging Face model {model}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # If Hugging Face fails, try LibreTranslate
122
- for endpoint in LIBRE_TRANSLATE_ENDPOINTS:
123
  try:
124
- print(f"Attempting translation using LibreTranslate: {endpoint}")
125
  payload = {
126
  "q": text,
127
  "source": source_lang if source_lang != "auto" else "auto",
 
9
  import traceback
10
  import io
11
 
12
+ # Import transformers for local model inference
13
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
+ import torch
15
+
16
  # --- Configuration ---
17
  # Determine the base directory of the main.py script
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
41
  "it": "Italian"
42
  }
43
 
44
+ # --- Set cache directory to a writeable location ---
45
+ # This is crucial for Hugging Face Spaces where /app/.cache is not writable
46
+ # Using /tmp which is typically writable in most environments
47
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
48
+ os.environ['HF_HOME'] = '/tmp/hf_home'
49
+ os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
50
+
51
+ # --- Global model and tokenizer variables ---
52
+ translator = None
53
+ tokenizer = None
54
+
55
+ # --- Model initialization function ---
56
+ def initialize_model():
57
+ """Initialize the translation model and tokenizer."""
58
+ global translator, tokenizer
59
+
60
+ try:
61
+ print("Initializing model and tokenizer...")
62
+
63
+ # Use a smaller model that works well for instruction-based translation
64
+ model_name = "google/flan-t5-small"
65
+
66
+ # Load the model and tokenizer with explicit cache directory
67
+ tokenizer = AutoTokenizer.from_pretrained(
68
+ model_name,
69
+ cache_dir="/tmp/transformers_cache"
70
+ )
71
+
72
+ # Create a pipeline for text2text generation
73
+ translator = pipeline(
74
+ "text2text-generation",
75
+ model=model_name,
76
+ tokenizer=tokenizer,
77
+ device=-1, # Use CPU for compatibility (-1) or GPU if available (0)
78
+ cache_dir="/tmp/transformers_cache",
79
+ max_length=512
80
+ )
81
+
82
+ print(f"Model {model_name} successfully initialized")
83
+ return True
84
+ except Exception as e:
85
+ print(f"Error initializing model: {e}")
86
+ traceback.print_exc()
87
+ return False
88
 
89
  # --- Translation Function ---
90
  def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar") -> str:
91
  """
92
+ Translate text using local T5 model with prompt engineering
93
  """
94
+ global translator
95
+
96
  if not text.strip():
97
  return ""
98
 
 
101
  # Get full language name for prompt
102
  source_lang_name = LANGUAGE_MAP.get(source_lang, source_lang)
103
 
104
+ # Initialize the model if it hasn't been loaded yet
105
+ if translator is None:
106
+ success = initialize_model()
107
+ if not success:
108
+ return fallback_translate(text, source_lang, target_lang)
109
+
110
+ try:
111
+ # Construct our eloquent Arabic translation prompt
112
+ prompt = f"""Translate the following {source_lang_name} text into Modern Standard Arabic (Fusha).
113
  Focus on conveying the meaning elegantly using proper Balagha (Arabic eloquence).
114
  Adapt any cultural references or idioms appropriately rather than translating literally.
115
  Ensure the translation reads naturally to a native Arabic speaker.
 
117
  Text to translate:
118
  {text}"""
119
 
120
+ # Generate translation using the model
121
+ outputs = translator(prompt, max_length=512, do_sample=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ if outputs and len(outputs) > 0:
124
+ translated_text = outputs[0]['generated_text']
125
+ print(f"Translation successful using transformers model")
126
+ return culturally_adapt_arabic(translated_text)
127
+ else:
128
+ print("Model returned empty output")
129
+ return fallback_translate(text, source_lang, target_lang)
130
+
131
+ except Exception as e:
132
+ print(f"Error in model translation: {e}")
133
+ traceback.print_exc()
134
+ return fallback_translate(text, source_lang, target_lang)
135
+
136
+ def fallback_translate(text: str, source_lang: str, target_lang: str = "ar") -> str:
137
+ """Fallback to online translation APIs if local model fails."""
138
+ # Try LibreTranslate
139
+ libre_translate_endpoints = [
140
+ "https://translate.terraprint.co/translate",
141
+ "https://libretranslate.de/translate",
142
+ "https://translate.argosopentech.com/translate"
143
+ ]
144
 
145
+ for endpoint in libre_translate_endpoints:
 
146
  try:
147
+ print(f"Attempting fallback translation using LibreTranslate: {endpoint}")
148
  payload = {
149
  "q": text,
150
  "source": source_lang if source_lang != "auto" else "auto",
backend/requirements.txt CHANGED
@@ -5,3 +5,6 @@ PyMuPDF
5
  requests
6
  python-multipart
7
  jinja2
 
 
 
 
5
  requests
6
  python-multipart
7
  jinja2
8
+ transformers
9
+ torch
10
+ sentencepiece