garyd1 committed on
Commit
1d1e331
·
verified ·
1 Parent(s): 5324d4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import openai
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
+ from dotenv import load_dotenv
9
+ import anthropic
10
+ import ast
11
+ import re
12
+
# Load environment variables from a local .env file (if present).
load_dotenv()

# Only copy keys into os.environ when they are actually set:
# os.environ[k] = None raises TypeError, so an unset variable would
# crash the app at import time instead of failing later with a clear
# authentication error.
_openai_key = os.getenv("OPENAI_API_KEY")
if _openai_key:
    os.environ["OPENAI_API_KEY"] = _openai_key
_anthropic_key = os.getenv("ANTHROPIC_API_KEY")
if _anthropic_key:
    os.environ["ANTHROPIC_API_KEY"] = _anthropic_key
# UI Styling: dark-theme CSS for buttons, text inputs, the file
# uploader, and response boxes. The CSS is kept in a named constant
# and rendered via st.markdown with unsafe_allow_html so the <style>
# tag reaches the page unescaped.
_CUSTOM_CSS = """
    <style>
    .stButton button {
        background-color: #1F6FEB;
        color: white;
        border-radius: 8px;
        border: none;
        padding: 10px 20px;
        font-weight: bold;
    }
    .stButton button:hover {
        background-color: #1A4FC5;
    }
    .stTextInput > div > input {
        border: 1px solid #30363D;
        background-color: #161B22;
        color: #C9D1D9;
        border-radius: 6px;
        padding: 10px;
    }
    .stFileUploader > div {
        border: 2px dashed #30363D;
        background-color: #161B22;
        color: #C9D1D9;
        border-radius: 6px;
        padding: 10px;
    }
    .response-box {
        background-color: #161B22;
        padding: 10px;
        border-radius: 6px;
        margin-bottom: 10px;
        color: #FFFFFF;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
st.title("Excel Q&A Chatbot 📊")

# Model Selection
model_choice = st.selectbox("Select LLM Model", ["OpenAI GPT-3.5", "Claude 3 Haiku", "Mistral-7B"])

# Load the backend for the selected model. Each branch defines exactly
# one ask_*(query) -> str helper used by the query handler below.
if model_choice == "Mistral-7B":
    # NOTE(review): Mistral instruct checkpoints on the Hub are
    # version-suffixed (e.g. "mistralai/Mistral-7B-Instruct-v0.2");
    # confirm this unversioned id resolves.
    model_name = "mistralai/Mistral-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    # Pick a device that actually exists: the original moved only the
    # inputs to "cuda", which crashes on CPU-only hosts and device-
    # mismatches against the CPU-resident model otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    def ask_mistral(query):
        """Generate a completion for *query* with the local Mistral model."""
        inputs = tokenizer(query, return_tensors="pt").to(device)
        output = model.generate(**inputs)
        return tokenizer.decode(output[0])

elif model_choice == "Claude 3 Haiku":
    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def ask_claude(query):
        """Send *query* to Claude 3 Haiku and return the reply text."""
        response = client.messages.create(
            # Anthropic model ids are date-versioned; the bare
            # "claude-3-haiku" id is rejected by the API.
            model="claude-3-haiku-20240307",
            max_tokens=512,
            messages=[{"role": "user", "content": query}]
        )
        # The SDK returns content blocks as objects, not dicts:
        # response.content[0]["text"] raises TypeError.
        return response.content[0].text

else:
    client = openai.OpenAI()

    def ask_gpt(query):
        """Send *query* to OpenAI gpt-3.5-turbo and return the reply text."""
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": query}]
        )
        return response.choices[0].message.content
# File Upload with validation
uploaded_file = st.file_uploader("Upload a file", type=["csv", "xlsx", "xls", "json", "tsv"])

if uploaded_file is not None:
    file_extension = uploaded_file.name.split(".")[-1].lower()

    try:
        if file_extension == "csv":
            df = pd.read_csv(uploaded_file)
        elif file_extension == "xlsx":
            # openpyxl reads only the newer .xlsx format.
            df = pd.read_excel(uploaded_file, engine="openpyxl")
        elif file_extension == "xls":
            # Legacy .xls is NOT supported by openpyxl (the original
            # forced it and failed); let pandas select the xlrd engine.
            df = pd.read_excel(uploaded_file)
        elif file_extension == "json":
            df = pd.read_json(uploaded_file)
        elif file_extension == "tsv":
            df = pd.read_csv(uploaded_file, sep="\t")
        else:
            st.error("Unsupported file format. Please upload a CSV, Excel, JSON, or TSV file.")
            st.stop()

        st.write("### Preview of Data:")
        st.write(df.head())

        # Extract metadata: names, dtype names, and per-column null counts.
        column_names = df.columns.tolist()
        data_types = df.dtypes.apply(lambda x: x.name).to_dict()
        missing_values = df.isnull().sum().to_dict()

        # Display metadata. Dicts preserve insertion order, so .values()
        # stays aligned with column_names.
        st.write("### Column Details:")
        st.write(pd.DataFrame({"Column": column_names, "Type": data_types.values(), "Missing Values": missing_values.values()}))

    except Exception as e:
        # Surface parse errors to the user and halt the script run
        # instead of letting the query UI operate on an undefined df.
        st.error(f"Error loading file: {str(e)}")
        st.stop()
# User Query
query = st.text_input("Ask a question about this data:")

if st.button("Submit Query"):
    if query:
        # Interpret the query using the selected LLM.
        prompt = f"Convert this question into a structured data operation: {query}"
        if model_choice == "Mistral-7B":
            parsed_query = ask_mistral(prompt)
        elif model_choice == "Claude 3 Haiku":
            parsed_query = ask_claude(prompt)
        else:
            parsed_query = ask_gpt(prompt)

        # Strip any character outside the whitelist. The original pattern
        # (r"[^a-zA-Z0-9_()\[\]"'., ]") contained an unescaped double
        # quote that terminated the string literal — a SyntaxError; a
        # triple-quoted raw string lets both quote characters appear.
        parsed_query = re.sub(r"""[^a-zA-Z0-9_()\[\]"'., ]""", "", parsed_query.strip())
        st.write(f"Parsed Query: `{parsed_query}`")

        # Predefined Safe Execution Methods
        SAFE_OPERATIONS = {
            "sum": lambda col: df[col].sum(),
            "mean": lambda col: df[col].mean(),
            "max": lambda col: df[col].max(),
            "groupby_sum": lambda col, group_by: df.groupby(group_by)[col].sum()
        }

        # SECURITY(review): eval() on LLM output is arbitrary code
        # execution on untrusted input — the regex whitelist above is not
        # a sandbox (e.g. attribute access and calls still get through).
        # Dispatching exclusively through SAFE_OPERATIONS, or parsing
        # with ast and validating node types, would be safer.
        try:
            exec_result = eval(parsed_query, {"df": df, "pd": pd, "SAFE_OPERATIONS": SAFE_OPERATIONS})
            st.write("### Result:")
            st.write(exec_result if isinstance(exec_result, pd.DataFrame) else str(exec_result))

            # If the result is a Series, render a bar chart: numeric
            # values directly, object values via their frequency counts.
            if isinstance(exec_result, pd.Series):
                fig, ax = plt.subplots()
                if exec_result.dtype in ["int64", "float64"]:
                    exec_result.plot(kind="bar", ax=ax)
                elif exec_result.dtype == "object":
                    exec_result.value_counts().plot(kind="bar", ax=ax)
                st.pyplot(fig)

        except SyntaxError as e:
            st.error(f"Syntax Error in parsed query: {str(e)}")
        except Exception as e:
            st.error(f"Error executing query: {str(e)}")

        # Memory for context retention across Streamlit reruns.
        if "query_history" not in st.session_state:
            st.session_state.query_history = []
        st.session_state.query_history.append(query)