Mustehson commited on
Commit
1734c0d
Β·
1 Parent(s): d6b332d

Added Dataframe for Result

Browse files
Files changed (1) hide show
  1. app.py +28 -34
app.py CHANGED
@@ -34,10 +34,7 @@ print('Loading Model...')
34
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
35
 
36
  quantization_config = BitsAndBytesConfig(
37
- load_in_4bit=True,
38
- bnb_4bit_compute_dtype=torch.bfloat16,
39
- bnb_4bit_use_double_quant=True,
40
- bnb_4bit_quant_type= "nf4")
41
 
42
  model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
43
  device_map="auto", torch_dtype=torch.bfloat16)
@@ -49,13 +46,14 @@ print('Model Loaded...')
49
  print(f'Model Device: {model.device}')
50
 
51
 
52
- def generate_output(schema, prompt, generated_sql_query, result_output):
53
- return {
54
- table_schema: schema,
55
- input_prompt: prompt,
56
- generated_query: generated_sql_query,
57
- result_output: result_output,
58
- }
 
59
 
60
  # Get Databases
61
  def get_schemas():
@@ -149,17 +147,12 @@ Recommend a visualization:
149
 
150
  final_prompt = prompt.format_prompt(question=text_query,
151
  sql_query=sql_query, results=sql_result)
152
- print(final_prompt)
153
  response = run_llm(final_prompt)
154
 
155
- print(f"Visualization Type: {response}")
156
-
157
  lines = response.strip().split('\n')
158
  visualization = lines[0].split(': ')[1]
159
  reason = lines[1].split(': ')[1]
160
 
161
-
162
- print(visualization, reason)
163
  return visualization, reason
164
 
165
  def format_data(text_query, sql_query, sql_result, visualization_type):
@@ -173,20 +166,17 @@ def format_data(text_query, sql_query, sql_result, visualization_type):
173
 
174
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
175
  results=sql_result, instructions=instruction)
176
- print(f'New Template for Data fromat {prompt}')
177
 
178
  formatted_data = run_llm(prompt)
179
  print(f'Formatted Data {formatted_data}')
180
-
181
  return json.loads(formatted_data.replace('.', '').strip())
182
 
183
- def visualize_result(text_query, visualization_type, reason, sql_query,
184
  sql_result):
185
 
186
  if visualization_type == 'bar':
187
  data = format_data(text_query=text_query, sql_query=sql_query,
188
  sql_result=sql_result, visualization_type=visualization_type)
189
- print(data)
190
  return plot_bar_chart(data)
191
 
192
  elif visualization_type == 'horizontal_bar':
@@ -227,36 +217,39 @@ def main(table, text_query):
227
  generated_sql_query = run_llm(prompt)
228
  print(f'Generated SQL Query: {generated_sql_query}')
229
  except Exception as e:
230
- return generate_output(schema, prompt, '', fig)
231
 
232
  try:
233
- sql_query_result = conn.sql(generated_sql_query).fetchall()
 
 
 
234
  print(f"SQL Query Result: {sql_query_result}")
235
  except Exception as e:
236
- return generate_output(schema, prompt, generated_sql_query, fig)
237
 
238
  try:
239
- print('Generating Visualization Type')
240
  visualization_type, reason = get_visualization_type(text_query=text_query,
241
  sql_query=generated_sql_query, sql_result=sql_query_result)
242
 
243
  print(f"Visualization Type: {visualization_type}")
 
244
 
245
  except Exception as e:
246
  return generate_output(schema, prompt, generated_sql_query,
247
- fig)
248
  try:
249
 
250
  plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
251
- sql_result=sql_query_result, visualization_type=visualization_type, reason=reason)
252
- print(f"Plot: {plot}")
253
  except Exception as e:
254
- return generate_output(schema, prompt, generated_sql_query, fig)
255
 
256
- return generate_output(schema, prompt, generated_sql_query, plot)
257
 
258
- def generate_output(schema, prompt, generated_sql_query, result_output):
259
- return [schema, prompt, generated_sql_query, result_output]
260
 
261
  custom_css = """
262
  .gradio-container {
@@ -301,9 +294,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
301
  generate_query_button = gr.Button("Run Query", variant="primary")
302
 
303
  with gr.Tabs():
 
 
304
  with gr.Tab("Plot"):
305
- result_output = gr.Plot()
306
- # result_output = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
307
  with gr.Tab("SQL Query"):
308
  generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
309
  with gr.Tab("Prompt"):
@@ -312,7 +306,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
312
  table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
313
 
314
  schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
315
- generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
316
 
317
  if __name__ == "__main__":
318
  demo.launch(debug=True)
 
34
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
35
 
36
  quantization_config = BitsAndBytesConfig(
37
+ load_in_4bit=True)
 
 
 
38
 
39
  model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
40
  device_map="auto", torch_dtype=torch.bfloat16)
 
46
  print(f'Model Device: {model.device}')
47
 
48
 
49
+ # def generate_output(schema, prompt, generated_sql_query, result_output, result_visulaized):
50
+ # return {
51
+ # table_schema: schema,
52
+ # input_prompt: prompt,
53
+ # generated_query: generated_sql_query,
54
+ # result_output: result_output,
55
+ # result_visulaized: result_visulaized
56
+ # }
57
 
58
  # Get Databases
59
  def get_schemas():
 
147
 
148
  final_prompt = prompt.format_prompt(question=text_query,
149
  sql_query=sql_query, results=sql_result)
 
150
  response = run_llm(final_prompt)
151
 
 
 
152
  lines = response.strip().split('\n')
153
  visualization = lines[0].split(': ')[1]
154
  reason = lines[1].split(': ')[1]
155
 
 
 
156
  return visualization, reason
157
 
158
  def format_data(text_query, sql_query, sql_result, visualization_type):
 
166
 
167
  prompt = template.format_prompt(question=text_query, sql_query=sql_query,
168
  results=sql_result, instructions=instruction)
 
169
 
170
  formatted_data = run_llm(prompt)
171
  print(f'Formatted Data {formatted_data}')
 
172
  return json.loads(formatted_data.replace('.', '').strip())
173
 
174
+ def visualize_result(text_query, visualization_type, sql_query,
175
  sql_result):
176
 
177
  if visualization_type == 'bar':
178
  data = format_data(text_query=text_query, sql_query=sql_query,
179
  sql_result=sql_result, visualization_type=visualization_type)
 
180
  return plot_bar_chart(data)
181
 
182
  elif visualization_type == 'horizontal_bar':
 
217
  generated_sql_query = run_llm(prompt)
218
  print(f'Generated SQL Query: {generated_sql_query}')
219
  except Exception as e:
220
+ return generate_output(schema, prompt, '', fig, pd.DataFrame([{"error": f"❌ Unable to generate the SQL query. {e}"}]))
221
 
222
  try:
223
+ sql_query_result_raw = conn.sql(generated_sql_query)
224
+ sql_query_result, = sql_query_result_raw.fetchall()
225
+ sql_query_df = sql_query_result_raw.df()
226
+
227
  print(f"SQL Query Result: {sql_query_result}")
228
  except Exception as e:
229
+ return generate_output(schema, prompt, generated_sql_query, fig, pd.DataFrame([{"error": f"❌ Unable to execute the SQL query. {e}"}]))
230
 
231
  try:
 
232
  visualization_type, reason = get_visualization_type(text_query=text_query,
233
  sql_query=generated_sql_query, sql_result=sql_query_result)
234
 
235
  print(f"Visualization Type: {visualization_type}")
236
+ print(f"Reason: {reason}")
237
 
238
  except Exception as e:
239
  return generate_output(schema, prompt, generated_sql_query,
240
+ fig, sql_query_df)
241
  try:
242
 
243
  plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
244
+ sql_result=sql_query_result, visualization_type=visualization_type)
245
+
246
  except Exception as e:
247
+ return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
248
 
249
+ return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)
250
 
251
+ def generate_output(schema, prompt, generated_sql_query, fig, sql_query_df):
252
+ return [schema, prompt, generated_sql_query, fig, sql_query_df]
253
 
254
  custom_css = """
255
  .gradio-container {
 
294
  generate_query_button = gr.Button("Run Query", variant="primary")
295
 
296
  with gr.Tabs():
297
+ with gr.Tab("Result"):
298
+ query_result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
299
  with gr.Tab("Plot"):
300
+ result_plot = gr.Plot()
 
301
  with gr.Tab("SQL Query"):
302
  generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
303
  with gr.Tab("Prompt"):
 
306
  table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
307
 
308
  schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
309
+ generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_plot, query_result_output])
310
 
311
  if __name__ == "__main__":
312
  demo.launch(debug=True)