# app.py (7.76 kB) - latest commit ceb21fe (verified) by bluenevus: "Update app.py"
import ast
import base64
import io
import json
import os
import re
import traceback
from threading import Thread

import dash
import dash_bootstrap_components as dbc
import google.generativeai as genai
import pandas as pd
import plotly.graph_objs as go
from dash import dcc, html, Input, Output, State
# Create the Dash application, styled with the Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Inline CSS for the dashed drag-and-drop upload target.
_upload_box_style = {
    'width': '100%',
    'height': '60px',
    'lineHeight': '60px',
    'borderWidth': '1px',
    'borderStyle': 'dashed',
    'borderRadius': '5px',
    'textAlign': 'center',
    'margin': '10px',
}

# Single-file upload widget.
_upload_widget = dcc.Upload(
    id='upload-data',
    children=html.Div(['Drag and Drop or ', html.A('Select Files')]),
    style=_upload_box_style,
    multiple=False,
)

# Input card: API key, file upload, free-text instructions and submit button.
_controls_card = dbc.Card(
    [
        dbc.CardBody([
            dbc.Input(
                id="api-key",
                placeholder="Enter your Gemini API key",
                type="password",
                className="mb-3",
            ),
            _upload_widget,
            dbc.Input(
                id="instructions",
                placeholder="Describe the analysis you want...",
                type="text",
            ),
            dbc.Button(
                "Generate Insights",
                id="submit-button",
                color="primary",
                className="mt-3",
            ),
        ])
    ],
    className="mb-4",
)

# The three graph slots, wrapped in a loading indicator.
_graph_ids = ('visualization-1', 'visualization-2', 'visualization-3')
_results_panel = dcc.Loading(
    id="loading-visualizations",
    type="default",
    children=[
        dbc.Card([
            dbc.CardBody([dcc.Graph(id=graph_id) for graph_id in _graph_ids])
        ])
    ],
)

# Page layout: heading, input card, error banner, then the graphs.
app.layout = dbc.Container(
    [
        html.H1("Data Analysis Dashboard", className="my-4"),
        _controls_card,
        html.Div(id="error-message", className="text-danger mb-3"),
        _results_panel,
    ],
    fluid=True,
)
def parse_contents(contents, filename):
    """Decode an uploaded file into a DataFrame.

    Args:
        contents: Upload payload in the form ``"<header>,<base64 data>"``.
        filename: Name of the uploaded file; its extension selects the parser.

    Returns:
        A ``pandas.DataFrame`` for ``.csv`` and Excel (``.xls``/``.xlsx``/
        ``.xlsm``/``.xlsb``) files, or ``None`` when the input is missing,
        malformed, of an unsupported type, or cannot be parsed.
    """
    if not contents or not filename:
        return None

    # Match on the extension, case-insensitively. A substring test would
    # treat "csv_notes.txt" as CSV and reject "DATA.CSV".
    lowered_name = filename.lower()

    try:
        # Split only on the first comma: everything after it is the payload.
        _header, separator, content_string = contents.partition(',')
        if not separator:
            return None
        decoded = base64.b64decode(content_string)

        if lowered_name.endswith('.csv'):
            # utf-8-sig so a leading byte-order mark, if present, is not
            # part of the decoded text.
            return pd.read_csv(io.StringIO(decoded.decode('utf-8-sig')))
        if lowered_name.endswith(('.xls', '.xlsx', '.xlsm', '.xlsb')):
            return pd.read_excel(io.BytesIO(decoded))
        return None
    except Exception as e:
        # Best effort: the caller reports a generic "unable to parse" message.
        print(f"Error in parse_contents: {e}")
        return None
def _parse_plot_specs(text):
    """Extract the list of plot-spec dicts from a model reply, or raise ValueError."""
    # Prefer the contents of the first fenced block, whatever its language tag
    # (```python, ```json, or a bare ```).
    fenced = re.search(r"```[\w+-]*\s*(.*?)```", text, re.DOTALL)
    payload = (fenced.group(1) if fenced else text).strip()

    # Narrow to the outermost [...] so surrounding prose does not break parsing.
    start, end = payload.find('['), payload.rfind(']')
    if start == -1 or end < start:
        raise ValueError("Model response does not contain a list of plot specifications.")
    payload = payload[start:end + 1]

    try:
        plots = ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError):
        # The reply may be JSON (null/true/false) rather than a Python literal.
        plots = json.loads(payload)

    if not isinstance(plots, list) or not all(isinstance(item, dict) for item in plots):
        raise ValueError("Model response is not a list of dictionaries.")
    return plots


def process_data(df, instructions, api_key, model_name='gemini-2.5-pro-preview-03-25'):
    """Ask Gemini for plot specifications describing ``df``.

    Only the column names and the shape of ``df`` are sent in the prompt,
    together with the user's instructions.

    Args:
        df: DataFrame to describe to the model.
        instructions: Free-text description of the desired analysis.
        api_key: Gemini API key used to configure the client.
        model_name: Gemini model identifier; defaults to the model this
            function previously hard-coded.

    Returns:
        A list of plot-spec dictionaries as returned by the model, or ``None``
        if the request fails or the reply cannot be parsed into such a list.
        Errors are printed, never raised.
    """
    try:
        # Initialize Gemini with provided API key
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name)
        # Generate visualization code
        response = model.generate_content(f"""
Analyze the following dataset and instructions:
Data columns: {list(df.columns)}
Data shape: {df.shape}
Instructions: {instructions}
Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
4. Provide a clear, concise title that explains the insight
Consider data density and choose visualizations that simplify and clarify the information.
Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).
Return your response as a Python list of dictionaries:
[
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
]
""")
        # The reply is parsed as a literal only; it is never executed.
        return _parse_plot_specs(response.text)
    except Exception as e:
        print(f"Error in process_data: {str(e)}")
        return None
def generate_plot(df, plot_info):
    """Build a Plotly figure from one plot specification.

    Args:
        df: Source DataFrame.
        plot_info: Dictionary with the keys ``plot_type`` (one of ``bar``,
            ``line``, ``scatter``, ``hist``, ``pie``, ``heatmap``), ``x`` and,
            except for histograms, ``y``. Optional keys: ``title``,
            ``agg_func`` (``sum``, ``mean``, ``median``, ``min``, ``max`` or
            ``count``; anything else means no aggregation), ``top_n`` and
            ``additional`` (a dict whose ``color`` entry names the second
            dimension of a heatmap).

    Returns:
        A ``plotly.graph_objs.Figure``.

    Raises:
        ValueError: If the plot type is unsupported, a referenced column is
            missing from ``df``, or a heatmap has no usable ``color`` column.
    """
    supported_types = ('bar', 'line', 'scatter', 'hist', 'pie', 'heatmap')
    plot_type = plot_info.get('plot_type')
    if plot_type not in supported_types:
        # Fail with a readable message instead of an unbound ``fig`` later on.
        raise ValueError(f"Unsupported plot type: {plot_type!r}")

    x_col = plot_info.get('x')
    y_col = plot_info.get('y')
    title = plot_info.get('title', '')
    if x_col not in df.columns:
        raise ValueError(f"Column {x_col!r} not found in the data.")

    if plot_type == 'hist':
        # A histogram bins the raw values itself; aggregating first would
        # leave one row per distinct x and flatten every bar.
        fig = go.Figure(go.Histogram(x=df[x_col]))
        fig.update_layout(title=title, xaxis_title=x_col, yaxis_title=y_col or 'count')
        return fig

    if y_col is None:
        raise ValueError(f"A 'y' column is required for {plot_type} plots.")

    agg_func = plot_info.get('agg_func')
    # For 'count', y only names the generated count column, so it does not
    # have to exist in the data.
    if agg_func != 'count' and y_col not in df.columns:
        raise ValueError(f"Column {y_col!r} not found in the data.")

    additional = plot_info.get('additional')
    color_col = additional.get('color') if isinstance(additional, dict) else None
    if plot_type == 'heatmap' and color_col not in df.columns:
        raise ValueError("Heatmap plots need an 'additional.color' column present in the data.")

    # A heatmap pivots on the color column, so that column must survive the
    # aggregation: group by both dimensions.
    group_keys = [x_col, color_col] if plot_type == 'heatmap' else [x_col]

    if agg_func in ('sum', 'mean', 'median', 'min', 'max'):
        plot_df = df.groupby(group_keys)[y_col].agg(agg_func).reset_index()
    elif agg_func == 'count':
        if y_col in group_keys:
            # reset_index cannot insert a column whose name is already taken.
            y_col = f"{y_col}_count"
        plot_df = df.groupby(group_keys).size().reset_index(name=y_col)
    else:
        plot_df = df.copy()

    # The model may return top_n as a string, None or a placeholder.
    top_n = plot_info.get('top_n')
    try:
        top_n = int(top_n) if top_n else 0
    except (TypeError, ValueError):
        top_n = 0
    if top_n > 0:
        plot_df = plot_df.nlargest(top_n, y_col)
        if plot_type == 'line':
            # nlargest orders rows by value; restore x order so the line
            # does not zig-zag.
            plot_df = plot_df.sort_values(x_col)

    x_title, y_title = x_col, y_col
    if plot_type == 'bar':
        trace = go.Bar(x=plot_df[x_col], y=plot_df[y_col])
    elif plot_type == 'line':
        trace = go.Scatter(x=plot_df[x_col], y=plot_df[y_col], mode='lines')
    elif plot_type == 'scatter':
        trace = go.Scatter(x=plot_df[x_col], y=plot_df[y_col], mode='markers')
    elif plot_type == 'pie':
        trace = go.Pie(labels=plot_df[x_col], values=plot_df[y_col])
    else:  # heatmap
        pivot_df = plot_df.pivot(index=x_col, columns=color_col, values=y_col)
        trace = go.Heatmap(z=pivot_df.values, x=pivot_df.columns, y=pivot_df.index)
        # The pivot puts the color column on the x axis and x on the y axis.
        x_title, y_title = color_col, x_col

    fig = go.Figure(trace)
    fig.update_layout(title=title, xaxis_title=x_title, yaxis_title=y_title)
    return fig
@app.callback(
    [Output('visualization-1', 'figure'),
     Output('visualization-2', 'figure'),
     Output('visualization-3', 'figure'),
     Output('error-message', 'children')],
    [Input('submit-button', 'n_clicks')],
    [State('upload-data', 'contents'),
     State('upload-data', 'filename'),
     State('instructions', 'value'),
     State('api-key', 'value')]
)
def update_output(n_clicks, contents, filename, instructions, api_key):
    """Generate the three figures when the submit button is clicked.

    Args:
        n_clicks: Click count of the submit button; ``None`` before any click.
        contents: Uploaded file payload, or ``None`` if nothing was uploaded.
        filename: Name of the uploaded file.
        instructions: Free-text analysis request; may be ``None``.
        api_key: Gemini API key entered by the user.

    Returns:
        A 4-tuple: the three figures followed by an error message. On any
        failure the figures are ``dash.no_update`` and the message explains
        what went wrong; on success the message is an empty string.
    """
    def _failure(message):
        # Leave the three graphs untouched and only update the error banner.
        return dash.no_update, dash.no_update, dash.no_update, message

    if n_clicks is None:
        # Callback invoked before the button was ever clicked: nothing to report.
        return _failure("")
    if contents is None:
        # The user clicked without uploading anything; say so instead of
        # silently doing nothing.
        return _failure("Please upload a CSV or Excel file first.")
    if not api_key:
        return _failure("Please enter a valid API key.")

    try:
        df = parse_contents(contents, filename)
        if df is None:
            return _failure("Unable to parse the uploaded file.")
        if df.empty:
            return _failure("The uploaded file contains no data.")

        # Avoid sending the literal text "None" as the instructions.
        plots = process_data(df, instructions or "", api_key)
        if not isinstance(plots, (list, tuple)) or len(plots) < 3:
            return _failure("Unable to generate visualizations. Please check your instructions and try again.")

        figures = [generate_plot(df, plot_info) for plot_info in plots[:3]]
        return figures[0], figures[1], figures[2], ""
    except Exception as e:
        # Keep the full stack trace in the server log; show a short message in the UI.
        traceback.print_exc()
        error_message = f"An error occurred: {str(e)}"
        return _failure(error_message)
if __name__ == '__main__':
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port run code.
    # Opt in explicitly with DASH_DEBUG=1 (or true/yes/on) for local work.
    debug_enabled = os.environ.get('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860, threaded=True)