|
import ast
import io
import re
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
|
|
|
def _load_dataframe(file):
    """Read the uploaded CSV or Excel file into a non-empty DataFrame.

    Args:
        file: Upload handle from the UI. Either an object exposing the file
            path as ``.name`` or a plain path string.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        ValueError: If nothing was uploaded or the file contains no data.
    """
    if file is None:
        raise ValueError("No file uploaded")

    path = str(getattr(file, "name", file))
    # Compare case-insensitively so "REPORT.CSV" is not sent to the Excel reader.
    if path.lower().endswith('.csv'):
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)

    if df.empty:
        raise ValueError("Uploaded file contains no data")
    return df


def _build_prompt(df, instructions):
    """Build the model prompt describing the data and the required code format."""
    return f"""Generate 3 matplotlib codes with these rules:
    1. Use ONLY these variables: df (DataFrame), plt
    2. Each visualization MUST:
    - Plot actual data from df
    - Include title, axis labels, and data labels if needed
    - Use clear color schemes
    - Avoid empty plots
    3. Code structure:
    plt.figure(figsize=(16,9), dpi=120)
    plt.style.use('ggplot')
    # Plotting code using df columns: {list(df.columns)}
    plt.tight_layout()

    Sample data: {df.head(3).to_dict()}
    User instructions: {instructions or 'General insights'}

    Format EXACTLY as:
    # Visualization 1
    [complete code]
    """


def process_file(api_key, file, instructions):
    """Generate three visualizations of an uploaded data file via Gemini.

    The file is loaded into a DataFrame, the model is asked for three
    matplotlib snippets, and each snippet is sanitized, syntax-checked and
    executed. Every failure is reported as an error image instead of an
    exception so the UI always receives three images.

    Args:
        api_key: Gemini API key.
        file: Uploaded CSV/Excel file (object with ``.name`` or a path string).
        instructions: Optional free-text guidance; falls back to
            'General insights' when empty.

    Returns:
        A list of exactly three images (rendered plots or error placeholders).
    """
    try:
        df = _load_dataframe(file)
    except Exception as e:
        print(f"Data Error: {str(e)}")
        return [generate_error_image(str(e))]*3

    prompt = _build_prompt(df, instructions)

    try:
        # Client setup lives inside the try so configuration failures are
        # reported the same way as request failures.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first marker; keep the next three.
        code_blocks = re.split(r'# Visualization \d+', response.text)[1:4]
    except Exception as e:
        print(f"API Error: {str(e)}")
        return [generate_error_image("API Error")]*3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            cleaned_code = sanitize_code(block, df.columns)
            # Reject syntactically invalid code before trying to run it.
            ast.parse(cleaned_code)
            img = execute_plot_code(cleaned_code, df)
            visualizations.append(img)
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(generate_error_image(f"Plot {i} Error"))

    # Pad when the model returned fewer than three blocks.
    return visualizations + [generate_error_image("Not Generated")]*(3-len(visualizations))
|
|
|
# Matches a single-line string literal (captured, so it can be kept verbatim)
# or the bare identifier ``data``.
_DATA_NAME_RE = re.compile(
    r"""('(?:\\.|[^'\\])*'|"(?:\\.|[^"\\])*")|\bdata\b"""
)
_LEGEND_CALL_RE = re.compile(r"plt\.legend\(\)")


def sanitize_code(code_block, columns):
    """Clean one block of model-generated plotting code.

    Blank lines and markdown fence lines are dropped, any common leading
    indentation is removed (relative indentation is preserved so loops and
    conditionals stay valid), and these rewrites are applied to each line:

    * ``'x_axis'`` becomes the first column name as a quoted literal.
    * ``'y_axis'`` becomes the second column name, or ``'Value'`` when there
      is only one column.
    * the identifier ``data`` becomes ``df`` (whole word only, and never
      inside a string literal).
    * ``plt.legend()`` calls are removed.

    Args:
        code_block: Raw text of one generated snippet.
        columns: Column labels of the DataFrame; must contain at least one.

    Returns:
        The cleaned source code as a single string.
    """
    columns = list(columns)
    # repr() produces a valid literal even if the name holds quotes/backslashes.
    x_literal = repr(str(columns[0]))
    y_literal = repr(str(columns[1])) if len(columns) > 1 else "'Value'"

    kept = []
    for raw_line in code_block.split('\n'):
        line = raw_line.rstrip()
        if not line or line.lstrip().startswith('`'):
            continue
        kept.append(line)

    # Strip only the indentation shared by all lines.
    body = textwrap.dedent('\n'.join(kept))

    cleaned = []
    for line in body.split('\n'):
        # Plain str.replace: the placeholders are fixed text, and the column
        # name must not be interpreted as a regex replacement template.
        line = line.replace("'y_axis'", y_literal)
        line = line.replace("'x_axis'", x_literal)
        line = _DATA_NAME_RE.sub(lambda m: m.group(1) or "df", line)
        line = _LEGEND_CALL_RE.sub("", line)
        cleaned.append(line.rstrip())

    return '\n'.join(cleaned)
|
|
|
def execute_plot_code(code, df):
    """Execute generated plotting code and return the rendered figure.

    A 16x9 figure with the 'ggplot' style is prepared, the code is run with
    only ``df`` and ``plt`` bound in its namespace, and the current figure is
    saved to an in-memory PNG.

    Args:
        code: Python source to execute.
        df: DataFrame exposed to the code as ``df``. A copy is passed so that
            in-place edits made by one snippet cannot affect later snippets.

    Returns:
        A PIL image of the rendered plot.

    Raises:
        Exception: Whatever the executed code or matplotlib raises.
    """
    buf = io.BytesIO()
    # Remember which figures already exist so cleanup only touches ours.
    existing_figures = set(plt.get_fignums())
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')

    try:
        # SECURITY: this runs model-generated code with full interpreter
        # privileges (builtins are still reachable). It is NOT sandboxed.
        exec(code, {'df': df.copy(), 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # The executed code may open additional figures; a bare plt.close()
        # would close only the current one and leak the rest.
        for fig_num in set(plt.get_fignums()) - existing_figures:
            plt.close(fig_num)
|
|
|
def generate_error_image(message):
    """Create a placeholder image that displays an error message.

    Args:
        message: Text to show on the image; ``None`` or empty yields a plain
            background.

    Returns:
        A 1920x1080 RGB PIL image with the message drawn near the top-left.
    """
    img = Image.new('RGB', (1920, 1080), color=(73, 109, 137))

    text = "" if message is None else str(message)
    if text:
        try:
            # Sized default font needs a recent Pillow with FreeType support.
            font = ImageFont.load_default(size=48)
        except (TypeError, ImportError, OSError):
            font = ImageFont.load_default()
        # Wrap so long error messages stay inside the canvas.
        wrapped = textwrap.fill(text, width=60)
        ImageDraw.Draw(img).text((60, 60), wrapped, fill=(255, 255, 255), font=font)

    return img
|
|
|
|
|
# Gradio front end: credentials and upload on the first row, free-text
# instructions and the trigger button below, three image slots at the bottom.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# Professional Data Visualizer")

    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])

    instructions = gr.Textbox(label="Visualization Instructions")
    submit = gr.Button("Generate Insights", variant="primary")

    with gr.Row():
        # One output slot per visualization returned by process_file.
        outputs = [
            gr.Image(label=f"Visualization {slot}", width=600)
            for slot in range(1, 4)
        ]

    # Wire the button to the pipeline.
    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

# Start the web server as soon as the script is run.
demo.launch()