Create app.py
app.py
ADDED
@@ -0,0 +1,191 @@
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    AutoConfig
)
from huggingface_hub import login
from threading import Thread
import PyPDF2
import pandas as pd
import torch
import time
import os

# Check if 'peft' is installed
try:
    from peft import PeftModel, PeftConfig
except ImportError:
    raise ImportError(
        "The 'peft' library is required but not installed. "
        "Please install it using: `pip install peft`"
    )

# Hugging Face Token via Environment Variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Missing Hugging Face Token. Please set the HF_TOKEN environment variable.")
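# On Hugging Face Spaces, HF_TOKEN is usually stored as a repository secret, which the runtime exposes to the app as an environment variable.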

# Model base and adapters
BASE_MODEL_NAME = "neuralmind/bert-base-portuguese-cased"
MODEL_OPTIONS = {
    "Full Fine-Tuned": "amiguel/mistral-angolan-laborlaw-bert-base-pt",
    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"
}
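# "Full Fine-Tuned" is loaded directly as a causal LM; the LoRA/QLoRA entries are PEFT adapters
# that load_model() attaches on top of BASE_MODEL_NAME.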

# UI Setup
st.set_page_config(page_title="Assistente LGT | Angola", page_icon="📚", layout="centered")
st.title("Assistente LGT | Angola")

USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Sidebar
with st.sidebar:
    st.header("Model Selection")
    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
    selected_model = MODEL_OPTIONS[model_type]

    st.header("Upload Documents")
    uploaded_file = st.file_uploader("Choose a PDF or XLSX file", type=["pdf", "xlsx"], label_visibility="collapsed")
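    # The uploaded document is converted to plain text by process_file() and prepended to each prompt as context.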

# Chat memory
if "messages" not in st.session_state:
    st.session_state.messages = []
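# st.session_state persists across Streamlit reruns, so the chat history survives each interaction within a session.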

# File processing
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    try:
        if uploaded_file.type == "application/pdf":
            reader = PyPDF2.PdfReader(uploaded_file)
            return "\n".join(page.extract_text() or "" for page in reader.pages)
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
        # Unsupported type: return empty context instead of falling through to None
        return ""
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return ""
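# Note: reading .xlsx files with pd.read_excel() requires the optional 'openpyxl' dependency,
# and df.to_markdown() requires 'tabulate'; both must be available in the Space environment.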

# Load model and tokenizer
@st.cache_resource
def load_model(model_type, selected_model):
    try:
        login(token=HF_TOKEN)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

        tokenizer = AutoTokenizer.from_pretrained(selected_model, token=HF_TOKEN)

        if model_type == "Full Fine-Tuned":
            model = AutoModelForCausalLM.from_pretrained(
                selected_model,
                device_map="auto",
                torch_dtype=dtype,
                token=HF_TOKEN
            )
        else:
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_NAME,
                device_map="auto",
                torch_dtype=dtype,
                token=HF_TOKEN
            )
            model = PeftModel.from_pretrained(
                base_model,
                selected_model,
                is_trainable=False,
                torch_dtype=dtype,
                token=HF_TOKEN
            )
        return model, tokenizer
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        return None, None
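# st.cache_resource keeps one loaded model per (model_type, selected_model) combination,
# so a previously selected model is reused instead of being downloaded and loaded again.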

# Generate response
def generate_with_streaming(prompt, file_context, model, tokenizer):
    full_prompt = f"Analisa este contexto:\n{file_context}\n\nPergunta: {prompt}\nResposta:"

    inputs = tokenizer(full_prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": True,
        "streamer": streamer
    }

    # Run generation in a background thread so the streamer can yield tokens as they are produced
    Thread(target=model.generate, kwargs=gen_kwargs).start()
    return streamer
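# TextIteratorStreamer is an iterator: the chat loop below consumes decoded text chunks from it
# while model.generate() is still running in the background thread.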

# Display chat history
for msg in st.session_state.messages:
    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
    with st.chat_message(msg["role"], avatar=avatar):
        st.markdown(msg["content"])

# Main interaction loop
if prompt := st.chat_input("Pergunta sobre a LGT?"):
    # Display user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Load model if needed
    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
        with st.spinner("A carregar modelo..."):
            model, tokenizer = load_model(model_type, selected_model)
            if not model:
                st.stop()
            st.session_state.model = model
            st.session_state.tokenizer = tokenizer
            st.session_state.model_type = model_type
    else:
        model = st.session_state.model
        tokenizer = st.session_state.tokenizer

    # Prepare context
    file_context = process_file(uploaded_file) or "Sem contexto adicional disponível."

    # Generate assistant response
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        response_box = st.empty()
        full_response = ""
        try:
            start_time = time.time()
            streamer = generate_with_streaming(prompt, file_context, model, tokenizer)

            for chunk in streamer:
                full_response += chunk.strip() + " "
                response_box.markdown(full_response + "▌", unsafe_allow_html=True)

            # Token and speed metrics
            end_time = time.time()
            input_tokens = len(tokenizer(prompt)["input_ids"])
            output_tokens = len(tokenizer(full_response)["input_ids"])
            speed = output_tokens / (end_time - start_time)
            cost_usd = ((input_tokens / 1e6) * 5) + ((output_tokens / 1e6) * 15)
            cost_aoa = cost_usd * 1160

            st.caption(
                f"Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                f"Speed: {speed:.1f} t/s | 💰 USD: ${cost_usd:.4f} | 🇦🇴 AOA: {cost_aoa:.2f}"
            )

            response_box.markdown(full_response.strip())
            st.session_state.messages.append({"role": "assistant", "content": full_response.strip()})

        except Exception as e:
            st.error(f"Erro ao gerar resposta: {str(e)}")
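
# Note: a minimal environment for this Space is assumed to include streamlit, transformers, torch,
# peft, huggingface_hub, PyPDF2, pandas, openpyxl, tabulate, and accelerate (needed for device_map="auto").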