# NOTE: This file was captured from a web file viewer. The viewer's residue
# (a "File size: 4,108 Bytes" banner, per-line commit hashes and the 1-129
# line-number gutter) preceded the source and is not part of the program.
import ast
import io
import re
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw, ImageFont

def _load_dataframe(file):
    """Read the uploaded CSV or Excel file into a non-empty DataFrame.

    Args:
        file: Either a filesystem path, or an object exposing the path as
            its ``name`` attribute.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        ValueError: If no file was supplied or the parsed table is empty.
        Exception: Whatever ``pandas`` raises for unreadable input.
    """
    if file is None:
        raise ValueError("No file uploaded")

    path = str(getattr(file, "name", file))
    # Compare the extension case-insensitively so "DATA.CSV" is not handed
    # to the Excel reader.
    if path.lower().endswith('.csv'):
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)

    if df.empty:
        raise ValueError("Uploaded file contains no data")
    return df


def process_file(api_key, file, instructions):
    """Generate three chart images for an uploaded data file.

    The file is loaded into a DataFrame, the Gemini model is asked for three
    matplotlib snippets, and each snippet is sanitized, syntax-checked and
    executed. Failures never propagate: every failure is replaced by an
    error image so the caller always receives exactly three images.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: Uploaded file (path string or object with a ``name`` path).
        instructions: Optional free-text guidance; falls back to
            ``'General insights'`` when empty.

    Returns:
        A list of three PIL images.
    """
    try:
        df = _load_dataframe(file)
    except Exception as e:
        print(f"Data Error: {str(e)}")
        return [generate_error_image(str(e))]*3

    # Enhanced prompt with strict plotting requirements
    prompt = f"""Generate 3 matplotlib codes with these rules:
    1. Use ONLY these variables: df (DataFrame), plt
    2. Each visualization MUST:
       - Plot actual data from df
       - Include title, axis labels, and data labels if needed
       - Use clear color schemes
       - Avoid empty plots
    3. Code structure:
       plt.figure(figsize=(16,9), dpi=120)
       plt.style.use('ggplot')
       # Plotting code using df columns: {list(df.columns)}
       plt.tight_layout()
    
    Sample data: {df.head(3).to_dict()}
    User instructions: {instructions or 'General insights'}
    
    Format EXACTLY as:
    # Visualization 1
    [complete code]
    """

    try:
        # Client setup lives inside the try so a bad key or model name is
        # reported as an error image instead of an unhandled exception.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        response = model.generate_content(prompt)
        # Element 0 is whatever precedes the first marker; keep at most
        # three snippets after it.
        code_blocks = re.split(r'# Visualization \d+', response.text)[1:4]
    except Exception as e:
        # Log the cause; previously it was discarded silently.
        print(f"API Error: {str(e)}")
        return [generate_error_image("API Error")]*3

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            # Advanced code sanitization
            cleaned_code = sanitize_code(block, df.columns)

            # Validate syntax before executing
            ast.parse(cleaned_code)
            img = execute_plot_code(cleaned_code, df)
            visualizations.append(img)
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(generate_error_image(f"Plot {i} Error"))

    # Pad with placeholders when the model returned fewer than three snippets.
    return visualizations + [generate_error_image("Not Generated")]*(3-len(visualizations))

def sanitize_code(code_block, columns):
    """Clean a model-generated plotting snippet so it can be parsed and run.

    The cleanup:
      * drops blank lines and Markdown fence lines (those starting with a
        backtick);
      * replaces the placeholders ``'x_axis'`` / ``'y_axis'`` with the first
        / second column name (``'Value'`` when that column does not exist);
      * renames the standalone identifier ``data`` to ``df``;
      * removes bare ``plt.legend()`` calls;
      * preserves relative indentation and strips only the indentation
        common to all lines, so ``for``/``if`` bodies stay valid.

    The rename is textual: the word ``data`` inside a string literal or
    comment is rewritten too.

    Args:
        code_block: Raw snippet text.
        columns: Sequence of column labels (e.g. ``df.columns``).

    Returns:
        The cleaned source as a single newline-joined string.
    """
    x_name = str(columns[0]) if len(columns) > 0 else "Value"
    y_name = str(columns[1]) if len(columns) > 1 else "Value"

    # repr() yields a correctly quoted/escaped literal even when the column
    # name itself contains quotes or backslashes.
    substitutions = {
        "'x_axis'": repr(x_name),
        "'y_axis'": repr(y_name),
        "data": "df",
        "plt.legend()": "",  # Remove empty legend calls
    }
    # One combined pattern applied in a single pass: inserted column names
    # are never rescanned, so a column called "data" is not turned into "df".
    # The lookarounds restrict the rename to the whole identifier `data`
    # that is not an attribute access (so `metadata`, `data_x` and
    # `obj.data` are left alone).
    pattern = re.compile(
        r"'y_axis'|'x_axis'|(?<![\w.])data(?!\w)|plt\.legend\(\)"
    )

    def _swap(match):
        return substitutions[match.group(0)]

    kept = []
    for raw_line in code_block.split('\n'):
        stripped = raw_line.strip()
        if not stripped or stripped.startswith('`'):
            continue

        # Only trailing whitespace is removed; leading indentation carries
        # Python block structure.
        line = pattern.sub(_swap, raw_line.rstrip())
        if line.strip():
            kept.append(line)

    # Remove indentation shared by every line (e.g. a uniformly indented
    # reply) without flattening nested blocks.
    return textwrap.dedent('\n'.join(kept))

def execute_plot_code(code, df):
    """Execute plotting code and return the rendered figure as a PIL image.

    A 16x9 figure is created up front so snippets that never call
    ``plt.figure`` still draw onto a correctly sized canvas. After the
    snippet runs, the current figure is saved as PNG into memory and decoded.

    SECURITY: ``code`` is run with ``exec`` and has access to the normal
    builtins; this is NOT a sandbox. Only pass code you are prepared to run
    with the privileges of this process.

    Args:
        code: Python source to execute; it sees the names ``df`` and ``plt``.
        df: DataFrame exposed to the snippet as ``df``.

    Returns:
        A fully loaded ``PIL.Image.Image`` of the current figure.

    Raises:
        Exception: Anything raised by the snippet or by matplotlib while
            rendering; figures opened during the call are closed first.
    """
    # Remember which figures already existed so cleanup only touches ours.
    preexisting = set(plt.get_fignums())

    buf = io.BytesIO()
    plt.figure(figsize=(16, 9), dpi=120)
    plt.style.use('ggplot')

    try:
        exec(code, {'df': df, 'plt': plt})
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf)
        # Decode now so a corrupt render surfaces here rather than later.
        image.load()
        return image
    finally:
        # The snippet may open its own figure(s); a bare plt.close() would
        # close only the current one and leak the rest on every call.
        for num in set(plt.get_fignums()) - preexisting:
            plt.close(num)

def generate_error_image(message):
    """Create a 1920x1080 placeholder image that displays ``message``.

    Args:
        message: Text to show on the image. ``None`` or blank text yields
            the plain background with nothing drawn.

    Returns:
        An RGB ``PIL.Image.Image`` with the (wrapped, truncated) message
        drawn in white on a solid background.
    """
    img = Image.new('RGB', (1920, 1080), color=(73, 109, 137))

    text = "" if message is None else str(message).strip()
    if not text:
        return img

    # Larger default font where the installed Pillow supports a size
    # argument; otherwise fall back to the built-in bitmap font.
    try:
        font = ImageFont.load_default(size=48)
    except (TypeError, OSError, ImportError):
        font = ImageFont.load_default()

    # Wrap and cap long exception text so it stays on the canvas.
    lines = textwrap.wrap(text, width=60, max_lines=12, placeholder=" ...")
    ImageDraw.Draw(img).text(
        (80, 80),
        "\n".join(lines),
        fill=(255, 255, 255),
        font=font,
        spacing=12,
    )
    return img

# Gradio interface: wires the input widgets to process_file and shows the
# three images it returns.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# Professional Data Visualizer")

    # The API key and the data upload share the first row.
    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])

    instructions = gr.Textbox(label="Visualization Instructions")
    submit = gr.Button("Generate Insights", variant="primary")

    # One image slot per chart, labelled 1 through 3.
    with gr.Row():
        outputs = [
            gr.Image(label=f"Visualization {slot}", width=600)
            for slot in range(1, 4)
        ]

    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

demo.launch()