File size: 2,795 Bytes
4fc79a4 ff768e2 4fc79a4 29f87f1 4fc79a4 69a028d 12598b6 4fc79a4 5fcb5c4 69a028d e5b6950 29f87f1 e5b6950 a888f55 5fcb5c4 a888f55 e5b6950 5fcb5c4 a888f55 5fcb5c4 a888f55 e5b6950 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 e5b6950 5fcb5c4 a888f55 5fcb5c4 a888f55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
def _read_table(file):
    """Load the uploaded file into a DataFrame: CSV by extension, Excel otherwise."""
    # gr.File may hand over a plain path string or a tempfile-like object that
    # exposes the path as ``.name`` (depends on Gradio version / ``type``);
    # accept both.
    path = getattr(file, "name", file)
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Return the Gemini prompt describing *df* and the expected reply format."""
    return "\n".join([
        "Generate 3 matplotlib codes with these requirements:",
        "1. Start with: plt.figure(figsize=(16,9), dpi=120)",
        "2. Use one of these styles: ggplot, bmh, dark_background, fast",
        "3. Include: title, labels, grid, legend if needed",
        "4. Different chart types (bar, line, scatter, etc)",
        f"Data columns: {list(df.columns)}",
        f"Sample data: {df.head(3).to_dict()}",
        f"User instructions: {instructions or 'None'}",
        "Format exactly as:",
        "# Visualization 1",
        "plt.figure(figsize=(16,9), dpi=120)",
        "plt.style.use('ggplot') # Example valid style",
        "[code]",
        "plt.tight_layout()",
    ]) + "\n"


def _extract_code_blocks(text, limit=3):
    """Split a "# Visualization N"-formatted reply into at most *limit* code strings."""
    blocks = []
    for chunk in text.split("# Visualization ")[1:limit + 1]:
        # The first line is the remainder of the header ("1", "1: Bar chart", ...),
        # not Python; executing it would raise a SyntaxError for titled headers.
        _, _, body = chunk.partition("\n")
        # Drop markdown fence lines (```python / ```).
        lines = [line for line in body.splitlines()
                 if not line.lstrip().startswith("```")]
        # Dedent rather than strip each line so loop/if bodies keep their indentation.
        blocks.append(textwrap.dedent("\n".join(lines)).strip("\n"))
    return blocks


def _render_code(code, df):
    """Execute one generated matplotlib snippet and return the figure as a PNG image."""
    buf = io.BytesIO()
    figures_before = set(plt.get_fignums())
    try:
        # rc_context restores rcParams on exit, so a snippet's plt.style.use()
        # does not leak into later charts or later requests.
        with plt.rc_context():
            plt.figure(figsize=(16, 9), dpi=120)
            # SECURITY: exec runs model-generated code with full builtins -- this
            # is NOT a sandbox. Only deploy where running arbitrary Python is acceptable.
            # Each snippet gets its own copy so in-place edits to df don't leak.
            exec(code, {"df": df.copy(), "plt": plt, "pd": pd})
            plt.tight_layout()
            plt.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # The snippet typically opens its own figure on top of ours; close every
        # figure created here, on success and on failure, so none accumulate.
        for num in set(plt.get_fignums()) - figures_before:
            plt.close(num)
    buf.seek(0)
    return Image.open(buf)


def process_file(api_key, file, instructions, *,
                 model_name="gemini-2.5-pro-preview-03-25"):
    """Ask Gemini for three matplotlib charts of an uploaded table and render them.

    Args:
        api_key: Gemini API key entered in the UI.
        file: The upload from ``gr.File`` -- a path string or an object whose
            ``.name`` is the path.
        instructions: Optional free-text instructions added to the prompt.
        model_name: Gemini model to query. Keyword-only, so the positional call
            made by the Gradio button is unaffected.

    Returns:
        A list of exactly three items, each a ``PIL.Image`` or ``None`` when the
        corresponding snippet was missing or failed to execute.

    Raises:
        gr.Error: If no API key or no file was provided.
    """
    if not api_key:
        raise gr.Error("Please enter a Gemini API key.")
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")

    df = _read_table(file)

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(_build_prompt(df, instructions))

    visualizations = []
    for i, code in enumerate(_extract_code_blocks(response.text), 1):
        try:
            visualizations.append(_render_code(code, df))
        except Exception as e:  # best effort: one bad snippet must not sink the rest
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(None)
    # Always return three values, one per output image in the UI.
    return visualizations + [None] * (3 - len(visualizations))
# Gradio interface: API key, data file and optional instructions in;
# three generated chart images out (one per element of process_file's result).
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization Tool")
    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
        instructions = gr.Textbox(label="Custom Instructions")
    submit = gr.Button("Generate Visualizations")
    with gr.Row():
        outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

# Launch only when run as a script; importing the module (tests, `gradio`
# reload mode) just builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()