# DataViz-Agent / app.py
# Mustehson -- "Added empty fig" (commit 28623de), 5 kB
import os
import duckdb
import gradio as gr
import matplotlib.pyplot as plt
from transformers import HfEngine, ReactCodeAgent
from transformers.agents import Tool
# Height of the Tabs Text Area
TAB_LINES = 8

# Load Token
# MotherDuck token; interpolated into the DuckDB connection strings in this file.
md_token = os.getenv('MD_TOKEN')

# os.environ only accepts str values, so assigning os.getenv(...) directly
# raises TypeError at import time whenever HF_TOKEN is missing. Re-export the
# token only when it is actually present and say so explicitly otherwise.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token is not None:
    os.environ['HF_TOKEN'] = _hf_token
else:
    print('HF_TOKEN is not set; Hugging Face requests will be unauthenticated.')
# Announce start-up, then open the shared module-level MotherDuck connection
# in read-only mode.
print('Connecting to DB...')
conn = duckdb.connect(
    f"md:my_db?motherduck_token={md_token}",
    read_only=True,
)

# LLM backend handed to the agent built in get_visualization().
llm_engine = HfEngine(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
)
def get_schemas():
    """Return the schema names found in ``information_schema.schemata``.

    The query runs on the module-level ``conn`` and leaves out the
    ``information_schema`` and ``pg_catalog`` schemas.

    Returns:
        A list with one schema name per row returned by the query.
    """
    query = """
SELECT DISTINCT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
"""
    rows = conn.execute(query).fetchall()
    # Every row holds a single column; unwrap it to the bare name.
    names = []
    for row in rows:
        names.append(row[0])
    return names
# Get Tables
def get_tables(schema_name):
    """Return the names of the tables registered under ``schema_name``.

    Args:
        schema_name: Schema to look up in ``information_schema.tables``.

    Returns:
        A list of table names; empty when the schema has no tables or does
        not exist.
    """
    # Bind the schema name as a query parameter instead of splicing it into
    # the SQL text, so names containing quotes cannot break or alter the query.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
# Update Tables
def update_tables(schema_name):
    """Build a Gradio update whose choices are the tables of ``schema_name``.

    Wired to the schema dropdown's change event so the table dropdown is
    refilled from ``get_tables``.
    """
    return gr.update(choices=get_tables(schema_name))
# Get Schema
def get_table_schema(table):
    """Return the ``CREATE`` statement and fully qualified path of ``table``.

    Args:
        table: Unqualified table name to look up in ``duckdb_tables()``.

    Returns:
        A ``(ddl_create, full_path)`` tuple. ``full_path`` is
        ``<database>.<schema>.<table>`` and ``ddl_create`` is the stored DDL
        with the table reference rewritten to that path.

    Raises:
        ValueError: If ``duckdb_tables()`` has no row for ``table`` (for
            example when no table has been selected yet).
    """
    # Bind the table name as a parameter rather than interpolating it into
    # the SQL text.
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    if result.empty:
        # Previously this surfaced as an opaque IndexError from ``iloc``.
        raise ValueError(f"Table {table!r} was not found in the connected database.")

    # NOTE(review): the lookup is by table name only, so when several schemas
    # contain a table with this name the first returned row is used.
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]

    full_path = f"{parent_database}.{schema_name}.{table}"
    # The path to rewrite is "<schema>.<table>" for non-"main" schemas and
    # just "<table>" for "main".
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # NOTE(review): plain substring replacement; a column whose name contains
    # ``old_path`` is rewritten as well.
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create, full_path
def get_visualization(question, tool):
    """Ask a code-writing agent to answer ``question`` with a figure.

    Args:
        question: Natural-language request that is embedded in the agent prompt.
        tool: Tool instance made available to the agent.

    Returns:
        Whatever ``ReactCodeAgent.run`` returns; the prompt instructs the
        agent to hand back a figure through ``final_answer``.
    """
    allowed_imports = ['matplotlib.pyplot', 'pandas', 'plotly.express', 'seaborn']
    viz_agent = ReactCodeAgent(
        tools=[tool],
        llm_engine=llm_engine,
        add_base_tools=True,
        additional_authorized_imports=allowed_imports,
        max_iterations=20,
    )
    prompt = f'''
Use seaborn. Always
Question: {question}
Always use the right colors.
If the question is about showing n number of rows return empty figure.
In the end you have to return a final fig using the `final_answer` tool
'''
    return viz_agent.run(task=prompt)
class SQLExecutorTool(Tool):
    """Agent tool that runs a DuckDB SQL query and returns a pandas DataFrame.

    No ``description`` is declared on the class; ``main`` assigns one to the
    shared ``tool`` instance before every agent run.
    """

    name = "sql_engine"
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be correct DuckDB SQL.",
        }
    }
    output_type = "pandas.core.frame.DataFrame"

    def forward(self, query: str):
        """Execute ``query`` on a fresh read-only connection.

        Args:
            query: DuckDB SQL text to run.

        Returns:
            The query result materialised as a pandas DataFrame.
        """
        # Run the query on the connection opened for this call. The previous
        # code opened ``con`` here but queried the module-level ``conn``,
        # leaving the new connection unused.
        with duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True) as con:
            # ``.df()`` materialises the result before the connection closes.
            output_df = con.sql(query).df()
        return output_df


# Single shared tool instance; main() refreshes its description per request.
tool = SQLExecutorTool()
def _empty_figure():
    """Build a blank matplotlib figure with its axes hidden."""
    fig, ax = plt.subplots()
    ax.set_axis_off()
    return fig


def main(table, text_query):
    """Produce the plot for ``text_query`` against ``table``.

    Looks up the table's DDL, stores it in the shared tool's description so
    the agent knows the columns, and runs the visualization agent.

    Args:
        table: Table name chosen in the UI.
        text_query: Natural-language request for the visualization.

    Returns:
        The agent's result on success; otherwise an empty figure, after a
        Gradio warning carrying the error message has been raised.
    """
    try:
        # The schema lookup is inside the try so that a missing or unselected
        # table is reported through the same warning as an agent failure.
        schema, _ = get_table_schema(table)
        tool.description = f"""Allows you to perform SQL queries on the table. Returns a pandas dataframe representation of the result.
The table schema is as follows: \n{schema}"""
        return get_visualization(question=text_query, tool=tool)
    except Exception as e:
        gr.Warning(f"❌ Unable to generate the visualization. {e}")
        # The placeholder is created only when needed, so successful runs no
        # longer leave an unused pyplot figure behind.
        return _empty_figure()
# Extra stylesheet passed to gr.Blocks(css=...): rules for the
# `.gradio-container`, `.logo`, and `.gr-button` selectors.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# Build the Gradio UI: logo and title, schema/table pickers, a free-text
# query box with a run button, and a plot area.
# NOTE(review): the original indentation was lost; the nesting of the button
# row and the tabs inside the right-hand column is a reconstruction. Confirm
# against the rendered layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>DataViz Agent</strong>
<br>
<span style='font-size: 20px;'>Visualize SQL queries based on a given text for the dataset.</span>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Schema choices are fetched once, while the UI is being built.
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            # Starts empty; filled by update_tables when a schema is chosen.
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=3, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                # Empty spacer column placed before the button column.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")
            with gr.Tabs():
                with gr.Tab("Plot"):
                    result_plot = gr.Plot()
    # Event wiring: refresh the table list on schema change, and run main()
    # on button click with its result shown in the plot.
    schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
    generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[result_plot])
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)