Mustehson commited on
Commit
fb83515
·
1 Parent(s): 5862423

Changed SQL Prompt

Browse files
Files changed (2) hide show
  1. app.py +46 -8
  2. plot_utils.py +3 -3
app.py CHANGED
@@ -85,14 +85,16 @@ def get_table_schema(table):
85
  def get_prompt(schema, query_input):
86
  text = f"""
87
  ### Instruction:
88
- Your task is to generate valid duckdb SQL query to answer the following question. Only Respond with SQL query do not inlcude anything else.
 
89
  ### Input:
90
  Here is the database schema that the SQL query will run on:
91
  {schema}
92
-
93
  ### Question:
94
  {query_input}
95
- ### Response (use duckdb shorthand if possible):
 
96
  """
97
  return text
98
 
@@ -151,19 +153,52 @@ Recommend a visualization:
151
  print(visualization, reason)
152
  return visualization, reason
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  def format_data(text_query, sql_query, sql_result, visualization_type):
156
  instruction = graph_instructions[visualization_type]
157
 
158
  template = ChatPromptTemplate.from_messages([
159
- ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
160
- ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it."),
161
  ])
162
 
163
 
164
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
165
  results=sql_result, instructions=instruction)
166
- print(prompt)
167
  formatted_data = run_llm(prompt)
168
  formatted_data = formatted_data.replace('```', '')
169
  formatted_data = formatted_data.replace('json', '')
@@ -174,8 +209,11 @@ def visualize_result(text_query, visualization_type, sql_query,
174
  sql_result):
175
 
176
  if visualization_type == 'bar':
177
- data = format_data(text_query=text_query, sql_query=sql_query,
178
- sql_result=sql_result, visualization_type=visualization_type)
 
 
 
179
  return plot_bar_chart(data)
180
 
181
  elif visualization_type == 'horizontal_bar':
 
85
  def get_prompt(schema, query_input):
86
  text = f"""
87
  ### Instruction:
88
+ Your task is to generate a valid DuckDB SQL query to answer the following question. Select only the columns that are necessary for visualization or that will make visualization easier (e.g., numeric or categorical data). Respond only with the SQL query and do not include anything else.
89
+
90
  ### Input:
91
  Here is the database schema that the SQL query will run on:
92
  {schema}
93
+
94
  ### Question:
95
  {query_input}
96
+
97
+ ### Response (use DuckDB shorthand if possible):
98
  """
99
  return text
100
 
 
153
  print(visualization, reason)
154
  return visualization, reason
155
 
156
+ def format_bar_data(results, question, sql_query):
157
+ if isinstance(results, str):
158
+ results = eval(results)
159
+
160
+ if len(results[0]) == 2:
161
+ labels = [str(row[0]) for row in results]
162
+ data = [float(row[1]) if row[1] is not None else 0 for row in results]
163
+
164
+ prompt = ChatPromptTemplate.from_messages([
165
+ ("system", "You are a data labeling expert. Given a question, SQL Query used, and some data, provide a concise and relevant label for the data series."),
166
+ ("human", "Question: {question}\n SQL Query: {sql_query} \nData (first few rows): {data}\n\n Provide a concise label name for this y axis. Just give me the name."),
167
+ ])
168
+ prompt = prompt.format_prompt(question=question, data=str(results[:2]),
169
+ sql_query=sql_query)
170
+ label = run_llm(prompt)
171
+ label = label.replace('**Answer:**', '')
172
+ values = [{"data": data, "label": label}]
173
+ elif len(results[0]) == 3:
174
+ categories = set(row[1] for row in results)
175
+ labels = list(categories)
176
+ entities = set(row[0] for row in results)
177
+ values = []
178
+ for entity in entities:
179
+ entity_data = [float(row[2]) for row in results if row[0] == entity]
180
+ values.append({"data": entity_data, "label": str(entity)})
181
+ else:
182
+ raise ValueError("Unexpected data format in results")
183
+
184
+ formatted_data = {
185
+ "labels": labels,
186
+ "values": values
187
+ }
188
+
189
+ return formatted_data
190
 
191
  def format_data(text_query, sql_query, sql_result, visualization_type):
192
  instruction = graph_instructions[visualization_type]
193
 
194
  template = ChatPromptTemplate.from_messages([
195
+ ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user,Given the question, SQL query, and results, your job is to Understand the question and SQL query. Ensure the result matches the query and can be easily visualized."),
196
+ ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\Result: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
197
  ])
198
 
199
 
200
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
201
  results=sql_result, instructions=instruction)
 
202
  formatted_data = run_llm(prompt)
203
  formatted_data = formatted_data.replace('```', '')
204
  formatted_data = formatted_data.replace('json', '')
 
209
  sql_result):
210
 
211
  if visualization_type == 'bar':
212
+ try:
213
+ data = format_bar_data(sql_result, text_query, sql_query)
214
+ except Exception as e:
215
+ data = format_data(text_query=text_query, sql_query=sql_query,
216
+ sql_result=sql_result, visualization_type=visualization_type)
217
  return plot_bar_chart(data)
218
 
219
  elif visualization_type == 'horizontal_bar':
plot_utils.py CHANGED
@@ -6,9 +6,9 @@ def plot_bar_chart(data):
6
  labels = data['labels']
7
  values = data['values']
8
 
9
- colors = ['darkviolet', 'indigo', 'blueviolet'] # Define the colors for each group
10
- width = 0.2 # Width of the bars
11
- x = np.arange(len(labels)) # Label locations
12
 
13
  for i, value in enumerate(values):
14
  ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])
 
6
  labels = data['labels']
7
  values = data['values']
8
 
9
+ colors = ['darkviolet', 'indigo', 'blueviolet']
10
+ width = 0.2
11
+ x = np.arange(len(labels))
12
 
13
  for i, value in enumerate(values):
14
  ax.bar(x + i * width, value['data'], width, label=value['label'], color=colors[i])