import atexit
import base64
import html
import io
import json
import os
import sqlite3
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dotenv import load_dotenv
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from sqlalchemy import create_engine
# Run dotenv's loader before the environment is read below.
load_dotenv()

# API token passed to the OpenAI LLM wrapper; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Most recently constructed DataChatApp (assigned by DataChatApp.__init__).
app_instance = None
class DataChatApp:
    """State and request handlers behind the Gradio data-chat UI.

    An instance owns the currently loaded DataFrame, its PandasAI wrapper and
    the resources (temporary files, database connection) that ``cleanup``
    releases. Public handlers report failures as message strings instead of
    raising, because their return values are mapped directly onto UI
    components.
    """

    # Chart kinds create_visualization knows how to draw.
    _CHART_TYPES = ('bar', 'line', 'scatter', 'pie', 'histogram')
    # Chart kinds that can be drawn from the X column alone.
    _Y_OPTIONAL_CHARTS = ('pie', 'histogram')

    def __init__(self):
        """Create an empty app and publish it as the module-level ``app_instance``."""
        global app_instance
        self.df = None  # currently loaded pandas DataFrame, or None
        self.data_source = None  # short label describing where self.df came from
        self.llm = OpenAI(api_token=OPENAI_API_KEY)
        self.smart_df = None  # SmartDataframe wrapping self.df, or None
        self.chat_history = []
        self.temp_files = []  # paths removed by cleanup()
        self.db_connection = None  # sqlite3 connection or SQLAlchemy engine
        app_instance = self

    def load_file(self, file):
        """Load a CSV, Excel or JSON file into the app.

        Args:
            file: A path (``str`` / ``os.PathLike``) or an object whose
                ``name`` attribute holds the path, or None.

        Returns:
            A ``(status message, HTML preview, JSON info)`` tuple; the last
            two items are None when nothing was loaded.
        """
        if file is None:
            return "No file uploaded", None, None

        # Accept plain paths as well as upload objects exposing `.name`.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        readers = {
            '.csv': pd.read_csv,
            '.xlsx': pd.read_excel,
            '.xls': pd.read_excel,
            '.json': pd.read_json,
        }
        reader = readers.get(file_ext)
        if reader is None:
            return f"Unsupported file format: {file_ext}", None, None

        try:
            frame = reader(file_path)
            preview, info = self._adopt_dataframe(frame, f"File: {file_name}")
        except Exception as e:
            return f"Error loading file: {str(e)}", None, None
        return f"Loaded successfully: {file_name}", preview, info

    def connect_database(self, connection_string, query):
        """Open a database connection and load the result of ``query``.

        ``sqlite:`` strings are opened with the standard-library driver; any
        other string is handed to SQLAlchemy's ``create_engine``.

        Args:
            connection_string: Database URL, e.g. ``sqlite:///data.db``.
            query: SQL text whose result set becomes the loaded DataFrame.

        Returns:
            A ``(status message, HTML preview, JSON info)`` tuple; the last
            two items are None on failure.
        """
        if not connection_string:
            return "Please provide a connection string", None, None
        # Validate before connecting so a missing query never opens a connection.
        if not (query or "").strip():
            return "Please provide a SQL query", None, None

        try:
            # Release the connection from any earlier call before replacing it.
            self._close_db_connection()
            self.db_connection = self._open_connection(connection_string)
            frame = pd.read_sql(query, self.db_connection)
            # Only the scheme goes into the label, never credentials or paths.
            scheme = connection_string.split('://')[0]
            preview, info = self._adopt_dataframe(frame, f"Database: {scheme}")
        except Exception as e:
            self._close_db_connection()
            return f"Database connection error: {str(e)}", None, None
        return "Database connected successfully", preview, info

    @staticmethod
    def _open_connection(connection_string):
        """Return a sqlite3 connection or SQLAlchemy engine for the URL."""
        if not connection_string.startswith('sqlite:'):
            return create_engine(connection_string)

        location = connection_string[len('sqlite:'):]
        if location.startswith('///'):
            location = location[3:]
        elif location.startswith('//'):
            location = location[2:]
        # Only an explicit in-memory spelling selects an in-memory database;
        # a file path that merely contains "memory" is opened as a file.
        if location in ('', ':memory:', 'memory'):
            location = ':memory:'
        # The connection may be closed from a different thread than the one
        # that opened it (e.g. at interpreter exit), so disable the check.
        return sqlite3.connect(location, check_same_thread=False)

    def _close_db_connection(self):
        """Close or dispose the stored connection, if any, and forget it."""
        connection, self.db_connection = self.db_connection, None
        if connection is None:
            return
        try:
            if hasattr(connection, 'close'):
                connection.close()
            elif hasattr(connection, 'dispose'):
                connection.dispose()
        except Exception:
            # Deliberate best effort: shutdown must not fail because a
            # driver refuses to close an already broken connection.
            pass

    def _adopt_dataframe(self, frame, source_label):
        """Make ``frame`` the active dataset and return ``(preview, info)``.

        State is only replaced after the SmartDataframe has been built, so a
        failure leaves the previously loaded data intact.
        """
        smart = SmartDataframe(frame, config={"llm": self.llm})
        self.df = frame
        self.smart_df = smart
        self.data_source = source_label
        return frame.head().to_html(), self._get_dataframe_info()

    def _get_dataframe_info(self):
        """Return a JSON string describing the loaded DataFrame, or None.

        The JSON object has the keys "Shape", "Columns", "Data Types" and
        "Missing Values".
        """
        if self.df is None:
            return None
        missing = self.df.isnull().sum()
        info = {
            "Shape": list(self.df.shape),
            "Columns": list(self.df.columns),
            "Data Types": {str(col): str(dtype) for col, dtype in self.df.dtypes.items()},
            "Missing Values": {str(col): int(count) for col, count in missing.items()},
        }
        # default=str keeps non-JSON column labels (e.g. timestamps) from
        # breaking what is purely a display payload.
        return json.dumps(info, indent=2, default=str)

    def chat_with_data(self, query, history):
        """Answer a natural-language question about the loaded data.

        Args:
            query: The user's question.
            history: Chat history as a list of ``{"role", "content"}`` dicts,
                or None.

        Returns:
            ``("", history)``: an empty string to clear the input box and the
            history extended with the question and the assistant's reply.
            Warnings and errors are delivered as assistant messages.
        """
        history = history or []
        if query:
            history.append({"role": "user", "content": query})

        if self.df is None or self.smart_df is None:
            reply = "Please load data first before querying."
        elif not query:
            reply = "Please enter a query."
        else:
            try:
                reply = self._render_chat_response(self.smart_df.chat(query))
            except Exception as e:
                reply = f"Error processing query: {str(e)}"

        history.append({"role": "assistant", "content": reply})
        return "", history

    def _render_chat_response(self, response):
        """Convert a chat result into text/HTML suitable for the chat window."""
        if isinstance(response, plt.Figure):
            encoded = self._figure_to_base64(response)
            return (
                f'<img src="data:image/png;base64,{encoded}" '
                'alt="Generated chart" style="max-width: 100%; height: auto;" />'
            )
        if isinstance(response, pd.DataFrame):
            return f'<div style="overflow-x: auto;">{response.to_html(index=False)}</div>'
        return str(response)

    @staticmethod
    def _figure_to_base64(figure, **savefig_kwargs):
        """Render ``figure`` to PNG in memory and return it base64-encoded."""
        buffer = io.BytesIO()
        figure.savefig(buffer, format='png', **savefig_kwargs)
        return base64.b64encode(buffer.getvalue()).decode('ascii')

    def create_visualization(self, viz_type, x_axis, y_axis, title):
        """Draw a chart of the loaded data and return it as an HTML snippet.

        Args:
            viz_type: One of 'bar', 'line', 'scatter', 'pie', 'histogram'.
            x_axis: Column used for the X axis / categories.
            y_axis: Column used for the Y axis; optional for pie and histogram.
            title: Chart title; a default is generated when empty.

        Returns:
            HTML containing the chart as an inline PNG, or a message string
            when the request is invalid or drawing fails.
        """
        if self.df is None:
            return "Please load data first before creating visualizations."
        if viz_type not in self._CHART_TYPES:
            return f"Unsupported visualization type: {html.escape(str(viz_type))}"

        needs_y = viz_type not in self._Y_OPTIONAL_CHARTS
        if not x_axis or (needs_y and not y_axis):
            return "Please select both X and Y axis for the visualization."
        if x_axis not in self.df.columns:
            return f"Column '{html.escape(str(x_axis))}' not found in the data."
        if needs_y and y_axis not in self.df.columns:
            return f"Column '{html.escape(str(y_axis))}' not found in the data."

        figure = None
        try:
            # An explicit figure avoids pyplot's implicit "current figure",
            # so the finally clause closes exactly the figure created here.
            figure, axes = plt.subplots(figsize=(10, 6))
            chart_title = self._draw_chart(axes, viz_type, x_axis, y_axis, title)
            figure.tight_layout()
            encoded = self._figure_to_base64(figure, dpi=100, bbox_inches='tight')
        except Exception as e:
            return f"Error creating visualization: {html.escape(str(e))}"
        finally:
            if figure is not None:
                plt.close(figure)

        return (
            '<div style="text-align: center;">'
            f'<img src="data:image/png;base64,{encoded}" '
            f'alt="{html.escape(str(chart_title))}" '
            'style="max-width: 100%; height: auto;" />'
            '</div>'
        )

    def _draw_chart(self, axes, viz_type, x_axis, y_axis, title):
        """Draw the requested chart on ``axes`` and return the title applied."""
        if viz_type == 'pie':
            if y_axis and y_axis in self.df.columns:
                slices = self.df.groupby(x_axis)[y_axis].sum()
            else:
                slices = self.df[x_axis].value_counts()
            axes.pie(slices, labels=slices.index, autopct='%1.1f%%')
            default_title = f"Pie Chart: Distribution of {x_axis}"
        elif viz_type == 'histogram':
            # Missing values cannot be binned, so they are left out.
            axes.hist(self.df[x_axis].dropna(), bins=20)
            axes.set_xlabel(x_axis)
            axes.set_ylabel('Frequency')
            default_title = f"Histogram: Distribution of {x_axis}"
        else:
            if viz_type == 'bar':
                axes.bar(self.df[x_axis], self.df[y_axis])
                default_title = f"Bar Chart: {y_axis} by {x_axis}"
            elif viz_type == 'line':
                axes.plot(self.df[x_axis], self.df[y_axis])
                default_title = f"Line Chart: {y_axis} over {x_axis}"
            else:
                axes.scatter(self.df[x_axis], self.df[y_axis])
                default_title = f"Scatter Plot: {y_axis} vs {x_axis}"
            axes.set_xlabel(x_axis)
            axes.set_ylabel(y_axis)

        chart_title = title or default_title
        axes.set_title(chart_title)
        return chart_title

    def generate_summary_cards(self):
        """Return HTML cards with mean/median/min/max for each numeric column.

        Returns:
            An HTML string, or a message string when no data is loaded, no
            numeric column exists, or the statistics cannot be computed.
        """
        if self.df is None:
            return "Please load data first before generating summary cards."
        try:
            numeric = self.df.select_dtypes(include=[np.number])
            if numeric.shape[1] == 0:
                return "No numerical columns found for summary cards."
            # items() yields one (name, Series) pair per column, which also
            # works when column labels are duplicated.
            cards = "".join(
                self._summary_card(name, series) for name, series in numeric.items()
            )
            return f'<div style="display: flex; flex-wrap: wrap; gap: 16px;">{cards}</div>'
        except Exception as e:
            return f"Error generating summary cards: {str(e)}"

    @staticmethod
    def _summary_card(name, series):
        """Build the HTML card for a single numeric column."""
        stats = (
            ("Mean", series.mean()),
            ("Median", series.median()),
            ("Min", series.min()),
            ("Max", series.max()),
        )
        rows = "".join(
            f'<p style="margin: 4px 0;">{label}: {value:.2f}</p>'
            for label, value in stats
        )
        # Column names come from user data, so escape them before embedding.
        return (
            '<div style="border: 1px solid #ddd; border-radius: 8px; '
            'padding: 16px; min-width: 200px;">'
            f'<h3 style="margin: 0 0 8px 0;">{html.escape(str(name))}</h3>'
            f'{rows}'
            '</div>'
        )

    def cleanup(self):
        """Delete tracked temporary files and close the database connection.

        Safe to call more than once.
        """
        pending, self.temp_files = self.temp_files, []
        for path in pending:
            try:
                os.unlink(path)
            except OSError:
                # Already removed or not removable; cleanup is best effort.
                pass
        self._close_db_connection()
def create_interface():
    """Build the Gradio UI and wire its events to a new DataChatApp.

    Returns:
        The assembled ``gr.Blocks`` interface; the caller launches it.
    """
    app = DataChatApp()

    def refresh_axis_choices():
        """Return updates that repopulate both axis dropdowns from the loaded data."""
        # Read the app this interface is wired to rather than a module global.
        columns = [] if app.df is None else list(app.df.columns)
        # Clearing the selection keeps a column from a previous dataset from
        # lingering once it is no longer among the choices.
        return (
            gr.update(choices=columns, value=None),
            gr.update(choices=columns, value=None),
        )

    custom_css = """
.plot-container {width: 100% !important; height: 100% !important;}
.js-plotly-plot {min-height: 500px;}
.plotly {min-height: 500px;}
"""
    intro_markdown = (
        "# GIN Data Chat Application\n"
        "Upload your data file or connect to a database, "
        "then chat with your data using natural language!\n"
    )
    viz_placeholder = (
        '<div style="text-align: center; padding: 20px;">'
        'Your visualization will appear here'
        '</div>'
    )

    with gr.Blocks(theme=gr.themes.Soft(), title="Data Chat App", css=custom_css) as interface:
        gr.Markdown(intro_markdown)

        with gr.Tabs():
            with gr.TabItem("Load Data"):
                with gr.Tabs():
                    with gr.Tab("File Upload"):
                        file_input = gr.File(label="Upload CSV, Excel, or JSON file")
                        file_upload_button = gr.Button("Load File")
                        file_result = gr.Textbox(label="Result")
                    with gr.Tab("Database Connection"):
                        conn_str = gr.Textbox(
                            label="Connection String",
                            placeholder="E.g., sqlite:///data.db, postgresql://user:pass@localhost/db",
                        )
                        query = gr.Textbox(
                            label="SQL Query",
                            placeholder="SELECT * FROM your_table LIMIT 1000",
                        )
                        db_connect_button = gr.Button("Connect to Database")
                        db_result = gr.Textbox(label="Result")
                # Shared by both loaders, so they sit below the nested tabs.
                preview = gr.HTML(label="Data Preview")
                info = gr.JSON(label="Data Information")

            with gr.TabItem("Chat with Data"):
                chat_interface = gr.Chatbot(height=400, type="messages")
                query_input = gr.Textbox(
                    label="Ask a question about your data",
                    placeholder="E.g., Show me the trend of sales over time",
                    lines=2,
                )
                chat_button = gr.Button("Ask")

            with gr.TabItem("Visualize Data"):
                with gr.Row():
                    with gr.Column(scale=1):
                        viz_type = gr.Dropdown(
                            choices=["bar", "line", "scatter", "pie", "histogram"],
                            label="Visualization Type",
                            value="bar",
                        )
                        x_axis = gr.Dropdown(label="X-Axis / Category")
                        y_axis = gr.Dropdown(label="Y-Axis / Values (Optional for Pie & Histogram)")
                        viz_title = gr.Textbox(label="Chart Title (Optional)")
                        viz_button = gr.Button("Generate Visualization", variant="primary")
                    with gr.Column(scale=2):
                        viz_output = gr.HTML(label="Visualization", value=viz_placeholder)

            with gr.TabItem("Summary Stats"):
                summary_button = gr.Button("Generate Summary Cards")
                summary_output = gr.HTML(label="Summary Statistics")

        # Loading data from either source refreshes both axis dropdowns in one step.
        file_upload_button.click(
            app.load_file,
            inputs=[file_input],
            outputs=[file_result, preview, info],
        ).then(
            refresh_axis_choices,
            inputs=None,
            outputs=[x_axis, y_axis],
        )
        db_connect_button.click(
            app.connect_database,
            inputs=[conn_str, query],
            outputs=[db_result, preview, info],
        ).then(
            refresh_axis_choices,
            inputs=None,
            outputs=[x_axis, y_axis],
        )

        # The button and pressing Enter in the textbox trigger the same handler.
        chat_button.click(
            app.chat_with_data,
            inputs=[query_input, chat_interface],
            outputs=[query_input, chat_interface],
        )
        query_input.submit(
            app.chat_with_data,
            inputs=[query_input, chat_interface],
            outputs=[query_input, chat_interface],
        )

        viz_button.click(
            app.create_visualization,
            inputs=[viz_type, x_axis, y_axis, viz_title],
            outputs=[viz_output],
        )
        summary_button.click(
            app.generate_summary_cards,
            outputs=[summary_output],
        )

    # No cleanup call here: a freshly created app holds no temporary files or
    # connection yet, so calling cleanup() at build time would do nothing.
    return interface
if __name__ == "__main__":
    interface = create_interface()
    # create_interface() constructs the DataChatApp that backs the UI, and the
    # constructor publishes it as app_instance. Register that instance's
    # cleanup so the resources of the app actually in use are released at
    # exit, instead of those of a separate, unused instance.
    if app_instance is not None:
        atexit.register(app_instance.cleanup)
    # NOTE(review): share=True requests a publicly reachable link, and the UI
    # accepts arbitrary connection strings and SQL — confirm this is intended.
    interface.launch(share=True)