amiguel committed
Commit cd58cfd · verified · 1 Parent(s): a7ba67c

Update app.py

Files changed (1)
  1. app.py +42 -110
app.py CHANGED
@@ -1,22 +1,11 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import login
- from threading import Thread
  import PyPDF2
  import pandas as pd
  import torch
- import time
  import os

- # Check if 'peft' is installed
- try:
-     from peft import PeftModel, PeftConfig
- except ImportError:
-     raise ImportError(
-         "The 'peft' library is required but not installed. "
-         "Please install it using: `pip install peft`"
-     )
-
  # Set page configuration
  st.set_page_config(
      page_title="WizNerd Insp",
@@ -25,14 +14,17 @@ st.set_page_config(
  )

  # Load Hugging Face token from environment variable
- HF_TOKEN = os.getenv("HF_TOKEN") # Set this in your environment, e.g., via export HF_TOKEN="your_token"

- # Model names
- BASE_MODEL_NAME = "google-bert/bert-base-uncased"
- MODEL_OPTIONS = {
-     "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model",
-     "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
-     "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora" # Hypothetical, adjust if needed
  }

  # Title with rocket emojis
@@ -44,10 +36,6 @@ BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/99

  # Sidebar configuration
  with st.sidebar:
-     st.header("Model Selection 🤖")
-     model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
-     selected_model = MODEL_OPTIONS[model_type]
-
      st.header("Upload Documents 📂")
      uploaded_file = st.file_uploader(
          "Choose a PDF or XLSX file",
@@ -78,7 +66,7 @@ def process_file(uploaded_file):

  # Model loading function
  @st.cache_resource
- def load_model(hf_token, model_type, selected_model):
      try:
          if not hf_token:
              st.error("🔐 Authentication required! Please set the HF_TOKEN environment variable.")
@@ -86,34 +74,17 @@ def load_model(hf_token, model_type, selected_model):

          login(token=hf_token)

-         # Load tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)

          # Determine device
          device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         # Load model based on type
-         if model_type == "Full Fine-Tuned":
-             # Load full fine-tuned model directly
-             model = AutoModelForCausalLM.from_pretrained(
-                 selected_model,
-                 torch_dtype=torch.bfloat16,
-                 token=hf_token
-             ).to(device)
-         else:
-             # Load base model and apply PEFT adapter
-             base_model = AutoModelForCausalLM.from_pretrained(
-                 BASE_MODEL_NAME,
-                 torch_dtype=torch.bfloat16,
-                 token=hf_token
-             ).to(device)
-             model = PeftModel.from_pretrained(
-                 base_model,
-                 selected_model,
-                 torch_dtype=torch.bfloat16,
-                 is_trainable=False, # Inference mode
-                 token=hf_token
-             ).to(device)

          return model, tokenizer

@@ -121,32 +92,22 @@ def load_model(hf_token, model_type, selected_model):
          st.error(f"🤖 Model loading failed: {str(e)}")
          return None

- # Generation function with KV caching
- def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
-     full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"

-     streamer = TextIteratorStreamer(
-         tokenizer,
-         skip_prompt=True,
-         skip_special_tokens=True
-     )

-     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

-     generation_kwargs = {
-         "input_ids": inputs["input_ids"],
-         "attention_mask": inputs["attention_mask"],
-         "max_new_tokens": 1024,
-         "temperature": 0.7,
-         "top_p": 0.9,
-         "repetition_penalty": 1.1,
-         "do_sample": True,
-         "use_cache": use_cache,
-         "streamer": streamer
-     }

-     Thread(target=model.generate, kwargs=generation_kwargs).start()
-     return streamer

  # Display chat messages
  for message in st.session_state.messages:
@@ -160,15 +121,14 @@ for message in st.session_state.messages:

  # Chat input handling
  if prompt := st.chat_input("Ask your inspection question..."):
-     # Load model if not already loaded or if model type changed
-     if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
-         model_data = load_model(HF_TOKEN, model_type, selected_model)
          if model_data is None:
              st.error("Failed to load model. Please ensure HF_TOKEN is set correctly.")
              st.stop()

          st.session_state.model, st.session_state.tokenizer = model_data
-         st.session_state.model_type = model_type

      model = st.session_state.model
      tokenizer = st.session_state.tokenizer
@@ -178,47 +138,19 @@ if prompt := st.chat_input("Ask your inspection question..."):
          st.markdown(prompt)
      st.session_state.messages.append({"role": "user", "content": prompt})

-     # Process Rank
      file_context = process_file(uploaded_file)

-     # Generate response with KV caching
      if model and tokenizer:
          try:
              with st.chat_message("assistant", avatar=BOT_AVATAR):
-                 start_time = time.time()
-                 streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
-
-                 response_container = st.empty()
-                 full_response = ""
-
-                 for chunk in streamer:
-                     cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
-                     full_response += cleaned_chunk + " "
-                     response_container.markdown(full_response + "▌", unsafe_allow_html=True)
-
-                 # Calculate performance metrics
-                 end_time = time.time()
-                 input_tokens = len(tokenizer(prompt)["input_ids"])
-                 output_tokens = len(tokenizer(full_response)["input_ids"])
-                 speed = output_tokens / (end_time - start_time)
-
-                 # Calculate costs (hypothetical pricing model)
-                 input_cost = (input_tokens / 1000000) * 5 # $5 per million input tokens
-                 output_cost = (output_tokens / 1000000) * 15 # $15 per million output tokens
-                 total_cost_usd = input_cost + output_cost
-                 total_cost_aoa = total_cost_usd * 1160 # Convert to AOA (Angolan Kwanza)
-
-                 # Display metrics
-                 st.caption(
-                     f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
-                     f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
-                     f"💵 Cost (AOA): {total_cost_aoa:.4f}"
-                 )
-
-                 response_container.markdown(full_response)
-                 st.session_state.messages.append({"role": "assistant", "content": full_response})

          except Exception as e:
-             st.error(f"⚡ Generation error: {str(e)}")
      else:
          st.error("🤖 Model not loaded!")
 
app.py (new version)

  import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
  from huggingface_hub import login
  import PyPDF2
  import pandas as pd
  import torch
  import os

  # Set page configuration
  st.set_page_config(
      page_title="WizNerd Insp",

  )

  # Load Hugging Face token from environment variable
+ HF_TOKEN = os.getenv("HF_TOKEN") # Set this in your Space's secrets
+
+ # Model name
+ MODEL_NAME = "amiguel/instruct_BERT-base-uncased_model"

+ # Label mapping (same as in Colab)
+ LABEL_TO_CLASS = {
+     0: "Campaign", 1: "Corrosion Monitoring", 2: "Flare Tip", 3: "Flare TIP",
+     4: "FU Items", 5: "Intelligent Pigging", 6: "Lifting", 7: "Non Structural Tank",
+     8: "Piping", 9: "Pressure Safety Device", 10: "Pressure Vessel (VIE)",
+     11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor"
  }

  # Title with rocket emojis

  # Sidebar configuration
  with st.sidebar:
      st.header("Upload Documents 📂")
      uploaded_file = st.file_uploader(
          "Choose a PDF or XLSX file",

  # Model loading function
  @st.cache_resource
+ def load_model(hf_token):
      try:
          if not hf_token:
              st.error("🔐 Authentication required! Please set the HF_TOKEN environment variable.")

          login(token=hf_token)

+         # Load tokenizer and model for classification
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
+         model = AutoModelForSequenceClassification.from_pretrained(
+             MODEL_NAME,
+             num_labels=len(LABEL_TO_CLASS), # Ensure correct number of labels
+             token=hf_token
+         )

          # Determine device
          device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)

          return model, tokenizer

          st.error(f"🤖 Model loading failed: {str(e)}")
          return None

+ # Classification function
+ def classify_instruction(prompt, file_context, model, tokenizer):
+     full_prompt = f"Context:\n{file_context}\n\nInstruction: {prompt}"

+     model.eval()
+     device = model.device

+     inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
+     inputs = {k: v.to(device) for k, v in inputs.items()}

+     with torch.no_grad():
+         outputs = model(**inputs)
+         prediction = outputs.logits.argmax().item()
+         class_name = LABEL_TO_CLASS[prediction]

+     return class_name

  # Display chat messages
  for message in st.session_state.messages:

  # Chat input handling
  if prompt := st.chat_input("Ask your inspection question..."):
+     # Load model if not already loaded
+     if "model" not in st.session_state:
+         model_data = load_model(HF_TOKEN)
          if model_data is None:
              st.error("Failed to load model. Please ensure HF_TOKEN is set correctly.")
              st.stop()

          st.session_state.model, st.session_state.tokenizer = model_data

      model = st.session_state.model
      tokenizer = st.session_state.tokenizer

          st.markdown(prompt)
      st.session_state.messages.append({"role": "user", "content": prompt})

+     # Process file context
      file_context = process_file(uploaded_file)

+     # Classify the instruction
      if model and tokenizer:
          try:
              with st.chat_message("assistant", avatar=BOT_AVATAR):
+                 predicted_class = classify_instruction(prompt, file_context, model, tokenizer)
+                 response = f"Predicted class: {predicted_class}"
+                 st.markdown(response)
+                 st.session_state.messages.append({"role": "assistant", "content": response})

          except Exception as e:
+             st.error(f"⚡ Classification error: {str(e)}")
      else:
          st.error("🤖 Model not loaded!")
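
For reference, a minimal standalone sketch of the classification path the updated app.py now uses. It is not part of the commit: it assumes the amiguel/instruct_BERT-base-uncased_model checkpoint is reachable with a valid HF_TOKEN, reuses the LABEL_TO_CLASS mapping from the diff, and the example instruction is hypothetical.

# Standalone sketch (not part of this commit) of the same classification path.
# Assumes HF_TOKEN grants access to the model repo; the prompt below is only an example.
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "amiguel/instruct_BERT-base-uncased_model"
LABEL_TO_CLASS = {
    0: "Campaign", 1: "Corrosion Monitoring", 2: "Flare Tip", 3: "Flare TIP",
    4: "FU Items", 5: "Intelligent Pigging", 6: "Lifting", 7: "Non Structural Tank",
    8: "Piping", 9: "Pressure Safety Device", 10: "Pressure Vessel (VIE)",
    11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor",
}

hf_token = os.getenv("HF_TOKEN")  # same environment variable the app reads
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(LABEL_TO_CLASS), token=hf_token
)
model.eval()

prompt = "Quarterly corrosion monitoring of the export pipeline"  # hypothetical instruction
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
with torch.no_grad():
    prediction = model(**inputs).logits.argmax(dim=-1).item()
print(LABEL_TO_CLASS[prediction])  # maps the predicted label id back to its class name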