File size: 3,947 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
29f87f1
8d1334d
4fc79a4
 
6cff8d5
69a028d
12598b6
4fc79a4
f7d95ea
6cff8d5
f7d95ea
 
 
 
 
6cff8d5
 
69a028d
6cff8d5
 
 
 
 
 
 
 
 
 
29f87f1
6cff8d5
 
e5b6950
a888f55
f7d95ea
a888f55
6cff8d5
 
 
 
5fcb5c4
a888f55
6cff8d5
 
 
 
 
 
a888f55
 
 
f7d95ea
a888f55
6cff8d5
 
8d1334d
 
 
6cff8d5
 
5fcb5c4
6cff8d5
f7d95ea
a888f55
6cff8d5
8d1334d
f7d95ea
8d1334d
5fcb5c4
6cff8d5
5fcb5c4
a888f55
 
 
 
5fcb5c4
6cff8d5
5fcb5c4
a888f55
6cff8d5
5fcb5c4
a888f55
6cff8d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image
import ast

def _is_dataset_reload(node):
    """Return True for a top-level ``df = x.read_*(...)`` / ``data = x.read_*(...)``.

    The prompt forbids re-loading the dataset, but the model sometimes does it
    anyway; such statements are skipped so the uploaded DataFrame is used.
    """
    if not isinstance(node, ast.Assign):
        return False
    rebinds_dataset = any(
        isinstance(target, ast.Name) and target.id in ('df', 'data')
        for target in node.targets
    )
    value = node.value
    return (
        rebinds_dataset
        and isinstance(value, ast.Call)
        and isinstance(value.func, ast.Attribute)
        and value.func.attr.startswith('read_')
    )


def _extract_source(block):
    """Pull the Python source out of one '# Visualization N' chunk of model text.

    The first line (the rest of the header, e.g. ``"1"``) is dropped, markdown
    code fences are removed, and everything after a closing fence is ignored.
    Indentation inside the code is preserved; only the common leading indent
    is removed.
    """
    import textwrap

    kept = []
    for line in block.split('\n')[1:]:
        stripped = line.strip()
        if not kept and not stripped:
            continue  # skip blank lines before the code starts
        if stripped.startswith('```'):
            if kept:
                break  # closing fence: whatever follows is prose, not code
            continue  # opening fence before any code
        kept.append(line)
    return textwrap.dedent('\n'.join(kept)).strip('\n')


def _compile_plot_code(source, label):
    """Parse *source*, drop imports and dataset reloads, and compile the rest.

    Filtering is done on the syntax tree rather than on raw text so that
    string literals, keyword arguments and indented bodies are left intact.

    Raises:
        SyntaxError: if *source* is not valid Python.
        ValueError: if nothing executable remains after filtering.
    """
    tree = ast.parse(source, filename=label)
    tree.body = [
        node for node in tree.body
        if not isinstance(node, (ast.Import, ast.ImportFrom))
        and not _is_dataset_reload(node)
    ]
    if not tree.body:
        raise ValueError("no executable plotting code found")
    return compile(tree, label, 'exec')


def process_file(api_key, file, instructions):
    """Generate up to three matplotlib charts for an uploaded dataset via Gemini.

    Args:
        api_key: Gemini API key used to configure ``google.generativeai``.
        file: Path of the uploaded file, or an object whose ``name`` attribute
            is that path. Paths ending in ``.csv`` are read with
            ``pd.read_csv``; anything else is read with ``pd.read_excel``.
        instructions: Optional free-text guidance inserted into the prompt.

    Returns:
        A list of exactly three items, each a ``PIL.Image`` or ``None`` when
        that visualization could not be produced. Errors are printed, not
        raised.
    """
    # Validate the upload first so a bad file never triggers an API call.
    try:
        path = getattr(file, 'name', file)
        if not path:
            raise ValueError("no file uploaded")
        path = str(path)
        if path.lower().endswith('.csv'):
            df = pd.read_csv(path)
        else:
            df = pd.read_excel(path)
    except Exception as e:
        print(f"File Error: {str(e)}")
        return [None, None, None]

    # Enhanced prompt with strict coding rules
    prompt = f"""Generate 3 distinct matplotlib visualization codes with these rules:
    1. Use ONLY these variables: df (existing DataFrame), plt
    2. No imports or additional data loading
    3. Each visualization must:
       - Start with: plt.figure(figsize=(16,9), dpi=120)
       - Use plt.style.use('ggplot')
       - Include title, axis labels, and grid
       - End with plt.tight_layout()
    4. Different chart types (bar, line, scatter, etc)
    
    Dataset columns: {list(df.columns)}
    Sample data: {df.head(3).to_dict()}
    User instructions: {instructions or 'None'}
    
    Format EXACTLY as:
    # Visualization 1
    plt.figure(figsize=(16,9), dpi=120)
    plt.style.use('ggplot')
    # Visualization code using df
    plt.tight_layout()
    """

    try:
        # Configure Gemini API with correct model
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
        response = model.generate_content(prompt)
        code_blocks = response.text.split("# Visualization ")[1:4]
    except Exception as e:
        print(f"Gemini Error: {str(e)}")
        return [None, None, None]

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        source = block  # fallback so the error report below always has text
        figures_before = set(plt.get_fignums())
        try:
            source = _extract_source(block)
            code = _compile_plot_code(source, f"<visualization {i}>")

            # SECURITY: this executes model-generated code with full
            # interpreter privileges; the filtering above is hygiene, not a
            # sandbox. Only run this app where that risk is acceptable.
            # Each chart gets its own copy so one block cannot mutate the
            # data seen by the next; 'data' is an alias the model often uses.
            working_df = df.copy()
            exec_env = {'df': working_df, 'data': working_df, 'plt': plt}
            # Fallback figure in case the generated code omits plt.figure().
            plt.figure(figsize=(16, 9), dpi=120)
            exec(code, exec_env)

            # Save HD image
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            visualizations.append(Image.open(buf))
        except Exception as e:
            print(f"Visualization {i} Error: {str(e)}")
            print(f"Problematic Code:\n{source}")
            visualizations.append(None)
        finally:
            # Close every figure this iteration opened (the fallback and any
            # the generated code created), on success and on failure alike.
            for num in set(plt.get_fignums()) - figures_before:
                plt.close(num)

    # Ensure exactly 3 outputs
    return visualizations + [None]*(3-len(visualizations))

# Gradio interface: API key + dataset upload on the first row, optional
# instructions below, and three image slots filled by process_file.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# 🔍 Data Visualization Generator")

    with gr.Row():
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            type="password",
            placeholder="Enter your API key here"
        )
        file = gr.File(
            label="📁 Upload Dataset",
            file_types=[".csv", ".xlsx"],
            type="filepath"
        )

    instructions = gr.Textbox(
        label="💡 Custom Instructions",
        placeholder="E.g.: Compare sales trends across regions..."
    )

    submit = gr.Button("🚀 Generate Visualizations", variant="primary")

    with gr.Row():
        # One image component per visualization returned by process_file.
        outputs = [
            gr.Image(
                label=f"Visualization {i+1}",
                width=600,
                height=400
            ) for i in range(3)
        ]

    # process_file returns a 3-item list, mapped positionally onto outputs.
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()