"""Gradio app: ask Gemini for matplotlib code and render it on an uploaded dataset.

The user supplies a Gemini API key, a CSV/Excel file and optional free-text
instructions. The model is prompted for three plotting snippets, each snippet
is sanitised, executed against the uploaded DataFrame, and rendered to a PNG.
"""

import ast
import io
import re
import textwrap

import matplotlib

# Select the non-interactive backend before pyplot is imported: the plots are
# only ever written to in-memory PNG buffers, and GUI backends are not safe to
# drive from a web-server handler that may run off the main thread.
matplotlib.use("Agg")

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

_MODEL_NAME = 'gemini-2.5-pro-preview-03-25'
_MAX_VISUALIZATIONS = 3

# Header line the model is told to emit before each snippet, e.g.
# "# Visualization 2". Requiring the number keeps the template's own
# "# Visualization code using df" comment from being treated as a separator.
_HEADER_RE = re.compile(
    r"^[ \t]*#+[ \t]*Visualization[ \t]+\d+.*$", re.MULTILINE
)
# Real import statements (not merely a line that mentions the word "import").
_IMPORT_RE = re.compile(r"^\s*(?:import|from)\s+\w")
# A statement that rebinds the bare name `data` (not a `data=` keyword argument
# and not an `==` comparison).
_DATA_ASSIGN_RE = re.compile(r"^\s*data\s*=(?!=)")

_PROMPT_TEMPLATE = """Generate 3 distinct matplotlib visualization codes with these rules:
1. Use ONLY these variables: df (existing DataFrame), plt
2. No imports or additional data loading
3. Each visualization must:
   - Start with: plt.figure(figsize=(16,9), dpi=120)
   - Use plt.style.use('ggplot')
   - Include title, axis labels, and grid
   - End with plt.tight_layout()
4. Different chart types (bar, line, scatter, etc)

Dataset columns: {columns}
Sample data: {sample}
User instructions: {instructions}

Format EXACTLY as:
# Visualization 1
plt.figure(figsize=(16,9), dpi=120)
plt.style.use('ggplot')
# Visualization code using df
plt.tight_layout()
"""


def _load_dataframe(file):
    """Read the uploaded dataset into a DataFrame.

    Args:
        file: Either a filesystem path (what ``gr.File(type="filepath")``
            passes) or an object exposing the path as ``.name``.

    Returns:
        The parsed ``pandas.DataFrame``; ``.csv`` files go through
        ``read_csv``, everything else through ``read_excel``.

    Raises:
        ValueError: If no file was uploaded.
    """
    if file is None:
        raise ValueError("no file uploaded")
    path = file if isinstance(file, str) else file.name
    # Compare case-insensitively so "DATA.CSV" is not sent to the Excel reader.
    if path.lower().endswith('.csv'):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Fill the prompt template with the dataset's columns, a sample and the user's request."""
    return _PROMPT_TEMPLATE.format(
        columns=list(df.columns),
        sample=df.head(3).to_dict(),
        instructions=instructions or 'None',
    )


def _split_code_blocks(text):
    """Split the model reply into at most three per-visualization code sections."""
    # Element 0 is whatever preceded the first header (preamble), so skip it.
    return _HEADER_RE.split(text)[1:1 + _MAX_VISUALIZATIONS]


def _clean_generated_code(block):
    """Strip non-code noise from one generated section while keeping it valid Python.

    Removes blank lines, Markdown code fences, import statements and
    re-assignments of ``data`` (the model is told to use the provided
    DataFrame only). Indentation, comments and string contents are left
    untouched so loops, hex colours and column names survive intact.
    """
    kept = []
    for line in block.split('\n'):
        stripped = line.strip()
        if not stripped or stripped.startswith('```'):
            continue
        if _IMPORT_RE.match(line) or _DATA_ASSIGN_RE.match(line):
            continue
        kept.append(line.rstrip())
    # The whole section may be uniformly indented (e.g. inside a list item).
    return textwrap.dedent('\n'.join(kept))


def _render_visualization(code, df):
    """Execute one sanitised snippet and return the resulting figure as a PIL image.

    Raises:
        ValueError: If the snippet is empty after cleaning.
        SyntaxError: If the snippet is not valid Python.
        Exception: Anything raised by the generated code itself.
    """
    if not code.strip():
        raise ValueError("no executable code in generated block")
    # Syntax check up front so a malformed snippet never half-executes.
    ast.parse(code)

    # Each snippet gets its own copy so one visualization cannot mutate the
    # data seen by the next. `data` is exposed as an alias because models
    # frequently use that name instead of `df`.
    frame = df.copy()
    exec_env = {'df': frame, 'data': frame, 'plt': plt}

    figures_before = set(plt.get_fignums())
    buf = io.BytesIO()
    try:
        # Fallback canvas in case the snippet forgets to create its own figure.
        plt.figure(figsize=(16, 9), dpi=120)
        # SECURITY: this runs model-generated code with full interpreter
        # privileges. The cleaning above is a convenience filter, not a
        # sandbox; do not expose this app to untrusted users without isolation.
        exec(code, exec_env)
        plt.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Close every figure opened here (the fallback and any the snippet
        # created), on success and on failure, so figures do not accumulate.
        for num in set(plt.get_fignums()) - figures_before:
            plt.close(num)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now rather than lazily on first access
    return image


def process_file(api_key, file, instructions):
    """Generate up to three visualizations of the uploaded dataset via Gemini.

    Args:
        api_key: Gemini API key used for this request.
        file: Uploaded dataset; a path string or an object with ``.name``.
        instructions: Optional free-text guidance forwarded to the model.

    Returns:
        A list of exactly three items, each a ``PIL.Image.Image`` or ``None``
        when that visualization could not be produced. File and API failures
        are printed and yield ``[None, None, None]``.
    """
    try:
        df = _load_dataframe(file)
    except Exception as e:
        print(f"File Error: {str(e)}")
        return [None, None, None]

    prompt = _build_prompt(df, instructions)

    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(_MODEL_NAME)
        response = model.generate_content(prompt)
        code_blocks = _split_code_blocks(response.text)
    except Exception as e:
        print(f"Gemini Error: {str(e)}")
        return [None, None, None]

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        # Cleaned outside the try so the handler below can always print it.
        cleaned_code = _clean_generated_code(block)
        try:
            visualizations.append(_render_visualization(cleaned_code, df))
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            print(f"Problematic Code:\n{cleaned_code}")
            visualizations.append(None)

    # Ensure exactly 3 outputs, one per gr.Image component.
    missing = _MAX_VISUALIZATIONS - len(visualizations)
    return visualizations + [None] * missing


# Gradio interface
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# 🔍 Data Visualization Generator")

    with gr.Row():
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            type="password",
            placeholder="Enter your API key here"
        )
        # type="filepath" means process_file receives the upload as a str path.
        file = gr.File(
            label="📁 Upload Dataset",
            file_types=[".csv", ".xlsx"],
            type="filepath"
        )

    instructions = gr.Textbox(
        label="💡 Custom Instructions",
        placeholder="E.g.: Compare sales trends across regions..."
    )
    submit = gr.Button("🚀 Generate Visualizations", variant="primary")

    with gr.Row():
        outputs = [
            gr.Image(
                label=f"Visualization {i+1}",
                width=600,
                height=400
            ) for i in range(3)
        ]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()