# DataViz-Agent / app.py
# Mustehson - "Gemma Model" - e44c00b - 14.1 kB
import os
import json
import torch
import duckdb
import spaces
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from visualization_prompt import graph_instructions
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface.llms import HuggingFacePipeline
from plot_utils import plot_bar_chart, plot_horizontal_bar_chart, plot_line_graph, plot_scatter, plot_pie
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
# Number of visible lines in the read-only text boxes shown in the output tabs.
TAB_LINES = 8

# Credentials are read from the environment.
md_token = os.getenv('MD_TOKEN')
# os.environ only accepts str values, so writing os.getenv('HF_TOKEN') back into
# it raises TypeError when the variable is unset (and changes nothing when it
# is set). Leave the environment untouched and just warn when it is missing.
if not os.getenv('HF_TOKEN'):
    print('Warning: HF_TOKEN is not set; downloading the model may fail if it requires authentication.')

print('Connecting to DB...')
# Read-only MotherDuck connection shared by every query in this module.
# Without MD_TOKEN, omit the token parameter instead of sending the literal
# string "None", so MotherDuck's own token lookup can apply.
if md_token:
    conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
else:
    conn = duckdb.connect("md:my_db", read_only=True)
# Report which accelerator torch can see. `device` is not referenced elsewhere
# in this file; model placement is handled by device_map="auto" below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
print('Loading Model...')
# Hugging Face checkpoint used for every LLM call in this app.
_MODEL_ID = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# 4-bit NF4 weights with double (nested) quantization; compute dtype bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# return_full_text=False: the pipeline returns only the newly generated text,
# not the prompt echoed back, so callers can parse the output directly.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=512,
)
# LangChain wrapper; run_llm() calls llm.invoke().
llm = HuggingFacePipeline(pipeline=pipe)
print('Model Loaded...')
print(f'Model Device: {model.device}')
def get_schemas():
    """Return the distinct schema names visible through ``conn``.

    The ``information_schema`` and ``pg_catalog`` system schemas are excluded.
    DISTINCT collapses repeated names (presumably the same schema name in
    several attached databases).

    Returns:
        list[str]: Schema names, used as the choices of the schema dropdown.
    """
    query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(query).fetchall()
    return [name for (name,) in rows]
def get_tables(schema_name):
    """Return the names of the tables in *schema_name*.

    The schema name arrives from a UI component and is not trusted. It is
    therefore bound as a prepared-statement parameter rather than interpolated
    into the SQL text. Names containing quotes can then neither break the
    query nor inject SQL. A ``None`` schema simply matches nothing.

    Args:
        schema_name: Schema selected in the schema dropdown.

    Returns:
        list[str]: Table names listed by ``information_schema.tables``.
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
def update_tables(schema_name):
    """Gradio callback: refresh the table dropdown for the chosen schema.

    Args:
        schema_name: Value of the schema dropdown.

    Returns:
        A ``gr.update`` whose ``choices`` are the tables in *schema_name*.
    """
    return gr.update(choices=get_tables(schema_name))
def get_table_schema(table):
    """Return the ``CREATE TABLE`` DDL for *table* with a fully qualified name.

    The DDL comes from ``duckdb_tables()``. The table reference in it is
    rewritten to ``database.schema.table``, so the schema text handed to the
    LLM prompt names the table by its full path.

    Args:
        table: Table name selected in the UI. It is bound as a query parameter,
            never interpolated into SQL.

    Returns:
        str: The rewritten DDL statement.

    Raises:
        ValueError: If ``duckdb_tables()`` has no table named *table*.
    """
    import re

    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?;",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found.")
    ddl_create, parent_database, schema_name = row

    full_path = f"{parent_database}.{schema_name}.{table}"
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # Rewrite only the first occurrence of the whole identifier, which is the
    # name in the CREATE TABLE header. A plain str.replace would also corrupt
    # column names that merely contain the table name (``sales_amount`` for a
    # table called ``sales``). The lambda keeps backslashes in names literal.
    pattern = rf"(?<![\w.]){re.escape(old_path)}(?!\w)"
    return re.sub(pattern, lambda _match: full_path, ddl_create, count=1)
def get_prompt(schema, query_input):
    """Build the text-to-SQL instruction prompt sent to the LLM.

    Args:
        schema: DDL of the table the generated query should run against.
        query_input: The user's natural-language question.

    Returns:
        str: The prompt with *schema* under "### Input" and *query_input*
        under "### Question", ending with the response header the model
        continues from.
    """
    text = f"""
### Instruction:
Your task is to generate valid duckdb SQL query to answer the following question. Only Respond with SQL query do not include anything else.
### Input:
Here is the database schema that the SQL query will run on:
{schema}
### Question:
{query_input}
### Response (use duckdb shorthand if possible):
"""
    return text
@spaces.GPU(duration=60)
def run_llm(prompt):
    """Send *prompt* to the shared LangChain ``llm`` and return its output text.

    Callers in this module pass either a plain string or the prompt value
    produced by ``ChatPromptTemplate.format_prompt``. The ``spaces.GPU``
    decorator requests a GPU for each call with a 60-second duration
    (Hugging Face ZeroGPU).
    """
    output = llm.invoke(prompt)
    return output
# Chart names the recommendation prompt asks the model to choose from.
_VISUALIZATION_TYPES = ('horizontal_bar', 'bar', 'line', 'pie', 'scatter', 'none')


def get_visualization_type(text_query, sql_query, sql_result):
    """Ask the LLM which chart type best fits a query result.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL that was executed to answer it.
        sql_result: The fetched rows, rendered into the prompt.

    Returns:
        tuple[str, str]: ``(visualization, reason)``. ``visualization`` is
        normalised to one of ``bar``, ``horizontal_bar``, ``line``, ``pie``,
        ``scatter`` or ``none`` whenever the reply names one of them (also as
        ``Bar.``, ``**Line Graph**`` and similar). Otherwise it is the
        normalised raw text. ``reason`` is ``''`` if the model gave none.

    Raises:
        ValueError: If the reply has no "Recommended Visualization" line.
    """
    system_message = '''
You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.
Available chart types and their use cases:
- Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2. Use for questions like "What are the sales figures for each product?" or "How does the population of cities compare? or "What percentage of each city is male?"
- Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large. Use for questions like "Show the revenue of A and B?" or "How does the population of 2 cities compare?" or "How many men and women got promoted?" or "What percentage of men and what percentage of women got promoted?" when the disparity between categories is large.
- Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous. Use for questions like "Plot a distribution of the fares (where the x axis is the fare and the y axis is the count of people who paid that fare)" or "Is there a relationship between advertising spend and sales?" or "How do height and weight correlate in the dataset? Do not use it for questions that do not have a continuous x axis."
- Pie Charts: Ideal for showing proportions or percentages within a whole. Use for questions like "What is the market share distribution among different companies?" or "What percentage of the total revenue comes from each product?"
- Line Graphs: Best for showing trends and distributions over time. Best used when both x axis and y axis are continuous. Used for questions like "How have website visits changed over the year?" or "What is the trend in temperature over the past decade?". Do not use it for questions that do not have a continuous x axis or a time based x axis.
Consider these types of questions when recommending a visualization:
1. Aggregations and Summarizations (e.g., "What is the average revenue by month?" - Line Graph)
2. Comparisons (e.g., "Compare the sales figures of Product A and Product B over the last year." - Line or Column Graph)
3. Plotting Distributions (e.g., "Plot a distribution of the age of users" - Scatter Plot)
4. Trends Over Time (e.g., "What is the trend in the number of active users over the past year?" - Line Graph)
5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)
Provide your response in the following format:
Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
Reason: [Brief explanation for your recommendation]
'''
    human_message = '''
User question: {question}
SQL query: {sql_query}
Query results: {results}
Recommend a visualization:
'''
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", human_message),
    ])
    final_prompt = prompt.format_prompt(question=text_query,
                                        sql_query=sql_query, results=sql_result)
    response = run_llm(final_prompt)
    return _parse_visualization_response(response)


def _parse_visualization_response(response):
    """Extract ``(visualization, reason)`` from the model's free-text reply.

    Handles the following: blank lines, code fences, Markdown bullets or bold
    labels (``**Recommended Visualization:** Bar.``), colons inside the reason,
    and longer variants of a chart name (``bar graph`` becomes ``bar``).

    Raises:
        ValueError: If no "Recommended Visualization" line is present.
    """
    lines = [line.strip() for line in response.replace('```', '').splitlines()]
    lines = [line for line in lines if line]
    print(lines)

    visualization = None
    reason = ''
    for line in lines:
        # Split on the first colon only so a reason containing ":" is kept whole.
        label, sep, value = line.partition(':')
        if not sep:
            continue
        key = label.strip(' *#_-').lower()
        if key == 'recommended visualization' and visualization is None:
            visualization = value
        elif key == 'reason' and not reason:
            reason = value.strip(' *')

    if visualization is None:
        raise ValueError(
            f"No 'Recommended Visualization' line in the model response: {response[:200]!r}")

    name = visualization.strip(' *`\'".[]').lower().replace(' ', '_').replace('-', '_')
    for known in _VISUALIZATION_TYPES:
        if name == known or name.startswith(known + '_'):
            return known, reason
    return name, reason
def format_data(text_query, sql_query, sql_result, visualization_type):
    """Ask the LLM to reshape a query result into the JSON a plot helper expects.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL that produced *sql_result*.
        sql_result: The fetched rows to reformat.
        visualization_type: Key into ``graph_instructions`` selecting the
            example structure shown to the model.

    Returns:
        The decoded JSON value from the model's reply.

    Raises:
        KeyError: If *visualization_type* has no entry in ``graph_instructions``.
        ValueError: If the reply contains no decodable JSON.
    """
    instruction = graph_instructions[visualization_type]
    template = ChatPromptTemplate.from_messages([
        ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
        # "\n\nResult" (the original "\n\Result" sent a stray backslash).
        ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\nResult: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
    ])
    prompt = template.format_prompt(question=text_query, sql_query=sql_query,
                                    results=sql_result, instructions=instruction)
    print(prompt)
    formatted_data = run_llm(prompt)
    print(f'Formatted Data {formatted_data}')
    return _parse_json_response(formatted_data)


def _parse_json_response(text):
    """Decode the first JSON object or array embedded in an LLM reply.

    Tolerates Markdown code fences, leading prose and trailing text (such as a
    final period) without altering the JSON itself. In particular, decimal
    points inside numbers are preserved.

    Raises:
        ValueError: If no JSON object or array can be decoded.
    """
    cleaned = text.replace('```json', '').replace('```', '').strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    decoder = json.JSONDecoder()
    for start, char in enumerate(cleaned):
        if char in '{[':
            try:
                value, _end = decoder.raw_decode(cleaned, start)
            except json.JSONDecodeError:
                continue
            return value
    raise ValueError(f"No JSON found in the model response: {text[:200]!r}")
def visualize_result(text_query, visualization_type, sql_query,
                     sql_result):
    """Render *sql_result* as the requested chart.

    For every chart type except ``none``, the rows are first reshaped by
    ``format_data``. The result is then passed to the matching ``plot_utils``
    helper.

    Args:
        text_query: The user's natural-language question.
        visualization_type: One of ``bar``, ``horizontal_bar``, ``line``,
            ``pie``, ``scatter`` or ``none``.
        sql_query: The SQL that produced *sql_result*.
        sql_result: The fetched rows to plot.

    Returns:
        Whatever the selected ``plot_utils`` helper returns. For ``none`` it
        is a matplotlib figure with its axes hidden.

    Raises:
        ValueError: If *visualization_type* is not a supported chart name.
            The old behavior silently returned ``None``.
    """
    if visualization_type == 'none':
        fig, ax = plt.subplots()
        ax.set_visible(False)
        return fig

    plotters = {
        'bar': plot_bar_chart,
        'horizontal_bar': plot_horizontal_bar_chart,
        'line': plot_line_graph,
        'pie': plot_pie,
        'scatter': plot_scatter,
    }
    plotter = plotters.get(visualization_type)
    if plotter is None:
        raise ValueError(f"Unsupported visualization type: {visualization_type!r}")

    data = format_data(text_query=text_query, sql_query=sql_query,
                       sql_result=sql_result, visualization_type=visualization_type)
    return plotter(data)
def _blank_figure():
    """Return a new matplotlib figure whose only axes are hidden (an empty plot)."""
    fig, ax = plt.subplots()
    ax.set_visible(False)
    return fig


def _error_frame(message):
    """Wrap *message* in a one-row DataFrame with a single ``error`` column."""
    return pd.DataFrame([{"error": message}])


def _extract_sql(text):
    """Return the SQL inside an optional Markdown code fence, with whitespace stripped.

    Chat models often wrap SQL in a fence such as ```sql ... ```, which
    DuckDB cannot parse. An unterminated fence (e.g. when generation is cut
    off) is handled too. Text without a fence is only stripped.
    """
    import re

    match = re.search(r"```(?:[A-Za-z]*\n)?(.*?)(?:```|\Z)", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


def main(table, text_query):
    """Gradio click handler: turn a question about *table* into SQL, results and a chart.

    Stages: load the table DDL, build the text-to-SQL prompt, generate SQL
    with the LLM, run it on ``conn``, ask the LLM for a chart type, then plot.
    A failure in any stage after the prompt is built is reported in the UI
    rather than propagated.

    Args:
        table: Table selected in the dropdown, or ``None`` if nothing is selected.
        text_query: The user's natural-language question.

    Returns:
        list: ``[schema, prompt, sql, plot, results_df]`` from
        ``generate_output``, one value per output component of the click
        event. When the query fails, ``results_df`` is a single ``error``
        row. When only the chart fails, the plot is blank and a warning is
        shown.
    """
    if table is None:
        # Must return five values, one per wired output component.
        return generate_output("", "", "", _blank_figure(), _error_frame("❌ Table is None."))

    try:
        schema = get_table_schema(table)
    except Exception as e:
        return generate_output("", "", "", _blank_figure(),
                               _error_frame(f"❌ Unable to load the table schema. {e}"))
    prompt = get_prompt(schema, text_query)

    try:
        generated_sql_query = _extract_sql(run_llm(prompt))
        print(f'Generated SQL Query: {generated_sql_query}')
    except Exception as e:
        return generate_output(schema, prompt, '', _blank_figure(),
                               _error_frame(f"❌ Unable to generate the SQL query. {e}"))

    try:
        sql_query_result_raw = conn.sql(generated_sql_query)
        sql_query_result = sql_query_result_raw.fetchall()
        sql_query_df = sql_query_result_raw.df()
        print(f"SQL Query Result: {sql_query_result}")
    except Exception as e:
        return generate_output(schema, prompt, generated_sql_query, _blank_figure(),
                               _error_frame(f"❌ Unable to execute the SQL query. {e}"))

    try:
        visualization_type, reason = get_visualization_type(text_query=text_query,
                                                             sql_query=generated_sql_query, sql_result=sql_query_result)
        print(f"Visualization Type: {visualization_type}")
        print(f"Reason: {reason}")
    except Exception as e:
        # The results are still useful without a chart; tell the user why there is none.
        gr.Warning(f"⚠️ {e}")
        return generate_output(schema, prompt, generated_sql_query, _blank_figure(), sql_query_df)

    if visualization_type == 'none':
        gr.Warning(f"⚠️ {reason}")
        return generate_output(schema, prompt, generated_sql_query, _blank_figure(), sql_query_df)

    try:
        plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
                                sql_result=sql_query_result, visualization_type=visualization_type)
    except Exception as e:
        gr.Warning(f"⚠️ {e}")
        return generate_output(schema, prompt, generated_sql_query, _blank_figure(), sql_query_df)
    return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)
def generate_output(schema, prompt, generated_sql_query, result_plot, sql_query_df):
    """Pack handler results in the order of the click event's output components.

    Order: schema tab, prompt tab, SQL query tab, plot tab, results table.

    Returns:
        list: The five values in that order.
    """
    outputs = [schema, prompt, generated_sql_query]
    outputs += [result_plot, sql_query_df]
    return outputs
# Extra CSS injected into the Gradio page: page background, logo sizing and
# button colours.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# UI: schema/table pickers on the left; question box, Run button and result
# tabs on the right.
# NOTE(review): the nesting below was reconstructed from a copy with stripped
# indentation. Confirm that the button row and the Tabs belong inside the
# right-hand column.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>DataViz Agent</strong>
<br>
<span style='font-size: 20px;'>Visualize SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1, variant='panel'):
            # Schema choices are fetched once, when the UI is built.
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Filled by update_tables() when the schema changes.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                # Empty wide column, presumably a spacer that pushes the button right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
            with gr.Tabs():
                with gr.Tab("Result"):
                    query_result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
                with gr.Tab("Plot"):
                    result_plot = gr.Plot()
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
    # Event wiring, registered inside the Blocks context. The outputs order
    # must match the list built by generate_output() and returned by main().
    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_plot, query_result_output])
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)