Spaces:
Runtime error
Runtime error
Mustehson
commited on
Commit
Β·
1734c0d
1
Parent(s):
d6b332d
Added Dataframe for Result
Browse files
app.py
CHANGED
@@ -34,10 +34,7 @@ print('Loading Model...')
|
|
34 |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
|
35 |
|
36 |
quantization_config = BitsAndBytesConfig(
|
37 |
-
load_in_4bit=True
|
38 |
-
bnb_4bit_compute_dtype=torch.bfloat16,
|
39 |
-
bnb_4bit_use_double_quant=True,
|
40 |
-
bnb_4bit_quant_type= "nf4")
|
41 |
|
42 |
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
|
43 |
device_map="auto", torch_dtype=torch.bfloat16)
|
@@ -49,13 +46,14 @@ print('Model Loaded...')
|
|
49 |
print(f'Model Device: {model.device}')
|
50 |
|
51 |
|
52 |
-
def generate_output(schema, prompt, generated_sql_query, result_output):
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
59 |
|
60 |
# Get Databases
|
61 |
def get_schemas():
|
@@ -149,17 +147,12 @@ Recommend a visualization:
|
|
149 |
|
150 |
final_prompt = prompt.format_prompt(question=text_query,
|
151 |
sql_query=sql_query, results=sql_result)
|
152 |
-
print(final_prompt)
|
153 |
response = run_llm(final_prompt)
|
154 |
|
155 |
-
print(f"Visualization Type: {response}")
|
156 |
-
|
157 |
lines = response.strip().split('\n')
|
158 |
visualization = lines[0].split(': ')[1]
|
159 |
reason = lines[1].split(': ')[1]
|
160 |
|
161 |
-
|
162 |
-
print(visualization, reason)
|
163 |
return visualization, reason
|
164 |
|
165 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
@@ -173,20 +166,17 @@ def format_data(text_query, sql_query, sql_result, visualization_type):
|
|
173 |
|
174 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
175 |
results=sql_result, instructions=instruction)
|
176 |
-
print(f'New Template for Data fromat {prompt}')
|
177 |
|
178 |
formatted_data = run_llm(prompt)
|
179 |
print(f'Formatted Data {formatted_data}')
|
180 |
-
|
181 |
return json.loads(formatted_data.replace('.', '').strip())
|
182 |
|
183 |
-
def visualize_result(text_query, visualization_type,
|
184 |
sql_result):
|
185 |
|
186 |
if visualization_type == 'bar':
|
187 |
data = format_data(text_query=text_query, sql_query=sql_query,
|
188 |
sql_result=sql_result, visualization_type=visualization_type)
|
189 |
-
print(data)
|
190 |
return plot_bar_chart(data)
|
191 |
|
192 |
elif visualization_type == 'horizontal_bar':
|
@@ -227,36 +217,39 @@ def main(table, text_query):
|
|
227 |
generated_sql_query = run_llm(prompt)
|
228 |
print(f'Generated SQL Query: {generated_sql_query}')
|
229 |
except Exception as e:
|
230 |
-
return generate_output(schema, prompt, '', fig)
|
231 |
|
232 |
try:
|
233 |
-
|
|
|
|
|
|
|
234 |
print(f"SQL Query Result: {sql_query_result}")
|
235 |
except Exception as e:
|
236 |
-
return generate_output(schema, prompt, generated_sql_query, fig)
|
237 |
|
238 |
try:
|
239 |
-
print('Generating Visualization Type')
|
240 |
visualization_type, reason = get_visualization_type(text_query=text_query,
|
241 |
sql_query=generated_sql_query, sql_result=sql_query_result)
|
242 |
|
243 |
print(f"Visualization Type: {visualization_type}")
|
|
|
244 |
|
245 |
except Exception as e:
|
246 |
return generate_output(schema, prompt, generated_sql_query,
|
247 |
-
fig)
|
248 |
try:
|
249 |
|
250 |
plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
|
251 |
-
sql_result=sql_query_result, visualization_type=visualization_type
|
252 |
-
|
253 |
except Exception as e:
|
254 |
-
return generate_output(schema, prompt, generated_sql_query, fig)
|
255 |
|
256 |
-
return generate_output(schema, prompt, generated_sql_query, plot)
|
257 |
|
258 |
-
def generate_output(schema, prompt, generated_sql_query,
|
259 |
-
return [schema, prompt, generated_sql_query,
|
260 |
|
261 |
custom_css = """
|
262 |
.gradio-container {
|
@@ -301,9 +294,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
301 |
generate_query_button = gr.Button("Run Query", variant="primary")
|
302 |
|
303 |
with gr.Tabs():
|
|
|
|
|
304 |
with gr.Tab("Plot"):
|
305 |
-
|
306 |
-
# result_output = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
|
307 |
with gr.Tab("SQL Query"):
|
308 |
generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
|
309 |
with gr.Tab("Prompt"):
|
@@ -312,7 +306,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
|
|
312 |
table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
|
313 |
|
314 |
schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
|
315 |
-
generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query,
|
316 |
|
317 |
if __name__ == "__main__":
|
318 |
demo.launch(debug=True)
|
|
|
34 |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
|
35 |
|
36 |
quantization_config = BitsAndBytesConfig(
|
37 |
+
load_in_4bit=True)
|
|
|
|
|
|
|
38 |
|
39 |
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
|
40 |
device_map="auto", torch_dtype=torch.bfloat16)
|
|
|
46 |
print(f'Model Device: {model.device}')
|
47 |
|
48 |
|
49 |
+
# def generate_output(schema, prompt, generated_sql_query, result_output, result_visulaized):
|
50 |
+
# return {
|
51 |
+
# table_schema: schema,
|
52 |
+
# input_prompt: prompt,
|
53 |
+
# generated_query: generated_sql_query,
|
54 |
+
# result_output: result_output,
|
55 |
+
# result_visulaized: result_visulaized
|
56 |
+
# }
|
57 |
|
58 |
# Get Databases
|
59 |
def get_schemas():
|
|
|
147 |
|
148 |
final_prompt = prompt.format_prompt(question=text_query,
|
149 |
sql_query=sql_query, results=sql_result)
|
|
|
150 |
response = run_llm(final_prompt)
|
151 |
|
|
|
|
|
152 |
lines = response.strip().split('\n')
|
153 |
visualization = lines[0].split(': ')[1]
|
154 |
reason = lines[1].split(': ')[1]
|
155 |
|
|
|
|
|
156 |
return visualization, reason
|
157 |
|
158 |
def format_data(text_query, sql_query, sql_result, visualization_type):
|
|
|
166 |
|
167 |
prompt = template.format_prompt(question=text_query, sql_query=sql_query,
|
168 |
results=sql_result, instructions=instruction)
|
|
|
169 |
|
170 |
formatted_data = run_llm(prompt)
|
171 |
print(f'Formatted Data {formatted_data}')
|
|
|
172 |
return json.loads(formatted_data.replace('.', '').strip())
|
173 |
|
174 |
+
def visualize_result(text_query, visualization_type, sql_query,
|
175 |
sql_result):
|
176 |
|
177 |
if visualization_type == 'bar':
|
178 |
data = format_data(text_query=text_query, sql_query=sql_query,
|
179 |
sql_result=sql_result, visualization_type=visualization_type)
|
|
|
180 |
return plot_bar_chart(data)
|
181 |
|
182 |
elif visualization_type == 'horizontal_bar':
|
|
|
217 |
generated_sql_query = run_llm(prompt)
|
218 |
print(f'Generated SQL Query: {generated_sql_query}')
|
219 |
except Exception as e:
|
220 |
+
return generate_output(schema, prompt, '', fig, pd.DataFrame([{"error": f"β Unable to generate the SQL query. {e}"}]))
|
221 |
|
222 |
try:
|
223 |
+
sql_query_result_raw = conn.sql(generated_sql_query)
|
224 |
+
sql_query_result, = sql_query_result_raw.fetchall()
|
225 |
+
sql_query_df = sql_query_result_raw.df()
|
226 |
+
|
227 |
print(f"SQL Query Result: {sql_query_result}")
|
228 |
except Exception as e:
|
229 |
+
return generate_output(schema, prompt, generated_sql_query, fig, pd.DataFrame([{"error": f"β Unable to execute the SQL query. {e}"}]))
|
230 |
|
231 |
try:
|
|
|
232 |
visualization_type, reason = get_visualization_type(text_query=text_query,
|
233 |
sql_query=generated_sql_query, sql_result=sql_query_result)
|
234 |
|
235 |
print(f"Visualization Type: {visualization_type}")
|
236 |
+
print(f"Reason: {reason}")
|
237 |
|
238 |
except Exception as e:
|
239 |
return generate_output(schema, prompt, generated_sql_query,
|
240 |
+
fig, sql_query_df)
|
241 |
try:
|
242 |
|
243 |
plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
|
244 |
+
sql_result=sql_query_result, visualization_type=visualization_type)
|
245 |
+
|
246 |
except Exception as e:
|
247 |
+
return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)
|
248 |
|
249 |
+
return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)
|
250 |
|
251 |
+
def generate_output(schema, prompt, generated_sql_query, fig, sql_query_df):
|
252 |
+
return [schema, prompt, generated_sql_query, fig, sql_query_df]
|
253 |
|
254 |
custom_css = """
|
255 |
.gradio-container {
|
|
|
294 |
generate_query_button = gr.Button("Run Query", variant="primary")
|
295 |
|
296 |
with gr.Tabs():
|
297 |
+
with gr.Tab("Result"):
|
298 |
+
query_result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
|
299 |
with gr.Tab("Plot"):
|
300 |
+
result_plot = gr.Plot()
|
|
|
301 |
with gr.Tab("SQL Query"):
|
302 |
generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
|
303 |
with gr.Tab("Prompt"):
|
|
|
306 |
table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
|
307 |
|
308 |
schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
|
309 |
+
generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_plot, query_result_output])
|
310 |
|
311 |
if __name__ == "__main__":
|
312 |
demo.launch(debug=True)
|