bluenevus's picture
Update app.py
5be932a verified
raw
history blame
4.11 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image
import ast
import re
def _load_dataframe(file):
    """Read the uploaded CSV/Excel file into a non-empty DataFrame.

    Args:
        file: The uploaded file. Either an object exposing a ``.name`` path
            attribute (what the original code assumed) or a plain path string.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        ValueError: If nothing was uploaded or the file holds no data.
        Exception: Whatever pandas raises for unreadable/malformed files.
    """
    if file is None:
        raise ValueError("No file uploaded")
    path = getattr(file, "name", file)
    # Compare the extension case-insensitively so "DATA.CSV" is not sent
    # to the Excel reader.
    if str(path).lower().endswith('.csv'):
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)
    if df.empty:
        raise ValueError("Uploaded file contains no data")
    return df


def process_file(api_key, file, instructions):
    """Turn an uploaded data file into three visualization images.

    The file is loaded into a DataFrame, Gemini is asked for three matplotlib
    snippets, and each snippet is sanitized, syntax-checked and executed.

    Args:
        api_key: Gemini API key used to configure ``google.generativeai``.
        file: Uploaded CSV/Excel file (see ``_load_dataframe``).
        instructions: Optional free-text guidance; 'General insights' is used
            in the prompt when empty.

    Returns:
        A list of exactly three images. Any stage that fails is represented
        by an image from ``generate_error_image`` instead of raising.
    """
    try:
        df = _load_dataframe(file)
    except Exception as e:
        print(f"Data Error: {e}")
        return [generate_error_image(str(e))] * 3

    # Enhanced prompt with strict plotting requirements
    prompt = f"""Generate 3 matplotlib codes with these rules:
1. Use ONLY these variables: df (DataFrame), plt
2. Each visualization MUST:
- Plot actual data from df
- Include title, axis labels, and data labels if needed
- Use clear color schemes
- Avoid empty plots
3. Code structure:
plt.figure(figsize=(16,9), dpi=120)
plt.style.use('ggplot')
# Plotting code using df columns: {list(df.columns)}
plt.tight_layout()
Sample data: {df.head(3).to_dict()}
User instructions: {instructions or 'General insights'}
Format EXACTLY as:
# Visualization 1
[complete code]
"""

    try:
        # Client setup lives inside the handler so a configuration failure
        # yields error images like any other API problem.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first header; keep at most the
        # three snippets that follow "# Visualization N" markers.
        code_blocks = re.split(r'# Visualization \d+', response.text)[1:4]
    except Exception as e:
        # Log the cause (the other handlers do too) instead of hiding it.
        print(f"API Error: {e}")
        return [generate_error_image("API Error")] * 3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            # Advanced code sanitization
            cleaned_code = sanitize_code(block, df.columns)
            # Validate syntax before executing
            ast.parse(cleaned_code)
            visualizations.append(execute_plot_code(cleaned_code, df))
        except Exception as e:
            print(f"Visualization {i} Error: {e}")
            visualizations.append(generate_error_image(f"Plot {i} Error"))

    # Pad with placeholders when the model returned fewer than three snippets.
    missing = 3 - len(visualizations)
    return visualizations + [generate_error_image("Not Generated")] * missing
def sanitize_code(code_block, columns):
    """Clean one generated code snippet so it can be parsed and executed.

    Blank lines and Markdown fence lines (starting with a backtick) are
    dropped, common leading indentation is removed, and a few placeholder
    rewrites are applied:

    * the identifier ``data`` becomes ``df`` (whole word only, so names such
      as ``metadata`` or ``get_data`` are left alone);
    * ``plt.legend()`` calls with no arguments are removed;
    * ``'y_axis'`` becomes the second column name, or ``'Value'`` if there
      is no second column;
    * ``'x_axis'`` becomes the first column name (left untouched when there
      are no columns).

    Relative indentation is preserved so that generated ``for``/``if``
    blocks stay syntactically valid.

    Args:
        code_block: Raw text of one snippet from the model response.
        columns: Sequence of DataFrame column labels.

    Returns:
        The cleaned source code as a single newline-joined string.
    """
    import textwrap

    names = [str(col) for col in columns]
    # repr() yields a correctly quoted/escaped literal even when the column
    # name itself contains quotes or backslashes.
    y_literal = repr(names[1]) if len(names) > 1 else "'Value'"

    # Order matters: rename `data` before inserting column names, otherwise
    # a column literally called "data" would be rewritten to "df".
    rules = [
        (re.compile(r"\bdata\b"), "df"),
        (re.compile(r"plt\.legend\(\)"), ""),  # Remove empty legend calls
        (re.compile(r"'y_axis'"), y_literal),
    ]
    if names:
        rules.append((re.compile(r"'x_axis'"), repr(names[0])))

    kept = [
        raw.rstrip()
        for raw in code_block.split('\n')
        if raw.strip() and not raw.strip().startswith('`')
    ]
    # Strip only the indentation shared by every line.
    body = textwrap.dedent('\n'.join(kept))

    cleaned = []
    for line in body.split('\n'):
        for pattern, replacement in rules:
            # A callable replacement is inserted verbatim, so backslashes in
            # column names are not treated as regex group references.
            line = pattern.sub(lambda _m, text=replacement: text, line)
        line = line.rstrip()
        if line:
            cleaned.append(line)
    return '\n'.join(cleaned)
def execute_plot_code(code, df):
    """Execute generated plotting code and return the rendered PNG as an image.

    A 16x9 figure is opened and the 'ggplot' style applied before the code
    runs with only ``df`` and ``plt`` in its globals. Whatever figure is
    current afterwards is saved to an in-memory PNG and opened with PIL.

    Args:
        code: Python source to execute.
        df: DataFrame exposed to the code under the name ``df``.

    Returns:
        A PIL image backed by the in-memory PNG buffer.

    Raises:
        Exception: Anything raised by the executed code or by matplotlib
            while laying out or saving the figure.
    """
    # Record the figures that already exist so that only figures opened
    # during this call are closed afterwards.
    preexisting = set(plt.get_fignums())
    buf = io.BytesIO()
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')
    try:
        # SECURITY: `code` is model-generated text run with full interpreter
        # privileges; restricting globals to df/plt is not a sandbox.
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # The executed code may open figures of its own; a bare plt.close()
        # would only close the current one and leak the rest.
        for num in set(plt.get_fignums()) - preexisting:
            plt.close(num)
def generate_error_image(message):
    """Create a 1920x1080 placeholder image showing an error message.

    Args:
        message: Text to render on the image. ``None`` or an empty string
            produces the plain background only.

    Returns:
        An RGB PIL image with a blue-grey background and, when possible,
        the message drawn in white near the top-left corner.
    """
    import textwrap

    from PIL import ImageDraw, ImageFont

    img = Image.new('RGB', (1920, 1080), color=(73, 109, 137))
    text = "" if message is None else str(message)
    if not text:
        return img

    try:
        try:
            # Newer Pillow releases accept a size for the built-in font.
            font = ImageFont.load_default(size=48)
        except (TypeError, ImportError, OSError):
            # Older Pillow / no FreeType: fall back to the small bitmap font.
            font = ImageFont.load_default()
        # Wrap long messages so they stay inside the canvas.
        wrapped = textwrap.fill(text, width=60)
        ImageDraw.Draw(img).text((60, 60), wrapped, fill=(255, 255, 255), font=font)
    except (UnicodeError, ValueError, OSError):
        # Best effort: this is the error path, so a rendering problem must
        # not raise. The plain placeholder is still returned.
        pass
    return img
# Gradio interface
# Builds the web UI: a password-style API-key box, a file upload limited to
# .csv/.xlsx, a free-text instructions box, a submit button and three image
# outputs.
# NOTE(review): indentation was lost in this copy of the file, so which
# components sit inside each gr.Row() cannot be determined from here;
# confirm the nesting against the original before editing the layout.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
gr.Markdown("# Professional Data Visualizer")
with gr.Row():
api_key = gr.Textbox(label="Gemini API Key", type="password")
file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])
instructions = gr.Textbox(label="Visualization Instructions")
submit = gr.Button("Generate Insights", variant="primary")
with gr.Row():
# Three image slots labelled "Visualization 1" .. "Visualization 3".
outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
# Clicking the button calls process_file with the three inputs; the list it
# returns (always three images) is sent to the three gr.Image outputs.
submit.click(
process_file,
inputs=[api_key, file, instructions],
outputs=outputs
)
# Starts the app unconditionally (there is no __main__ guard in this file).
demo.launch()