"""Gradio dashboard that renders Gemini-suggested charts for an uploaded dataset.

The user uploads a CSV/Excel file, describes the analysis they want and
supplies a Gemini API key.  Gemini is asked for a list of plot specifications
(plain dictionaries); each one is drawn with pandas/matplotlib and returned to
the UI as an image.
"""

import ast
import io
import json
import traceback

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import is_color_like
from PIL import Image, ImageDraw

# Figures are only ever written to an in-memory PNG, and the handler may be
# invoked outside the main thread, so use the non-interactive Agg backend
# instead of whatever GUI backend the host happens to default to.
plt.switch_backend("Agg")

MODEL_NAME = "gemini-2.5-pro-preview-03-25"
MAX_PLOTS = 3  # must match the number of gr.Image outputs declared in the UI
BLANK_IMAGE_SIZE = (800, 600)
ERROR_IMAGE_SIZE = (800, 400)
WHITE = (255, 255, 255)
RED = (255, 0, 0)

# One example entry of the structure the model is asked to return.  It is kept
# as a plain (non f-)string so its literal braces never need escaping.
_SPEC_EXAMPLE = (
    '{"title": "...", "plot_type": "...", "x": "...", "y": "...", '
    '"agg_func": "...", "top_n": ..., '
    '"additional": {"color": "...", "size": "..."}}'
)


def _load_dataframe(file):
    """Read the uploaded file into a DataFrame (CSV by extension, else Excel).

    ``file`` may be an object exposing the path as ``.name`` or a plain path
    string.
    """
    file_path = file if isinstance(file, str) else file.name
    if file_path.lower().endswith(".csv"):
        return pd.read_csv(file_path)
    return pd.read_excel(file_path)


def _build_prompt(df, instructions):
    """Build the Gemini prompt describing the dataset and the expected reply."""
    example_entries = ",\n".join([f"    {_SPEC_EXAMPLE}"] * MAX_PLOTS)
    return (
        "Analyze the following dataset and instructions:\n"
        "\n"
        f"Data columns: {list(df.columns)}\n"
        f"Data shape: {df.shape}\n"
        f"Instructions: {instructions}\n"
        "\n"
        f"Based on this, create {MAX_PLOTS} appropriate visualizations that "
        "provide meaningful insights. For each visualization:\n"
        "1. Choose the most suitable plot type "
        "(bar, line, scatter, hist, pie, heatmap)\n"
        "2. Determine appropriate data aggregation "
        "(e.g., top 5 categories, monthly averages)\n"
        "3. Select relevant columns for x-axis, y-axis, and any additional "
        "dimensions (color, size)\n"
        "4. Provide a clear, concise title that explains the insight\n"
        "\n"
        "Consider data density and choose visualizations that simplify and "
        "clarify the information.\n"
        "Limit the number of data points displayed to ensure readability "
        "(e.g., top 5, top 10).\n"
        "\n"
        "Return your response as a Python list of dictionaries:\n"
        f"[\n{example_entries}\n]\n"
    )


def _extract_plot_specs(response_text):
    """Parse the model reply into a list of plot-specification dictionaries.

    A ```python fence is preferred; otherwise the first generic fence is used
    (dropping a language tag such as ``json``); otherwise the whole reply is
    parsed.  Parsing never executes code: ``ast.literal_eval`` is tried first
    and ``json.loads`` is the fallback for JSON-only literals.

    Raises:
        ValueError: if the reply cannot be parsed or is not a list of dicts.
    """
    payload = response_text
    if "```python" in payload:
        payload = payload.split("```python", 1)[1].split("```", 1)[0]
    elif "```" in payload:
        payload = payload.split("```")[1]
        tag, newline, remainder = payload.partition("\n")
        # A fence like ```json leaves the tag as the first line of the block.
        if newline and tag.strip().isalnum():
            payload = remainder
    payload = payload.strip()

    print("Generated code block:")
    print(payload)

    try:
        specs = ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError) as literal_error:
        try:
            specs = json.loads(payload)
        except ValueError:
            raise ValueError(
                "Could not parse the model response as a list of plot "
                "dictionaries."
            ) from literal_error

    if isinstance(specs, dict):
        specs = [specs]
    if not isinstance(specs, (list, tuple)) or not all(
        isinstance(spec, dict) for spec in specs
    ):
        raise ValueError("Model response is not a list of plot dictionaries.")
    return list(specs)


def _is_column(df, name):
    """Return True when ``name`` is a usable column label of ``df``."""
    try:
        return name is not None and name in df.columns
    except TypeError:  # unhashable value such as a list or dict
        return False


def _coerce_top_n(raw):
    """Return ``raw`` as a positive int, or None when it is absent/unusable."""
    if raw is None or isinstance(raw, bool):
        return None
    try:
        count = int(raw)
    except (TypeError, ValueError):  # e.g. the "..." placeholder (Ellipsis)
        return None
    return count if count > 0 else None


def _aggregate(df, kind, x_col, y_col, agg_func, top_n, extras):
    """Apply the requested aggregation and top-N filter to a copy of ``df``."""
    plot_df = df.copy()

    group_keys = [x_col]
    color_col = extras.get("color")
    # A heatmap pivots on the colour column afterwards, so that column must
    # survive the aggregation.
    if kind == "heatmap" and color_col != x_col and _is_column(df, color_col):
        group_keys.append(color_col)

    if agg_func in ("sum", "mean"):
        plot_df = plot_df.groupby(group_keys)[y_col].agg(agg_func).reset_index()
    elif agg_func == "count":
        plot_df = plot_df.groupby(group_keys).size().reset_index(name=y_col)

    if top_n:
        plot_df = plot_df.nlargest(top_n, y_col)
    return plot_df


def _draw(ax, plot_df, kind, x_col, y_col, extras):
    """Draw ``plot_df`` on ``ax`` according to the plot type ``kind``.

    Unknown plot types leave the axes empty.
    """
    if kind in ("bar", "line"):
        plot_df.plot(kind=kind, x=x_col, y=y_col, ax=ax)
    elif kind == "scatter":
        scatter_kwargs = {}
        color = extras.get("color")
        # Accept either a column of the plotted frame or a literal colour.
        if _is_column(plot_df, color) or (
            isinstance(color, str) and is_color_like(color)
        ):
            scatter_kwargs["c"] = color
        size_col = extras.get("size")
        if _is_column(plot_df, size_col) and pd.api.types.is_numeric_dtype(
            plot_df[size_col]
        ):
            scatter_kwargs["s"] = plot_df[size_col]
        plot_df.plot(kind="scatter", x=x_col, y=y_col, ax=ax, **scatter_kwargs)
    elif kind == "hist":
        plot_df[x_col].hist(ax=ax, bins=20)
    elif kind == "pie":
        plot_df.plot(
            kind="pie",
            y=y_col,
            labels=plot_df[x_col],
            ax=ax,
            autopct="%1.1f%%",
        )
    elif kind == "heatmap":
        color_col = extras.get("color")
        if not _is_column(plot_df, color_col):
            raise ValueError(
                "Heatmap needs an existing column name in additional['color']."
            )
        pivot_df = plot_df.pivot(index=x_col, columns=color_col, values=y_col)
        ax.imshow(pivot_df, cmap="YlOrRd")
        ax.set_xticks(range(len(pivot_df.columns)))
        ax.set_yticks(range(len(pivot_df.index)))
        ax.set_xticklabels(pivot_df.columns)
        ax.set_yticklabels(pivot_df.index)


def _render_plot(df, spec):
    """Render one plot specification to a PIL image.

    The figure is always closed, even when drawing fails, so a long-running
    server does not accumulate open matplotlib figures.
    """
    kind = spec["plot_type"]
    x_col = spec["x"]
    y_col = spec["y"]
    extras = spec.get("additional")
    if not isinstance(extras, dict):
        extras = {}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        plot_df = _aggregate(
            df,
            kind,
            x_col,
            y_col,
            spec.get("agg_func"),
            _coerce_top_n(spec.get("top_n")),
            extras,
        )
        _draw(ax, plot_df, kind, x_col, y_col, extras)

        ax.set_title(spec.get("title", ""))
        if kind != "pie":
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
        fig.tight_layout()

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
        buffer.seek(0)
        image = Image.open(buffer)
        image.load()  # decode now so the image no longer depends on the buffer
        return image
    finally:
        plt.close(fig)


def _error_images(message):
    """Return ``MAX_PLOTS`` references to one image showing ``message`` in red."""
    error_image = Image.new("RGB", ERROR_IMAGE_SIZE, WHITE)
    ImageDraw.Draw(error_image).text((10, 10), message, fill=RED)
    return [error_image] * MAX_PLOTS


def process_file(file, instructions, api_key):
    """Generate up to three visualizations for the uploaded dataset.

    Args:
        file: Uploaded dataset from ``gr.File`` (object with ``.name`` or a
            path string); ``None`` when nothing was uploaded.
        instructions: Free-text description of the desired analysis.
        api_key: Gemini API key used for this request.

    Returns:
        A list of exactly ``MAX_PLOTS`` PIL images.  Missing plots are padded
        with blank white images.  On any failure every slot shows the same
        image containing the error message and traceback; no exception
        propagates to Gradio.
    """
    try:
        if file is None:
            raise ValueError("Please upload a CSV or Excel file first.")

        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(MODEL_NAME)

        df = _load_dataframe(file)
        response = model.generate_content(_build_prompt(df, instructions))
        specs = _extract_plot_specs(response.text)

        images = [_render_plot(df, spec) for spec in specs[:MAX_PLOTS]]
        images.extend(
            Image.new("RGB", BLANK_IMAGE_SIZE, WHITE)
            for _ in range(MAX_PLOTS - len(images))
        )
        return images
    except Exception as e:  # UI boundary: surface every failure as an image
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print to console for debugging
        return _error_images(error_message)


# NOTE(review): the original indentation was lost, so which widgets belong
# inside the gr.Row() is reconstructed here -- confirm the intended layout.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
        instructions = gr.Textbox(
            label="Analysis Instructions",
            placeholder="Describe the analysis you want...",
        )

    api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")

    output_images = [
        gr.Image(label=f"Visualization {i+1}") for i in range(MAX_PLOTS)
    ]

    submit.click(
        process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )

if __name__ == "__main__":
    demo.launch()