File size: 3,273 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
 
ccd7bab
4fc79a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccd7bab
1e2b302
4fc79a4
1e2b302
 
 
 
 
 
 
e6ca829
7d1e58a
 
4fc79a4
1e2b302
 
ccd7bab
 
4fc79a4
 
1e2b302
4fc79a4
e6ca829
ccd7bab
 
 
 
 
 
e6ca829
 
 
 
 
 
 
 
 
1e2b302
e6ca829
1e2b302
 
e6ca829
1e2b302
e6ca829
1e2b302
4fc79a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6ca829
4fc79a4
 
 
1e2b302
4fc79a4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import google.generativeai as genai
import textwrap

def _read_table(file):
    """Load the uploaded CSV or Excel file into a DataFrame.

    ``file`` may be a tempfile-like object exposing ``.name`` or a plain path
    string, depending on the Gradio version; both are accepted.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")
    path = getattr(file, "name", file)
    # Case-insensitive so "DATA.CSV" is not handed to read_excel.
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Build the Gemini prompt from *df*'s column dtypes and describe() summary."""
    data_description = df.describe().to_string()
    columns_info = "\n".join(f"{col}: {df[col].dtype}" for col in df.columns)
    return f"""
    Given this dataset:
    Columns and types:
    {columns_info}
    
    Data summary:
    {data_description}
    
    User instructions: {instructions if instructions else 'No specific instructions provided.'}
    
    Generate Python code for 3 different visualizations using matplotlib. Each visualization should be unique and provide insights into the data. Do not include any explanations or descriptions, only the Python code for each visualization. Ensure the code is not indented.
    The data is already loaded in a pandas DataFrame named `df`, and matplotlib.pyplot is available as `plt`. Do not read any files, do not call plt.show(), and do not wrap the code in markdown fences.

    Format your response as:
    # Visualization 1
    # Your code here

    # Visualization 2
    # Your code here

    # Visualization 3
    # Your code here
    """


def _clean_snippet(chunk):
    """Turn one "# Visualization" chunk of the model reply into executable code.

    The first line of *chunk* is the remainder of the header (e.g. " 1") and
    is dropped, as are markdown fence lines. Common indentation is removed with
    textwrap.dedent so nested blocks (loops, ifs) keep their structure. Only if
    that does not compile, the function falls back to stripping every line.
    """
    _, _, body = chunk.partition("\n")
    lines = [line for line in body.splitlines() if not line.lstrip().startswith("```")]
    code = textwrap.dedent("\n".join(lines))
    try:
        # compile() only parses; nothing is executed here.
        compile(code, "<visualization>", "exec")
    except SyntaxError:  # includes IndentationError
        return "\n".join(line.strip() for line in lines if line.strip())
    return code


def _render_visualization(code, df, index):
    """Execute *code* against a copy of *df* and return the plot as a PNG data URI.

    Every pyplot figure opened while doing so is closed afterwards, even when
    the code raises, so failing snippets do not leak figures.
    """
    existing = set(plt.get_fignums())
    plt.figure(figsize=(10, 6))
    try:
        # SECURITY: this runs model-generated code with full interpreter access
        # (builtins, imports, file/network I/O). Only use with a trusted model
        # and key, or move this step into a sandbox.
        # df is copied so one snippet mutating it cannot affect the next.
        exec(code, {"df": df.copy(), "plt": plt, "pd": pd})
        plt.title(f"Visualization {index}")
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        img_str = base64.b64encode(buf.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    finally:
        for num in set(plt.get_fignums()) - existing:
            plt.close(num)


def process_file(api_key, file, instructions, model_name="gemini-2.5-pro-preview-03-25"):
    """Generate three matplotlib visualizations of an uploaded table via Gemini.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: The ``gr.File`` value (object with ``.name`` or a path string).
        instructions: Optional free-text guidance forwarded to the model.
        model_name: Gemini model identifier to use.

    Returns:
        A list of exactly three items. Each item is a PNG data URI string, or
        None when that visualization could not be produced.

    Raises:
        gr.Error: if the API key or file is missing, or Gemini returned no text.
    """
    if not api_key:
        raise gr.Error("Please enter a Gemini API key.")
    df = _read_table(file)

    # Set up Gemini API
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(_build_prompt(df, instructions))
    try:
        code = response.text.strip()
    except ValueError as err:
        # response.text raises when the reply has no text parts (e.g. blocked).
        raise gr.Error(f"Gemini returned no usable text: {err}") from err

    print("Generated code:")
    print(code)

    visualizations = []
    for i, chunk in enumerate(code.split("# Visualization")[1:4], 1):
        snippet = _clean_snippet(chunk)
        print(f"\nVisualization {i} code:")
        print(snippet)
        try:
            visualizations.append(_render_visualization(snippet, df, i))
        except Exception as e:
            # Generated code can fail in arbitrary ways; one bad snippet must
            # not prevent the others from rendering.
            print(f"Error in visualization {i}: {str(e)}")
            visualizations.append(None)

    # Ensure we always return 3 visualizations
    visualizations += [None] * (3 - len(visualizations))
    return visualizations

# Gradio interface: API key, file upload and optional instructions in; three
# generated plots out. process_file is wired to the button below.
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization with Gemini")
    api_key = gr.Textbox(label="Enter Gemini API Key", type="password")
    # Restrict the picker to the formats process_file can parse
    # (CSV via read_csv, everything else via read_excel).
    file = gr.File(
        label="Upload Excel or CSV file",
        file_types=[".csv", ".xlsx", ".xls"],
    )
    instructions = gr.Textbox(label="Optional visualization instructions")
    submit = gr.Button("Generate Visualizations")

    with gr.Row():
        output1 = gr.Image(label="Visualization 1")
        output2 = gr.Image(label="Visualization 2")
        output3 = gr.Image(label="Visualization 3")

    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=[output1, output2, output3],
        show_progress=True,
    )

# Only start the server when run as a script, so importing this module
# (e.g. for tests or `gradio` reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch()