Spaces:
Runtime error
Runtime error
Mustehson
committed on
Commit
·
fb83515
1
Parent(s):
5862423
Changed SQL Prompt
Browse files- app.py +46 -8
- plot_utils.py +3 -3
app.py
CHANGED
@@ -85,14 +85,16 @@ def get_table_schema(table):
|
|
85 |
def get_prompt(schema, query_input):
|
86 |
text = f"""
|
87 |
### Instruction:
|
88 |
-
Your task is to generate valid
|
|
|
89 |
### Input:
|
90 |
Here is the database schema that the SQL query will run on:
|
91 |
{schema}
|
92 |
-
|
93 |
### Question:
|
94 |
{query_input}
|
95 |
-
|
|
|
96 |
"""
|
97 |
return text
|
98 |
|
@@ -151,19 +153,52 @@ Recommend a visualization:
|
|
151 |
print(visualization, reason)
|
152 |
return visualization, reason
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
156 |
instruction = graph_instructions[visualization_type]
|
157 |
|
158 |
template = ChatPromptTemplate.from_messages([
|
159 |
-
("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user,
|
160 |
-
("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it."),
|
161 |
])
|
162 |
|
163 |
|
164 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
165 |
results=sql_result, instructions=instruction)
|
166 |
-
print(prompt)
|
167 |
formatted_data = run_llm(prompt)
|
168 |
formatted_data = formatted_data.replace('```', '')
|
169 |
formatted_data = formatted_data.replace('json', '')
|
@@ -174,8 +209,11 @@ def visualize_result(text_query, visualization_type, sql_query,
|
|
174 |
sql_result):
|
175 |
|
176 |
if visualization_type == 'bar':
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
179 |
return plot_bar_chart(data)
|
180 |
|
181 |
elif visualization_type == 'horizontal_bar':
|
|
|
85 |
def get_prompt(schema, query_input):
    """Build the text-to-SQL prompt sent to the model.

    The prompt asks for a single DuckDB SQL query that answers
    ``query_input`` against ``schema``, restricted to the columns useful
    for visualization, with no extra text in the response.

    Args:
        schema: Textual description of the database schema; interpolated
            verbatim under the ``### Input:`` section.
        query_input: The user's natural-language question; interpolated
            verbatim under the ``### Question:`` section.

    Returns:
        The fully formatted prompt string.
    """
    # NOTE(review): this block was recovered from a diff view in which the
    # original leading whitespace of the prompt lines was lost; the lines are
    # written flush-left here -- confirm against the committed file.
    text = f"""
### Instruction:
Your task is to generate a valid DuckDB SQL query to answer the following question. Select only the columns that are necessary for visualization or that will make visualization easier (e.g., numeric or categorical data). Respond only with the SQL query and do not include anything else.

### Input:
Here is the database schema that the SQL query will run on:
{schema}

### Question:
{query_input}

### Response (use DuckDB shorthand if possible):
"""
    return text
|
100 |
|
|
|
153 |
print(visualization, reason)
|
154 |
return visualization, reason
|
155 |
|
156 |
+
def format_bar_data(results, question, sql_query):
    """Convert SQL result rows into the dict structure used for bar charts.

    Two row shapes are supported:

    * 2 columns ``(label, value)``: one data series. The series name is
      obtained by asking the LLM (via ``run_llm``) for a concise y-axis label.
    * 3 columns ``(entity, category, value)``: one series per distinct
      entity, with the distinct categories as the x-axis labels. Every
      series is aligned to the label order; missing (entity, category)
      pairs are filled with 0.

    ``None`` values are converted to 0 in both shapes.

    Args:
        results: Sequence of row tuples, or the string representation of one.
        question: The user's natural-language question (used for labeling).
        sql_query: The SQL query that produced ``results`` (used for labeling).

    Returns:
        ``{"labels": [...], "values": [{"data": [...], "label": str}, ...]}``.

    Raises:
        ValueError: If ``results`` is empty, is a string that is not a plain
            Python literal, or its rows do not have 2 or 3 columns.
    """
    import ast  # local import: only needed to parse stringified results

    if isinstance(results, str):
        # SECURITY: this used to be eval(results), which executes arbitrary
        # code embedded in the result string. literal_eval only accepts
        # Python literals; anything else raises ValueError/SyntaxError, which
        # the caller's fallback path is expected to handle.
        try:
            results = ast.literal_eval(results)
        except SyntaxError as err:
            raise ValueError("Unexpected data format in results") from err

    if not results:
        # Previously surfaced as a bare IndexError from results[0].
        raise ValueError("Unexpected data format in results: no rows")

    column_count = len(results[0])

    if column_count == 2:
        labels = [str(row[0]) for row in results]
        data = [float(row[1]) if row[1] is not None else 0 for row in results]

        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a data labeling expert. Given a question, SQL Query used, and some data, provide a concise and relevant label for the data series."),
            ("human", "Question: {question}\n SQL Query: {sql_query} \nData (first few rows): {data}\n\n Provide a concise label name for this y axis. Just give me the name."),
        ])
        # Only the first two rows are sent, to keep the prompt small.
        prompt = prompt.format_prompt(question=question, data=str(results[:2]),
                                      sql_query=sql_query)
        label = run_llm(prompt)
        # Drop the markdown prefix some models add, plus stray whitespace.
        label = label.replace('**Answer:**', '').strip()
        values = [{"data": data, "label": label}]
    elif column_count == 3:
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # chart is deterministic (a plain set has arbitrary iteration order).
        labels = list(dict.fromkeys(row[1] for row in results))
        entities = list(dict.fromkeys(row[0] for row in results))

        # (entity, category) -> value. If a pair occurs more than once the
        # last row wins.
        cell = {
            (row[0], row[1]): float(row[2]) if row[2] is not None else 0.0
            for row in results
        }

        # Build every series in label order so each bar lines up with its
        # x-axis label and all series have the same length.
        values = [
            {
                "data": [cell.get((entity, category), 0.0) for category in labels],
                "label": str(entity),
            }
            for entity in entities
        ]
    else:
        raise ValueError("Unexpected data format in results")

    formatted_data = {
        "labels": labels,
        "values": values,
    }

    return formatted_data
|
190 |
|
191 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
192 |
instruction = graph_instructions[visualization_type]
|
193 |
|
194 |
template = ChatPromptTemplate.from_messages([
|
195 |
+
("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user,Given the question, SQL query, and results, your job is to Understand the question and SQL query. Ensure the result matches the query and can be easily visualized."),
|
196 |
+
("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
|
197 |
])
|
198 |
|
199 |
|
200 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
201 |
results=sql_result, instructions=instruction)
|
|
|
202 |
formatted_data = run_llm(prompt)
|
203 |
formatted_data = formatted_data.replace('```', '')
|
204 |
formatted_data = formatted_data.replace('json', '')
|
|
|
209 |
sql_result):
|
210 |
|
211 |
if visualization_type == 'bar':
|
212 |
+
try:
|
213 |
+
data = format_bar_data(sql_result, text_query, sql_query)
|
214 |
+
except Exception as e:
|
215 |
+
data = format_data(text_query=text_query, sql_query=sql_query,
|
216 |
+
sql_result=sql_result, visualization_type=visualization_type)
|
217 |
return plot_bar_chart(data)
|
218 |
|
219 |
elif visualization_type == 'horizontal_bar':
|
plot_utils.py
CHANGED
@@ -6,9 +6,9 @@ def plot_bar_chart(data):
|
|
6 |
labels = data['labels']
|
7 |
values = data['values']
|
8 |
|
9 |
-
colors = ['darkviolet', 'indigo', 'blueviolet']
|
10 |
-
width = 0.2
|
11 |
-
x = np.arange(len(labels))
|
12 |
|
13 |
for i, value in enumerate(values):
|
14 |
ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
|
|
|
6 |
labels = data['labels']
|
7 |
values = data['values']
|
8 |
|
9 |
+
colors = ['darkviolet', 'indigo', 'blueviolet']
|
10 |
+
width = 0.2
|
11 |
+
x = np.arange(len(labels))
|
12 |
|
13 |
for i, value in enumerate(values):
|
14 |
ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
|