amiguel committed
Commit f149660 · verified · 1 Parent(s): 33724eb

Update app.py

Files changed (1):
  app.py +37 -53
app.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
 import torch
 import time
 
-# Check if 'peft' is installed
+# Check if 'peft' is installed (though not used here, kept for potential future use)
 try:
     from peft import PeftModel, PeftConfig
 except ImportError:
@@ -23,16 +23,21 @@ st.set_page_config(
     layout="centered"
 )
 
-# Model names
-BASE_MODEL_NAME = "amiguel/en2fr-transformer"
-#MODEL_OPTIONS = {
-#    "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model", #"amiguel/playbook_FT",#"amiguel/SmolLM2-360M-concise-reasoning",
-#    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
-#    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora" # Hypothetical, adjust if needed
-#}
+# Model name
+MODEL_NAME = "amiguel/en2fr-transformer"
+
+# Translation prompt template
+TRANSLATION_PROMPT = """
+You are a professional translator specializing in English-to-French translation. Translate the following text accurately and naturally into French, preserving the original meaning and tone:
+
+**Text to translate:**
+{input_text}
+
+**French translation:**
+"""
 
 # Title with rocket emojis
-st.title("🚀 Translator 🚀")
+st.title("🚀 English to French Translator 🚀")
 
 # Configure Avatars
 USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
@@ -44,13 +49,9 @@ with st.sidebar:
     hf_token = st.text_input("Hugging Face Token", type="password",
                              help="Get your token from https://huggingface.co/settings/tokens")
 
-    st.header("Model Selection 🤖")
-    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
-    selected_model = MODEL_OPTIONS[model_type]
-
     st.header("Upload Documents 📂")
     uploaded_file = st.file_uploader(
-        "Choose a PDF or XLSX file",
+        "Choose a PDF or XLSX file to translate",
         type=["pdf", "xlsx"],
         label_visibility="collapsed"
     )
@@ -78,7 +79,7 @@ def process_file(uploaded_file):
 
 # Model loading function
 @st.cache_resource
-def load_model(hf_token, model_type, selected_model):
+def load_model(hf_token):
     try:
         if not hf_token:
             st.error("🔐 Authentication required! Please provide a Hugging Face token.")
@@ -87,32 +88,15 @@ def load_model(hf_token, model_type, selected_model):
         login(token=hf_token)
 
         # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
 
-        # Load model based on type
-        if model_type == "Full Fine-Tuned":
-            # Load full fine-tuned model directly
-            model = AutoModelForCausalLM.from_pretrained(
-                selected_model,
-                torch_dtype=torch.bfloat16,
-                device_map="auto",
-                token=hf_token
-            )
-        else:
-            # Load base model and apply PEFT adapter
-            base_model = AutoModelForCausalLM.from_pretrained(
-                BASE_MODEL_NAME,
-                torch_dtype=torch.bfloat16,
-                device_map="auto",
-                token=hf_token
-            )
-            model = PeftModel.from_pretrained(
-                base_model,
-                selected_model,
-                torch_dtype=torch.bfloat16,
-                is_trainable=False,  # Inference mode
-                token=hf_token
-            )
+        # Load the full model (no adapters since we're using the base transformer)
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            token=hf_token
+        )
 
         return model, tokenizer
 
@@ -121,8 +105,8 @@ def load_model(hf_token, model_type, selected_model):
         return None
 
 # Generation function with KV caching
-def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
-    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
+def generate_translation(input_text, model, tokenizer, use_cache=True):
+    full_prompt = TRANSLATION_PROMPT.format(input_text=input_text)
 
     streamer = TextIteratorStreamer(
         tokenizer,
@@ -158,20 +142,19 @@ for message in st.session_state.messages:
         st.markdown(message["content"])
 
 # Chat input handling
-if prompt := st.chat_input("Ask your inspection question..."):
+if prompt := st.chat_input("Enter text to translate into French..."):
     if not hf_token:
         st.error("🔑 Authentication required!")
         st.stop()
 
-    # Load model if not already loaded or if model type changed
-    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
-        model_data = load_model(hf_token, model_type, selected_model)
+    # Load model if not already loaded
+    if "model" not in st.session_state:
+        model_data = load_model(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
 
        st.session_state.model, st.session_state.tokenizer = model_data
-        st.session_state.model_type = model_type
 
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
@@ -181,27 +164,28 @@ if prompt := st.chat_input("Ask your inspection question..."):
         st.markdown(prompt)
         st.session_state.messages.append({"role": "user", "content": prompt})
 
-    # Process file
+    # Process file or use prompt directly
     file_context = process_file(uploaded_file)
+    input_text = file_context if file_context else prompt
 
-    # Generate response with KV caching
+    # Generate translation
     if model and tokenizer:
         try:
             with st.chat_message("assistant", avatar=BOT_AVATAR):
                 start_time = time.time()
-                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
+                streamer = generate_translation(input_text, model, tokenizer, use_cache=True)
 
                 response_container = st.empty()
                 full_response = ""
 
                 for chunk in streamer:
-                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
+                    cleaned_chunk = chunk.strip()
                     full_response += cleaned_chunk + " "
                     response_container.markdown(full_response + "▌", unsafe_allow_html=True)
 
                 # Calculate performance metrics
                 end_time = time.time()
-                input_tokens = len(tokenizer(prompt)["input_ids"])
+                input_tokens = len(tokenizer(input_text)["input_ids"])
                 output_tokens = len(tokenizer(full_response)["input_ids"])
                 speed = output_tokens / (end_time - start_time)
 
@@ -222,6 +206,6 @@ if prompt := st.chat_input("Ask your inspection question..."):
             st.session_state.messages.append({"role": "assistant", "content": full_response})
 
         except Exception as e:
-            st.error(f"⚡ Generation error: {str(e)}")
+            st.error(f"⚡ Translation error: {str(e)}")
     else:
         st.error("🤖 Model not loaded!")
 