File size: 4,704 Bytes
4fc79a4
12ce912
 
 
 
 
 
 
 
4fc79a4
12ce912
 
f7d95ea
12ce912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a888f55
12ce912
 
 
 
a888f55
12ce912
 
5be932a
12ce912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ab6fac
12ce912
 
 
 
 
 
1ab6fac
12ce912
 
 
1ab6fac
12ce912
 
 
 
1ab6fac
12ce912
5be932a
12ce912
6cff8d5
12ce912
 
 
 
 
6cff8d5
12ce912
 
 
 
 
 
 
6cff8d5
12ce912
 
 
 
 
 
 
 
 
 
6cff8d5
12ce912
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image
import ast
import re
import traceback

def process_file(api_key, file, instructions):
    """Generate three visualizations of an uploaded dataset via Gemini.

    The uploaded file is loaded into a DataFrame, Gemini is asked for three
    matplotlib snippets, and each snippet is sanitized and executed.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: Uploaded file object; its ``name`` attribute is used as the
            path to read. A name ending in ``.csv`` (any case) is read with
            ``pd.read_csv``; anything else with ``pd.read_excel``.
        instructions: Optional free-text guidance inserted into the prompt;
            falls back to ``'Show general trends'`` when falsy.

    Returns:
        A list of exactly three images. Failed or missing visualizations are
        replaced with images produced by ``generate_error_image``.
    """
    # Configure Gemini API with error handling
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    except Exception as e:
        return [generate_error_image(f"API Config Error: {str(e)}")] * 3

    # Load data with validation
    try:
        if file is None:
            # Give a readable message instead of an AttributeError on file.name
            raise ValueError("No file uploaded")

        # Compare case-insensitively so "DATA.CSV" is not sent to read_excel
        if file.name.lower().endswith('.csv'):
            df = pd.read_csv(file.name)
        else:
            df = pd.read_excel(file.name)

        if df.empty:
            raise ValueError("Empty dataset uploaded")
        if len(df.columns) < 2:
            raise ValueError("Dataset needs at least 2 columns")
    except Exception as e:
        return [generate_error_image(f"Data Error: {str(e)}")] * 3

    # Enhanced prompt with strict formatting
    prompt = f"""Generate 3 Python matplotlib codes with these rules:
    1. Perfect Python 3.10 syntax with 4-space indentation
    2. Complete code blocks with proper indentation
    3. Use ONLY these variables: df, plt
    4. Each visualization must:
       - Start with: plt.figure(figsize=(16,9), dpi=120)
       - Use plt.style.use('ggplot')
       - Include title, axis labels, and data visualization
       - End with plt.tight_layout()
    
    Dataset Columns: {list(df.columns)}
    Data Sample: {df.head(3).to_dict()}
    User Instructions: {instructions or 'Show general trends'}
    
    Format EXACTLY as:
    # Visualization 1
    [properly indented code]
    
    # Visualization 2
    [properly indented code]
    
    # Visualization 3
    [properly indented code]
    """

    # Get and process Gemini response
    try:
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first marker; keep at most 3 blocks
        code_blocks = re.split(r'# Visualization \d+\s*', response.text)[1:4]
    except Exception:
        # Log the cause like the per-visualization handler below does,
        # rather than discarding it silently.
        print(f"API Response Error:\n{traceback.format_exc()}")
        return [generate_error_image("API Response Error")] * 3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            # Clean and validate code
            cleaned_code = sanitize_code(block, df.columns)
            img = execute_plot_code(cleaned_code, df)
            visualizations.append(img)
        except Exception:
            print(f"Visualization {i} Error:\n{traceback.format_exc()}")
            visualizations.append(generate_error_image(f"Error in Viz {i}"))

    # Ensure exactly 3 outputs; only build the placeholder when it is needed
    missing = 3 - len(visualizations)
    if missing > 0:
        visualizations.extend([generate_error_image("Not Generated")] * missing)
    return visualizations

def sanitize_code(code_block, columns):
    """Strip markdown artifacts from generated code and validate its syntax.

    The generated code's own relative indentation is preserved; only
    markdown fences, blank lines and a shared leading margin are removed.
    Re-deriving indentation from keywords cannot tell where a block ends,
    and rejects otherwise valid code containing ``for``/``if``/``with``.

    Args:
        code_block: Raw text of one generated code snippet.
        columns: Dataset column labels. Currently unused; kept so existing
            callers keep working.

    Returns:
        The cleaned source code as a string.

    Raises:
        ValueError: If the cleaned code is not valid Python syntax.
    """
    import textwrap  # local import keeps this block self-contained

    kept = []
    for raw in code_block.split('\n'):
        # Remove markdown artifacts: an opening fence with an optional
        # language tag, or a closing fence at the end of the line.
        line = re.sub(r'^\s*```\w*|```\s*$', '', raw).rstrip()
        if line.strip():
            kept.append(line)

    # First choice: the snippet as written, minus any shared left margin.
    candidates = [textwrap.dedent('\n'.join(kept))]
    if len(kept) > 1:
        # Fallback for a snippet whose first line lost its leading
        # whitespace while the remaining lines kept a uniform margin
        # (e.g. when the preceding marker was split off with trailing
        # whitespace consumed).
        candidates.append(
            kept[0].lstrip() + '\n' + textwrap.dedent('\n'.join(kept[1:]))
        )

    # Validate syntax; report the error from the most faithful candidate.
    first_error = None
    for candidate in candidates:
        try:
            ast.parse(candidate)
        except SyntaxError as e:
            if first_error is None:
                first_error = e
        else:
            return candidate

    raise ValueError(f"Syntax Error: {str(first_error)}") from first_error

def execute_plot_code(code, df):
    """Execute generated plotting code and return the rendered figure.

    Args:
        code: Python source to run. It is executed with only ``df`` and
            ``plt`` supplied as globals.
        df: DataFrame exposed to the executed code as ``df``.

    Returns:
        A PIL image of the current figure, rendered as PNG from an
        in-memory buffer.

    Raises:
        RuntimeError: If executing the code or rendering the figure fails;
            the original exception is chained as the cause.
    """
    buf = io.BytesIO()
    # Remember which figures already exist so cleanup only touches ours.
    preexisting = set(plt.get_fignums())
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')

    try:
        # SECURITY NOTE(review): `code` is model-generated and runs with
        # full interpreter privileges via exec(); restricting globals to
        # df/plt is not a sandbox. Consider isolating this if inputs are
        # untrusted.
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    except Exception as e:
        raise RuntimeError(f"Execution Error: {str(e)}") from e
    finally:
        # The executed code may open its own figure(s); a bare plt.close()
        # would close only the current one and leak the rest.
        for num in set(plt.get_fignums()) - preexisting:
            plt.close(num)

def generate_error_image(message):
    """Create error indication image with message"""
    from PIL import ImageDraw, ImageFont
    
    img = Image.new('RGB', (1920, 1080), color=(255, 255, 255))
    try:
        draw = ImageDraw.Draw(img)
        font = ImageFont