File size: 3,507 Bytes
4fc79a4 12ce912 72c5969 bca92aa ec365ce 3c50a2d 4fc79a4 601022d 9a4fc1b ec365ce 422964b ec365ce 9a4fc1b bca92aa 9a4fc1b 72c5969 601022d 72c5969 3c50a2d 72c5969 9a4fc1b bca92aa ec365ce 72c5969 1b2886c b06a85b ec365ce 1b2886c ec365ce 1b2886c ec365ce 1b2886c ec365ce 601022d 1b2886c 601022d 1b2886c ec365ce 9a4fc1b 1b2886c 9a4fc1b 1b2886c b06a85b 1b2886c 12ce912 b06a85b 12ce912 9a4fc1b 3c50a2d ec365ce 3c50a2d b06a85b 9a4fc1b 601022d ec365ce 5be932a 23a3b49 ec365ce 601022d 6cff8d5 601022d bca92aa 601022d bca92aa 601022d b06a85b bca92aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import ast
from PIL import Image, ImageDraw
import google.generativeai as genai
import traceback
def _read_dataframe(file):
    """Load the uploaded dataset into a DataFrame.

    Depending on the Gradio version, ``gr.File`` hands the callback either a
    tempfile-like object exposing ``.name`` or the path string itself, so
    both are accepted. CSV is detected by extension (case-insensitively);
    anything else is passed to ``pd.read_excel``.

    Raises:
        ValueError: if no file was uploaded.
    """
    if file is None:
        raise ValueError("No dataset uploaded.")
    file_path = getattr(file, "name", file)
    if str(file_path).lower().endswith(".csv"):
        return pd.read_csv(file_path)
    return pd.read_excel(file_path)


def _extract_code_block(text):
    """Return the body of the first fenced code block in *text*.

    A ```python fence is preferred when present; otherwise the first generic
    ``` fence is used. Without any fence, the whole reply is returned
    (stripped). A language tag on the opening fence line (e.g. ``json``,
    ``py``) is dropped so that only the literal reaches ``ast.literal_eval``.
    """
    fence = "```python" if "```python" in text else "```"
    if fence not in text:
        return text.strip()
    body = text.split(fence, 1)[1].split("```", 1)[0]
    tag, newline, rest = body.partition("\n")
    if newline and tag.strip().isidentifier():
        body = rest
    return body.strip()


def _render_plot(df, spec):
    """Render one ``(title, plot_type, x, y)`` spec from *df* as a PIL image.

    Supported plot types are ``bar``, ``line``, ``scatter`` and ``hist``.
    Any other type yields an annotated empty chart rather than a silently
    blank one. The matplotlib figure is always closed, even on error.
    """
    # Unpack before creating the figure so a malformed spec cannot leak one.
    title, plot_type, x, y = spec
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type in ("bar", "line", "scatter"):
            df.plot(kind=plot_type, x=x, y=y, ax=ax)
        elif plot_type == "hist":
            df[x].hist(ax=ax)
        else:
            ax.text(0.5, 0.5, f"Unsupported plot type: {plot_type!r}",
                    ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title)
        ax.set_xlabel(x)
        ax.set_ylabel(y if y else 'Frequency')
        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now instead of lazily from the buffer
        return img
    finally:
        plt.close(fig)


def _error_images(error_message):
    """Return three identical white images with *error_message* drawn in red."""
    error_image = Image.new('RGB', (800, 400), (255, 255, 255))
    draw = ImageDraw.Draw(error_image)
    draw.text((10, 10), error_message, fill=(255, 0, 0))
    return [error_image] * 3


def process_file(file, instructions, api_key, model_name='gemini-2.5-pro-preview-03-25'):
    """Ask Gemini for up to three chart specs for a dataset and render them.

    Args:
        file: Uploaded dataset from ``gr.File``, as a path string or an
            object with a ``.name`` path. ``.csv`` files are read as CSV,
            anything else with ``pd.read_excel``.
        instructions: Free-text analysis request forwarded to the model.
        api_key: Gemini API key passed to ``genai.configure``.
        model_name: Gemini model identifier. Defaults to the model the app
            was written against; override it when that id is retired.

    Returns:
        Exactly three PIL images, one per Gradio output slot. Fewer than
        three specs are padded with blank 800x600 white images. On any
        failure, three copies of an image showing the error and traceback
        are returned (the message is also printed to the console).
    """
    try:
        if not api_key:
            raise ValueError("A Gemini API key is required.")
        df = _read_dataframe(file)

        # Initialize Gemini
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name)

        # Generate visualization specs
        response = model.generate_content(f"""
Create 3 matplotlib visualization codes based on: {instructions}
Data columns: {list(df.columns)}
Return Python code as: [('title','plot_type','x','y'), ...]
Allowed plot_types: bar, line, scatter, hist
Use only DataFrame 'df' and these exact variable names.
""")

        code_block = _extract_code_block(response.text)
        print("Generated code block:")
        print(code_block)
        # literal_eval (never eval): the model's reply is untrusted text.
        plots = ast.literal_eval(code_block)
        if not isinstance(plots, (list, tuple)):
            raise ValueError(f"Expected a list of plot specs, got: {code_block!r}")
        # A single bare spec tuple would otherwise be sliced into strings.
        if plots and not isinstance(plots[0], (list, tuple)):
            plots = [plots]

        images = [_render_plot(df, spec) for spec in plots[:3]]  # max 3 plots
        blank = Image.new('RGB', (800, 600), (255, 255, 255))
        return images + [blank] * (3 - len(images))
    except Exception as e:
        # Top-level Gradio callback boundary: report the failure in the UI.
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print to console for debugging
        return _error_images(error_message)
# Gradio UI: dataset upload, instructions and API key on one row; a submit
# button wired to process_file; three image slots for the generated charts.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")
    with gr.Row():
        file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
        instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
        api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")
    # process_file always returns a list of exactly three images, one per slot.
    output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
    submit.click(
        process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )


if __name__ == "__main__":
    demo.launch()