amiguel committed
Commit fe1fad9 · verified · 1 Parent(s): 5633122

Update app.py

Files changed (1): app.py +23 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import pandas as pd
 import torch
 import os
 import re
+import base64  # For CSV download
 
 try:
     from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -16,6 +17,19 @@ except ImportError as e:
 # Set page configuration
 st.set_page_config(page_title="WizNerd Insp", page_icon="🚀", layout="centered")
 
+# Custom CSS for Tw Cen MT font
+st.markdown("""
+    <style>
+    @import url('https://fonts.googleapis.com/css2?family=Tw+Cen+MT&display=swap');
+    html, body, [class*="css"] {
+        font-family: 'Tw Cen MT', sans-serif !important;
+    }
+    .stTable table {
+        font-family: 'Tw Cen MT', sans-serif !important;
+    }
+    </style>
+""", unsafe_allow_html=True)
+
 # Load Hugging Face token
 HF_TOKEN = os.getenv("HF_TOKEN")
 
@@ -131,6 +145,13 @@ def classify_instruction(prompt, context, model, tokenizer):
     prediction = outputs.logits.argmax().item()
     return LABEL_TO_CLASS[prediction]
 
+# CSV download function
+def get_csv_download_link(df, filename="predicted_classes.csv"):
+    csv = df.to_csv(index=False)
+    b64 = base64.b64encode(csv.encode()).decode()  # Encode to base64
+    href = f'<a href="data:file/csv;base64,{b64}" download="{filename}">Download CSV</a>'
+    return href
+
 # Load model
 if "model" not in st.session_state:
     model_data = load_model(HF_TOKEN)
@@ -179,10 +200,11 @@ if prompt := st.chat_input("Ask your inspection question..."):
         file_data = st.session_state.file_data
         if file_data["type"] == "table":
             predictions = classify_instruction(prompt, file_data["content"], model, tokenizer)
-            result_df = file_data["content"].copy()
+            result_df = file_data["content"][["Scope", "Functional Location"]].copy()
             result_df["Predicted Class"] = predictions
             st.write("Predicted Item Classes:")
             st.table(result_df)
+            st.markdown(get_csv_download_link(result_df), unsafe_allow_html=True)
             response = "Predictions completed for uploaded file."
         else:
             predicted_class = classify_instruction(prompt, file_data["content"], model, tokenizer)
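
Note on the new download link: the commit hand-rolls a base64 data URI and injects it as raw HTML via unsafe_allow_html. Recent Streamlit releases also ship st.download_button, which can serve the same CSV without HTML injection. A minimal sketch of that alternative follows; the sample DataFrame merely stands in for result_df from the diff, and its values are assumptions.

import pandas as pd
import streamlit as st

# Stand-in for result_df from the diff above; the sample row is illustrative only.
result_df = pd.DataFrame({
    "Scope": ["Example scope"],
    "Functional Location": ["FL-0001"],
    "Predicted Class": ["Example class"],
})

# Serialize once and hand the bytes to Streamlit's built-in download widget.
csv_bytes = result_df.to_csv(index=False).encode("utf-8")
st.download_button(
    label="Download CSV",
    data=csv_bytes,
    file_name="predicted_classes.csv",
    mime="text/csv",
)

Either approach produces the same file; the button avoids unsafe_allow_html and declares the MIME type explicitly.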