import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
from threading import Thread
import PyPDF2
import pandas as pd
import torch
import time

# Set page configuration
st.set_page_config(
    page_title="WizNerd Insp",
    page_icon="🚀",
    layout="centered"
)

MODEL_NAME = "amiguel/optimizedModelListing6.1"

# Title with rocket emojis
st.title("🚀 WizNerd Insp 🚀")

# Sidebar configuration
with st.sidebar:
    st.header("Authentication πŸ”’")
    hf_token = st.text_input("Hugging Face Token", type="password", 
                           help="Get your token from https://huggingface.co/settings/tokens")
    
    st.header("Upload Documents πŸ“‚")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )

# Initialize chat history
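# (st.session_state persists across reruns within the same browser session)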
if "messages" not in st.session_state:
    st.session_state.messages = []

# File processing function
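# (st.cache_data memoizes the result so the same upload isn't re-parsed on every rerun)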
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            return "\n".join([page.extract_text() for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
    except Exception as e:
        st.error(f"πŸ“„ Error processing file: {str(e)}")
        return ""

# Model loading function
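# (st.cache_resource keeps the loaded model/tokenizer in memory across reruns and sessions)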
@st.cache_resource
def load_model(hf_token):
    try:
        if not hf_token:
            st.error("πŸ” Authentication required! Please provide a Hugging Face token.")
            return None
        
        login(token=hf_token)
        
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=hf_token
        )
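        # device_map="auto" requires the accelerate package; torch_dtype=torch.float16
        # roughly halves GPU memory use compared with the default float32 weights.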
        
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token
        )
        
        return model, tokenizer
        
    except Exception as e:
        st.error(f"πŸ€– Model loading failed: {str(e)}")
        return None

# Generation function with KV caching
def generate_with_kv_cache(prompt, file_context, use_cache=True):
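    # Relies on the module-level `model` and `tokenizer` assigned in the chat
    # handler below. use_cache toggles the KV cache: past key/value tensors are
    # reused so each new token only computes attention against cached states.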
    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
    
    streamer = TextIteratorStreamer(
        tokenizer, 
        skip_prompt=True, 
        skip_special_tokens=True
    )
    
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    
    generation_kwargs = {
        **inputs,
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": use_cache,
        "streamer": streamer
    }
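    # model.generate runs in a background thread; TextIteratorStreamer yields
    # decoded text chunks that the main Streamlit thread renders as they arrive.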
    
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer

# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = "👤" if message["role"] == "user" else "🤖"
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:  # older Streamlit versions may not support the avatar argument
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# Chat input handling
if prompt := st.chat_input("Ask your inspection question..."):
    if not hf_token:
        st.error("πŸ”‘ Authentication required!")
        st.stop()

    # Load model if not already loaded
    if "model" not in st.session_state:
        model_data = load_model(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
            
        st.session_state.model, st.session_state.tokenizer = model_data
    
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    
    # Add user message
    with st.chat_message("user", avatar="πŸ‘€"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Process file
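    # (returns "" when no file is uploaded, so questions work without a document)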
    file_context = process_file(uploaded_file)
    
    # Generate response with KV caching
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar="πŸ€–"):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, use_cache=True)
                
                response_container = st.empty()
                full_response = ""
                
                for chunk in streamer:
                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
                    full_response += cleaned_chunk + " "
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)
                
                # Display metrics
                end_time = time.time()
                st.caption(f"Generated in {end_time - start_time:.2f}s using KV caching")
                
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
                
        except Exception as e:
            st.error(f"⚑ Generation error: {str(e)}")
    else:
        st.error("πŸ€– Model not loaded!")