amiguel committed on
Commit
36ef005
·
verified ·
1 Parent(s): d607a2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py CHANGED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
3
+ from huggingface_hub import login
4
+ from threading import Thread
5
+ import PyPDF2
6
+ import pandas as pd
7
+ import torch
8
+ import time
9
+
10
+ # Check if 'peft' is installed
11
# Fail fast with an actionable message when the 'peft' dependency is missing.
try:
    from peft import PeftModel, PeftConfig
except ImportError as exc:
    # Chain explicitly so the traceback shows the original import failure as the cause.
    raise ImportError(
        "The 'peft' library is required but not installed. "
        "Please install it using: `pip install peft`"
    ) from exc
18
+
19
+ # Set page configuration
20
# Page-level settings; this is the first Streamlit call in the script.
st.set_page_config(page_title="Translator Agent", page_icon="🚀", layout="centered")

# Hub repository used for the tokenizer and as the base model for adapters.
BASE_MODEL_NAME = "amiguel/en2fr-transformer"

# Sidebar label -> Hub repository id.
# The author left two alternatives for the first entry commented out:
# "amiguel/playbook_FT" and "amiguel/SmolLM2-360M-concise-reasoning".
MODEL_OPTIONS = {
    "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model",
    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
    # Marked by the author as hypothetical; adjust if needed.
    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora",
}

# Main heading shown at the top of the page.
st.title("🚀 WizNerd Insp 🚀")

# Avatar images used for the user and assistant chat bubbles.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
40
+
41
+ # Sidebar configuration
42
# --- Sidebar: credentials, model choice and document upload ---
st.sidebar.header("Authentication 🔒")
hf_token = st.sidebar.text_input(
    "Hugging Face Token",
    type="password",
    help="Get your token from https://huggingface.co/settings/tokens",
)

st.sidebar.header("Model Selection 🤖")
model_type = st.sidebar.selectbox("Choose Model Type", list(MODEL_OPTIONS), index=0)
# Resolve the chosen label to its Hub repository id.
selected_model = MODEL_OPTIONS[model_type]

st.sidebar.header("Upload Documents 📂")
uploaded_file = st.sidebar.file_uploader(
    "Choose a PDF or XLSX file",
    type=["pdf", "xlsx"],
    label_visibility="collapsed",
)

# Create the conversation history the first time the script runs in a session.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
61
+
62
+ # File processing function
63
@st.cache_data
def process_file(uploaded_file):
    """Extract the text content of an uploaded PDF or XLSX file.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``, or ``None``
            when nothing has been uploaded.

    Returns:
        str: For a PDF, the text of every page joined by newlines. For an XLSX
        workbook, the sheet read by ``pd.read_excel`` (default arguments)
        rendered as a Markdown table. An empty string when no file was given,
        the MIME type is not one of the two handled here, or extraction fails
        (the failure is reported through ``st.error``).
    """
    if uploaded_file is None:
        return ""

    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # A page may yield no text (e.g. scanned images); treat that as empty
            # instead of letting a None value break the join.
            return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        if uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""

    # Unsupported MIME type: return an empty context rather than None, which
    # would otherwise be interpolated into the prompt as the text "None".
    return ""
78
+
79
+ # Model loading function
80
@st.cache_resource
def _load_model_cached(hf_token, model_type, selected_model):
    """Load and cache the (model, tokenizer) pair; raise on any failure.

    Only successful results are returned from here, so a failed attempt is
    never stored by ``st.cache_resource`` and can be retried on the next run.
    """
    login(token=hf_token)

    # NOTE(review): the tokenizer always comes from BASE_MODEL_NAME, even when a
    # "Full Fine-Tuned" checkpoint from a different repository is selected —
    # confirm the two share a vocabulary.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)

    if model_type == "Full Fine-Tuned":
        # The selected repository holds complete weights; load it directly.
        model = AutoModelForCausalLM.from_pretrained(
            selected_model,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=hf_token,
        )
    else:
        # Adapter repositories are applied on top of the base model.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=hf_token,
        )
        model = PeftModel.from_pretrained(
            base_model,
            selected_model,
            torch_dtype=torch.bfloat16,
            is_trainable=False,  # Inference mode
            token=hf_token,
        )

    return model, tokenizer


def load_model(hf_token, model_type, selected_model):
    """Return the ``(model, tokenizer)`` pair for the selected model, or ``None``.

    Args:
        hf_token: Hugging Face access token; an empty value aborts the load.
        model_type: Key of ``MODEL_OPTIONS``; "Full Fine-Tuned" loads
            ``selected_model`` directly, any other value loads
            ``BASE_MODEL_NAME`` and applies ``selected_model`` as a PEFT adapter.
        selected_model: Hub repository id of the checkpoint or adapter.

    Returns:
        tuple | None: ``(model, tokenizer)`` on success; ``None`` when the token
        is missing or loading fails (the reason is shown via ``st.error``).

    The caching decorator lives on the private helper rather than on this
    function so that a ``None`` failure result is never cached: a transient
    error (network, bad token) no longer sticks for the same arguments.
    """
    if not hf_token:
        st.error("🔐 Authentication required! Please provide a Hugging Face token.")
        return None

    try:
        return _load_model_cached(hf_token, model_type, selected_model)
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
122
+
123
+ # Generation function with KV caching
124
def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
    """Start text generation in a background thread and return its streamer.

    Args:
        prompt: The user's question.
        file_context: Text extracted from the uploaded document (may be empty).
        model: Model object exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode streamed text.
        use_cache: Forwarded to ``model.generate`` as ``use_cache``.

    Returns:
        TextIteratorStreamer: Iterable that yields decoded text pieces as the
        background thread produces them (the prompt itself is skipped).
    """
    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": use_cache,
        "streamer": streamer,
    }

    def _run_generation():
        # If generate() raises, nothing signals the end of the stream and the
        # consumer iterating the streamer would block forever. Close the
        # stream explicitly, then re-raise so the failure is still reported.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            streamer.end()
            raise

    Thread(target=_run_generation).start()
    return streamer
149
+
150
+ # Display chat messages
151
# Replay the stored conversation so it stays visible across Streamlit reruns.
for message in st.session_state.messages:
    role = message["role"]
    avatar = USER_AVATAR if role == "user" else BOT_AVATAR
    # Only the avatar-bearing call is guarded: if the avatar is rejected, fall
    # back to the default one. Keeping st.markdown outside the try prevents a
    # rendering error from drawing the same message twice, and avoiding a bare
    # `except:` lets KeyboardInterrupt/SystemExit propagate.
    try:
        bubble = st.chat_message(role, avatar=avatar)
    except Exception:
        bubble = st.chat_message(role)
    with bubble:
        st.markdown(message["content"])
159
+
160
+ # Chat input handling
161
# Handle a newly submitted question: load the model, stream the answer, show metrics.
if prompt := st.chat_input("Ask your inspection question..."):
    if not hf_token:
        st.error("🔑 Authentication required!")
        st.stop()

    # Load the model if this session has none yet or the selected type changed.
    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
        model_data = load_model(hf_token, model_type, selected_model)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()

        st.session_state.model, st.session_state.tokenizer = model_data
        st.session_state.model_type = model_type

    model = st.session_state.model
    tokenizer = st.session_state.tokenizer

    # Echo the user's message and record it in the history.
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Text extracted from the uploaded document ("" when there is none).
    file_context = process_file(uploaded_file)

    # Explicit None checks: do not rely on the truthiness of model/tokenizer objects.
    if model is not None and tokenizer is not None:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)

                response_container = st.empty()
                full_response = ""

                for chunk in streamer:
                    # Drop reasoning tags but keep the chunk's own whitespace.
                    # Stripping each chunk and appending " " turned newlines
                    # into spaces and inserted a spurious space for every
                    # empty or partial-word chunk.
                    full_response += chunk.replace("<think>", "").replace("</think>", "")
                    # Model output is rendered as plain Markdown; raw HTML is
                    # not enabled for text the model (or an uploaded document)
                    # controls.
                    response_container.markdown(full_response + "▌")

                # Performance metrics.
                end_time = time.time()
                elapsed = end_time - start_time
                # NOTE: counts only the user's question, not the file context
                # or prompt template that generate_with_kv_cache adds.
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                # Guard against a zero-length interval on coarse clocks.
                speed = output_tokens / elapsed if elapsed > 0 else 0.0

                # Costs (hypothetical pricing model).
                input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)

                st.caption(
                    f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )

                # Final render without the typing cursor, then persist the reply.
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})

        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")