bluenevus's picture
Update app.py
eb443e3 verified
raw
history blame
3.28 kB
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
# Size and colour of the stand-in image used when a chart cannot be produced.
_PLACEHOLDER_SIZE = (1920, 1080)
_PLACEHOLDER_COLOR = (73, 109, 137)


def _placeholder_image():
    """Return a solid-colour 1920x1080 RGB image used in place of a failed chart."""
    return Image.new('RGB', _PLACEHOLDER_SIZE, color=_PLACEHOLDER_COLOR)


def _read_dataset(file):
    """Load an uploaded dataset: ``.csv`` (any case) via read_csv, anything else via read_excel.

    ``file`` may be an object exposing ``.name`` (older Gradio file objects) or a
    plain path string (newer Gradio ``gr.File`` default); both are accepted.
    """
    path = getattr(file, 'name', file)
    if str(path).lower().endswith('.csv'):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Build the Gemini prompt describing the dataset and the required output format."""
    return f"""Generate exactly 3 distinct matplotlib visualization codes for this dataset:
Columns: {list(df.columns)}
Data types: {dict(df.dtypes)}
Sample data: {df.head(3).to_dict()}
Requirements:
1. 1920x1080 resolution (figsize=(16,9), dpi=120)
2. Professional styling (seaborn style, grid lines)
3. Clear labels and titles
4. Diverse chart types (at least 1 must be non-basic)
User instructions: {instructions or 'No specific instructions'}
Format response strictly as:
# Visualization 1
plt.figure(figsize=(16,9), dpi=120)
[code]
plt.tight_layout()
# Visualization 2
...
"""


def _resolve_style():
    """Return a seaborn-like matplotlib style name available in this matplotlib install.

    matplotlib 3.6 renamed the bundled 'seaborn' sheet to 'seaborn-v0_8' and 3.8
    removed the old name, so plain 'seaborn' raises on current releases.
    """
    for name in ('seaborn-v0_8', 'seaborn'):
        if name in plt.style.available:
            return name
    return 'default'


def _extract_code(block):
    """Turn one '# Visualization N' chunk of the model reply into executable code.

    The first line (the remainder of the heading, e.g. "1") is dropped, as are
    blank lines and markdown fence lines (```` ``` ````/```` ```python ````).
    Lines are NOT stripped individually -- that would destroy Python
    indentation -- instead the common leading indent is removed with dedent.
    """
    kept = [
        line
        for line in block.split('\n')[1:]
        if line.strip() and not line.strip().startswith('```')
    ]
    return textwrap.dedent('\n'.join(kept))


def _render(code, df, index, style):
    """Execute generated plotting code and return the current figure as a PIL image.

    Raises whatever the generated code raises; all pyplot figures are closed
    either way so failed or extra figures do not accumulate between requests.
    """
    try:
        # Apply the style before creating the figure so the figure picks it up.
        with plt.style.context(style):
            plt.figure(figsize=(16, 9), dpi=120)
            # SECURITY NOTE(review): exec runs model-generated code with full
            # interpreter privileges; only acceptable in a sandboxed deployment.
            exec(code, {'df': df, 'plt': plt, 'pd': pd})
            # Keep the model's own title; only label untitled charts.
            if not plt.gca().get_title():
                plt.title(f"Visualization {index}", fontsize=14)
            plt.tight_layout()
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
    finally:
        # NOTE(review): pyplot state is process-global, so this is not safe if
        # several requests render concurrently in the same process.
        plt.close('all')
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now rather than lazily from the buffer
    return image


def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib charts of an uploaded dataset and render them.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: uploaded dataset from ``gr.File`` (file object or path string).
        instructions: optional free-text guidance inserted into the prompt.

    Returns:
        A list of exactly three PIL images. A chart the model did not provide,
        or whose code fails to run, is replaced by a solid placeholder image.

    Raises:
        gr.Error: missing API key or file, unreadable dataset, or a failed
            Gemini request (shown to the user by Gradio).
    """
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")

    try:
        df = _read_dataset(file)
    except Exception as e:
        raise gr.Error(f"File Error: {str(e)}") from e

    genai.configure(api_key=api_key)
    # NOTE(review): confirm this model id is published; 'gemini-2.5-pro' may be
    # the intended name.
    model = genai.GenerativeModel('gemini-2.5-pro-latest')
    try:
        reply = model.generate_content(_build_prompt(df, instructions)).text
    except Exception as e:
        raise gr.Error(f"Gemini Error: {str(e)}") from e

    style = _resolve_style()
    visualizations = []
    # Text before the first heading is discarded; at most 3 charts are kept.
    for i, block in enumerate(reply.split("# Visualization ")[1:4], 1):
        try:
            visualizations.append(_render(_extract_code(block), df, i, style))
        except Exception as e:
            # Best-effort: one broken chart must not fail the whole request.
            print(f"Visualization {i} failed: {str(e)}")
            visualizations.append(_placeholder_image())

    # The UI always has three image slots; pad if the model returned fewer.
    visualizations.extend(_placeholder_image() for _ in range(3 - len(visualizations)))
    return visualizations
# Gradio interface with HD layout.
# spacing_size must be one of Gradio's built-in spacing shortcuts ("sm", "md",
# "lg"); "lg" is the largest, the previous "xl" is not a defined shortcut.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# **HD Data Visualizer** 📊✨")
    with gr.Row():
        api_key = gr.Textbox(label="🔑 Gemini API Key", type="password")
        file = gr.File(label="📁 Upload Dataset", file_types=[".csv", ".xlsx"])
    instructions = gr.Textbox(
        label="💡 Custom Instructions (optional)",
        placeholder="E.g.: Focus on time series patterns...",
    )
    submit = gr.Button("🚀 Generate Visualizations", variant="primary")
    with gr.Row():
        # One output slot per chart; process_file always returns exactly three images.
        outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()