|
import ast
import io
import re
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
|
|
|
# Gemini model queried when the caller does not pick one explicitly.
GEMINI_MODEL = 'gemini-2.5-pro-preview-03-25'

# Number of charts requested from the model; matches the three image slots in the UI.
_N_VISUALIZATIONS = 3

# Prompt sent to Gemini; filled in with str.format() by _build_prompt.
_PROMPT_TEMPLATE = textwrap.dedent("""\
    Generate 3 distinct matplotlib visualization codes with these rules:
    1. Use ONLY these variables: df (existing DataFrame), plt
    2. No imports or additional data loading
    3. Each visualization must:
    - Start with: plt.figure(figsize=(16,9), dpi=120)
    - Use plt.style.use('ggplot')
    - Include title, axis labels, and grid
    - End with plt.tight_layout()
    4. Different chart types (bar, line, scatter, etc)

    Dataset columns: {columns}
    Sample data: {sample}
    User instructions: {instructions}

    Format EXACTLY as:
    # Visualization 1
    plt.figure(figsize=(16,9), dpi=120)
    plt.style.use('ggplot')
    # Visualization code using df
    plt.tight_layout()
    """)

# Contents of a markdown ```fenced``` block (language tag optional).
_FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)

# A whole line introducing a chart, e.g. "# Visualization 2" or "### Visualization 2: ...".
# It requires a number, so the prompt's "# Visualization code using df" placeholder
# is never treated as a new section.
_SECTION_RE = re.compile(r"^[ \t]*#+[ \t]*Visualization[ \t]+\d+\b.*$", re.MULTILINE)

# Statement lines that must not run: imports and plt.show().
_DROP_LINE_RE = re.compile(r"^\s*(?:import\s|from\s+\S+\s+import\s|plt\.show\s*\()")

# Data-loading calls such as pd.read_csv(...); the DataFrame is supplied instead.
_LOAD_RE = re.compile(r"\bread_\w+\s*\(")


def _load_dataframe(path):
    """Read *path* with pd.read_csv when it ends in .csv (any case), else pd.read_excel."""
    if str(path).lower().endswith('.csv'):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Return the Gemini prompt describing *df*'s columns and first three rows."""
    return _PROMPT_TEMPLATE.format(
        columns=list(df.columns),
        sample=df.head(3).to_dict(),
        instructions=instructions or 'None',
    )


def _split_sections(text):
    """Split a Gemini reply into at most _N_VISUALIZATIONS raw code sections."""
    fenced = _FENCE_RE.findall(text)
    # Prefer fenced code so explanatory prose around it is ignored.
    code_text = "\n".join(fenced) if fenced else text
    sections = _SECTION_RE.split(code_text)[1:]
    if not sections:
        # No "# Visualization N" headers: treat each fenced block as one chart.
        sections = fenced
    return sections[:_N_VISUALIZATIONS]


def _clean_code(section):
    """Turn one reply section into runnable Python source.

    Fence lines are removed. Imports, ``read_*(...)`` data loading and
    ``plt.show()`` are replaced by ``pass`` at the same indentation, so an
    enclosing ``for``/``if`` stays syntactically valid. All other lines keep
    their indentation and comments: stripping those breaks loop bodies and
    strings containing '#', such as hex colours.
    """
    kept = []
    for line in section.splitlines():
        if line.strip().startswith("```"):
            continue
        if _DROP_LINE_RE.match(line) or _LOAD_RE.search(line):
            indent = line[:len(line) - len(line.lstrip())]
            kept.append(indent + "pass")
        else:
            kept.append(line)
    return textwrap.dedent("\n".join(kept)).strip("\n")


def _render_visualization(code, df):
    """Execute one plotting snippet and return the current figure as a PIL image.

    Any exception raised by parsing or running the snippet propagates to the caller.
    """
    ast.parse(code)  # fail fast on syntax errors before executing anything

    frame = df.copy()  # snippets may mutate the frame; keep each chart independent
    # `data` is an alias because replies sometimes name the frame that way.
    # SECURITY: this executes model-generated code with full builtins inside the
    # server process. Only expose the app to trusted users, or run it sandboxed.
    namespace = {'df': frame, 'data': frame, 'plt': plt, 'pd': pd}

    existing = set(plt.get_fignums())
    buf = io.BytesIO()
    try:
        # rc_context restores rcParams on exit, so a snippet's plt.style.use()
        # does not leak into later charts.
        with plt.rc_context():
            # Default canvas for snippets that never call plt.figure() themselves.
            plt.figure(figsize=(16, 9), dpi=120)
            exec(code, namespace)
            plt.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Close every figure opened during this call, including the default one,
        # even when the snippet raised.
        for num in set(plt.get_fignums()) - existing:
            plt.close(num)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now so a bad PNG fails here rather than later in the UI
    return image


def process_file(api_key, file, instructions, model_name=GEMINI_MODEL):
    """Generate up to three matplotlib charts for an uploaded dataset using Gemini.

    Args:
        api_key: Gemini API key entered in the UI.
        file: The uploaded dataset. This is either a path string (what
            ``gr.File(type="filepath")`` delivers) or an object with a ``.name``
            path attribute. Paths ending in ``.csv`` (any case) are read with
            ``pd.read_csv``; everything else with ``pd.read_excel``.
        instructions: Optional free-text guidance added to the prompt.
        model_name: Gemini model to query.

    Returns:
        A list of exactly three entries. Each is a PIL image, or ``None`` when
        that chart could not be produced. Errors are printed rather than raised.
    """
    empty = [None] * _N_VISUALIZATIONS

    path = getattr(file, 'name', file)
    if not path:
        print("File Error: no dataset uploaded")
        return empty
    try:
        df = _load_dataframe(path)
    except Exception as e:
        print(f"File Error: {str(e)}")
        return empty

    if not api_key:
        print("Gemini Error: no API key provided")
        return empty
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(_build_prompt(df, instructions))
        sections = _split_sections(response.text)
    except Exception as e:
        print(f"Gemini Error: {str(e)}")
        return empty

    visualizations = []
    for i, section in enumerate(sections, 1):
        code = _clean_code(section)
        try:
            visualizations.append(_render_visualization(code, df))
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            print(f"Problematic Code:\n{code}")
            visualizations.append(None)

    return visualizations + [None] * (_N_VISUALIZATIONS - len(visualizations))
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio interface:
# - API key and dataset upload side by side
# - optional instructions
# - a submit button
# - one image slot for each of the three values process_file returns
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# 📊 Data Visualization Generator")

    with gr.Row():
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            type="password",
            placeholder="Enter your API key here"
        )
        file = gr.File(
            label="📁 Upload Dataset",
            file_types=[".csv", ".xlsx"],
            # "filepath" hands the handler the upload's path rather than a file object.
            type="filepath"
        )

    instructions = gr.Textbox(
        label="💡 Custom Instructions",
        placeholder="E.g.: Compare sales trends across regions..."
    )

    submit = gr.Button("🚀 Generate Visualizations", variant="primary")

    with gr.Row():
        outputs = [
            gr.Image(
                label=f"Visualization {i+1}",
                width=600,
                height=400
            ) for i in range(3)
        ]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )


if __name__ == "__main__":
    demo.launch()