import ast
import io
import json
import re
import traceback

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw

# Number of image slots exposed by the UI; process_file always returns this many.
MAX_PLOTS = 3

# Matches the first Markdown code fence (with any or no language tag) and
# captures its content; tolerates a missing closing fence (truncated reply).
_FENCE_RE = re.compile(r"```[\w+-]*\s*(.*?)(?:```|\Z)", re.DOTALL)

# Prompt sent to the model. Doubled braces are literal braces for str.format.
_PROMPT_TEMPLATE = """
Analyze the following dataset and instructions:

Data columns: {columns}
Instructions: {instructions}

Based on this, create 3 appropriate visualizations. For each visualization, provide:
1. A title
2. The most suitable plot type (choose from: bar, line, scatter, hist)
3. The column to use for the x-axis
4. The column(s) to use for the y-axis (can be a list for multiple columns, or None for histograms)
5. Any necessary data preprocessing steps (e.g., grouping, sorting, etc.)

Return your response as a Python list of dictionaries:
[
    {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
    {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
    {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}}
]
"""


def _format_error(exc):
    """Return the error text shown to the user; call from inside ``except``."""
    return f"Error: {str(exc)}\n\nTraceback:\n{traceback.format_exc()}"


def _error_image(message):
    """Render *message* as red text on a white 800x400 image."""
    image = Image.new('RGB', (800, 400), (255, 255, 255))
    ImageDraw.Draw(image).text((10, 10), message, fill=(255, 0, 0))
    return image


def _parse_plot_specs(text):
    """Extract the list of plot dictionaries from the model's reply.

    The payload is taken from the first fenced code block if there is one,
    otherwise the whole reply is used. It is parsed as a Python literal
    first and as JSON second (so ``null``/``true`` are accepted).

    Raises:
        ValueError: if the payload cannot be parsed or is not a list of dicts.
    """
    fenced = _FENCE_RE.search(text)
    payload = (fenced.group(1) if fenced else text).strip()
    try:
        specs = ast.literal_eval(payload)
    except (ValueError, SyntaxError, TypeError) as literal_err:
        try:
            specs = json.loads(payload)
        except json.JSONDecodeError:
            raise ValueError(
                "Could not parse the model response as a list of plot dictionaries"
            ) from literal_err
    if not isinstance(specs, list) or not all(isinstance(s, dict) for s in specs):
        raise ValueError("Model response is not a list of plot dictionaries")
    return specs


def _render_plot(df, spec):
    """Draw one plot described by *spec* from *df* and return it as a PIL image.

    *df* itself is never modified: preprocessing runs against a copy so one
    plot's transformations cannot leak into the next.
    """
    frame = df.copy()
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        preprocessing = spec.get('preprocessing')
        if preprocessing:
            # SECURITY: this executes code written by the model, which is in
            # turn steered by user-supplied instructions and column names.
            # It runs with the full privileges of this process. Only deploy
            # this where that is acceptable, or replace it with a whitelist
            # of declarative operations.
            #
            # An explicit namespace is required: exec() inside a function
            # cannot rebind the function's locals, so a snippet such as
            # ``df = df.groupby(...)`` would otherwise be silently ignored.
            namespace = {'df': frame, 'pd': pd, 'plt': plt, 'ax': ax, 'fig': fig}
            exec(preprocessing, namespace)
            frame = namespace['df']

        plot_type = spec['plot_type']
        x = spec['x']
        y = spec.get('y')
        if plot_type in ('bar', 'line', 'scatter'):
            frame.plot(kind=plot_type, x=x, y=y, ax=ax)
        elif plot_type == 'hist':
            frame[x].hist(ax=ax)
        # Any other plot type leaves the axes empty but still titled.

        ax.set_title(spec['title'])
        ax.set_xlabel(x)
        if isinstance(y, (list, tuple)):
            ax.set_ylabel(", ".join(str(col) for col in y))
        else:
            ax.set_ylabel(y if y else 'Frequency')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now so the image does not depend on the buffer
        return image
    finally:
        # Always release the figure, including when plotting raised.
        plt.close(fig)


def process_file(file, instructions, api_key):
    """Ask Gemini for up to three plot specs for the uploaded data and render them.

    Args:
        file: The uploaded dataset; either a path string or an object whose
            ``name`` attribute is the path. ``.csv`` (any case) is read with
            ``pandas.read_csv``; everything else with ``pandas.read_excel``.
        instructions: Free-text description of the desired analysis.
        api_key: Gemini API key.

    Returns:
        A list of exactly three PIL images. A plot that fails to render is
        replaced by an image containing its error text; unused slots are
        blank white images. If a step before rendering fails (upload, API
        call, parsing), all three slots show the same error image.
    """
    try:
        if file is None:
            raise ValueError("No dataset uploaded. Please upload a .csv or .xlsx file.")

        # Initialize Gemini
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        # Read uploaded file
        file_path = getattr(file, 'name', file)
        if str(file_path).lower().endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Ask the model for plot specifications
        response = model.generate_content(
            _PROMPT_TEMPLATE.format(columns=list(df.columns), instructions=instructions)
        )
        print("Generated code block:")
        print(response.text)
        specs = _parse_plot_specs(response.text)

        # Render each plot independently so one bad spec does not blank the rest
        images = []
        for spec in specs[:MAX_PLOTS]:
            try:
                images.append(_render_plot(df, spec))
            except Exception as e:
                error_message = _format_error(e)
                print(error_message)  # Print to console for debugging
                images.append(_error_image(error_message))

        # Pad with blank images so every output component receives a value
        while len(images) < MAX_PLOTS:
            images.append(Image.new('RGB', (800, 600), (255, 255, 255)))
        return images

    except Exception as e:
        # Top-level boundary: surface the failure in the UI instead of crashing.
        error_message = _format_error(e)
        print(error_message)  # Print to console for debugging
        return [_error_image(error_message)] * MAX_PLOTS


with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    # NOTE(review): the original indentation was lost, so which components
    # belong inside this Row is an assumption; confirm against the intended layout.
    with gr.Row():
        file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
        instructions = gr.Textbox(
            label="Analysis Instructions",
            placeholder="Describe the analysis you want...",
        )

    api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")

    output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(MAX_PLOTS)]

    submit.click(
        process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )

if __name__ == "__main__":
    demo.launch()