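# Page header captured together with this file (not part of the program):
#   author avatar: bluenevus's picture
#   last commit: "Update app.py" (12ce912, verified)
#   page links: raw | history | blame
#   size: 4.7 kB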
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image
import ast
import re
import traceback
def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib snippets for an uploaded dataset and render them.

    Args:
        api_key: Key passed to ``genai.configure``.
        file: The uploaded file. Either an object exposing the path as ``.name``
            or a plain path string is accepted.
        instructions: Free-text user guidance for the plots; when falsy the
            prompt uses 'Show general trends'.

    Returns:
        A list of exactly three images. Each entry is either the result of
        ``execute_plot_code`` or a placeholder from ``generate_error_image``.
        Failures never propagate as exceptions; they are reported as
        placeholder images instead.
    """
    # Configure Gemini API with error handling
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    except Exception as e:
        return [generate_error_image(f"API Config Error: {str(e)}")] * 3

    # Load data with validation
    try:
        if file is None:
            raise ValueError("No file uploaded")
        # Accept both file-like upload objects (path in .name) and plain paths.
        path = getattr(file, 'name', file)
        # Compare case-insensitively so 'DATA.CSV' is not sent to read_excel.
        if str(path).lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
        if df.empty:
            raise ValueError("Empty dataset uploaded")
        if len(df.columns) < 2:
            raise ValueError("Dataset needs at least 2 columns")
    except Exception as e:
        return [generate_error_image(f"Data Error: {str(e)}")] * 3

    # Enhanced prompt with strict formatting
    prompt = f"""Generate 3 Python matplotlib codes with these rules:
1. Perfect Python 3.10 syntax with 4-space indentation
2. Complete code blocks with proper indentation
3. Use ONLY these variables: df, plt
4. Each visualization must:
- Start with: plt.figure(figsize=(16,9), dpi=120)
- Use plt.style.use('ggplot')
- Include title, axis labels, and data visualization
- End with plt.tight_layout()
Dataset Columns: {list(df.columns)}
Data Sample: {df.head(3).to_dict()}
User Instructions: {instructions or 'Show general trends'}
Format EXACTLY as:
# Visualization 1
[properly indented code]
# Visualization 2
[properly indented code]
# Visualization 3
[properly indented code]
"""

    # Get and process Gemini response
    try:
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first marker; keep at most three snippets.
        code_blocks = re.split(r'# Visualization \d+\s*', response.text)[1:4]
    except Exception:
        # Log the cause, as the per-visualization handler below does, instead
        # of discarding it silently.
        print(f"API Response Error:\n{traceback.format_exc()}")
        return [generate_error_image("API Response Error")] * 3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            # Clean and validate code
            cleaned_code = sanitize_code(block, df.columns)
            img = execute_plot_code(cleaned_code, df)
            visualizations.append(img)
        except Exception:
            print(f"Visualization {i} Error:\n{traceback.format_exc()}")
            visualizations.append(generate_error_image(f"Error in Viz {i}"))

    # Ensure exactly 3 outputs: pad when the response held fewer than three snippets.
    missing = 3 - len(visualizations)
    return visualizations + [generate_error_image("Not Generated")] * missing
def sanitize_code(code_block, columns):
    """Strip markdown artifacts from generated code and make sure it parses.

    The snippet is first tried as written (fences removed, blank lines dropped,
    common leading whitespace removed). Only when that does not parse is the
    code re-indented by a best-effort keyword heuristic intended for snippets
    whose indentation was lost.

    Args:
        code_block: Raw text of one generated snippet.
        columns: Dataset column labels. Kept for interface compatibility; the
            function does not use it.

    Returns:
        Source text accepted by ``ast.parse``.

    Raises:
        ValueError: If neither the as-written nor the re-indented code parses.
    """
    import textwrap

    indent_size = 4

    kept = []
    for raw_line in code_block.split('\n'):
        # Remove markdown artifacts: a leading fence (with optional language
        # tag) or a fence trailing a line of code.
        line = re.sub(r'^\s*```\w*|```$', '', raw_line.rstrip()).rstrip()
        if line:
            kept.append(line)

    # Prefer the indentation the code already has; re-indenting valid code
    # would throw away its block structure.
    as_written = textwrap.dedent('\n'.join(kept))
    try:
        ast.parse(as_written)
    except SyntaxError:
        pass
    else:
        return as_written

    # Fallback heuristic. A header is emitted at the current level and only the
    # lines after it are indented; else/elif/except/finally step back out to
    # the level of the statement they continue.
    opener = re.compile(r'(for|if|while|with|def|class|try|except|else|elif|finally)\b')
    continuation = re.compile(r'(else|elif|except|finally)\b')
    terminator = re.compile(r'(return|pass|break|continue|raise)\b')

    rebuilt = []
    indent_level = 0
    closed_by_terminator = False
    for line in kept:
        text = line.strip()
        # A terminator on the previous line already left the block, so a
        # continuation keyword must not dedent a second time.
        if continuation.match(text) and not closed_by_terminator:
            indent_level = max(0, indent_level - 1)
        rebuilt.append(' ' * (indent_level * indent_size) + text)

        closed_by_terminator = False
        if text.endswith(':') and opener.match(text):
            indent_level += 1
        elif terminator.match(text):
            indent_level = max(0, indent_level - 1)
            closed_by_terminator = True

    cleaned_code = '\n'.join(rebuilt)

    # Validate syntax
    try:
        ast.parse(cleaned_code)
    except SyntaxError as e:
        raise ValueError(f"Syntax Error: {str(e)}") from e
    return cleaned_code
def execute_plot_code(code, df):
    """Execute plotting code and return the current figure rendered as an image.

    Args:
        code: Python source to run. It is executed with ``df`` and ``plt`` as
            its only explicitly supplied globals.
        df: The dataset, exposed to the code as ``df``.

    Returns:
        The image opened from an in-memory PNG of the figure that is current
        once the code has run.

    Raises:
        RuntimeError: If executing the code or rendering the figure fails; the
            original exception is attached as the cause.
    """
    # SECURITY: `code` is model-generated text and exec() runs it with full
    # interpreter privileges (Python adds builtins to the globals dict). This
    # is not a sandbox; isolate at the process level if the input is untrusted.
    buf = io.BytesIO()
    # Remember figures that already exist so cleanup only closes what this call
    # (or the executed code) created.
    preexisting = set(plt.get_fignums())
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')
    try:
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        raise RuntimeError(f"Execution Error: {str(e)}") from e
    finally:
        # The executed code may open figures of its own, so closing only the
        # current figure would leak the others on every call.
        for num in plt.get_fignums():
            if num not in preexisting:
                plt.close(num)
def generate_error_image(message):
"""Create error indication image with message"""
from PIL import ImageDraw, ImageFont
img = Image.new('RGB', (1920, 1080), color=(255, 255, 255))
try:
draw = ImageDraw.Draw(img)
font = ImageFont