# Column names preferred (in this order) when the Y-axis has to be inferred.
_PRIORITY_Y_COLUMNS = ("salary_in_usd", "income", "earnings", "revenue")

# Chart types that receive the statistics annotation box.
_STATS_ANNOTATED_CHARTS = frozenset({"bar", "line", "scatter"})

# Chart types that additionally receive min/median/avg/max reference lines.
_REFERENCE_LINE_CHARTS = frozenset({"bar", "line"})


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    # Drop the opening fence line, which may carry a language tag (```json).
    _, _, cleaned = cleaned.partition("\n")
    cleaned = cleaned.rstrip()
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def ask_gpt4o_for_visualization(query, df, llm):
    """Ask the LLM which chart(s) best answer *query* for the columns of *df*.

    Args:
        query: Natural-language question from the user.
        df: DataFrame whose column names are offered to the model.
        llm: Object exposing ``generate(prompt)`` that returns the model reply.

    Returns:
        The decoded JSON suggestion(s) (normally a list of dicts with the keys
        ``chart_type``, ``x_axis``, ``y_axis`` and optionally ``group_by``),
        or ``None`` when the reply cannot be parsed as JSON.
    """
    # Column labels are not guaranteed to be strings; stringify before joining.
    columns = ", ".join(map(str, df.columns))
    prompt = f"""
    Analyze the query and suggest one or more relevant visualizations.

    Query: "{query}"
    Available Columns: {columns}

    Respond in this JSON format (as a list if multiple suggestions):
    [
        {{
            "chart_type": "bar/box/line/scatter",
            "x_axis": "column_name",
            "y_axis": "column_name",
            "group_by": "optional_column_name"
        }}
    ]
    """
    response = llm.generate(prompt)
    try:
        # Models frequently wrap JSON in a ```json fence; strip it first.
        payload = _strip_code_fence(response) if isinstance(response, str) else response
        return json.loads(payload)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers replies that are neither str nor bytes.
        st.error("⚠️ GPT-4o failed to generate a valid suggestion.")
        return None


def add_stats_to_figure(fig, df, y_axis, chart_type):
    """Add statistical annotations to *fig* according to the chart type.

    Bar and line charts get a statistics box plus min/median/avg/max reference
    lines, scatter plots get the box only, and every other chart type is
    returned without overlays.

    Args:
        fig: Plotly figure to decorate (modified in place). ``None`` is
            passed straight through.
        df: DataFrame the figure was built from.
        y_axis: Name of the column the statistics are computed on.
        chart_type: Chart type name; compared case-insensitively.

    Returns:
        The same figure object that was passed in.
    """
    if fig is None:
        return None

    if y_axis not in df.columns:
        st.warning(f"⚠️ Cannot compute statistics for unknown column: {y_axis}")
        return fig

    # Statistics only make sense for numeric data.
    if not pd.api.types.is_numeric_dtype(df[y_axis]):
        st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
        return fig

    kind = str(chart_type).strip().lower()

    if kind in _STATS_ANNOTATED_CHARTS:
        series = df[y_axis]
        min_val = series.min()
        max_val = series.max()
        avg_val = series.mean()
        median_val = series.median()
        std_dev_val = series.std()

        # Plotly annotations render a small HTML subset, not Markdown: line
        # breaks must be <br> and bold must be <b>, otherwise the text is
        # collapsed onto one line with literal asterisks.
        stats_text = (
            "📊 <b>Statistics</b><br><br>"
            f"• <b>Min:</b> ${min_val:,.2f}<br>"
            f"• <b>Max:</b> ${max_val:,.2f}<br>"
            f"• <b>Average:</b> ${avg_val:,.2f}<br>"
            f"• <b>Median:</b> ${median_val:,.2f}<br>"
            f"• <b>Std Dev:</b> ${std_dev_val:,.2f}"
        )
        fig.add_annotation(
            text=stats_text,
            xref="paper",
            yref="paper",
            x=1.02,
            y=1,
            showarrow=False,
            align="left",
            font=dict(size=12, color="black"),
            bordercolor="gray",
            borderwidth=1,
            bgcolor="rgba(255, 255, 255, 0.85)",
        )

        if kind in _REFERENCE_LINE_CHARTS:
            # (value, dash style, colour, label, label position)
            reference_lines = (
                (min_val, "dot", "red", "Min", "bottom right"),
                (median_val, "dash", "orange", "Median", "top right"),
                (avg_val, "dashdot", "green", "Avg", "top right"),
                (max_val, "dot", "blue", "Max", "top right"),
            )
            for value, dash, color, label, position in reference_lines:
                fig.add_hline(
                    y=value,
                    line_dash=dash,
                    line_color=color,
                    annotation_text=label,
                    annotation_position=position,
                )
    elif kind == "box":
        # Box plots inherently show the distribution; nothing to add.
        pass
    elif kind == "pie":
        st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
    elif kind == "heatmap":
        st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
    else:
        st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")

    return fig


def _infer_y_axis(df, x_axis):
    """Pick a numeric column of *df* (other than *x_axis*) to use as Y-axis."""
    candidates = [
        column
        for column in df.select_dtypes(include="number").columns
        if column != x_axis
    ]
    for preferred in _PRIORITY_Y_COLUMNS:
        if preferred in candidates:
            return preferred
    return candidates[0] if candidates else None


# Dynamically generate Plotly visualizations based on GPT-4o suggestions
def generate_visualization(suggestion, df):
    """Build a Plotly Express figure from a single GPT-4o suggestion.

    A missing Y-axis is inferred from the numeric columns of *df*. Statistics
    overlays are applied here, exactly once, via ``add_stats_to_figure``.

    Args:
        suggestion: Dict with ``chart_type``, ``x_axis``, ``y_axis`` and
            optionally ``group_by``.
        df: DataFrame to plot.

    Returns:
        The generated figure, or ``None`` (after reporting the reason through
        Streamlit) when the suggestion cannot be turned into a chart.
    """
    if not isinstance(suggestion, dict):
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # `or "bar"` also covers an explicit null chart type in the model's JSON.
    chart_type = str(suggestion.get("chart_type") or "bar").strip().lower()
    x_axis = suggestion.get("x_axis")
    y_axis = suggestion.get("y_axis")
    group_by = suggestion.get("group_by")

    # Step 1: Infer Y-axis if not provided
    if not y_axis:
        y_axis = _infer_y_axis(df, x_axis)

    # Step 2: Validate axes
    if not x_axis or not y_axis:
        st.warning("⚠️ Unable to determine appropriate columns for visualization.")
        return None

    # The model may hallucinate column names; reject them with a clear message
    # instead of letting Plotly fail further down.
    unknown = [
        column
        for column in (x_axis, y_axis)
        if not isinstance(column, str) or column not in df.columns
    ]
    if unknown:
        st.warning(f"⚠️ Suggested column(s) not found in the data: {unknown}")
        return None

    # Step 3: Dynamically select the Plotly function. Only public callables
    # qualify; px also exposes sub-modules such as `px.data` and `px.colors`.
    plotly_function = None if chart_type.startswith("_") else getattr(px, chart_type, None)
    if not callable(plotly_function):
        st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
        return None

    # Step 4: Prepare dynamic plot arguments
    plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
    if isinstance(group_by, str) and group_by in df.columns:
        plot_args["color"] = group_by

    x_label = x_axis.replace("_", " ").title()
    y_label = y_axis.replace("_", " ").title()

    try:
        # Step 5: Generate the visualization
        fig = plotly_function(**plot_args)
        fig.update_layout(
            title=f"{chart_type.title()} Plot of {y_label} by {x_label}",
            xaxis_title=x_label,
            yaxis_title=y_label,
        )
        # Step 6: Apply statistics using the resolved (possibly inferred)
        # Y-axis and the normalised chart type.
        return add_stats_to_figure(fig, df, y_axis, chart_type)
    except Exception as e:
        # Plotly raises a variety of exception types for bad arguments; this
        # is the UI boundary, so report the failure instead of crashing.
        st.error(f"⚠️ Failed to generate visualization: {e}")
        return None


def generate_multiple_visualizations(suggestions, df):
    """Generate a figure for every usable suggestion.

    Statistics are already applied by ``generate_visualization``; they are
    deliberately not applied a second time here. Suggestions that cannot be
    turned into a chart are skipped (``generate_visualization`` reports why).

    Args:
        suggestions: A list of suggestion dicts, or a single suggestion dict.
        df: DataFrame to plot.

    Returns:
        List of generated figures; empty when no suggestion was usable.
    """
    if isinstance(suggestions, dict):
        suggestions = [suggestions]
    if not isinstance(suggestions, (list, tuple)):
        return []

    visualizations = []
    for suggestion in suggestions:
        fig = generate_visualization(suggestion, df)
        if fig is not None:
            visualizations.append(fig)
    return visualizations


def handle_visualization_suggestions(suggestions, df):
    """Render every chart that can be built from *suggestions* in Streamlit.

    Accepts a single suggestion dict or a list of them; both forms go through
    the same generation path. A warning is shown when nothing could be drawn.

    Args:
        suggestions: Suggestion dict, list of suggestion dicts, or ``None``.
        df: DataFrame to plot.
    """
    visualizations = generate_multiple_visualizations(suggestions, df)

    # Handle cases when no visualization could be generated
    if not visualizations:
        st.warning("⚠️ Unable to generate any visualization based on the suggestion.")

    # Display all generated visualizations
    for fig in visualizations:
        st.plotly_chart(fig, use_container_width=True)