Refactor functions and improve llm accuracy (#14)
Browse files- refactor functions and improve llm accuracy (3d660e2779859c6a266e023b0f93f48de373bb3a)
- app.py +99 -2
- functions/__init__.py +2 -2
- functions/chart_functions.py +77 -26
- functions/chat_functions.py +3 -111
- pipelines/__init__.py +0 -3
- pipelines/pipelines.py +0 -91
- requirements.txt +1 -1
- tools.py +12 -7
- utils.py +3 -1
app.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
-
from
|
|
|
|
|
|
|
|
|
|
| 2 |
|
|
|
|
| 3 |
import os
|
| 4 |
from getpass import getpass
|
| 5 |
from dotenv import load_dotenv
|
|
@@ -9,5 +13,98 @@ load_dotenv()
|
|
| 9 |
if "OPENAI_API_KEY" not in os.environ:
|
| 10 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
## Uncomment the line below to launch the chat app with UI
|
| 13 |
-
demo.launch(debug=True, allowed_paths=["temp/"])
|
|
|
|
| 1 |
+
from data_sources import process_data_upload
|
| 2 |
+
from functions import example_question_generator, chatbot_with_fc
|
| 3 |
+
from utils import TEMP_DIR, message_dict
|
| 4 |
+
import gradio as gr
|
| 5 |
|
| 6 |
+
import ast
|
| 7 |
import os
|
| 8 |
from getpass import getpass
|
| 9 |
from dotenv import load_dotenv
|
|
|
|
| 13 |
if "OPENAI_API_KEY" not in os.environ:
|
| 14 |
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
| 15 |
|
| 16 |
+
def delete_db(req: gr.Request):
|
| 17 |
+
import shutil
|
| 18 |
+
dir_path = TEMP_DIR / str(req.session_hash)
|
| 19 |
+
if os.path.exists(dir_path):
|
| 20 |
+
shutil.rmtree(dir_path)
|
| 21 |
+
message_dict[req.session_hash] = None
|
| 22 |
+
|
| 23 |
+
def run_example(input):
|
| 24 |
+
return input
|
| 25 |
+
|
| 26 |
+
def example_display(input):
|
| 27 |
+
if input == None:
|
| 28 |
+
display = True
|
| 29 |
+
else:
|
| 30 |
+
display = False
|
| 31 |
+
return [gr.update(visible=display),gr.update(visible=display)]
|
| 32 |
+
|
| 33 |
+
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"
|
| 34 |
+
|
| 35 |
+
with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
|
| 36 |
+
title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
|
| 37 |
+
description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
|
| 38 |
+
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
|
| 39 |
+
Can now generate charts and graphs!
|
| 40 |
+
Try a sample file to get started!</p>
|
| 41 |
+
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
|
| 42 |
+
open a discussion in the community tab and I will respond.</p>""")
|
| 43 |
+
example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
|
| 44 |
+
example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
|
| 45 |
+
with gr.Row():
|
| 46 |
+
example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
|
| 47 |
+
example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
|
| 48 |
+
|
| 49 |
+
file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
|
| 50 |
+
example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
|
| 51 |
+
example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
|
| 52 |
+
file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])
|
| 53 |
+
|
| 54 |
+
@gr.render(inputs=file_output)
|
| 55 |
+
def data_options(filename, request: gr.Request):
|
| 56 |
+
print(filename)
|
| 57 |
+
message_dict[request.session_hash] = None
|
| 58 |
+
if filename:
|
| 59 |
+
process_upload(filename, request.session_hash)
|
| 60 |
+
if "bank_marketing_campaign" in filename:
|
| 61 |
+
example_questions = [
|
| 62 |
+
["Describe the dataset"],
|
| 63 |
+
["What levels of education have the highest and lowest average balance?"],
|
| 64 |
+
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
| 65 |
+
["Can you generate a bar chart of education vs. average balance?"],
|
| 66 |
+
["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
|
| 67 |
+
]
|
| 68 |
+
elif "online_retail_data" in filename:
|
| 69 |
+
example_questions = [
|
| 70 |
+
["Describe the dataset"],
|
| 71 |
+
["What month had the highest revenue?"],
|
| 72 |
+
["Is revenue higher in the morning or afternoon?"],
|
| 73 |
+
["Can you generate a line graph of revenue per month?"],
|
| 74 |
+
["Can you generate a table of revenue per month?"]
|
| 75 |
+
]
|
| 76 |
+
else:
|
| 77 |
+
try:
|
| 78 |
+
generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
|
| 79 |
+
example_questions = [
|
| 80 |
+
["Describe the dataset"]
|
| 81 |
+
]
|
| 82 |
+
for example in generated_examples:
|
| 83 |
+
example_questions.append([example])
|
| 84 |
+
except:
|
| 85 |
+
example_questions = [
|
| 86 |
+
["Describe the dataset"],
|
| 87 |
+
["List the columns in the dataset"],
|
| 88 |
+
["What could this data be used for?"],
|
| 89 |
+
]
|
| 90 |
+
parameters = gr.Textbox(visible=False, value=request.session_hash)
|
| 91 |
+
bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
|
| 92 |
+
chat = gr.ChatInterface(
|
| 93 |
+
fn=chatbot_with_fc,
|
| 94 |
+
type='messages',
|
| 95 |
+
chatbot=bot,
|
| 96 |
+
title="Chat with your data file",
|
| 97 |
+
concurrency_limit=None,
|
| 98 |
+
examples=example_questions,
|
| 99 |
+
additional_inputs=parameters
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def process_upload(upload_value, session_hash):
|
| 103 |
+
if upload_value:
|
| 104 |
+
process_data_upload(upload_value, session_hash)
|
| 105 |
+
return [], []
|
| 106 |
+
|
| 107 |
+
demo.unload(delete_db)
|
| 108 |
+
|
| 109 |
## Uncomment the line below to launch the chat app with UI
|
| 110 |
+
demo.launch(debug=True, allowed_paths=["temp/"])
|
functions/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
| 2 |
from .chart_functions import chart_generation_func, table_generation_func
|
| 3 |
-
from .chat_functions import
|
| 4 |
|
| 5 |
-
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","
|
|
|
|
| 1 |
from .sqlite_functions import SQLiteQuery, sqlite_query_func
|
| 2 |
from .chart_functions import chart_generation_func, table_generation_func
|
| 3 |
+
from .chat_functions import example_question_generator, chatbot_with_fc
|
| 4 |
|
| 5 |
+
__all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","example_question_generator","chatbot_with_fc"]
|
functions/chart_functions.py
CHANGED
|
@@ -1,45 +1,96 @@
|
|
| 1 |
from typing import List
|
| 2 |
-
from
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
from utils import TEMP_DIR
|
| 5 |
import os
|
|
|
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
load_dotenv()
|
| 9 |
|
| 10 |
root_url = os.getenv("ROOT_URL")
|
| 11 |
|
| 12 |
-
def chart_generation_func(
|
| 13 |
print("CHART GENERATION")
|
| 14 |
-
|
| 15 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
interactive_url = url_base + '/chart-maker/view/' + url_id
|
| 28 |
-
edit_url = url_base + '/chart-maker/edit/' + url_id
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
print("TABLE GENERATION")
|
| 37 |
print(data)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
+
from typing import Dict
|
| 3 |
+
import plotly.io as pio
|
| 4 |
import pandas as pd
|
| 5 |
from utils import TEMP_DIR
|
| 6 |
import os
|
| 7 |
+
import ast
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
root_url = os.getenv("ROOT_URL")
|
| 13 |
|
| 14 |
+
def chart_generation_func(data: List[dict], session_hash: str, layout: Dict[str,str]={}):
|
| 15 |
print("CHART GENERATION")
|
| 16 |
+
print(data)
|
| 17 |
+
print(layout)
|
| 18 |
+
try:
|
| 19 |
+
dir_path = TEMP_DIR / str(session_hash)
|
| 20 |
+
chart_path = f'{dir_path}/chart.html'
|
| 21 |
+
|
| 22 |
+
#Processing data to account for variation from LLM
|
| 23 |
+
data_list = []
|
| 24 |
+
layout_dict = {}
|
| 25 |
+
if isinstance(data, list):
|
| 26 |
+
data_list = data
|
| 27 |
+
else:
|
| 28 |
+
data_list.append(data)
|
| 29 |
+
|
| 30 |
+
if isinstance(data[0], str):
|
| 31 |
+
data_list[0] = ast.literal_eval(data_list[0])
|
| 32 |
|
| 33 |
+
if isinstance(layout, list):
|
| 34 |
+
layout_obj = layout[0]
|
| 35 |
+
else:
|
| 36 |
+
layout_obj = layout
|
| 37 |
|
| 38 |
+
if isinstance(layout_obj, str):
|
| 39 |
+
layout_dict = ast.literal_eval(layout_obj)
|
| 40 |
+
else:
|
| 41 |
+
layout_dict = layout_obj
|
| 42 |
+
|
| 43 |
|
| 44 |
+
fig = dict({"data": data_list,
|
| 45 |
+
"layout": layout_dict})
|
| 46 |
+
pio.write_html(fig, chart_path, full_html=False)
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
|
| 49 |
|
| 50 |
+
iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
|
| 51 |
|
| 52 |
+
return {"reply": iframe}
|
| 53 |
+
|
| 54 |
+
except Exception as e:
|
| 55 |
+
print("CHART ERROR")
|
| 56 |
+
reply = f"""There was an error generating the Plotly Chart from {data} and {layout}
|
| 57 |
+
The error is {e},
|
| 58 |
+
You should probably try again.
|
| 59 |
+
"""
|
| 60 |
+
return {"reply": reply}
|
| 61 |
+
|
| 62 |
+
def table_generation_func(data: List[dict], session_hash):
|
| 63 |
print("TABLE GENERATION")
|
| 64 |
print(data)
|
| 65 |
+
try:
|
| 66 |
+
dir_path = TEMP_DIR / str(session_hash)
|
| 67 |
+
csv_path = f'{dir_path}/data.csv'
|
| 68 |
+
|
| 69 |
+
#Processing data to account for variation from LLM
|
| 70 |
+
if isinstance(data, list):
|
| 71 |
+
data_obj = data[0]
|
| 72 |
+
else:
|
| 73 |
+
data_obj = data
|
| 74 |
+
|
| 75 |
+
if isinstance(data_obj, str):
|
| 76 |
+
data_dict = ast.literal_eval(data_obj)
|
| 77 |
+
else:
|
| 78 |
+
data_dict = data_obj
|
| 79 |
+
|
| 80 |
+
df = pd.DataFrame.from_dict(data_dict)
|
| 81 |
+
print(df)
|
| 82 |
+
df.to_csv(csv_path)
|
| 83 |
+
|
| 84 |
+
download_path = f'{root_url}/gradio_api/file/temp/{session_hash}/data.csv'
|
| 85 |
+
html_table = df.to_html() + f'<p>Download as a <a href="{download_path}">CSV file</a></p>'
|
| 86 |
+
print(html_table)
|
| 87 |
+
|
| 88 |
+
return {"reply": html_table}
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print("TABLE ERROR")
|
| 92 |
+
reply = f"""There was an error generating the Pandas DataFrame table from {data}
|
| 93 |
+
The error is {e},
|
| 94 |
+
You should probably try again.
|
| 95 |
+
"""
|
| 96 |
+
return {"reply": reply}
|
functions/chat_functions.py
CHANGED
|
@@ -1,24 +1,10 @@
|
|
| 1 |
-
from
|
| 2 |
-
from utils import TEMP_DIR
|
| 3 |
-
|
| 4 |
-
import gradio as gr
|
| 5 |
|
| 6 |
from haystack.dataclasses import ChatMessage
|
| 7 |
from haystack.components.generators.chat import OpenAIChatGenerator
|
| 8 |
|
| 9 |
-
import os
|
| 10 |
-
import ast
|
| 11 |
-
from getpass import getpass
|
| 12 |
-
from dotenv import load_dotenv
|
| 13 |
-
|
| 14 |
-
load_dotenv()
|
| 15 |
-
|
| 16 |
-
if "OPENAI_API_KEY" not in os.environ:
|
| 17 |
-
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
| 18 |
-
|
| 19 |
chat_generator = OpenAIChatGenerator(model="gpt-4o")
|
| 20 |
response = None
|
| 21 |
-
message_dict = {}
|
| 22 |
|
| 23 |
def example_question_generator(session_hash):
|
| 24 |
import sqlite3
|
|
@@ -51,10 +37,9 @@ def example_question_generator(session_hash):
|
|
| 51 |
|
| 52 |
def chatbot_with_fc(message, history, session_hash):
|
| 53 |
from functions import sqlite_query_func, chart_generation_func, table_generation_func
|
| 54 |
-
from pipelines import rag_pipeline_func
|
| 55 |
import tools
|
| 56 |
|
| 57 |
-
available_functions = {"sql_query_func": sqlite_query_func, "
|
| 58 |
|
| 59 |
if message_dict[session_hash] != None:
|
| 60 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
|
@@ -62,7 +47,7 @@ def chatbot_with_fc(message, history, session_hash):
|
|
| 62 |
messages = [
|
| 63 |
ChatMessage.from_system(
|
| 64 |
"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
|
| 65 |
-
You also have access to a chart
|
| 66 |
You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV."""
|
| 67 |
)
|
| 68 |
]
|
|
@@ -95,97 +80,4 @@ def chatbot_with_fc(message, history, session_hash):
|
|
| 95 |
break
|
| 96 |
return response["replies"][0].text
|
| 97 |
|
| 98 |
-
def delete_db(req: gr.Request):
|
| 99 |
-
import shutil
|
| 100 |
-
dir_path = TEMP_DIR / str(req.session_hash)
|
| 101 |
-
if os.path.exists(dir_path):
|
| 102 |
-
shutil.rmtree(dir_path)
|
| 103 |
-
message_dict[req.session_hash] = None
|
| 104 |
-
|
| 105 |
-
def run_example(input):
|
| 106 |
-
return input
|
| 107 |
-
|
| 108 |
-
def example_display(input):
|
| 109 |
-
if input == None:
|
| 110 |
-
display = True
|
| 111 |
-
else:
|
| 112 |
-
display = False
|
| 113 |
-
return [gr.update(visible=display),gr.update(visible=display)]
|
| 114 |
-
|
| 115 |
-
css= ".file_marker .large{min-height:50px !important;} .example_btn{max-width:300px;}"
|
| 116 |
-
|
| 117 |
-
with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
|
| 118 |
-
title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
|
| 119 |
-
description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
|
| 120 |
-
to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
|
| 121 |
-
Can now generate charts and graphs!
|
| 122 |
-
Try a sample file to get started!</p>
|
| 123 |
-
<p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
|
| 124 |
-
open a discussion in the community tab and I will respond.</p>""")
|
| 125 |
-
example_file_1 = gr.File(visible=False, value="samples/bank_marketing_campaign.csv")
|
| 126 |
-
example_file_2 = gr.File(visible=False, value="samples/online_retail_data.csv")
|
| 127 |
-
with gr.Row():
|
| 128 |
-
example_btn_1 = gr.Button(value="Try Me: bank_marketing_campaign.csv", elem_classes="example_btn", size="md", variant="primary")
|
| 129 |
-
example_btn_2 = gr.Button(value="Try Me: online_retail_data.csv", elem_classes="example_btn", size="md", variant="primary")
|
| 130 |
-
|
| 131 |
-
file_output = gr.File(label="Data File (CSV, TSV, TXT, XLS, XLSX, XML, JSON)", show_label=True, elem_classes="file_marker", file_types=['.csv','.xlsx','.txt','.json','.ndjson','.xml','.xls','.tsv'])
|
| 132 |
-
example_btn_1.click(fn=run_example, inputs=example_file_1, outputs=file_output)
|
| 133 |
-
example_btn_2.click(fn=run_example, inputs=example_file_2, outputs=file_output)
|
| 134 |
-
file_output.change(fn=example_display, inputs=file_output, outputs=[example_btn_1, example_btn_2])
|
| 135 |
-
|
| 136 |
-
@gr.render(inputs=file_output)
|
| 137 |
-
def data_options(filename, request: gr.Request):
|
| 138 |
-
print(filename)
|
| 139 |
-
message_dict[request.session_hash] = None
|
| 140 |
-
if filename:
|
| 141 |
-
process_upload(filename, request.session_hash)
|
| 142 |
-
if "bank_marketing_campaign" in filename:
|
| 143 |
-
example_questions = [
|
| 144 |
-
["Describe the dataset"],
|
| 145 |
-
["What levels of education have the highest and lowest average balance?"],
|
| 146 |
-
["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
|
| 147 |
-
["Can you generate a bar chart of education vs. average balance?"],
|
| 148 |
-
["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
|
| 149 |
-
]
|
| 150 |
-
elif "online_retail_data" in filename:
|
| 151 |
-
example_questions = [
|
| 152 |
-
["Describe the dataset"],
|
| 153 |
-
["What month had the highest revenue?"],
|
| 154 |
-
["Is revenue higher in the morning or afternoon?"],
|
| 155 |
-
["Can you generate a line graph of revenue per month?"],
|
| 156 |
-
["Can you generate a table of revenue per month?"]
|
| 157 |
-
]
|
| 158 |
-
else:
|
| 159 |
-
try:
|
| 160 |
-
generated_examples = ast.literal_eval(example_question_generator(request.session_hash))
|
| 161 |
-
example_questions = [
|
| 162 |
-
["Describe the dataset"]
|
| 163 |
-
]
|
| 164 |
-
for example in generated_examples:
|
| 165 |
-
example_questions.append([example])
|
| 166 |
-
except:
|
| 167 |
-
example_questions = [
|
| 168 |
-
["Describe the dataset"],
|
| 169 |
-
["List the columns in the dataset"],
|
| 170 |
-
["What could this data be used for?"],
|
| 171 |
-
]
|
| 172 |
-
parameters = gr.Textbox(visible=False, value=request.session_hash)
|
| 173 |
-
bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
|
| 174 |
-
chat = gr.ChatInterface(
|
| 175 |
-
fn=chatbot_with_fc,
|
| 176 |
-
type='messages',
|
| 177 |
-
chatbot=bot,
|
| 178 |
-
title="Chat with your data file",
|
| 179 |
-
concurrency_limit=None,
|
| 180 |
-
examples=example_questions,
|
| 181 |
-
additional_inputs=parameters
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
def process_upload(upload_value, session_hash):
|
| 185 |
-
if upload_value:
|
| 186 |
-
process_data_upload(upload_value, session_hash)
|
| 187 |
-
return [], []
|
| 188 |
-
|
| 189 |
-
demo.unload(delete_db)
|
| 190 |
-
|
| 191 |
|
|
|
|
| 1 |
+
from utils import TEMP_DIR, message_dict
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from haystack.dataclasses import ChatMessage
|
| 4 |
from haystack.components.generators.chat import OpenAIChatGenerator
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
chat_generator = OpenAIChatGenerator(model="gpt-4o")
|
| 7 |
response = None
|
|
|
|
| 8 |
|
| 9 |
def example_question_generator(session_hash):
|
| 10 |
import sqlite3
|
|
|
|
| 37 |
|
| 38 |
def chatbot_with_fc(message, history, session_hash):
|
| 39 |
from functions import sqlite_query_func, chart_generation_func, table_generation_func
|
|
|
|
| 40 |
import tools
|
| 41 |
|
| 42 |
+
available_functions = {"sql_query_func": sqlite_query_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func }
|
| 43 |
|
| 44 |
if message_dict[session_hash] != None:
|
| 45 |
message_dict[session_hash].append(ChatMessage.from_user(message))
|
|
|
|
| 47 |
messages = [
|
| 48 |
ChatMessage.from_system(
|
| 49 |
"""You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
|
| 50 |
+
You also have access to a chart function that uses plotly dictionaries to generate charts and graphs.
|
| 51 |
You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV."""
|
| 52 |
)
|
| 53 |
]
|
|
|
|
| 80 |
break
|
| 81 |
return response["replies"][0].text
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
pipelines/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from .pipelines import rag_pipeline_func
|
| 2 |
-
|
| 3 |
-
__all__ = ["rag_pipeline_func"]
|
|
|
|
|
|
|
|
|
|
|
|
pipelines/pipelines.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
from haystack import Pipeline
|
| 2 |
-
from haystack.components.builders import PromptBuilder
|
| 3 |
-
from haystack.components.generators.openai import OpenAIGenerator
|
| 4 |
-
from haystack.components.routers import ConditionalRouter
|
| 5 |
-
|
| 6 |
-
from functions import SQLiteQuery
|
| 7 |
-
|
| 8 |
-
from typing import List
|
| 9 |
-
import sqlite3
|
| 10 |
-
|
| 11 |
-
import os
|
| 12 |
-
from getpass import getpass
|
| 13 |
-
from dotenv import load_dotenv
|
| 14 |
-
|
| 15 |
-
load_dotenv()
|
| 16 |
-
|
| 17 |
-
if "OPENAI_API_KEY" not in os.environ:
|
| 18 |
-
os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
|
| 19 |
-
|
| 20 |
-
from haystack.components.builders import PromptBuilder
|
| 21 |
-
from haystack.components.generators import OpenAIGenerator
|
| 22 |
-
|
| 23 |
-
llm = OpenAIGenerator(model="gpt-4o")
|
| 24 |
-
def rag_pipeline_func(queries: str, session_hash):
|
| 25 |
-
sql_query = SQLiteQuery(f'data_source_{session_hash}.db')
|
| 26 |
-
|
| 27 |
-
connection = sqlite3.connect(f'data_source_{session_hash}.db')
|
| 28 |
-
cur=connection.execute('select * from data_source')
|
| 29 |
-
columns = [i[0] for i in cur.description]
|
| 30 |
-
cur.close()
|
| 31 |
-
|
| 32 |
-
#Rag Pipeline
|
| 33 |
-
prompt = PromptBuilder(template="""Please generate an SQL query. The query should answer the following Question: {{question}};
|
| 34 |
-
If the question cannot be answered given the provided table and columns, return 'no_answer'
|
| 35 |
-
The query is to be answered for the table is called 'data_source' with the following
|
| 36 |
-
Columns: {{columns}};
|
| 37 |
-
Answer:""")
|
| 38 |
-
|
| 39 |
-
routes = [
|
| 40 |
-
{
|
| 41 |
-
"condition": "{{'no_answer' not in replies[0]}}",
|
| 42 |
-
"output": "{{replies}}",
|
| 43 |
-
"output_name": "sql",
|
| 44 |
-
"output_type": List[str],
|
| 45 |
-
},
|
| 46 |
-
{
|
| 47 |
-
"condition": "{{'no_answer' in replies[0]}}",
|
| 48 |
-
"output": "{{question}}",
|
| 49 |
-
"output_name": "go_to_fallback",
|
| 50 |
-
"output_type": str,
|
| 51 |
-
},
|
| 52 |
-
]
|
| 53 |
-
|
| 54 |
-
router = ConditionalRouter(routes)
|
| 55 |
-
|
| 56 |
-
fallback_prompt = PromptBuilder(template="""User entered a query that cannot be answered with the given table.
|
| 57 |
-
The query was: {{question}} and the table had columns: {{columns}}.
|
| 58 |
-
Let the user know why the question cannot be answered""")
|
| 59 |
-
fallback_llm = OpenAIGenerator(model="gpt-4")
|
| 60 |
-
|
| 61 |
-
conditional_sql_pipeline = Pipeline()
|
| 62 |
-
conditional_sql_pipeline.add_component("prompt", prompt)
|
| 63 |
-
conditional_sql_pipeline.add_component("llm", llm)
|
| 64 |
-
conditional_sql_pipeline.add_component("router", router)
|
| 65 |
-
conditional_sql_pipeline.add_component("fallback_prompt", fallback_prompt)
|
| 66 |
-
conditional_sql_pipeline.add_component("fallback_llm", fallback_llm)
|
| 67 |
-
conditional_sql_pipeline.add_component("sql_querier", sql_query)
|
| 68 |
-
|
| 69 |
-
conditional_sql_pipeline.connect("prompt", "llm")
|
| 70 |
-
conditional_sql_pipeline.connect("llm.replies", "router.replies")
|
| 71 |
-
conditional_sql_pipeline.connect("router.sql", "sql_querier.queries")
|
| 72 |
-
conditional_sql_pipeline.connect("router.go_to_fallback", "fallback_prompt.question")
|
| 73 |
-
conditional_sql_pipeline.connect("fallback_prompt", "fallback_llm")
|
| 74 |
-
|
| 75 |
-
print("RAG PIPELINE FUNCTION")
|
| 76 |
-
result = conditional_sql_pipeline.run({"prompt": {"question": queries,
|
| 77 |
-
"columns": columns},
|
| 78 |
-
"router": {"question": queries},
|
| 79 |
-
"fallback_prompt": {"columns": columns}})
|
| 80 |
-
|
| 81 |
-
if 'sql_querier' in result:
|
| 82 |
-
reply = result['sql_querier']['results'][0]
|
| 83 |
-
elif 'fallback_llm' in result:
|
| 84 |
-
reply = result['fallback_llm']['replies'][0]
|
| 85 |
-
else:
|
| 86 |
-
reply = result["llm"]["replies"][0]
|
| 87 |
-
|
| 88 |
-
print("reply content")
|
| 89 |
-
print(reply.content)
|
| 90 |
-
|
| 91 |
-
return {"reply": reply.content}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -2,5 +2,5 @@ haystack-ai
|
|
| 2 |
python-dotenv
|
| 3 |
gradio
|
| 4 |
pandas
|
| 5 |
-
|
| 6 |
openpyxl
|
|
|
|
| 2 |
python-dotenv
|
| 3 |
gradio
|
| 4 |
pandas
|
| 5 |
+
plotly
|
| 6 |
openpyxl
|
tools.py
CHANGED
|
@@ -29,7 +29,7 @@ def tools_call(session_hash):
|
|
| 29 |
}
|
| 30 |
}
|
| 31 |
},
|
| 32 |
-
"required": ["
|
| 33 |
},
|
| 34 |
},
|
| 35 |
},
|
|
@@ -44,17 +44,22 @@ def tools_call(session_hash):
|
|
| 44 |
"parameters": {
|
| 45 |
"type": "object",
|
| 46 |
"properties": {
|
| 47 |
-
"
|
| 48 |
"type": "array",
|
| 49 |
-
"description": """The
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"items": {
|
| 53 |
"type": "string",
|
| 54 |
}
|
| 55 |
}
|
| 56 |
},
|
| 57 |
-
"required": ["
|
| 58 |
},
|
| 59 |
},
|
| 60 |
},
|
|
@@ -82,7 +87,7 @@ def tools_call(session_hash):
|
|
| 82 |
}
|
| 83 |
}
|
| 84 |
},
|
| 85 |
-
"required": ["
|
| 86 |
},
|
| 87 |
},
|
| 88 |
}
|
|
|
|
| 29 |
}
|
| 30 |
}
|
| 31 |
},
|
| 32 |
+
"required": ["queries"],
|
| 33 |
},
|
| 34 |
},
|
| 35 |
},
|
|
|
|
| 44 |
"parameters": {
|
| 45 |
"type": "object",
|
| 46 |
"properties": {
|
| 47 |
+
"data": {
|
| 48 |
"type": "array",
|
| 49 |
+
"description": """The list containing a dictionary that contains the 'data' portion of the plotly chart generation. Infer this from the user's message.""",
|
| 50 |
+
"items": {
|
| 51 |
+
"type": "string",
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"layout": {
|
| 55 |
+
"type": "array",
|
| 56 |
+
"description": """The dictionary that contains the 'layout' portion of the plotly chart generation""",
|
| 57 |
"items": {
|
| 58 |
"type": "string",
|
| 59 |
}
|
| 60 |
}
|
| 61 |
},
|
| 62 |
+
"required": ["data"],
|
| 63 |
},
|
| 64 |
},
|
| 65 |
},
|
|
|
|
| 87 |
}
|
| 88 |
}
|
| 89 |
},
|
| 90 |
+
"required": ["data"],
|
| 91 |
},
|
| 92 |
},
|
| 93 |
}
|
utils.py
CHANGED
|
@@ -2,4 +2,6 @@ from pathlib import Path
|
|
| 2 |
|
| 3 |
current_dir = Path(__file__).parent
|
| 4 |
|
| 5 |
-
TEMP_DIR = current_dir / 'temp'
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
current_dir = Path(__file__).parent
|
| 4 |
|
| 5 |
+
TEMP_DIR = current_dir / 'temp'
|
| 6 |
+
|
| 7 |
+
message_dict = {}
|