File size: 4,704 Bytes
4fc79a4 12ce912 4fc79a4 12ce912 f7d95ea 12ce912 a888f55 12ce912 a888f55 12ce912 5be932a 12ce912 1ab6fac 12ce912 1ab6fac 12ce912 1ab6fac 12ce912 1ab6fac 12ce912 5be932a 12ce912 6cff8d5 12ce912 6cff8d5 12ce912 6cff8d5 12ce912 6cff8d5 12ce912 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image
import ast
import re
import traceback
def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib snippets for an uploaded dataset and render them.

    Args:
        api_key: Key passed to ``genai.configure``.
        file: The uploaded file. Either an object exposing the on-disk path
            as ``.name`` or the path itself. Paths ending in ``.csv`` (any
            letter case) are read with ``pd.read_csv``; everything else is
            handed to ``pd.read_excel``.
        instructions: Optional free-text guidance inserted into the prompt;
            a falsy value is replaced by ``'Show general trends'``.

    Returns:
        A list of exactly three items. Each item is either the image
        returned by ``execute_plot_code`` or the placeholder returned by
        ``generate_error_image``. Failures never propagate as exceptions:
        they are converted into placeholder images.
    """
    # Configure Gemini API with error handling
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    except Exception as e:
        return [generate_error_image(f"API Config Error: {str(e)}")] * 3

    # Load data with validation
    try:
        if file is None:
            # Give a readable message instead of an AttributeError on `.name`.
            raise ValueError("No file uploaded")
        path = getattr(file, 'name', file)
        # Compare case-insensitively so "DATA.CSV" is not sent to read_excel.
        if str(path).lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
        if df.empty:
            raise ValueError("Empty dataset uploaded")
        if len(df.columns) < 2:
            raise ValueError("Dataset needs at least 2 columns")
    except Exception as e:
        return [generate_error_image(f"Data Error: {str(e)}")] * 3

    # Enhanced prompt with strict formatting
    prompt = f"""Generate 3 Python matplotlib codes with these rules:
1. Perfect Python 3.10 syntax with 4-space indentation
2. Complete code blocks with proper indentation
3. Use ONLY these variables: df, plt
4. Each visualization must:
- Start with: plt.figure(figsize=(16,9), dpi=120)
- Use plt.style.use('ggplot')
- Include title, axis labels, and data visualization
- End with plt.tight_layout()
Dataset Columns: {list(df.columns)}
Data Sample: {df.head(3).to_dict()}
User Instructions: {instructions or 'Show general trends'}
Format EXACTLY as:
# Visualization 1
[properly indented code]
# Visualization 2
[properly indented code]
# Visualization 3
[properly indented code]
"""

    # Get and process Gemini response
    try:
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first marker; keep at most the
        # three blocks that follow the "# Visualization N" headers.
        code_blocks = re.split(r'# Visualization \d+\s*', response.text)[1:4]
    except Exception as e:
        # Log and report the cause, matching the other error handlers,
        # instead of silently discarding the exception.
        print(f"API Response Error:\n{traceback.format_exc()}")
        return [generate_error_image(f"API Response Error: {str(e)}")] * 3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            # Clean and validate code
            cleaned_code = sanitize_code(block, df.columns)
            img = execute_plot_code(cleaned_code, df)
            visualizations.append(img)
        except Exception:
            # One bad snippet must not abort the others: log it and show a
            # placeholder in its slot.
            print(f"Visualization {i} Error:\n{traceback.format_exc()}")
            visualizations.append(generate_error_image(f"Error in Viz {i}"))

    # Ensure exactly 3 outputs
    missing = 3 - len(visualizations)
    return visualizations + [generate_error_image("Not Generated")] * missing
def sanitize_code(code_block, columns):
    """Strip markdown artifacts from generated code and validate its syntax.

    The relative indentation of the input is preserved; only indentation
    common to every line is removed. Re-deriving indentation from keywords
    cannot work, because nothing in the text marks where a block ends, and
    it previously made any snippet containing ``for``/``if``/``with`` fail
    to parse.

    Args:
        code_block: Raw text of one generated snippet, possibly wrapped in
            markdown code fences.
        columns: Accepted for interface compatibility; not used here.

    Returns:
        The cleaned source: fence lines and blank lines removed, tabs
        expanded to 4 spaces, common leading indentation stripped.

    Raises:
        ValueError: If the cleaned source is not valid Python syntax.
    """
    import textwrap

    fence_line = re.compile(r'^\s*```[\w+-]*$')
    trailing_fence = re.compile(r'```$')

    kept_lines = []
    for raw_line in code_block.split('\n'):
        line = raw_line.rstrip()
        # Drop lines that are nothing but a fence ("```", "```python", ...).
        if fence_line.match(line):
            continue
        # A closing fence glued to the end of a code line is removed too.
        line = trailing_fence.sub('', line).rstrip()
        if not line:
            continue
        # Normalise tabs so dedent sees consistent leading whitespace.
        kept_lines.append(line.expandtabs(4))

    cleaned_code = textwrap.dedent('\n'.join(kept_lines))

    # Validate syntax
    try:
        ast.parse(cleaned_code)
    except SyntaxError as e:
        raise ValueError(f"Syntax Error: {str(e)}") from e
    return cleaned_code
def execute_plot_code(code, df):
    """Run plotting code against ``df`` and return the rendered figure as an image.

    A 16x9 figure is opened and the ``ggplot`` style selected before the
    code runs. Afterwards the current figure is saved as PNG into an
    in-memory buffer and opened with ``Image.open``.

    SECURITY NOTE: ``code`` is run with ``exec``. Restricting the globals
    to ``df`` and ``plt`` is NOT a sandbox: builtins, including
    ``__import__``, remain reachable. Only pass code from a trusted source,
    or run this in an isolated process.

    Args:
        code: Python source to execute; it sees the names ``df`` and ``plt``.
        df: The dataset made available to the code.

    Returns:
        The image object produced by ``Image.open`` on the PNG buffer.

    Raises:
        RuntimeError: If executing the code or saving the figure fails; the
            original exception is attached as ``__cause__``.
    """
    buf = io.BytesIO()
    # Remember which figures already existed so that only figures opened
    # during this call are closed afterwards.
    preexisting = set(plt.get_fignums())
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')
    try:
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        raise RuntimeError(f"Execution Error: {str(e)}") from e
    finally:
        # The executed code may open its own figure(s) via plt.figure();
        # a bare plt.close() would close only the current one and leak the
        # rest on every call.
        for fig_num in set(plt.get_fignums()) - preexisting:
            plt.close(fig_num)
def generate_error_image(message):
"""Create error indication image with message"""
from PIL import ImageDraw, ImageFont
img = Image.new('RGB', (1920, 1080), color=(255, 255, 255))
try:
draw = ImageDraw.Draw(img)
font = ImageFont |