import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import login
import PyPDF2
import pandas as pd
import torch
import os
import time
import re
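# Note: PyPDF2 is in maintenance mode; its successor is the 'pypdf' package,
# but the PdfReader API used below works the same way in both.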
# Set page configuration
st.set_page_config(
    page_title="WizNerd Insp",
    page_icon="🚀",
    layout="centered"
)
# Load Hugging Face token from environment variable
HF_TOKEN = os.getenv("HF_TOKEN")  # Set this in your Space's secrets

# Model name
MODEL_NAME = "amiguel/instruct_BERT-base-uncased_model"
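# The checkpoint is a fine-tuned BERT-base sequence classifier on the Hugging
# Face Hub; the token is presumably needed because the repo is private or
# gated (an assumption -- public repos load without one).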
# Label mapping (same as in Colab)
LABEL_TO_CLASS = {
    0: "Campaign", 1: "Corrosion Monitoring", 2: "Flare Tip", 3: "Flare TIP",
    4: "FU Items", 5: "Intelligent Pigging", 6: "Lifting", 7: "Non Structural Tank",
    8: "Piping", 9: "Pressure Safety Device", 10: "Pressure Vessel (VIE)",
    11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor"
}
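# Note: indices 2 and 3 differ only in capitalization ("Flare Tip" vs "Flare TIP");
# both are kept because this mapping must match the label order used at training time.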
# Title with rocket emojis
st.title("🚀 WizNerd Insp 🚀")
# Configure Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Sidebar configuration | |
with st.sidebar: | |
st.header("Upload Documents π") | |
uploaded_files = st.file_uploader( | |
"Choose PDF, XLSX, or CSV files", | |
type=["pdf", "xlsx", "csv"], | |
accept_multiple_files=True, # Allow multiple file uploads | |
label_visibility="collapsed" | |
) | |
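# With accept_multiple_files=True, uploaded_files is a (possibly empty) list of
# UploadedFile objects; the chat handler below reads it again on every rerun.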
# Initialize chat history | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
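# st.session_state persists across Streamlit reruns, so the chat history
# survives each interaction within a browser session.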
# File processing function
def process_files(uploaded_files):
    if not uploaded_files:
        return []
    scopes = []
    try:
        for uploaded_file in uploaded_files:
            if uploaded_file.type == "application/pdf":
                pdf_reader = PyPDF2.PdfReader(uploaded_file)
                # extract_text() can return None for image-only pages, so default to ""
                text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
                # Split text into potential scope lines (e.g., by newlines or sentences)
                lines = [line.strip() for line in text.split("\n") if line.strip()]
                # Filter lines that look like scope instructions (e.g., contain keywords like "at location", "DAL/")
                scope_lines = [line for line in lines if re.search(r"(at location|DAL/|PSV-|CD-|DA-)", line, re.IGNORECASE)]
                scopes.extend(scope_lines)
            elif uploaded_file.type in ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "text/csv"]:
                if uploaded_file.type == "text/csv":
                    df = pd.read_csv(uploaded_file)
                else:
                    df = pd.read_excel(uploaded_file)
                # Assume the first column contains scope instructions
                if not df.empty:
                    scope_column = df.columns[0]  # First column
                    scope_lines = df[scope_column].dropna().astype(str).tolist()
                    scopes.extend([line.strip() for line in scope_lines if line.strip()])
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return []
    return scopes
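# Hypothetical example: a PDF line such as "Inspect PSV-1101 at location DAL/23"
# (invented here for illustration) would match the regex above and be kept as a scope.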
# Model loading function
def load_model(hf_token):
    try:
        if not hf_token:
            st.error("🔒 Authentication required! Please set the HF_TOKEN environment variable.")
            return None
        login(token=hf_token)
        # Load tokenizer and model for classification
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels=len(LABEL_TO_CLASS),
            token=hf_token
        )
        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model, tokenizer
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
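# load_model returns a (model, tokenizer) tuple on success and None on failure;
# the chat handler below caches the pair in st.session_state so the model is
# downloaded and loaded only once per session.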
# Classification function
def classify_instruction(prompt, file_context, model, tokenizer):
    full_prompt = f"Context:\n{file_context}\n\nInstruction: {prompt}"
    model.eval()
    device = model.device
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = outputs.logits.argmax(dim=-1).item()
    class_name = LABEL_TO_CLASS[prediction]
    return class_name
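# Note: with truncation=True and max_length=128 above, only the first ~128
# tokens of the combined context + instruction influence the prediction; long
# file contexts are silently cut off.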
# Streaming simulation for the classification output
def stream_classification_output(class_name, delay=0.05):
    """Simulate streaming by displaying the class name character by character."""
    response_container = st.empty()
    full_response = ""
    for char in class_name:
        full_response += char
        response_container.markdown(f"Predicted class: {full_response} ▌")
        time.sleep(delay)
    response_container.markdown(f"Predicted class: {full_response}")
    return full_response
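# The "streaming" above is purely cosmetic: the prediction is already complete
# before the first character is shown. The trailing "▌" imitates a typing cursor.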
# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        # Fall back to the default avatar if rendering with a custom one fails
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Chat input handling | |
if prompt := st.chat_input("Ask your inspection question..."): | |
# Load model if not already loaded | |
if "model" not in st.session_state: | |
model_data = load_model(HF_TOKEN) | |
if model_data is None: | |
st.error("Failed to load model. Please ensure HF_TOKEN is set correctly.") | |
st.stop() | |
st.session_state.model, st.session_state.tokenizer = model_data | |
model = st.session_state.model | |
tokenizer = st.session_state.tokenizer | |
# Add user message | |
with st.chat_message("user", avatar=USER_AVATAR): | |
st.markdown(prompt) | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
# Process file context (if any) | |
file_scopes = process_files(uploaded_files) | |
file_context = "\n".join(file_scopes) if file_scopes else "" | |
# Classify the user prompt | |
if model and tokenizer: | |
try: | |
with st.chat_message("assistant", avatar=BOT_AVATAR): | |
# Classify the user-entered prompt | |
predicted_class = classify_instruction(prompt, file_context, model, tokenizer) | |
# Stream the classification output | |
streamed_response = stream_classification_output(predicted_class) | |
response = f"Predicted class: {predicted_class}" | |
# If there are scopes from files, classify them too | |
if file_scopes: | |
st.markdown("### Classifications from Uploaded Files") | |
results = [] | |
for scope in file_scopes: | |
predicted_class = classify_instruction(scope, file_context, model, tokenizer) | |
results.append({"Scope": scope, "Predicted Class": predicted_class}) | |
# Display results in a table | |
df_results = pd.DataFrame(results) | |
st.table(df_results) | |
# Add table to chat history | |
response += "\n\n### Classifications from Uploaded Files\n" + df_results.to_markdown(index=False) | |
st.session_state.messages.append({"role": "assistant", "content": response}) | |
except Exception as e: | |
st.error(f"β‘ Classification error: {str(e)}") | |
else: | |
st.error("π€ Model not loaded!") |