"""Gradio app that asks Gemini for matplotlib code and renders the plots.

The user uploads a CSV/Excel file, the model is prompted for three plotting
snippets, each snippet is lightly rewritten by ``sanitize_code`` and then
executed to produce a PNG that is shown in the UI.
"""

import ast
import io
import re
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw, ImageFont

# Figures are only ever saved to PNG, and the handlers run in server worker
# threads, so select the non-interactive backend instead of a GUI one.
matplotlib.use("Agg")

MODEL_NAME = 'gemini-2.5-pro-preview-03-25'
# Number of plots requested from the model and number of image slots in the UI.
NUM_VISUALIZATIONS = 3

# Matches either a complete single-line string literal (group 1) or the bare
# identifier ``data``.  Matching literals first lets the rename skip text such
# as plot titles and column names.
_DATA_NAME = re.compile(
    r"""('(?:\\.|[^'\\])*'|"(?:\\.|[^"\\])*")|(?<![\w.])data(?!\w)"""
)


def process_file(api_key, file, instructions):
    """Build the visualizations for one uploaded data file.

    Args:
        api_key: Gemini API key entered by the user.
        file: Upload from ``gr.File``; either an object exposing the path as
            ``.name`` or a plain path string.
        instructions: Optional free-text guidance; falls back to
            ``'General insights'`` when empty.

    Returns:
        A list of exactly ``NUM_VISUALIZATIONS`` PIL images.  Any slot that
        could not be produced holds an error image instead; this function does
        not raise for data, API or plotting failures.
    """
    try:
        if file is None:
            raise ValueError("No file uploaded")
        path = getattr(file, "name", file)
        # Compare case-insensitively so "REPORT.CSV" is not sent to read_excel.
        if str(path).lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
        if df.empty:
            raise ValueError("Uploaded file contains no data")
    except Exception as e:
        print(f"Data Error: {str(e)}")
        return [generate_error_image(str(e))] * NUM_VISUALIZATIONS

    # Enhanced prompt with strict plotting requirements
    prompt = f"""Generate {NUM_VISUALIZATIONS} matplotlib codes with these rules:
1. Use ONLY these variables: df (DataFrame), plt
2. Each visualization MUST:
- Plot actual data from df
- Include title, axis labels, and data labels if needed
- Use clear color schemes
- Avoid empty plots
3. Code structure:
plt.figure(figsize=(16,9), dpi=120)
plt.style.use('ggplot')
# Plotting code using df columns: {list(df.columns)}
plt.tight_layout()

Sample data: {df.head(3).to_dict()}
User instructions: {instructions or 'General insights'}

Format EXACTLY as:
# Visualization 1
[complete code]
"""

    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(MODEL_NAME)
        response = model.generate_content(prompt)
        # Consume the whole heading line: anything the model appends after
        # "# Visualization N" (e.g. ": Sales by region") was part of a comment
        # and must not end up as the first line of the code block.
        code_blocks = re.split(
            r'(?m)^.*# Visualization \d+.*$', response.text
        )[1:NUM_VISUALIZATIONS + 1]
    except Exception as e:
        # Log the cause like the other handlers do; the UI gets a generic text.
        print(f"API Error: {str(e)}")
        return [generate_error_image("API Error")] * NUM_VISUALIZATIONS

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            cleaned_code = sanitize_code(block, df.columns)
            # Reject unparsable code before touching matplotlib state.
            ast.parse(cleaned_code)
            # Each snippet gets its own copy so in-place edits made by one
            # plot cannot change the data seen by the next.
            img = execute_plot_code(cleaned_code, df.copy())
            visualizations.append(img)
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(generate_error_image(f"Plot {i} Error"))

    missing = NUM_VISUALIZATIONS - len(visualizations)
    return visualizations + [generate_error_image("Not Generated")] * missing


def _rename_data(match):
    """Return the replacement text for one ``_DATA_NAME`` match."""
    if match.group(1) is not None:
        return match.group(1)  # string literal: leave untouched
    before = match.string[:match.start()].rstrip()
    after = match.string[match.end():].lstrip()
    # ``f(x, data=...)`` is a keyword argument name, not the variable.
    is_keyword = (
        before.endswith(("(", ","))
        and after.startswith("=")
        and not after.startswith("==")
    )
    return "data" if is_keyword else "df"


def sanitize_code(code_block, columns):
    """Clean one model-generated code block so it can be parsed and run.

    The following rewrites are applied line by line:

    * blank lines and Markdown fence lines (starting with a backtick) are
      dropped;
    * the placeholders ``'x_axis'`` / ``'y_axis'`` become the first / second
      column name (``'Index'`` / ``'Value'`` when that column does not exist);
    * the bare identifier ``data`` becomes ``df`` (string literals, attribute
      accesses, longer identifiers and same-line ``data=`` keyword arguments
      are left alone);
    * ``plt.legend()`` becomes ``pass`` so that a block whose only statement
      was the legend call stays syntactically valid.

    Relative indentation is preserved and common leading indentation is
    removed, so loops and conditionals in the generated code survive.

    Args:
        code_block: Raw text of one visualization section.
        columns: Sequence of column labels of the DataFrame.

    Returns:
        The cleaned source code as a single string.
    """
    x_name = str(columns[0]) if len(columns) > 0 else "Index"
    y_name = str(columns[1]) if len(columns) > 1 else "Value"
    # Callables are used as replacements so that backslashes in a column name
    # are not interpreted as regex template escapes; repr() yields a valid
    # literal even when the name contains quotes.
    substitutions = [
        (re.compile(r"'y_axis'"), lambda _m, text=repr(y_name): text),
        (re.compile(r"'x_axis'"), lambda _m, text=repr(x_name): text),
        (_DATA_NAME, _rename_data),
        (re.compile(r"plt\.legend\(\)"), "pass"),  # Remove empty legend calls
    ]

    cleaned = []
    for raw_line in code_block.split('\n'):
        line = raw_line.rstrip()
        stripped = line.lstrip()
        if not stripped or stripped.startswith('`'):
            continue
        for pattern, replacement in substitutions:
            line = pattern.sub(replacement, line)
        cleaned.append(line)

    return textwrap.dedent('\n'.join(cleaned))


def execute_plot_code(code, df):
    """Run plotting code and return the resulting figure as a PIL image.

    SECURITY: ``code`` is produced by a language model and is executed with
    ``exec`` without any sandboxing.  It can import modules and access the
    file system and network of the host process.  Run this app only in an
    environment where that is acceptable.

    Args:
        code: Python source that draws with ``plt`` using the variable ``df``.
        df: DataFrame made available to the code as ``df``.

    Returns:
        The current figure rendered to PNG and loaded as a PIL image.

    Raises:
        Exception: Whatever the executed code or matplotlib raises.
    """
    buf = io.BytesIO()
    open_before = set(plt.get_fignums())
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')

    try:
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now instead of lazily on first access
        return image
    finally:
        # The generated code normally opens its own figure, so closing only
        # the current one would leak the figure created above.  Close every
        # figure opened during this call and leave pre-existing ones alone.
        for number in set(plt.get_fignums()) - open_before:
            plt.close(number)


def generate_error_image(message):
    """Create a placeholder image that displays ``message``.

    Args:
        message: Text to show; ``None`` or an empty string gives a plain
            background.

    Returns:
        A 1920x1080 RGB PIL image.
    """
    img = Image.new('RGB', (1920, 1080), color=(73, 109, 137))
    text = "" if message is None else str(message).strip()
    if text:
        try:
            font = ImageFont.load_default(size=48)
        except (TypeError, OSError, ImportError):
            # Older Pillow has no ``size`` parameter, or FreeType is missing.
            font = ImageFont.load_default()
        # Wrap and cap the text so long exception messages stay on the canvas.
        wrapped = "\n".join(textwrap.wrap(text, width=60)[:12])
        ImageDraw.Draw(img).multiline_text(
            (80, 80), wrapped, fill=(255, 255, 255), font=font
        )
    return img


# Gradio interface
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# Professional Data Visualizer")

    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])

    instructions = gr.Textbox(label="Visualization Instructions")
    submit = gr.Button("Generate Insights", variant="primary")

    with gr.Row():
        outputs = [
            gr.Image(label=f"Visualization {i+1}", width=600)
            for i in range(NUM_VISUALIZATIONS)
        ]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()