|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import io |
|
import google.generativeai as genai |
|
from PIL import Image |
|
import ast |
|
import re |
|
import traceback |
|
|
|
def process_file(api_key, file, instructions):
    """Generate three matplotlib visualizations of an uploaded dataset via Gemini.

    Args:
        api_key: Google Generative AI API key passed to ``genai.configure``.
        file: The uploaded CSV/Excel file. Either a filesystem path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name`` (e.g. the tempfile wrapper some Gradio versions pass).
        instructions: Optional free-text guidance forwarded to the model;
            'Show general trends' is used when it is empty.

    Returns:
        A list of exactly three images: rendered plots from
        ``execute_plot_code`` or, for any slot that could not be produced,
        an image from ``generate_error_image`` describing the failure.
    """
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    except Exception as e:
        return [generate_error_image(f"API Config Error: {str(e)}")] * 3

    try:
        df = _load_dataset(file)
    except Exception as e:
        return [generate_error_image(f"Data Error: {str(e)}")] * 3

    try:
        response = model.generate_content(_build_prompt(df, instructions))
        # Element 0 of the split is any preamble before the first header;
        # keep at most the three requested visualization sections.
        code_blocks = re.split(r'# Visualization \d+\s*', response.text)[1:4]
    except Exception as e:
        print(f"API Response Error:\n{traceback.format_exc()}")
        return [generate_error_image(f"API Response Error: {str(e)}")] * 3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            cleaned_code = sanitize_code(block, df.columns)
            visualizations.append(execute_plot_code(cleaned_code, df))
        except Exception:
            print(f"Visualization {i} Error:\n{traceback.format_exc()}")
            visualizations.append(generate_error_image(f"Error in Viz {i}"))

    # Pad so callers always receive exactly three images, even when the
    # model returned fewer sections than requested.
    missing = 3 - len(visualizations)
    return visualizations + [generate_error_image("Not Generated")] * missing


def _load_dataset(file):
    """Read the uploaded CSV/Excel file into a DataFrame and validate its shape.

    Raises:
        ValueError: If no file was given, the dataset is empty, or it has
            fewer than two columns. Pandas read errors propagate unchanged.
    """
    import os

    if file is None:
        raise ValueError("No file uploaded")
    # Accept a plain path as well as a file object carrying the path in
    # `.name`. Check PathLike first: pathlib.Path.name is only the basename.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    # Case-insensitive so that e.g. "DATA.CSV" is not sent to read_excel.
    if str(path).lower().endswith('.csv'):
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)

    if df.empty:
        raise ValueError("Empty dataset uploaded")
    if len(df.columns) < 2:
        raise ValueError("Dataset needs at least 2 columns")
    return df


def _build_prompt(df, instructions):
    """Return the Gemini prompt describing the dataset and required output format."""
    return f"""Generate 3 Python matplotlib codes with these rules:
1. Perfect Python 3.10 syntax with 4-space indentation
2. Complete code blocks with proper indentation
3. Use ONLY these variables: df, plt
4. Each visualization must:
    - Start with: plt.figure(figsize=(16,9), dpi=120)
    - Use plt.style.use('ggplot')
    - Include title, axis labels, and data visualization
    - End with plt.tight_layout()

Dataset Columns: {list(df.columns)}
Data Sample: {df.head(3).to_dict()}
User Instructions: {instructions or 'Show general trends'}

Format EXACTLY as:
# Visualization 1
[properly indented code]

# Visualization 2
[properly indented code]

# Visualization 3
[properly indented code]
"""
|
|
|
def sanitize_code(code_block, columns):
    """Clean one model-generated code block and verify that it parses.

    The model is asked for properly indented code, so the block's own
    relative indentation is preserved rather than re-derived. The cleanup
    steps are:

    - drop markdown fence lines;
    - expand tabs;
    - strip trailing whitespace;
    - remove the common leading margin with ``textwrap.dedent``.

    Args:
        code_block: Raw text of one visualization section of the model
            response.
        columns: Dataset column labels. Currently unused; kept so existing
            callers keep working.

    Returns:
        The cleaned source code as a string.

    Raises:
        ValueError: If the block contains no code or is not valid Python.
    """
    import textwrap

    kept_lines = []
    for raw_line in code_block.splitlines():
        line = raw_line.expandtabs(4).rstrip()
        if line.lstrip().startswith('```'):
            # Markdown fence such as ```python or a closing ```.
            continue
        if line.endswith('```'):
            # Closing fence glued onto the last code line.
            line = line[:-3].rstrip()
        kept_lines.append(line)

    cleaned_code = textwrap.dedent('\n'.join(kept_lines)).strip('\n')
    if not cleaned_code.strip():
        # Without this an empty section would render as a blank figure.
        raise ValueError("Empty code block")

    try:
        ast.parse(cleaned_code)
    except SyntaxError as e:
        raise ValueError(f"Syntax Error: {str(e)}") from e

    return cleaned_code
|
|
|
def execute_plot_code(code, df):
    """Run generated plotting code and render the resulting figure to an image.

    SECURITY NOTE(review): ``exec`` runs model-generated code in this
    process with full builtins available; it is NOT sandboxed. Only deploy
    where that is acceptable, or move execution into an isolated worker.

    Args:
        code: Python source that draws with ``plt`` using ``df``.
        df: The dataset. Each call gets its own copy so one snippet's
            in-place edits cannot leak into the next visualization.

    Returns:
        A PIL image (PNG) of the figure that is current after ``code`` runs.

    Raises:
        RuntimeError: If preparing the figure, executing ``code``, or saving
            the figure fails.
    """
    buf = io.BytesIO()
    # Remember which figures already existed so every figure created here
    # is closed afterwards. That includes ours and any the snippet opens
    # with plt.figure(); plt.close() alone would leak all but the last one.
    preexisting = set(plt.get_fignums())
    try:
        # Scope the style to this render instead of mutating global
        # rcParams; a plt.style.use() inside `code` is also undone on exit.
        with plt.style.context('ggplot'):
            plt.figure(figsize=(16, 9), dpi=120)
            exec(code, {'df': df.copy(), 'plt': plt})
            plt.tight_layout()
            plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        raise RuntimeError(f"Execution Error: {str(e)}") from e
    finally:
        for fig_num in set(plt.get_fignums()) - preexisting:
            plt.close(fig_num)
|
|
|
def generate_error_image(message): |
|
"""Create error indication image with message""" |
|
from PIL import ImageDraw, ImageFont |
|
|
|
img = Image.new('RGB', (1920, 1080), color=(255, 255, 255)) |
|
try: |
|
draw = ImageDraw.Draw(img) |
|
font = ImageFont |