|
import json
import os

import pandas as pd
import plotly.express as px
import streamlit as st
|
|
|
|
|
|
|
def _llm_reply_text(llm, prompt):
    """Send *prompt* to *llm* and return the reply text, or None if it is not a str."""
    # LangChain chat models (such as ChatOpenAI) accept a plain string via
    # invoke() and return a message object exposing .content; their generate()
    # expects a batch of message lists and returns an LLMResult, so it must not
    # be called with a bare string. Objects without invoke() keep the old path.
    if hasattr(llm, "invoke"):
        reply = llm.invoke(prompt)
    else:
        reply = llm.generate(prompt)
    text = getattr(reply, "content", reply)
    return text if isinstance(text, str) else None


def _parse_viz_response(text):
    """Return the JSON object embedded in *text*, or None if there is no usable one."""
    if not text:
        return None
    # Models frequently wrap JSON in ```json fences or add a sentence around it;
    # slicing from the first "{" to the last "}" tolerates both.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    # Callers read the result with .get(), so only a JSON object is usable.
    return parsed if isinstance(parsed, dict) else None


def ask_gpt4o_for_viz(query, df, llm):
    """Ask the LLM which chart best answers *query* over the columns of *df*.

    Args:
        query: Free-text question typed by the user.
        df: DataFrame whose column names are listed in the prompt.
        llm: Chat model client. If it has an ``invoke`` method it is used and
            the reply's ``content`` is read; otherwise ``llm.generate(prompt)``
            is called and its return value is used as the reply text.

    Returns:
        The parsed suggestion dict (keys requested in the prompt:
        ``chart_type``, ``x_axis``, ``y_axis``, optional ``group_by``), or
        ``None`` after showing a Streamlit error when the reply contains no
        usable JSON object.
    """
    # map(str, ...) so non-string column labels (e.g. ints) don't break join().
    columns = ", ".join(map(str, df.columns))

    prompt = f"""
Analyze the user query and suggest the best way to visualize the data.

Query: "{query}"

Available Columns: {columns}

Respond with ONLY a JSON object (no prose, no code fences) in this format:

{{
    "chart_type": "bar/box/line/scatter",
    "x_axis": "column_name",
    "y_axis": "column_name",
    "group_by": "optional_column_name"
}}
"""

    suggestion = _parse_viz_response(_llm_reply_text(llm, prompt))
    if suggestion is None:
        st.error("⚠️ Failed to interpret AI response. Please refine your query.")
    return suggestion
|
|
|
|
|
def generate_viz(suggestion, df):
    """Build a Plotly Express figure from an LLM chart *suggestion* over *df*.

    Args:
        suggestion: Dict with ``chart_type`` (bar/box/line/scatter; defaults
            to "bar" when missing or empty), ``x_axis``, ``y_axis`` (defaults
            to "salary_in_usd" when missing) and optional ``group_by``.
        df: DataFrame to plot.

    Returns:
        The figure, or ``None`` after showing a Streamlit warning when x/y do
        not name columns of *df* or the chart type is unsupported.
    """
    builders = {
        "bar": px.bar,
        "box": px.box,
        "line": px.line,
        "scatter": px.scatter,
    }

    # Map names the model returns (always strings in JSON) back to real labels,
    # so a hallucinated or placeholder name is caught here rather than making
    # plotly raise.
    labels = {str(col): col for col in df.columns}

    def resolve(name):
        # Real column label for *name*, or None when absent/unknown.
        if name is None:
            return None
        return labels.get(str(name))

    x_axis = resolve(suggestion.get("x_axis"))
    y_axis = resolve(suggestion.get("y_axis", "salary_in_usd"))
    if x_axis is None or y_axis is None:
        st.warning("⚠️ Could not identify required columns.")
        return None

    # group_by is optional: an unknown name (e.g. the prompt's placeholder
    # "optional_column_name" echoed back) just means "no grouping".
    group_by = resolve(suggestion.get("group_by"))

    # Accept "Bar", " line ", or null from the model.
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    build = builders.get(chart_type)
    if build is None:
        st.warning("⚠️ Unsupported chart type suggested.")
        return None

    fig = build(df, x=x_axis, y=y_axis, color=group_by)
    fig.update_layout(title=f"{chart_type.title()} Plot of {y_axis} by {x_axis}")
    return fig
|
|
|
|
|
# Streamlit page: upload a CSV, ask a question, render the LLM-suggested chart.

# NOTE(review): the title's leading emoji appears mis-decoded (only the UTF-8
# lead byte survived, shown as "๐"); "📊" is a stand-in — confirm the glyph.
st.title("📊 GPT-4o Powered Data Visualization")

uploaded_file = st.file_uploader("Upload CSV File", type=["csv"])
query = st.text_input("Ask a question about the data:")

if uploaded_file:
    # load_data() is not defined in the visible source; read the upload
    # directly and report unreadable files instead of crashing the page.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"⚠️ Could not read the CSV file: {err}")
        st.stop()

    st.write("### Dataset Preview", df.head())

    if query and st.button("Generate Visualization"):
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.error("⚠️ OPENAI_API_KEY is not set.")
            st.stop()

        # NOTE(review): ChatOpenAI is not imported in the visible code —
        # presumably `from langchain_openai import ChatOpenAI`; confirm.
        llm = ChatOpenAI(api_key=api_key, model="gpt-4o")
        try:
            with st.spinner("Asking GPT-4o..."):
                suggestion = ask_gpt4o_for_viz(query, df, llm)
        except Exception as err:  # UI boundary: surface API/network failures.
            st.error(f"⚠️ GPT-4o request failed: {err}")
            st.stop()

        if suggestion:
            fig = generate_viz(suggestion, df)
            if fig:
                st.plotly_chart(fig, use_container_width=True)
else:
    st.info("Upload a CSV file to get started.")