File size: 2,795 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
29f87f1
4fc79a4
 
69a028d
12598b6
4fc79a4
5fcb5c4
 
 
 
 
69a028d
e5b6950
 
 
 
 
 
29f87f1
e5b6950
 
 
a888f55
5fcb5c4
a888f55
e5b6950
 
 
 
5fcb5c4
a888f55
 
 
 
 
 
 
5fcb5c4
 
 
 
 
 
a888f55
 
e5b6950
 
 
 
 
 
 
5fcb5c4
a888f55
5fcb5c4
a888f55
5fcb5c4
a888f55
 
 
5fcb5c4
 
a888f55
5fcb5c4
a888f55
 
5fcb5c4
 
a888f55
 
5fcb5c4
 
a888f55
e5b6950
5fcb5c4
a888f55
 
5fcb5c4
a888f55
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

def _extract_plot_code(block):
    """Return the runnable Python source from one "# Visualization" section.

    Args:
        block: Text that followed a "# Visualization " marker in the model
            response, beginning with the remainder of that header line.

    Returns:
        The code lines joined with newlines. Relative indentation is kept
        (so loops and conditionals stay valid); only indentation common to
        every line is removed.
    """
    # The first line is what is left of the header comment (e.g. "1" or
    # "1: Bar chart"); it is not code and must not reach exec().
    _header, _, body = block.partition('\n')

    kept = []
    for line in body.split('\n'):
        if line.lstrip().startswith('```'):
            if kept:
                # Closing fence: whatever follows is commentary, not code.
                break
            # Opening fence before any code has been seen.
            continue
        if line.strip():
            kept.append(line.rstrip())

    return textwrap.dedent('\n'.join(kept))


def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib snippets and render each to an image.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: Uploaded file; either an object exposing the path as ``.name``
            or a plain path string. ``None`` means nothing was uploaded.
        instructions: Optional free-text instructions added to the prompt.

    Returns:
        A list of exactly three items, one per requested visualization:
        a PIL image when the snippet rendered, ``None`` when it failed or
        was missing from the response.
    """
    # Nothing to plot without an upload; bail out before touching the API.
    if file is None:
        return [None] * 3

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

    # Read file (extension check is case-insensitive so "DATA.CSV" works)
    path = getattr(file, 'name', file)
    if path.lower().endswith('.csv'):
        df = pd.read_csv(path)
    else:
        df = pd.read_excel(path)

    # Updated prompt with valid matplotlib styles
    prompt = f"""Generate 3 matplotlib codes with these requirements:
    1. Start with: plt.figure(figsize=(16,9), dpi=120)
    2. Use one of these styles: ggplot, bmh, dark_background, fast
    3. Include: title, labels, grid, legend if needed
    4. Different chart types (bar, line, scatter, etc)
    
    Data columns: {list(df.columns)}
    Sample data: {df.head(3).to_dict()}
    User instructions: {instructions or 'None'}
    
    Format exactly as:
    # Visualization 1
    plt.figure(figsize=(16,9), dpi=120)
    plt.style.use('ggplot')  # Example valid style
    [code]
    plt.tight_layout()
    """

    response = model.generate_content(prompt)
    code_blocks = response.text.split("# Visualization ")[1:4]

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            code = _extract_plot_code(block)

            # Fallback canvas in case the snippet omits its own plt.figure().
            plt.figure(figsize=(16, 9), dpi=120)

            # SECURITY: this executes model-generated code with full builtins.
            # The restricted globals only limit which names are pre-bound;
            # this is NOT a sandbox. Run only in a disposable environment.
            exec_env = {
                'df': df,
                'plt': plt,
                'pd': pd
            }
            exec(code, exec_env)

            plt.tight_layout()
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')

            buf.seek(0)
            visualizations.append(Image.open(buf))
        except Exception as e:
            # Best effort: one bad snippet must not sink the other charts.
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(None)
        finally:
            # Close every figure (ours and any the snippet opened), on the
            # failure path too, so figures do not pile up across requests.
            plt.close('all')

    # Pad so the three image outputs always receive a value.
    return visualizations + [None] * (3 - len(visualizations))

# Gradio interface: collects an API key, a data file and optional
# instructions, then shows the three images returned by process_file.
# Components appear in the order they are created inside each container,
# so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization Tool")
    
    # First row: masked API key field next to the file upload
    # (upload is limited to the listed .csv / .xlsx extensions).
    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
    
    instructions = gr.Textbox(label="Custom Instructions")
    submit = gr.Button("Generate Visualizations")
    
    # Second row: one image slot per visualization (three in total).
    with gr.Row():
        outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]

    # Clicking the button calls process_file with the three inputs and maps
    # its returned list onto the three image components.
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

# Runs at module level, i.e. whenever this file is executed or imported.
demo.launch()