Mustehson committed on
Commit
4858ba5
·
1 Parent(s): 72bb4d1

Initial Commit

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. app.py +309 -49
  3. plot_utils.py +78 -0
  4. visualization_prompt.py +151 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: DataViz Agent
3
- emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: DataViz Agent
3
+ emoji: 📊
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,63 +1,323 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
27
 
28
- response = ""
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import duckdb
5
+ import spaces
6
+ import pandas as pd
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ from visualization_prompt import graph_instructions
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_huggingface.llms import HuggingFacePipeline
12
+ from plot_utils import plot_bar_chart, plot_horizontal_bar_chart, plot_line_graph, plot_scatter, plot_pie
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
14
+ from dotenv import load_dotenv
15
 
16
+ load_dotenv()
 
 
 
17
 
18
 
19
# Number of text lines shown in each read-only tab textbox.
TAB_LINES = 8

# MotherDuck token used to build the DuckDB connection string (None when unset).
md_token = os.getenv('MD_TOKEN')

# Re-export HF_TOKEN only when it is actually set: assigning None into
# os.environ raises TypeError and would abort start-up.
_hf_token = os.getenv('HF_TOKEN')
if _hf_token is not None:
    os.environ['HF_TOKEN'] = _hf_token
 
 
 
 
24
 
25
+ print('Connecting to DB...')
26
+ # Connect to DB
27
+ conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
 
 
28
 
29
+ if torch.cuda.is_available():
30
+ device = torch.device("cuda")
31
+ print(f"Using GPU: {torch.cuda.get_device_name(device)}")
32
+ else:
33
+ device = torch.device("cpu")
34
+ print("Using CPU")
35
 
36
+ print('Loading Model...')
37
 
38
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
 
 
 
 
 
 
 
39
 
40
+ quantization_config = BitsAndBytesConfig(
41
+ load_in_4bit=True,
42
+ bnb_4bit_compute_dtype=torch.bfloat16,
43
+ bnb_4bit_use_double_quant=True,
44
+ bnb_4bit_quant_type= "nf4")
45
 
46
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=quantization_config,
47
+ device_map="auto", torch_dtype=torch.bfloat16)
48
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, return_full_text=False, max_new_tokens=512)
49
+ llm = HuggingFacePipeline(pipeline=pipe)
50
+
51
+
52
+ print('Model Loaded...')
53
+ print(f'Model Device: {model.device}')
54
+
55
+
56
def generate_output(schema, prompt, generated_sql_query, result_output):
    """Pack the four handler results into a list for Gradio.

    The order matches the ``outputs`` list of the "Run Query" click handler:
    table schema, input prompt, generated SQL query, plot.

    A positional list is returned instead of a dict keyed by components
    because the ``result_output`` parameter shadows the module-level Plot
    component of the same name, so the dict form used the plot value itself
    as a key. A later definition of this function in the same module returns
    the same list, so behaviour is identical whichever definition is active.
    """
    return [schema, prompt, generated_sql_query, result_output]
63
+
64
+ # Get Databases
65
def get_schemas():
    """Return the distinct schema names, excluding the system schemas.

    ``information_schema`` and ``pg_catalog`` are filtered out in the query.
    """
    query = """
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    # Each row is a 1-tuple holding the schema name.
    return [row[0] for row in rows]
72
+
73
+ # Get Tables
74
def get_tables(schema_name):
    """Return the names of all tables that belong to ``schema_name``.

    The schema name is bound as a query parameter instead of being
    interpolated into the SQL text, so a value containing quotes cannot
    change the statement.
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    # Each row is a 1-tuple holding the table name.
    return [row[0] for row in rows]
77
+
78
+ # Update Tables
79
def update_tables(schema_name):
    """Return a Gradio update that sets the table choices for ``schema_name``."""
    return gr.update(choices=get_tables(schema_name))
82
+
83
+ # Get Schema
84
def get_table_schema(table):
    """Return the CREATE statement of ``table`` with a fully qualified name.

    The DDL reported by ``duckdb_tables()`` is rewritten so the table is
    referenced as ``database.schema.table``.

    Args:
        table: Name of the table to look up.

    Returns:
        The DDL string with the table path replaced by its full path.

    Raises:
        ValueError: If ``duckdb_tables()`` has no row for ``table``.
    """
    # Bind the table name as a parameter rather than formatting it into the SQL.
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    if result.empty:
        raise ValueError(f"Table not found: {table!r}")

    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]
    full_path = f"{parent_database}.{schema_name}.{table}"

    # The original code assumes tables in "main" appear unqualified in the
    # DDL and tables in other schemas appear as "schema.table".
    old_path = table if schema_name == "main" else f"{schema_name}.{table}"
    return ddl_create.replace(old_path, full_path)
96
+
97
+ # Get Prompt
98
def get_prompt(schema, query_input):
    """Build the text-to-SQL instruction prompt.

    Args:
        schema: DDL of the table the generated query will run against.
        query_input: The user's natural-language question.

    Returns:
        The prompt string with ``schema`` and ``query_input`` embedded.
    """
    text = f"""
    ### Instruction:
    Your task is to generate valid duckdb SQL query to answer the following question. Only Respond with SQL query do not include anything else.
    ### Input:
    Here is the database schema that the SQL query will run on:
    {schema}

    ### Question:
    {query_input}
    ### Response (use duckdb shorthand if possible):
    """
    return text
111
+
112
@spaces.GPU(duration=60)
def run_llm(prompt):
    """Invoke the module-level ``llm`` on ``prompt`` and return its output."""
    return llm.invoke(prompt)
116
+
117
+
118
# Chart names the model is told to use. 'horizontal_bar' comes before 'bar'
# so substring matching does not classify it as a plain bar chart.
_VISUALIZATION_NAMES = ('horizontal_bar', 'bar', 'line', 'pie', 'scatter', 'none')


def _parse_visualization_response(response):
    """Extract ``(visualization, reason)`` from the model's reply.

    Looks for the first "Recommended Visualization:" and "Reason:" lines.
    The chart name is lower-cased and matched against the allowed names;
    anything unrecognized or missing becomes ``'none'``.
    """
    visualization = None
    reason = None
    for line in response.strip().splitlines():
        # partition keeps any later ':' inside the value (e.g. in the reason).
        key, sep, value = line.partition(':')
        if not sep:
            continue
        key = key.strip().lower()
        value = value.strip()
        if key == 'recommended visualization' and visualization is None:
            normalized = value.lower().replace(' ', '_')
            visualization = next(
                (name for name in _VISUALIZATION_NAMES if name in normalized),
                'none',
            )
        elif key == 'reason' and reason is None:
            reason = value
    return visualization or 'none', reason or ''


def get_visualization_type(text_query, sql_query, sql_result):
    """Ask the LLM which chart type suits the query result.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL query that was generated for the question.
        sql_result: The rows returned by running ``sql_query``.

    Returns:
        A ``(visualization, reason)`` tuple. ``visualization`` is one of
        ``bar``, ``horizontal_bar``, ``line``, ``pie``, ``scatter`` or
        ``none``; ``reason`` is the model's explanation, or ``''`` if absent.
    """
    system_message = '''
    You are an AI assistant that recommends appropriate data visualizations. Based on the user's question, SQL query, and query results, suggest the most suitable type of graph or chart to visualize the data. If no visualization is appropriate, indicate that.

    Available chart types and their use cases:
    - Bar Graphs: Best for comparing categorical data or showing changes over time when categories are discrete and the number of categories is more than 2. Use for questions like "What are the sales figures for each product?" or "How does the population of cities compare?" or "What percentage of each city is male?"
    - Horizontal Bar Graphs: Best for comparing categorical data or showing changes over time when the number of categories is small or the disparity between categories is large. Use for questions like "Show the revenue of A and B?" or "How does the population of 2 cities compare?" or "How many men and women got promoted?" or "What percentage of men and what percentage of women got promoted?" when the disparity between categories is large.
    - Scatter Plots: Useful for identifying relationships or correlations between two numerical variables or plotting distributions of data. Best used when both x axis and y axis are continuous. Use for questions like "Plot a distribution of the fares (where the x axis is the fare and the y axis is the count of people who paid that fare)" or "Is there a relationship between advertising spend and sales?" or "How do height and weight correlate in the dataset?" Do not use it for questions that do not have a continuous x axis.
    - Pie Charts: Ideal for showing proportions or percentages within a whole. Use for questions like "What is the market share distribution among different companies?" or "What percentage of the total revenue comes from each product?"
    - Line Graphs: Best for showing trends and distributions over time. Best used when both x axis and y axis are continuous. Used for questions like "How have website visits changed over the year?" or "What is the trend in temperature over the past decade?". Do not use it for questions that do not have a continuous x axis or a time based x axis.

    Consider these types of questions when recommending a visualization:
    1. Aggregations and Summarizations (e.g., "What is the average revenue by month?" - Line Graph)
    2. Comparisons (e.g., "Compare the sales figures of Product A and Product B over the last year." - Line or Column Graph)
    3. Plotting Distributions (e.g., "Plot a distribution of the age of users" - Scatter Plot)
    4. Trends Over Time (e.g., "What is the trend in the number of active users over the past year?" - Line Graph)
    5. Proportions (e.g., "What is the market share of the products?" - Pie Chart)
    6. Correlations (e.g., "Is there a correlation between marketing spend and revenue?" - Scatter Plot)

    Provide your response in the following format:
    Recommended Visualization: [Chart type or "None"]. ONLY use the following names: bar, horizontal_bar, line, pie, scatter, none
    Reason: [Brief explanation for your recommendation]
    '''
    human_message = '''
    User question: {question}
    SQL query: {sql_query}
    Query results: {results}

    Recommend a visualization:
    '''

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("human", human_message),
    ])

    final_prompt = prompt.format_prompt(question=text_query,
                                        sql_query=sql_query, results=sql_result)
    print(final_prompt)
    response = run_llm(final_prompt)

    print(f"Visualization Type: {response}")

    visualization, reason = _parse_visualization_response(response)

    print(visualization, reason)
    return visualization, reason
168
+
169
def _parse_llm_json(raw):
    """Parse the model's reply as JSON.

    Tolerates a surrounding Markdown code fence and a trailing period.
    Periods inside the payload are kept, so decimals such as ``21.5``
    are not corrupted.
    """
    cleaned = raw.strip()
    if cleaned.startswith('```'):
        cleaned = cleaned.strip('`').strip()
        # Drop the optional language tag of a ```json fence.
        if cleaned.lower().startswith('json'):
            cleaned = cleaned[4:]
    cleaned = cleaned.strip().rstrip('.').strip()
    return json.loads(cleaned)


def format_data(text_query, sql_query, sql_result, visualization_type):
    """Ask the LLM to reshape a query result into the chart's input format.

    Args:
        text_query: The user's natural-language question.
        sql_query: The SQL query that produced ``sql_result``.
        sql_result: The rows returned by the query.
        visualization_type: Key into ``graph_instructions`` that selects the
            format example sent to the model.

    Returns:
        The parsed JSON value produced by the model.

    Raises:
        KeyError: If ``visualization_type`` has no entry in ``graph_instructions``.
        json.JSONDecodeError: If the model's reply is not valid JSON.
    """
    instruction = graph_instructions[visualization_type]

    template = ChatPromptTemplate.from_messages([
        ("system", "You are a Data expert who formats data according to the required needs. You are given the question asked by the user, it's sql query, the result of the query and the format you need to format it in."),
        ("human", "For the given question: {question}\n\nSQL query: {sql_query}\n\nResult: {results}\n\nUse the following example to structure the data: {instructions}. If there is None in Result please change it to '0'. Just give the json string. Do not format it or add any label or text."),
    ])

    prompt = template.format_prompt(question=text_query, sql_query=sql_query,
                                    results=sql_result, instructions=instruction)
    print(f'New Template for Data fromat {prompt}')

    formatted_data = run_llm(prompt)
    print(f'Formatted Data {formatted_data}')

    return _parse_llm_json(formatted_data)
186
+
187
def visualize_result(text_query, visualization_type, reason, sql_query,
                     sql_result):
    """Render the query result with the plot function for ``visualization_type``.

    Args:
        text_query: The user's natural-language question.
        visualization_type: One of ``bar``, ``horizontal_bar``, ``line``,
            ``pie``, ``scatter`` or ``none``.
        reason: The model's explanation for the chart choice (not used here).
        sql_query: The SQL query that produced ``sql_result``.
        sql_result: The rows returned by the query.

    Returns:
        A matplotlib figure. For ``none`` or any unrecognized type this is
        an empty figure with its axes hidden.
    """
    plotters = {
        'bar': plot_bar_chart,
        'horizontal_bar': plot_horizontal_bar_chart,
        'line': plot_line_graph,
        'pie': plot_pie,
        'scatter': plot_scatter,
    }
    plotter = plotters.get(visualization_type)

    if plotter is None:
        # Nothing to draw: return a blank figure instead of None.
        fig, ax = plt.subplots()
        ax.set_visible(False)
        return fig

    data = format_data(text_query=text_query, sql_query=sql_query,
                       sql_result=sql_result, visualization_type=visualization_type)
    return plotter(data)
220
+
221
+
222
def main(table, text_query):
    """Handle a "Run Query" click: generate SQL, run it, and plot the result.

    Args:
        table: Name of the selected table, or None if nothing is selected.
        text_query: The user's natural-language question.

    Returns:
        The list built by ``generate_output``: schema, prompt, generated SQL
        and a matplotlib figure. If a step fails, the figure is blank and the
        fields produced so far are still returned.
    """
    # Blank figure shown whenever no real plot can be produced.
    fig, ax = plt.subplots()
    ax.set_visible(False)

    if table is None:
        # The fourth output is a gr.Plot, so it must receive a figure.
        gr.Warning("❌ Table is None.")
        return generate_output("", "", "", fig)

    schema = get_table_schema(table)
    prompt = get_prompt(schema, text_query)

    try:
        generated_sql_query = run_llm(prompt)
        print(f'Generated SQL Query: {generated_sql_query}')
    except Exception as e:
        print(f'SQL generation failed: {e}')
        return generate_output(schema, prompt, '', fig)

    try:
        sql_query_result = conn.sql(generated_sql_query).fetchall()
        print(f"SQL Query Result: {sql_query_result}")
    except Exception as e:
        print(f'SQL execution failed: {e}')
        return generate_output(schema, prompt, generated_sql_query, fig)

    try:
        print('Generating Visualization Type')
        visualization_type, reason = get_visualization_type(
            text_query=text_query,
            sql_query=generated_sql_query,
            sql_result=sql_query_result,
        )
        print(f"Visualization Type: {visualization_type}")
    except Exception as e:
        print(f'Choosing a visualization failed: {e}')
        return generate_output(schema, prompt, generated_sql_query, fig)

    try:
        plot = visualize_result(
            text_query=text_query,
            sql_query=generated_sql_query,
            sql_result=sql_query_result,
            visualization_type=visualization_type,
            reason=reason,
        )
        print(f"Plot: {plot}")
    except Exception as e:
        print(f'Plotting failed: {e}')
        return generate_output(schema, prompt, generated_sql_query, fig)

    if plot is None:
        return generate_output(schema, prompt, generated_sql_query, fig)

    # The placeholder is no longer needed; close it so pyplot does not
    # accumulate open figures across requests.
    if plot is not fig:
        plt.close(fig)
    return generate_output(schema, prompt, generated_sql_query, plot)
261
+
262
def generate_output(schema, prompt, generated_sql_query, result_output):
    """Pack the handler results into a list in the order Gradio expects.

    Order: table schema, input prompt, generated SQL query, plot.
    """
    outputs = [schema, prompt, generated_sql_query, result_output]
    return outputs
265
+
266
+ custom_css = """
267
+ .gradio-container {
268
+ background-color: #f0f4f8;
269
+ }
270
+ .logo {
271
+ max-width: 200px;
272
+ margin: 20px auto;
273
+ display: block;
274
+ }
275
+ .gr-button {
276
+ background-color: #4a90e2 !important;
277
+ }
278
+ .gr-button:hover {
279
+ background-color: #3a7bc8 !important;
280
+ }
281
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
284
+ gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
285
+
286
+ gr.Markdown("""
287
+ <div style='text-align: center;'>
288
+ <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
289
+ <br>
290
+ <span style='font-size: 20px;'>Generate and Run SQL queries based on a given text for the dataset.</span>
291
+ </div>
292
+ """)
293
+
294
+ with gr.Row():
295
+
296
+ with gr.Column(scale=1, variant='panel'):
297
+ schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
298
+ tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
299
+
300
+ with gr.Column(scale=2):
301
+ query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
302
+ with gr.Row():
303
+ with gr.Column(scale=7):
304
+ pass
305
+ with gr.Column(scale=1):
306
+ generate_query_button = gr.Button("Run Query", variant="primary")
307
+
308
+ with gr.Tabs():
309
+ with gr.Tab("Plot"):
310
+ result_output = gr.Plot()
311
+ # result_output = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
312
+ with gr.Tab("SQL Query"):
313
+ generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
314
+ with gr.Tab("Prompt"):
315
+ input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
316
+ with gr.Tab("Schema"):
317
+ table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
318
+
319
+ schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
320
+ generate_query_button.click(main, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
321
 
322
  if __name__ == "__main__":
323
+ demo.launch(debug=True)
plot_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
def plot_bar_chart(data):
    """Draw a vertical bar chart, grouping the bars when there are several series.

    Args:
        data: Mapping with ``labels`` (one per category) and ``values``, a
            list of series dicts each holding ``data`` and an optional ``label``.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots()
    labels = data['labels']
    series_list = data['values']

    # Give every series its own slot inside each category. Drawing all series
    # at the same x position would paint later bars over earlier ones.
    positions = list(range(len(labels)))
    n_series = max(len(series_list), 1)
    width = 0.8 / n_series
    for index, series in enumerate(series_list):
        shift = (index - (n_series - 1) / 2) * width
        ax.bar([pos + shift for pos in positions], series['data'],
               width=width, label=series.get('label'))

    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_title('Bar Chart')
    ax.set_xlabel('Labels')
    ax.set_ylabel('Values')
    # Only add a legend when at least one series has a label to show.
    if any(series.get('label') for series in series_list):
        ax.legend()
    return fig
17
+
18
def plot_horizontal_bar_chart(data):
    """Draw a horizontal bar chart, grouping the bars when there are several series.

    Args:
        data: Mapping with ``labels`` (one per category) and ``values``, a
            list of series dicts each holding ``data`` and an optional ``label``.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots()
    labels = data['labels']
    series_list = data['values']

    # Give every series its own slot inside each category. Drawing all series
    # at the same y position would paint later bars over earlier ones.
    positions = list(range(len(labels)))
    n_series = max(len(series_list), 1)
    height = 0.8 / n_series
    for index, series in enumerate(series_list):
        shift = (index - (n_series - 1) / 2) * height
        ax.barh([pos + shift for pos in positions], series['data'],
                height=height, label=series.get('label'))

    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_title('Horizontal Bar Chart')
    ax.set_xlabel('Values')
    ax.set_ylabel('Labels')
    # Only add a legend when at least one series has a label to show.
    if any(series.get('label') for series in series_list):
        ax.legend()
    return fig
31
+
32
def plot_line_graph(data):
    """Draw one line per entry of ``data['yValues']`` over ``data['xValues']``.

    Returns the matplotlib figure. A legend is added only when at least one
    series dict has a ``label`` key.
    """
    fig, ax = plt.subplots()
    xs = data['xValues']
    all_series = data['yValues']

    for series in all_series:
        ax.plot(xs, series['data'], label=series.get('label', None))

    ax.set_title('Line Graph')
    ax.set_xlabel('X Values')
    ax.set_ylabel('Y Values')

    has_label = False
    for series in all_series:
        if 'label' in series:
            has_label = True
            break
    if has_label:
        ax.legend()

    return fig
49
+
50
def plot_scatter(data):
    """Draw one scatter series per entry of ``data['series']``.

    Each series holds ``data``, a list of points with ``x`` and ``y`` keys,
    and an optional ``label``. Returns the matplotlib figure. A legend is
    added only when at least one series has a ``label`` key.
    """
    fig, ax = plt.subplots()
    all_series = data['series']

    for entry in all_series:
        points = entry['data']
        xs = [pt['x'] for pt in points]
        ys = [pt['y'] for pt in points]
        ax.scatter(xs, ys, label=entry.get('label', None))

    ax.set_title('Scatter Plot')
    ax.set_xlabel('X Values')
    ax.set_ylabel('Y Values')

    labelled = [entry for entry in data['series'] if 'label' in entry]
    if labelled:
        ax.legend()

    return fig
66
+
67
def plot_pie(data):
    """Draw a pie chart from a list of items with ``value`` and ``label`` keys.

    Returns the matplotlib figure. Slice percentages are shown with one
    decimal place and all pie text uses font size 10.
    """
    values = []
    for item in data:
        values.append(item['value'])
    labels = []
    for item in data:
        labels.append(item['label'])

    fig, ax = plt.subplots()
    pie_parts = ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=140)
    # With autopct set, pie() returns (wedges, label texts, percentage texts).
    label_texts, pct_texts = pie_parts[1], pie_parts[2]
    ax.set_title('Pie Chart')

    for text in label_texts + pct_texts:
        text.set_fontsize(10)

    return fig
visualization_prompt.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Format example sent to the LLM for vertical bar charts. The shape matches
# what plot_bar_chart reads: data['labels'] and data['values'][i]['data'/'label'].
barGraphIntstruction = '''

  Where data is: {
    labels: string[]
    values: {data: number[], label: string}[]
  }

// Examples of usage:
Each label represents a column on the x axis.
Each array in values represents a different entity.

Here we are looking at average income for each month.
{
  labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
  values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
}

Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
{
  labels: ['series A', 'series B', 'series C'],
  values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
}
'''
24
+
25
# Format example sent to the LLM for horizontal bar charts. The shape matches
# what plot_horizontal_bar_chart reads: data['labels'] and
# data['values'][i]['data'/'label'].
horizontalBarGraphIntstruction = '''

  Where data is: {
    labels: string[]
    values: {data: number[], label: string}[]
  }

// Examples of usage:
Each label represents a bar on the y axis.
Each array in values represents a different entity.

Here we are looking at average income for each month.
{
  labels: ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
  values: [{data:[21.5, 25.0, 47.5, 64.8, 105.5, 133.2], label: 'Income'}],
}

Here we are looking at the performance of american and european players for each series. Since there are two entities, we have two arrays in values.
{
  labels: ['series A', 'series B', 'series C'],
  values: [{data:[10, 15, 20], label: 'American'}, {data:[20, 25, 30], label: 'European'}],
}

'''
49
+
50
+
51
+ lineGraphIntstruction = '''
52
+
53
+ Where data is: {
54
+ xValues: number[] | string[]
55
+ yValues: { data: number[]; label: string }[]
56
+ }
57
+
58
+ // Examples of usage:
59
+
60
+ Here we are looking at the momentum of a body as a function of mass.
61
+ {
62
+ xValues: ['2020', '2021', '2022', '2023', '2024'],
63
+ yValues: [
64
+ { data: [2, 5.5, 2, 8.5, 1.5]},
65
+ ],
66
+ }
67
+
68
+ Here we are looking at the performance of american and european players for each year. Since there are two entities, we have two arrays in yValues.
69
+ {
70
+ xValues: ['2020', '2021', '2022', '2023', '2024'],
71
+ yValues: [
72
+ { data: [2, 5.5, 2, 8.5, 1.5], label: 'American' },
73
+ { data: [2, 5.5, 2, 8.5, 1.5], label: 'European' },
74
+ ],
75
+ }
76
+ '''
77
+
78
# Format example sent to the LLM for pie charts. The field names are the
# singular 'label' and 'value', matching the example and what plot_pie reads.
pieChartIntstruction = '''

  Where data is: {
    label: string
    value: number
  }[]

// Example usage:
 [
  { id: 0, value: 10, label: 'series A' },
  { id: 1, value: 15, label: 'series B' },
  { id: 2, value: 20, label: 'series C' },
],
'''
92
+
93
+ scatterPlotIntstruction = '''
94
+ Where data is: {
95
+ series: {
96
+ data: { x: number; y: number; id: number }[]
97
+ label: string
98
+ }[]
99
+ }
100
+
101
+ // Examples of usage:
102
+ 1. Here each data array represents the points for a different entity.
103
+ We are looking for correlation between amount spent and quantity bought for men and women.
104
+ {
105
+ series: [
106
+ {
107
+ data: [
108
+ { x: 100, y: 200, id: 1 },
109
+ { x: 120, y: 100, id: 2 },
110
+ { x: 170, y: 300, id: 3 },
111
+ ],
112
+ label: 'Men',
113
+ },
114
+ {
115
+ data: [
116
+ { x: 300, y: 300, id: 1 },
117
+ { x: 400, y: 500, id: 2 },
118
+ { x: 200, y: 700, id: 3 },
119
+ ],
120
+ label: 'Women',
121
+ }
122
+ ],
123
+ }
124
+
125
+ 2. Here we are looking for correlation between the height and weight of players.
126
+ {
127
+ series: [
128
+ {
129
+ data: [
130
+ { x: 180, y: 80, id: 1 },
131
+ { x: 170, y: 70, id: 2 },
132
+ { x: 160, y: 60, id: 3 },
133
+ ],
134
+ label: 'Players',
135
+ },
136
+ ],
137
+ }
138
+
139
+ // Note: Each object in the 'data' array represents a point on the scatter plot.
140
+ // The 'x' and 'y' values determine the position of the point, and 'id' is a unique identifier.
141
+ // Multiple series can be represented, each as an object in the outer array.
142
+ '''
143
+
144
+
145
+ graph_instructions = {
146
+ "bar": barGraphIntstruction,
147
+ "horizontal_bar": horizontalBarGraphIntstruction,
148
+ "line": lineGraphIntstruction,
149
+ "pie": pieChartIntstruction,
150
+ "scatter": scatterPlotIntstruction
151
+ }