amiguel committed on
Commit
5633122
·
verified ·
1 Parent(s): 1812c1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -76
app.py CHANGED
@@ -1,20 +1,22 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- from huggingface_hub import login
4
  import PyPDF2
5
  import pandas as pd
6
  import torch
7
  import os
8
  import re
9
 
 
 
 
 
 
 
 
 
10
  # Set page configuration
11
- st.set_page_config(
12
- page_title="WizNerd Insp",
13
- page_icon="🚀",
14
- layout="centered"
15
- )
16
 
17
- # Load Hugging Face token from environment variable
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
  # Model name
@@ -28,14 +30,14 @@ LABEL_TO_CLASS = {
28
  11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor"
29
  }
30
 
31
- # Title with rocket emojis
32
  st.title("🚀 WizNerd Insp 🚀")
33
 
34
- # Configure Avatars
35
  USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
36
  BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
37
 
38
- # Sidebar configuration
39
  with st.sidebar:
40
  st.header("Upload Documents 📂")
41
  uploaded_file = st.file_uploader(
@@ -44,11 +46,15 @@ with st.sidebar:
44
  label_visibility="collapsed"
45
  )
46
 
47
- # Initialize chat history
48
  if "messages" not in st.session_state:
49
  st.session_state.messages = []
 
 
 
 
50
 
51
- # File processing function with pre-processing
52
  @st.cache_data
53
  def process_file(uploaded_file):
54
  if uploaded_file is None:
@@ -58,24 +64,24 @@ def process_file(uploaded_file):
58
  if uploaded_file.type == "application/pdf":
59
  pdf_reader = PyPDF2.PdfReader(uploaded_file)
60
  text = "\n".join([page.extract_text() for page in pdf_reader.pages])
61
- # Basic pre-processing
62
  text = re.sub(r'\s+', ' ', text.lower().strip())
63
  return {"type": "text", "content": text}
64
 
65
- elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
66
- df = pd.read_excel(uploaded_file)
67
- elif uploaded_file.type == "text/csv":
68
- df = pd.read_csv(uploaded_file)
69
-
70
- # For tabular data (xlsx, csv), detect scope columns
71
- if 'df' in locals():
72
- scope_cols = [col for col in df.columns if "scope" in col.lower()]
73
- if not scope_cols:
74
- st.warning("No 'scope' column found in the file. Using all data as context.")
75
- return {"type": "table", "content": df.to_markdown()}
76
- # Pre-process scope data
77
- scope_data = df[scope_cols].dropna().astype(str).apply(lambda x: re.sub(r'\s+', ' ', x.lower().strip()))
78
- return {"type": "scope", "content": scope_data}
 
79
 
80
  except Exception as e:
81
  st.error(f"📄 Error processing file: {str(e)}")
@@ -84,36 +90,31 @@ def process_file(uploaded_file):
84
  # Model loading function
85
  @st.cache_resource
86
  def load_model(hf_token):
 
 
87
  try:
88
  if not hf_token:
89
- st.error("🔐 Authentication required! Please set the HF_TOKEN environment variable.")
90
  return None
91
-
92
  login(token=hf_token)
93
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
94
- model = AutoModelForSequenceClassification.from_pretrained(
95
- MODEL_NAME,
96
- num_labels=len(LABEL_TO_CLASS),
97
- token=hf_token
98
- )
99
  device = "cuda" if torch.cuda.is_available() else "cpu"
100
  model.to(device)
101
  return model, tokenizer
102
-
103
  except Exception as e:
104
  st.error(f"🤖 Model loading failed: {str(e)}")
105
  return None
106
 
107
  # Classification function
108
- def classify_instruction(prompt, file_context, model, tokenizer):
109
  model.eval()
110
  device = model.device
111
 
112
- if file_context["type"] == "scope":
113
- # Batch prediction for multiple scope entries
114
  predictions = []
115
- for scope in file_context["content"].values.flatten():
116
- full_prompt = f"Context:\n{scope}\n\nInstruction: {prompt}"
117
  inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
118
  inputs = {k: v.to(device) for k, v in inputs.items()}
119
  with torch.no_grad():
@@ -122,8 +123,7 @@ def classify_instruction(prompt, file_context, model, tokenizer):
122
  predictions.append(LABEL_TO_CLASS[prediction])
123
  return predictions
124
  else:
125
- # Single prediction for text or table context
126
- full_prompt = f"Context:\n{file_context['content']}\n\nInstruction: {prompt}"
127
  inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
128
  inputs = {k: v.to(device) for k, v in inputs.items()}
129
  with torch.no_grad():
@@ -131,6 +131,29 @@ def classify_instruction(prompt, file_context, model, tokenizer):
131
  prediction = outputs.logits.argmax().item()
132
  return LABEL_TO_CLASS[prediction]
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # Display chat messages
135
  for message in st.session_state.messages:
136
  avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
@@ -139,49 +162,38 @@ for message in st.session_state.messages:
139
 
140
  # Chat input handling
141
  if prompt := st.chat_input("Ask your inspection question..."):
142
- # Load model if not already loaded
143
- if "model" not in st.session_state:
144
- model_data = load_model(HF_TOKEN)
145
- if model_data is None:
146
- st.error("Failed to load model. Please ensure HF_TOKEN is set correctly.")
147
- st.stop()
148
- st.session_state.model, st.session_state.tokenizer = model_data
149
-
150
- model = st.session_state.model
151
- tokenizer = st.session_state.tokenizer
152
-
153
  # Add user message
154
  with st.chat_message("user", avatar=USER_AVATAR):
155
  st.markdown(prompt)
156
  st.session_state.messages.append({"role": "user", "content": prompt})
157
 
158
- # Process file context
159
- file_context = process_file(uploaded_file)
160
- if file_context is None:
161
- st.error("No file uploaded or file processing failed.")
162
- st.stop()
163
-
164
- # Classify the instruction
165
  if model and tokenizer:
166
  try:
167
  with st.chat_message("assistant", avatar=BOT_AVATAR):
168
- predicted_output = classify_instruction(prompt, file_context, model, tokenizer)
169
- if file_context["type"] == "scope":
170
- # Display multiple predictions in a table
171
- scope_values = file_context["content"].values.flatten()
172
- result_df = pd.DataFrame({
173
- "Scope": scope_values,
174
- "Predicted Class": predicted_output
175
- })
176
- st.write("Predicted Classes:")
177
- st.table(result_df)
178
- response = "Predictions completed for multiple scope entries."
 
179
  else:
180
- # Single prediction
181
- response = f"The Item Class is: {predicted_output}"
182
- st.markdown(response)
183
- st.session_state.messages.append({"role": "assistant", "content": response})
184
 
 
 
185
  except Exception as e:
186
  st.error(f"⚡ Classification error: {str(e)}")
187
  else:
 
1
  import streamlit as st
 
 
2
  import PyPDF2
3
  import pandas as pd
4
  import torch
5
  import os
6
  import re
7
 
8
+ try:
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
+ from huggingface_hub import login
11
+ TRANSFORMERS_AVAILABLE = True
12
+ except ImportError as e:
13
+ st.error(f"Failed to import transformers: {str(e)}. Please install it with `pip install transformers`.")
14
+ TRANSFORMERS_AVAILABLE = False
15
+
16
  # Set page configuration
17
+ st.set_page_config(page_title="WizNerd Insp", page_icon="🚀", layout="centered")
 
 
 
 
18
 
19
+ # Load Hugging Face token
20
  HF_TOKEN = os.getenv("HF_TOKEN")
21
 
22
  # Model name
 
30
  11: "Pressure Vessel (VII)", 12: "Structure", 13: "Flame Arrestor"
31
  }
32
 
33
+ # Title
34
  st.title("🚀 WizNerd Insp 🚀")
35
 
36
+ # Avatars
37
  USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
38
  BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
39
 
40
+ # Sidebar
41
  with st.sidebar:
42
  st.header("Upload Documents 📂")
43
  uploaded_file = st.file_uploader(
 
46
  label_visibility="collapsed"
47
  )
48
 
49
+ # Initialize session state
50
  if "messages" not in st.session_state:
51
  st.session_state.messages = []
52
+ if "file_processed" not in st.session_state:
53
+ st.session_state.file_processed = False
54
+ if "file_data" not in st.session_state:
55
+ st.session_state.file_data = None
56
 
57
+ # File processing function
58
  @st.cache_data
59
  def process_file(uploaded_file):
60
  if uploaded_file is None:
 
64
  if uploaded_file.type == "application/pdf":
65
  pdf_reader = PyPDF2.PdfReader(uploaded_file)
66
  text = "\n".join([page.extract_text() for page in pdf_reader.pages])
 
67
  text = re.sub(r'\s+', ' ', text.lower().strip())
68
  return {"type": "text", "content": text}
69
 
70
+ elif uploaded_file.type in ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "text/csv"]:
71
+ df = pd.read_excel(uploaded_file) if uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" else pd.read_csv(uploaded_file)
72
+ required_cols = ["Scope", "Functional Location"]
73
+ available_cols = [col for col in required_cols if col in df.columns]
74
+
75
+ if not available_cols:
76
+ st.warning("No 'Scope' or 'Functional Location' columns found. Treating as plain text.")
77
+ return {"type": "text", "content": df.to_string()}
78
+
79
+ # Pre-process and concatenate Scope and Functional Location
80
+ df = df.dropna(subset=available_cols)
81
+ df["input_text"] = df[available_cols].apply(
82
+ lambda row: " ".join([re.sub(r'\s+', ' ', str(val).lower().strip()) for val in row]), axis=1
83
+ )
84
+ return {"type": "table", "content": df[["input_text"] + available_cols]}
85
 
86
  except Exception as e:
87
  st.error(f"📄 Error processing file: {str(e)}")
 
90
  # Model loading function
91
  @st.cache_resource
92
  def load_model(hf_token):
93
+ if not TRANSFORMERS_AVAILABLE:
94
+ return None
95
  try:
96
  if not hf_token:
97
+ st.error("🔐 Please set the HF_TOKEN environment variable.")
98
  return None
 
99
  login(token=hf_token)
100
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
101
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(LABEL_TO_CLASS), token=hf_token)
 
 
 
 
102
  device = "cuda" if torch.cuda.is_available() else "cpu"
103
  model.to(device)
104
  return model, tokenizer
 
105
  except Exception as e:
106
  st.error(f"🤖 Model loading failed: {str(e)}")
107
  return None
108
 
109
  # Classification function
110
+ def classify_instruction(prompt, context, model, tokenizer):
111
  model.eval()
112
  device = model.device
113
 
114
+ if isinstance(context, pd.DataFrame):
 
115
  predictions = []
116
+ for text in context["input_text"]:
117
+ full_prompt = f"Context:\n{text}\n\nInstruction: {prompt}"
118
  inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
119
  inputs = {k: v.to(device) for k, v in inputs.items()}
120
  with torch.no_grad():
 
123
  predictions.append(LABEL_TO_CLASS[prediction])
124
  return predictions
125
  else:
126
+ full_prompt = f"Context:\n{context}\n\nInstruction: {prompt}"
 
127
  inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
128
  inputs = {k: v.to(device) for k, v in inputs.items()}
129
  with torch.no_grad():
 
131
  prediction = outputs.logits.argmax().item()
132
  return LABEL_TO_CLASS[prediction]
133
 
134
+ # Load model
135
+ if "model" not in st.session_state:
136
+ model_data = load_model(HF_TOKEN)
137
+ if model_data is None and TRANSFORMERS_AVAILABLE:
138
+ st.error("Failed to load model. Check HF_TOKEN.")
139
+ st.stop()
140
+ elif TRANSFORMERS_AVAILABLE:
141
+ st.session_state.model, st.session_state.tokenizer = model_data
142
+
143
+ model = st.session_state.get("model")
144
+ tokenizer = st.session_state.get("tokenizer")
145
+
146
+ # Process uploaded file once
147
+ if uploaded_file and not st.session_state.file_processed:
148
+ file_data = process_file(uploaded_file)
149
+ if file_data:
150
+ st.session_state.file_data = file_data
151
+ st.session_state.file_processed = True
152
+ if file_data["type"] == "table":
153
+ st.write("File uploaded with Scope and Functional Location data. Please provide an instruction.")
154
+ else:
155
+ st.write("File uploaded as text context. Please provide an instruction.")
156
+
157
  # Display chat messages
158
  for message in st.session_state.messages:
159
  avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
 
162
 
163
  # Chat input handling
164
  if prompt := st.chat_input("Ask your inspection question..."):
165
+ if not TRANSFORMERS_AVAILABLE:
166
+ st.error("Transformers library not available.")
167
+ st.stop()
168
+
 
 
 
 
 
 
 
169
  # Add user message
170
  with st.chat_message("user", avatar=USER_AVATAR):
171
  st.markdown(prompt)
172
  st.session_state.messages.append({"role": "user", "content": prompt})
173
 
174
+ # Handle response
 
 
 
 
 
 
175
  if model and tokenizer:
176
  try:
177
  with st.chat_message("assistant", avatar=BOT_AVATAR):
178
+ if st.session_state.file_data:
179
+ file_data = st.session_state.file_data
180
+ if file_data["type"] == "table":
181
+ predictions = classify_instruction(prompt, file_data["content"], model, tokenizer)
182
+ result_df = file_data["content"].copy()
183
+ result_df["Predicted Class"] = predictions
184
+ st.write("Predicted Item Classes:")
185
+ st.table(result_df)
186
+ response = "Predictions completed for uploaded file."
187
+ else:
188
+ predicted_class = classify_instruction(prompt, file_data["content"], model, tokenizer)
189
+ response = f"The Item Class is: {predicted_class}"
190
  else:
191
+ # Handle single prompt without file
192
+ predicted_class = classify_instruction(prompt, "", model, tokenizer)
193
+ response = f"The Item Class is: {predicted_class}"
 
194
 
195
+ st.markdown(response)
196
+ st.session_state.messages.append({"role": "assistant", "content": response})
197
  except Exception as e:
198
  st.error(f"⚡ Classification error: {str(e)}")
199
  else: