# bluenevus's picture
# Update app.py
# 6cff8d5 verified
# raw
# history blame
# 3.95 kB
import ast
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
# Substrings marking a generated line as re-loading data from disk, which the
# prompt forbids; such lines are dropped so the code plots the uploaded df.
_DATA_LOADERS = ("read_csv(", "read_excel(", "read_json(", "read_parquet(", "read_table(")


def _read_dataset(file):
    """Load the uploaded dataset, or return None if it is missing/unreadable.

    ``file`` is whatever the Gradio ``File`` component passes: a path string
    (``type="filepath"``) or a tempfile-like object exposing a ``.name`` path.
    Paths ending in ``.csv`` (any case) are read with ``pd.read_csv``;
    anything else is handed to ``pd.read_excel``.
    """
    path = file if isinstance(file, str) else getattr(file, "name", None)
    if not path:
        print("File Error: no file uploaded")
        return None
    try:
        if path.lower().endswith(".csv"):
            return pd.read_csv(path)
        return pd.read_excel(path)
    except Exception as e:
        print(f"File Error: {e}")
        return None


def _build_prompt(df, instructions):
    """Return the Gemini prompt: plotting rules, df's columns and sample, user instructions."""
    return f"""Generate 3 distinct matplotlib visualization codes with these rules:
1. Use ONLY these variables: df (existing DataFrame), plt
2. No imports or additional data loading
3. Each visualization must:
- Start with: plt.figure(figsize=(16,9), dpi=120)
- Use plt.style.use('ggplot')
- Include title, axis labels, and grid
- End with plt.tight_layout()
4. Different chart types (bar, line, scatter, etc)
Dataset columns: {list(df.columns)}
Sample data: {df.head(3).to_dict()}
User instructions: {instructions or 'None'}
Format EXACTLY as:
# Visualization 1
plt.figure(figsize=(16,9), dpi=120)
plt.style.use('ggplot')
# Visualization code using df
plt.tight_layout()
"""


def _clean_code(block):
    """Turn one ``# Visualization N`` chunk of the model reply into plain Python.

    The first line (the rest of the header, e.g. ``1`` or ``1: Sales``) is
    dropped. A Markdown fence before any code is skipped, and a fence after
    code ends the chunk, so trailing prose or the next fence is ignored.
    Import statements and data-loading lines are dropped because the prompt
    forbids them. Relative indentation is kept so loops and conditionals stay
    valid; only the common leading indentation is removed.
    """
    kept = []
    seen_code = False
    for line in block.splitlines()[1:]:
        stripped = line.strip()
        if stripped.startswith("```"):
            if seen_code:
                break  # closing fence: this chart's code is complete
            continue  # opening fence (e.g. ```python) before any code
        is_import = stripped.startswith("import ") or (
            stripped.startswith("from ") and " import " in stripped
        )
        if is_import or any(loader in line for loader in _DATA_LOADERS):
            continue
        kept.append(line)
        seen_code = seen_code or bool(stripped)
    return textwrap.dedent("\n".join(kept)).strip()


def _render(code, df):
    """Run generated plotting code and return the current figure as a PIL image.

    The code sees a private copy of ``df``, also bound as ``data`` since models
    often use that name. One chart therefore cannot mutate the data the next
    one plots. Every figure opened during the call is closed afterwards, even
    on failure. Raises whatever checking or running the code raises.
    """
    if not code.strip():
        raise ValueError("no code found for this visualization")
    ast.parse(code)  # reject syntax errors before touching pyplot state
    open_before = set(plt.get_fignums())
    try:
        plt.figure(figsize=(16, 9), dpi=120)
        frame = df.copy()
        # NOTE(review): exec() runs model-generated code with full builtins and
        # is not a sandbox. User instructions flow into the prompt, so prompt
        # injection could run arbitrary Python here; consider real isolation.
        exec(code, {"df": frame, "data": frame, "pd": pd, "plt": plt})
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now so a bad PNG fails here, inside the caller's try
        return image
    finally:
        # Close only the figures created during this call (the pre-created one
        # plus any the generated code opened) so pyplot does not accumulate them.
        for num in set(plt.get_fignums()) - open_before:
            plt.close(num)


def process_file(api_key, file, instructions):
    """Generate three charts for an uploaded dataset using Gemini-written code.

    Args:
        api_key: Gemini API key entered by the user.
        file: the uploaded dataset, either a path string (as passed by
            ``gr.File(type="filepath")``) or a file-like object with ``.name``.
        instructions: optional free-text guidance inserted into the prompt.

    Returns:
        A list of exactly three entries. Each is a PIL image, or ``None`` when
        that chart could not be produced. Errors are reported with print().
    """
    df = _read_dataset(file)
    if df is None:
        return [None, None, None]

    try:
        # NOTE(review): genai.configure appears to set process-wide state, so
        # concurrent requests with different keys may interfere; confirm.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        response = model.generate_content(_build_prompt(df, instructions))
        code_blocks = response.text.split("# Visualization ")[1:4]
    except Exception as e:
        print(f"Gemini Error: {e}")
        return [None, None, None]

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        code = _clean_code(block)  # outside the try so the error log always shows this block's code
        try:
            visualizations.append(_render(code, df))
        except Exception as e:
            print(f"Visualization {i} Error: {e}")
            print(f"Problematic Code:\n{code}")
            visualizations.append(None)

    # Pad to exactly three outputs: the reply may contain fewer blocks.
    return visualizations + [None] * (3 - len(visualizations))
# Gradio interface: API key, dataset upload and optional instructions in;
# three chart images out, all wired to process_file.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# 🔍 Data Visualization Generator")
    with gr.Row():
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            type="password",
            placeholder="Enter your API key here",
        )
        # type="filepath" hands process_file a path string, not a file object.
        file = gr.File(
            label="📁 Upload Dataset",
            file_types=[".csv", ".xlsx"],
            type="filepath",
        )
    instructions = gr.Textbox(
        label="💡 Custom Instructions",
        placeholder="E.g.: Compare sales trends across regions...",
    )
    submit = gr.Button("🚀 Generate Visualizations", variant="primary")
    with gr.Row():
        outputs = [
            gr.Image(label=f"Visualization {i+1}", width=600, height=400)
            for i in range(3)
        ]
    # process_file returns one value per output image (three in total).
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch()