Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
TextIteratorStreamer, | |
AutoConfig | |
) | |
from huggingface_hub import login | |
from threading import Thread | |
import PyPDF2 | |
import pandas as pd | |
import torch | |
import time | |
import os | |
# peft supplies PeftModel/PeftConfig, which load_model needs for the LoRA and
# QLoRA adapter options. Fail fast with an actionable message when it is
# missing, keeping the original ImportError chained for debugging.
try:
    from peft import PeftModel, PeftConfig
except ImportError as err:
    raise ImportError(
        "The 'peft' library is required but not installed. "
        "Please install it using: `pip install peft`"
    ) from err
# Hugging Face access token, read from the environment rather than stored in
# source. The app refuses to start when it is unset or empty.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("Missing Hugging Face Token. Please set the HF_TOKEN environment variable.")
# Hub id of the base checkpoint associated with the adapter options.
# NOTE(review): this is a BERT checkpoint, while the adapter repo names mention
# SmolLM2-360M; confirm which base the adapters were actually trained on.
BASE_MODEL_NAME = "neuralmind/bert-base-portuguese-cased"

# Sidebar label -> Hub repo id. Insertion order is the order shown in the
# selectbox, so the first entry is the default choice.
MODEL_OPTIONS = dict(
    [
        ("Full Fine-Tuned", "amiguel/mistral-angolan-laborlaw-bert-base-pt"),
        ("LoRA Adapter", "amiguel/SmolLM2-360M-concise-reasoning-lora"),
        ("QLoRA Adapter", "amiguel/SmolLM2-360M-concise-reasoning-qlora"),
    ]
)
# Avatar images shown next to user and assistant chat messages.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Page configuration and heading; configured before any other page content.
st.set_page_config(
    page_title="Assistente LGT | Angola",
    page_icon="π",
    layout="centered",
)
st.title("π Assistente LGT | Angola π")
# Sidebar: pick which model variant answers, and optionally attach a document
# whose text is later passed to the model as context.
with st.sidebar:
    st.header("Model Selection π€")
    model_type = st.selectbox(
        "Choose Model Type",
        list(MODEL_OPTIONS),
        index=0,
    )
    selected_model = MODEL_OPTIONS[model_type]

    st.header("Upload Documents π")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file",
        type=["pdf", "xlsx"],
        label_visibility="collapsed",
    )

# The chat transcript lives in session state so it survives script reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
def process_file(uploaded_file):
    """Extract plain text from an uploaded PDF or XLSX file.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``, or None when
            nothing has been uploaded. Its ``type`` (MIME type) and ``name``
            attributes decide how it is parsed.

    Returns:
        The extracted text: PDF pages joined by newlines, or the first sheet of
        the workbook rendered as a table. Returns ``""`` when there is no file,
        when the type is unsupported, or when parsing fails (the failure is
        reported in the UI).
    """
    if uploaded_file is None:
        return ""

    pdf_mime = "application/pdf"
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    file_name = (uploaded_file.name or "").lower()
    # The MIME type is reported by the client and may be missing or vary by
    # browser/OS, so the file extension is accepted as well.
    is_pdf = uploaded_file.type == pdf_mime or file_name.endswith(".pdf")
    is_xlsx = uploaded_file.type == xlsx_mime or file_name.endswith(".xlsx")
    if not (is_pdf or is_xlsx):
        # Explicit empty result instead of falling through to an implicit None.
        return ""

    try:
        # Rewind in case the stream has already been read.
        uploaded_file.seek(0)
        if is_pdf:
            reader = PyPDF2.PdfReader(uploaded_file)
            return "\n".join(page.extract_text() or "" for page in reader.pages)
        df = pd.read_excel(uploaded_file)
        try:
            return df.to_markdown()
        except ImportError:
            # to_markdown needs the optional 'tabulate' package; plain text
            # keeps the spreadsheet usable as context without it.
            return df.to_string()
    except Exception as e:
        st.error(f"π Error processing file: {str(e)}")
        return ""
def load_model(model_type, selected_model):
    """Log in to the Hub and load the tokenizer and model for the chosen option.

    Args:
        model_type: A MODEL_OPTIONS key. "Full Fine-Tuned" loads
            ``selected_model`` as a complete causal LM; any other value treats
            ``selected_model`` as a PEFT adapter repo.
        selected_model: Hub id of the full model or of the adapter.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` after reporting
        the error in the UI.
    """
    try:
        login(token=HF_TOKEN)
        # Check CUDA availability first: on some torch versions
        # is_bf16_supported() queries the current CUDA device and raises on
        # CPU-only hosts.
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        dtype = torch.bfloat16 if use_bf16 else torch.float32

        if model_type == "Full Fine-Tuned":
            tokenizer = AutoTokenizer.from_pretrained(selected_model, token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(
                selected_model,
                device_map="auto",
                torch_dtype=dtype,
                token=HF_TOKEN,
            )
            return model, tokenizer

        # An adapter only fits the base it was trained on. Its config records
        # that base, so prefer it; BASE_MODEL_NAME is only a fallback.
        adapter_config = PeftConfig.from_pretrained(selected_model, token=HF_TOKEN)
        base_model_name = adapter_config.base_model_name_or_path or BASE_MODEL_NAME
        tokenizer = _load_adapter_tokenizer(selected_model, base_model_name)
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            torch_dtype=dtype,
            token=HF_TOKEN,
        )
        model = PeftModel.from_pretrained(
            base_model,
            selected_model,
            is_trainable=False,
            torch_dtype=dtype,
            token=HF_TOKEN,
        )
        return model, tokenizer
    except Exception as e:
        st.error(f"π€ Model loading failed: {str(e)}")
        return None, None


def _load_adapter_tokenizer(adapter_repo, base_model_name):
    """Load the tokenizer from the adapter repo, falling back to the base model's."""
    try:
        return AutoTokenizer.from_pretrained(adapter_repo, token=HF_TOKEN)
    except (OSError, ValueError):
        # Adapter repos are not guaranteed to ship tokenizer files.
        return AutoTokenizer.from_pretrained(base_model_name, token=HF_TOKEN)
def generate_with_streaming(prompt, file_context, model, tokenizer):
    """Start sampling an answer in a background thread and return its text stream.

    Args:
        prompt: The user's question.
        file_context: Text extracted from the uploaded document (or a placeholder).
        model: The loaded causal LM (possibly wrapped by a PEFT adapter).
        tokenizer: The tokenizer matching ``model``.

    Returns:
        A TextIteratorStreamer (subclass) that yields decoded text chunks as
        they are generated. If ``model.generate`` fails in the worker thread,
        iterating the streamer raises that exception instead of blocking forever.
    """

    class _ErrorAwareStreamer(TextIteratorStreamer):
        # Set by the worker thread when model.generate raises.
        error = None

        def __next__(self):
            try:
                return super().__next__()
            except StopIteration:
                if self.error is not None:
                    # Surface the generation failure in the consuming thread.
                    raise self.error from None
                raise

    full_prompt = f"Analisa este contexto:\n{file_context}\n\nPergunta: {prompt}\nResposta:"
    inputs = tokenizer(full_prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = _ErrorAwareStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": True,
        "streamer": streamer,
    }

    def _generate():
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            # Without an end-of-stream signal the consumer would wait on the
            # streamer's queue forever; record the error, then end the stream.
            streamer.error = exc
            streamer.end()

    # Daemon thread: an in-flight generation must not block process shutdown.
    Thread(target=_generate, daemon=True).start()
    return streamer
# Redraw the conversation so far; Streamlit reruns the whole script on every
# interaction, so earlier messages must be rendered again each time.
for entry in st.session_state.messages:
    role = entry["role"]
    if role == "user":
        role_avatar = USER_AVATAR
    else:
        role_avatar = BOT_AVATAR
    with st.chat_message(role, avatar=role_avatar):
        st.markdown(entry["content"])
# Main interaction: handles the rerun triggered by a submitted question.
if prompt := st.chat_input("Pergunta sobre a LGT?"):
    # Show the question and store it in the transcript.
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Load the model on first use or when the sidebar selection changed;
    # otherwise reuse the copy cached in this session.
    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
        with st.spinner("π A carregar modelo..."):
            model, tokenizer = load_model(model_type, selected_model)
        if model is None:
            # load_model has already reported the failure in the UI.
            st.stop()
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.model_type = model_type
    else:
        model = st.session_state.model
        tokenizer = st.session_state.tokenizer

    # Text of the uploaded document, or a placeholder when there is none.
    file_context = process_file(uploaded_file) or "Sem contexto adicional disponível."

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        response_box = st.empty()
        full_response = ""
        try:
            start_time = time.perf_counter()
            streamer = generate_with_streaming(prompt, file_context, model, tokenizer)
            for chunk in streamer:
                # Append chunks verbatim: stripping each chunk and re-joining
                # with spaces dropped the newlines in the output, breaking
                # markdown paragraphs and lists.
                full_response += chunk
                # Model output can echo uploaded document text, so it is not
                # rendered with unsafe_allow_html.
                response_box.markdown(full_response + "β")
            elapsed = time.perf_counter() - start_time

            # Token and speed metrics. The model read the document context as
            # well as the question, so both count as input (approximate: the
            # fixed prompt-template words are not included).
            input_tokens = len(tokenizer(file_context + "\n" + prompt)["input_ids"])
            output_tokens = len(tokenizer(full_response, add_special_tokens=False)["input_ids"])
            # Guard against a zero interval (e.g. an empty, instant response).
            speed = output_tokens / elapsed if elapsed > 0 else 0.0
            # Hard-coded nominal pricing: 5 / 15 USD per million input / output
            # tokens, converted at 1160 AOA per USD.
            cost_usd = ((input_tokens / 1e6) * 5) + ((output_tokens / 1e6) * 15)
            cost_aoa = cost_usd * 1160
            st.caption(
                f"π Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                f"π Speed: {speed:.1f}t/s | π° USD: ${cost_usd:.4f} | π¦π΄ AOA: {cost_aoa:.2f}"
            )

            final_response = full_response.strip()
            response_box.markdown(final_response)
            st.session_state.messages.append({"role": "assistant", "content": final_response})
        except Exception as e:
            st.error(f"β‘ Erro ao gerar resposta: {str(e)}")