Update app.py
app.py (CHANGED)
@@ -16,38 +16,35 @@ except ImportError:
         "Please install it using: `pip install peft`"
     )
 
+# π Hardcoded Hugging Face Token
+HF_TOKEN = HF_TOKEN  # Replace with your actual token
+
 # Set page configuration
 st.set_page_config(
-    page_title="
+    page_title="Assistente LGT | Angola",
     page_icon="π",
     layout="centered"
 )
 
-# Model
-BASE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
+# Model base and options
+BASE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
 MODEL_OPTIONS = {
-    "Full Fine-Tuned": "amiguel/mistral-angolan-laborlaw",
+    "Full Fine-Tuned": "amiguel/mistral-angolan-laborlaw",
     "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
-    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"
+    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"
 }
 
-# Title with rocket emojis
 st.title("🚀 WizNerd Insp 🚀")
 
-# Configure Avatars
 USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
 BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
 
-# Sidebar
+# Sidebar
 with st.sidebar:
-    st.header("Authentication π")
-    hf_token = st.text_input("Hugging Face Token", type="password",
-                             help="Get your token from https://huggingface.co/settings/tokens")
-
     st.header("Model Selection π€")
     model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
     selected_model = MODEL_OPTIONS[model_type]
-
+
     st.header("Upload Documents π")
     uploaded_file = st.file_uploader(
         "Choose a PDF or XLSX file",
@@ -55,11 +52,11 @@ with st.sidebar:
         label_visibility="collapsed"
     )
 
-#
+# Session state
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
-# File
+# File processor
 @st.cache_data
 def process_file(uploaded_file):
     if uploaded_file is None:
@@ -76,51 +73,42 @@ def process_file(uploaded_file):
         st.error(f"π Error processing file: {str(e)}")
         return ""
 
-# Model
+# Model loader
 @st.cache_resource
-def load_model(
+def load_model(model_type, selected_model):
     try:
-
-
-
-
-        login(token=hf_token)
-
-        # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
-
-        # Load model based on type
+        login(token=HF_TOKEN)
+
+        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=HF_TOKEN)
+
         if model_type == "Full Fine-Tuned":
-            # Load full fine-tuned model directly
             model = AutoModelForCausalLM.from_pretrained(
                 selected_model,
                 torch_dtype=torch.bfloat16,
                 device_map="auto",
-                token=hf_token
+                token=HF_TOKEN
             )
         else:
-            # Load base model and apply PEFT adapter
             base_model = AutoModelForCausalLM.from_pretrained(
                 BASE_MODEL_NAME,
                 torch_dtype=torch.bfloat16,
                 device_map="auto",
-                token=hf_token
+                token=HF_TOKEN
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 selected_model,
                 torch_dtype=torch.bfloat16,
-                is_trainable=False,
-                token=hf_token
+                is_trainable=False,
+                token=HF_TOKEN
             )
-
         return model, tokenizer
-
+
     except Exception as e:
         st.error(f"π€ Model loading failed: {str(e)}")
         return None
 
-# Generation function
+# Generation function
 def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
     full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
 
@@ -147,81 +135,68 @@ def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
     Thread(target=model.generate, kwargs=generation_kwargs).start()
     return streamer
 
-# Display chat
+# Display chat history
 for message in st.session_state.messages:
-
-
-
-
-
-    with st.chat_message(message["role"]):
-        st.markdown(message["content"])
-
-# Chat input handling
+    avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
+    with st.chat_message(message["role"], avatar=avatar):
+        st.markdown(message["content"])
+
+# Prompt interaction
 if prompt := st.chat_input("Ask your inspection question..."):
-    if not hf_token:
-        st.error("π Authentication required!")
-        st.stop()
 
-    # Load model if
+    # Load model if necessary
     if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
-        model_data = load_model(
+        model_data = load_model(model_type, selected_model)
         if model_data is None:
-            st.error("Failed to load model.
+            st.error("Failed to load model.")
             st.stop()
-
+
         st.session_state.model, st.session_state.tokenizer = model_data
         st.session_state.model_type = model_type
-
+
     model = st.session_state.model
     tokenizer = st.session_state.tokenizer
-
-    # Add user message
+
     with st.chat_message("user", avatar=USER_AVATAR):
         st.markdown(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})
 
-    # Process file
     file_context = process_file(uploaded_file)
-
-    # Generate response with KV caching
+
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
-
+
                response_container = st.empty()
                full_response = ""
-
+
                for chunk in streamer:
                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
                    full_response += cleaned_chunk + " "
                    response_container.markdown(full_response + "β", unsafe_allow_html=True)
-
-                # Calculate performance metrics
+
                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time)
-
-
-
-                output_cost = (output_tokens / 1000000) * 15 # $15 per million output tokens
+
+                input_cost = (input_tokens / 1_000_000) * 5
+                output_cost = (output_tokens / 1_000_000) * 15
                total_cost_usd = input_cost + output_cost
-                total_cost_aoa = total_cost_usd * 1160
-
-                # Display metrics
+                total_cost_aoa = total_cost_usd * 1160
+
                st.caption(
                    f"π Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"π Speed: {speed:.1f}t/s | π° Cost (USD): ${total_cost_usd:.4f} | "
                    f"π΅ Cost (AOA): {total_cost_aoa:.4f}"
                )
-
+
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
-
+
        except Exception as e:
            st.error(f"β‘ Generation error: {str(e)}")
    else:
-        st.error("π€ Model not loaded!")
+        st.error("π€ Model not loaded!")
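
Note: as committed, the added line `HF_TOKEN = HF_TOKEN  # Replace with your actual token` only works if `HF_TOKEN` is already defined elsewhere; otherwise it raises a `NameError` at startup. Below is a minimal sketch (not part of this commit) of one way to supply the token without hardcoding it, assuming the deployment exposes an `HF_TOKEN` environment variable; the variable name and messages are illustrative.

import os

import streamlit as st

# Assumption: the token is provided through the environment (e.g. a Space secret),
# not written into app.py.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

if not HF_TOKEN:
    # Fail early with a clear message instead of failing later inside load_model().
    st.error("HF_TOKEN is not set. Define it as an environment variable or secret before starting the app.")
    st.stop()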