File size: 14,120 Bytes
4858ba5
 
 
 
 
 
2a1d061
4858ba5
 
 
 
 
 
2a1d061
4858ba5
 
 
 
 
2a1d061
4858ba5
 
 
2a1d061
4858ba5
 
 
 
 
 
2a1d061
4858ba5
2a1d061
e44c00b
 
2a1d061
4858ba5
e44c00b
 
 
 
2a1d061
e44c00b
4858ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e44c00b
4858ba5
e44c00b
4858ba5
 
 
 
 
e44c00b
 
 
4858ba5
 
 
 
 
e44c00b
4858ba5
 
 
 
 
e44c00b
4858ba5
 
 
 
1734c0d
4858ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e44c00b
4858ba5
 
 
 
 
 
 
 
 
 
 
 
1734c0d
4858ba5
 
1734c0d
bf6bb3c
1734c0d
 
4858ba5
 
1734c0d
4858ba5
 
 
 
 
 
1734c0d
4858ba5
 
 
1734c0d
bf6bb3c
 
 
4858ba5
bf6bb3c
 
1734c0d
bf6bb3c
0eb72dc
bf6bb3c
 
 
 
4858ba5
1734c0d
4858ba5
0eb72dc
 
4858ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a1d061
 
4858ba5
 
 
 
 
b2bfa4a
4858ba5
b2bfa4a
4858ba5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1734c0d
 
4858ba5
1734c0d
4858ba5
 
 
 
 
 
 
 
1734c0d
2a1d061
 
4858ba5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import os
import json
import torch
import duckdb
import spaces
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from visualization_prompt import graph_instructions
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface.llms import HuggingFacePipeline
from plot_utils import plot_bar_chart, plot_horizontal_bar_chart, plot_line_graph, plot_scatter, plot_pie
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

# Number of text lines shown by the read-only textboxes in the
# "SQL Query", "Prompt" and "Schema" tabs of the UI.
TAB_LINES = 8
# MotherDuck token, read from the environment and embedded in the DuckDB
# connection string below.
md_token = os.getenv('MD_TOKEN')
# NOTE(review): when HF_TOKEN is already set this re-assignment is a no-op;
# when it is unset, os.getenv returns None and assigning None into os.environ
# raises TypeError at import time. Confirm whether that fail-fast is intended.
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')

print('Connecting to DB...')
# Module-wide, read-only DuckDB connection to the "my_db" database; every
# query helper in this file goes through `conn`.
# NOTE(review): if MD_TOKEN is unset the literal text "None" is sent as the token.
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)

# Report which device is available. NOTE(review): `device` is not referenced
# again in this file; model placement is driven by device_map="auto" below.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

print('Loading Model...')


tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

# 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type= "nf4")

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", quantization_config=quantization_config,
                                             device_map="auto", torch_dtype=torch.bfloat16)
# Text-generation pipeline that returns only the newly generated text
# (return_full_text=False), capped at 512 new tokens per call.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
# LangChain wrapper around the pipeline; this is the object run_llm invokes.
llm = HuggingFacePipeline(pipeline=pipe)


print('Model Loaded...')
print(f'Model Device: {model.device}')

def get_schemas():
    """Return the distinct schema names visible on the shared connection.

    The system schemas 'information_schema' and 'pg_catalog' are filtered
    out by the query itself.
    """
    query = """
    SELECT DISTINCT schema_name
    FROM information_schema.schemata
    WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    # Each row is a one-column tuple; keep just the schema name.
    return [row[0] for row in rows]

# Get Tables
def get_tables(schema_name):
    """Return the names of all tables that live in *schema_name*.

    Args:
        schema_name: Schema selected in the UI. A falsy value (nothing
            selected) yields an empty list without querying the database.

    Returns:
        A list of table-name strings (possibly empty).
    """
    if not schema_name:
        return []
    # Bind the schema name as a query parameter instead of interpolating it
    # into the SQL text, so quotes in the value cannot alter the statement.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]

# Update Tables
def update_tables(schema_name):
    """Build a Gradio update that sets a dropdown's choices to the tables of *schema_name*."""
    return gr.update(choices=get_tables(schema_name))

# Get Schema
def get_table_schema(table):
    """Return the CREATE statement of *table* with its name fully qualified.

    The DDL is read from ``duckdb_tables()`` and the table reference in it is
    rewritten to ``<database>.<schema>.<table>`` so a query generated from it
    addresses the right database.

    Args:
        table: Unqualified table name as listed in the UI.

    Returns:
        The DDL string with the qualified table path.

    Raises:
        ValueError: If no table named *table* exists on the connection.
    """
    # Bound parameter instead of string interpolation: the table name comes
    # from the UI and must not be able to change the statement.
    # NOTE(review): table names are not unique across schemas; as before, the
    # first matching row wins.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table '{table}' was not found in the connected database.")

    ddl_create, parent_database, schema_name = row
    full_path = f"{parent_database}.{schema_name}.{table}"
    # Tables in the default "main" schema appear unqualified in the DDL.
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    # Replace only the first occurrence (the CREATE TABLE header); a global
    # replace would also rewrite column names that contain the table name.
    return ddl_create.replace(old_path, full_path, 1)

# Get Prompt
def get_prompt(schema, query_input):
    """Assemble the text-to-SQL prompt for the LLM.

    Args:
        schema: DDL of the table the generated query will run against.
        query_input: The user's natural-language question.

    Returns:
        The prompt text: instruction, schema, question and a response cue,
        each line indented by four spaces.
    """
    indent = "    "
    prompt_lines = [
        "",
        indent + "### Instruction:",
        indent + "Your task is to generate valid duckdb SQL query to answer the following question. Only Respond with SQL query do not inlcude anything else.",
        indent + "### Input:",
        indent + "Here is the database schema that the SQL query will run on:",
        indent + f"{schema}",
        "",
        indent + "### Question:",
        indent + f"{query_input}",
        indent + "### Response (use duckdb shorthand if possible):",
        indent,
    ]
    return "\n".join(prompt_lines)

@spaces.GPU(duration=60)
def run_llm(prompt):
    """Invoke the module-level LLM with *prompt* and return what it generates.

    The call is wrapped by ``spaces.GPU(duration=60)``.
    """
    return llm.invoke(prompt)


def _parse_visualization_response(response):
    """Pull the chart name and the reason out of the model's free-text reply."""
    visualization = None
    reason = None
    for line in response.replace('```', '').splitlines():
        label, separator, value = line.partition(':')
        if not separator:
            continue
        # Tolerate markdown decoration such as "**Reason:** ..." or "- Reason: ...".
        label = label.strip(' \t*#-').lower()
        value = value.strip(' \t*')
        if visualization is None and label == 'recommended visualization':
            # Normalise to the identifiers the plotting code expects
            # ("Bar." -> "bar", "Horizontal Bar" -> "horizontal_bar").
            visualization = value.strip(' .[]"\'').lower().replace(' ', '_')
        elif reason is None and label == 'reason':
            reason = value
    if not visualization:
        raise ValueError(
            f"Could not find a 'Recommended Visualization' line in the model response: {response!r}"
        )
    return visualization, reason or ''


def get_visualization_type(text_query, sql_query, sql_result):
    """Ask the LLM which chart type best fits a question and its query result.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL that was generated for the question.
        sql_result: The rows returned by running that SQL.

    Returns:
        A ``(visualization, reason)`` tuple. ``visualization`` is lower-cased
        and stripped of surrounding punctuation; the prompt asks the model to
        answer with one of bar, horizontal_bar, line, pie, scatter, none.

    Raises:
        ValueError: If the reply contains no "Recommended Visualization" line.
    """
    system_message = '''
You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.

Available chart types and their use cases:
- Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2. Use for questions like "What are the sales figures for each product?" or "How does the population of cities compare? or "What percentage of each city is male?"
- Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large. Use for questions like "Show the revenue of A and B?" or "How does the population of 2 cities compare?" or "How many men and women got promoted?" or "What percentage of men and what percentage of women got promoted?" when the disparity between categories is large.
- Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous. Use for questions like "Plot a distribution of the fares (where the x axis is the fare and the y axis is the count of people who paid that fare)" or "Is there a relationship between advertising spend and sales?" or "How do height and weight correlate in the dataset? Do not use it for questions that do not have a continuous x axis."
- Pie Charts: Ideal for showing proportions or percentages within a whole. Use for questions like "What is the market share distribution among different companies?" or "What percentage of the total revenue comes from each product?"
- Line Graphs: Best for showing trends and distributionsover time. Best used when both x axis and y axis are continuous. Used for questions like "How have website visits changed over the year?" or "What is the trend in temperature over the past decade?". Do not use it for questions that do not have a continuous x axis or a time based x axis.

Consider these types of questions when recommending a visualization:
1. Aggregations and Summarizations (e.g., "What is the average revenue by month?" - Line Graph)
2. Comparisons (e.g., "Compare the sales figures of Product A and Product B over the last year." - Line or Column Graph)
3. Plotting Distributions (e.g., "Plot a distribution of the age of users" - Scatter Plot)
4. Trends Over Time (e.g., "What is the trend in the number of active users over the past year?" - Line Graph)
5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)

Provide your response in the following format:
Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
Reason: [Brief explanation for your recommendation]
'''
    human_message = '''
User question: {question}
SQL query: {sql_query}
Query results: {results}

Recommend a visualization:
'''

    prompt = ChatPromptTemplate.from_messages([
                    ("system", system_message),
                    ("human", human_message),
                ])

    final_prompt = prompt.format_prompt(question=text_query,
                                        sql_query=sql_query, results=sql_result)
    response = run_llm(final_prompt)
    print(response.strip().split('\n'))

    # Locate the two labelled lines wherever they appear instead of assuming
    # they are exactly the first and second line of the reply.
    return _parse_visualization_response(response)




def _parse_json_payload(raw):
    """Decode the first JSON object or array contained in *raw*."""
    text = raw.strip()
    starts = [pos for pos in (text.find('{'), text.find('[')) if pos != -1]
    if not starts:
        # No container found: let json report the problem (or parse a scalar).
        return json.loads(text)
    # raw_decode stops at the end of the first JSON value, so code fences,
    # a leading "json" tag or a trailing period around it are ignored.
    payload, _end = json.JSONDecoder().raw_decode(text, min(starts))
    return payload


def format_data(text_query, sql_query, sql_result, visualization_type):
    """Ask the LLM to reshape a query result into the JSON a plot function expects.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL that was generated for the question.
        sql_result: The rows returned by running that SQL.
        visualization_type: Key into ``graph_instructions`` selecting the
            example structure shown to the model.

    Returns:
        The decoded JSON value produced by the model.

    Raises:
        KeyError: If *visualization_type* has no entry in ``graph_instructions``.
        json.JSONDecodeError: If the model reply contains no valid JSON.
    """
    instruction = graph_instructions[visualization_type]

    template = ChatPromptTemplate.from_messages([
    ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
    # "\n\nResult" replaces the original "\n\Result", whose invalid escape put
    # a literal backslash in front of "Result" in the prompt.
    ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\nResult: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it. Do not give backticks."),
    ])

    prompt = template.format_prompt(question=text_query, sql_query=sql_query,
                                    results=sql_result, instructions=instruction)
    print(prompt)
    formatted_data = run_llm(prompt)
    print(f'Formatted Data {formatted_data}')
    # Do not delete periods from the reply: that corrupts decimal values
    # (12.5 -> 125) and any label containing a dot.
    return _parse_json_payload(formatted_data)

def visualize_result(text_query, visualization_type, sql_query, sql_result):
    """Build the figure for *visualization_type* from a query result.

    Args:
        text_query: The user's natural-language question.
        visualization_type: One of 'bar', 'horizontal_bar', 'line', 'pie',
            'scatter' or 'none'.
        sql_query: The SQL that was generated for the question.
        sql_result: The rows returned by running that SQL.

    Returns:
        The object returned by the matching plot function, or an invisible
        empty matplotlib figure for 'none'.

    Raises:
        ValueError: If *visualization_type* is not one of the supported names
            (previously this fell through and returned None).
    """
    if visualization_type == 'none':
        fig, ax = plt.subplots()
        ax.set_visible(False)
        return fig

    plotters = {
        'bar': plot_bar_chart,
        'horizontal_bar': plot_horizontal_bar_chart,
        'line': plot_line_graph,
        'pie': plot_pie,
        'scatter': plot_scatter,
    }
    plotter = plotters.get(visualization_type)
    if plotter is None:
        raise ValueError(f"Unsupported visualization type: {visualization_type!r}")

    # Every chart type goes through the same LLM formatting step; only the
    # plotting function differs.
    data = format_data(text_query=text_query, sql_query=sql_query,
                       sql_result=sql_result, visualization_type=visualization_type)
    return plotter(data)



def main(table, text_query):
    """Answer *text_query* against *table* and return the values for the UI outputs.

    Steps: read the table DDL, prompt the LLM for a DuckDB query, execute it,
    ask the LLM which chart fits the result, then build that chart. A failure
    in any step ends the run early with whatever was produced so far plus an
    invisible placeholder figure.

    Args:
        table: Name of the table selected in the UI, or None if none is selected.
        text_query: The user's natural-language question.

    Returns:
        The five-element list built by ``generate_output``:
        [schema, prompt, generated SQL, plot, results DataFrame].
    """
    # Invisible placeholder shown in the Plot tab whenever no chart is produced.
    fig, ax = plt.subplots()
    ax.set_visible(False)

    if table is None:
        # Must return five values like every other path; the click handler is
        # wired to five output components.
        return generate_output("", "", "", fig, pd.DataFrame([{"error": "❌ Table is None."}]))

    try:
        schema = get_table_schema(table)
    except Exception as e:
        return generate_output("", "", "", fig, pd.DataFrame([{"error": f"❌ Unable to read the table schema. {e}"}]))

    prompt = get_prompt(schema, text_query)
    try:
        generated_sql_query = run_llm(prompt)
        print(f'Generated SQL Query: {generated_sql_query}')
    except Exception as e:
        return generate_output(schema, prompt, '', fig, pd.DataFrame([{"error": f"❌ Unable to generate the SQL query. {e}"}]))

    try:
        sql_query_result_raw = conn.sql(generated_sql_query)
        sql_query_result = sql_query_result_raw.fetchall()
        sql_query_df = sql_query_result_raw.df()

        print(f"SQL Query Result: {sql_query_result}")
    except Exception as e:
        return generate_output(schema, prompt, generated_sql_query, fig, pd.DataFrame([{"error": f"❌ Unable to execute the SQL query. {e}"}]))

    try:
        visualization_type, reason = get_visualization_type(text_query=text_query,
                                                            sql_query=generated_sql_query,
                                                            sql_result=sql_query_result)
        print(f"Visualization Type: {visualization_type}")
        print(f"Reason: {reason}")
    except Exception as e:
        # The query itself succeeded, so still show its result; surface the
        # chart-selection failure instead of dropping it silently.
        gr.Warning(f"⚠️ {e}")
        return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)

    if visualization_type == 'none':
        gr.Warning(f"⚠️ {reason}")
        return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)

    try:
        plot = visualize_result(text_query=text_query, sql_query=generated_sql_query,
                                sql_result=sql_query_result, visualization_type=visualization_type)
    except Exception as e:
        gr.Warning(f"⚠️ {e}")
        return generate_output(schema, prompt, generated_sql_query, fig, sql_query_df)

    if plot is None:
        # No figure was produced for this type; fall back to the placeholder.
        gr.Warning(f"⚠️ Unsupported visualization type: {visualization_type}")
        plot = fig

    return generate_output(schema, prompt, generated_sql_query, plot, sql_query_df)

def generate_output(schema, prompt, generated_sql_query, result_plot, sql_query_df):
    """Pack the five result values into one list.

    The order is schema, prompt, generated SQL, plot, results DataFrame,
    matching the order of the output components wired to the Run Query button.
    """
    text_outputs = [schema, prompt, generated_sql_query]
    return [*text_outputs, result_plot, sql_query_df]

# Stylesheet passed to gr.Blocks(css=...): page background colour, a centred
# `.logo` rule, and the normal/hover background colours of `.gr-button`.
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""

# Gradio UI definition. Components are laid out in the order they are created
# inside the nested `with` blocks, so statement order here is significant.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
    # Header: logo image followed by the centred title and subtitle.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

    gr.Markdown("""
    <div style='text-align: center;'>
    <strong style='font-size: 36px;'>DataViz Agent</strong>
    <br>
    <span style='font-size: 20px;'>Visualize SQL queries based on a given text for the dataset.</span>
    </div>
    """)

    with gr.Row():

        # Left panel: schema/table selection. The schema list is queried once,
        # when the UI is built; the table list starts empty and is filled by
        # update_tables whenever the schema selection changes.
        with gr.Column(scale=1, variant='panel'):
            schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
            tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)

        # Right panel: question input, run button and the result tabs.
        with gr.Column(scale=2):
            query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
            with gr.Row():
                # Empty column; presumably a spacer that pushes the button to the right.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    generate_query_button = gr.Button("Run Query", variant="primary")

            with gr.Tabs():
                with gr.Tab("Result"):
                    query_result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
                with gr.Tab("Plot"):
                    result_plot = gr.Plot()
                with gr.Tab("SQL Query"):
                    generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
                with gr.Tab("Prompt"):
                    input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
                with gr.Tab("Schema"):
                    table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)

        # Event wiring. The `outputs` order below must match the list order
        # returned by main (schema, prompt, SQL, plot, results DataFrame).
        schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
        generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_plot, query_result_output])

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)