File size: 4,453 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
 
e6ca829
4fc79a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6ca829
4fc79a4
 
 
 
 
e6ca829
 
 
 
4fc79a4
 
 
 
 
 
 
7d1e58a
 
4fc79a4
 
 
 
 
 
 
e6ca829
 
 
 
 
 
 
4fc79a4
 
 
e6ca829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fc79a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6ca829
 
 
 
 
4fc79a4
 
 
 
e6ca829
 
 
4fc79a4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import base64
import io
import re
import traceback

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Matches a section header such as "Visualization 2:" or "**Visualization 2:**"
# at the start of a line, so the word "Visualization" inside prose is not a split point.
_SECTION_RE = re.compile(r"^[ \t#*]*Visualization[ \t]+\d+[ \t]*:[ \t*]*", re.MULTILINE)
# Captures the body of the first fenced code block (any or no language tag).
_FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)
_NO_EXPLANATION = "No explanation provided."


def _load_dataframe(file):
    """Read the uploaded CSV or Excel file into a DataFrame.

    Accepts either a filepath string (Gradio 4+ ``gr.File`` default) or a
    tempfile-like object exposing ``.name`` (Gradio 3.x).
    """
    path = getattr(file, "name", file)
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Compose the Gemini prompt from the dataset's schema, summary and user instructions."""
    data_description = df.describe().to_string()
    columns_info = "\n".join(f"{col}: {df[col].dtype}" for col in df.columns)
    return f"""
    Given this dataset:
    Columns and types:
    {columns_info}
    
    Data summary:
    {data_description}
    
    User instructions: {instructions if instructions else 'No specific instructions provided.'}
    
    Suggest 3 ways to visualize this data. For each visualization:
    1. Describe the visualization type and what it will show.
    2. Provide Python code using matplotlib to create the visualization. Ensure the code is complete and executable.
       The data is already loaded in a pandas DataFrame named `df` (pandas is available as `pd`, matplotlib.pyplot as `plt`); use it directly, do not read any files, and do not call plt.show().
    3. Explain why this visualization is useful for understanding the data.
    
    Format your response as:
    Visualization 1:
    Description: ...
    Code:
    ```python
    # Your code here
    ```
    Explanation: ...
    
    Visualization 2:
    ...
    
    Visualization 3:
    ...
    """


def _strip_markup(text):
    """Trim whitespace and stray markdown bold markers around a parsed field."""
    return text.strip().strip("*").strip()


def _parse_suggestion(section):
    """Split one "Visualization N" section into (description, code, explanation).

    ``code`` is "" when the section contains no "Code:" marker.
    """
    head, found, rest = section.partition("Code:")
    description = _strip_markup(head)
    if description.startswith("Description:"):
        description = _strip_markup(description[len("Description:"):])
    if not found:
        return description, "", _NO_EXPLANATION

    code_text, _, explanation = rest.partition("Explanation:")
    fenced = _FENCE_RE.search(code_text)
    code = fenced.group(1).strip() if fenced else _strip_markup(code_text)
    return description, code, _strip_markup(explanation) or _NO_EXPLANATION


def _render_figure(code, df, index):
    """Run model-generated matplotlib code and return the drawn figure as an RGB uint8 array.

    Any exception raised by the generated code propagates to the caller.
    Every figure opened during the call is closed afterwards, even on error.
    """
    existing = set(plt.get_fignums())
    plt.figure(figsize=(10, 6))
    # SECURITY: this executes arbitrary code produced by the LLM with the full
    # privileges of the server process. Only deploy where that is acceptable
    # (e.g. a sandboxed Space), or replace with a restricted executor.
    namespace = {"df": df, "pd": pd, "plt": plt, "np": np}
    try:
        exec(code, namespace)
        fig = plt.gcf()  # the generated code may have created its own figure
        ax = fig.gca()
        if not ax.get_title():  # keep a title the generated code set itself
            ax.set_title(f"Visualization {index}")
        fig.canvas.draw()
        # gr.Image accepts numpy arrays; data-URI strings are treated as file paths.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        for num in set(plt.get_fignums()) - existing:
            plt.close(num)


def process_file(api_key, file, instructions):
    """Ask Gemini for three visualizations of an uploaded dataset and render them.

    Args:
        api_key: Gemini API key entered by the user.
        file: Uploaded CSV/Excel file (filepath string or object with ``.name``).
        instructions: Optional free-text guidance forwarded to the model.

    Returns:
        A flat tuple of 12 values: ``(image, description, code, explanation)``
        for each of the three visualizations, in the order the Gradio outputs
        are declared. ``image`` is ``None`` and ``explanation`` holds the error
        text when a visualization fails.

    Raises:
        gr.Error: If the API key or file is missing, or the Gemini request fails.
    """
    if not api_key:
        raise gr.Error("Please enter a Gemini API key.")
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

    df = _load_dataframe(file)
    try:
        response_text = model.generate_content(_build_prompt(df, instructions)).text
    except Exception as err:  # top-level boundary: surface the failure in the UI
        raise gr.Error(f"Gemini request failed: {err}") from err

    # Element 0 is any preamble before "Visualization 1:"; keep at most three sections.
    sections = _SECTION_RE.split(response_text)[1:4]

    outputs = []
    for i, section in enumerate(sections, 1):
        description, code, explanation = _parse_suggestion(section)
        if not code:
            outputs.extend((None, description, "",
                            f"Error in visualization {i}: the response contained no 'Code:' section."))
            continue
        try:
            # Copy so in-place mutations by one snippet cannot affect the next.
            image = _render_figure(code, df.copy(), i)
        except Exception as e:
            error_message = f"Error in visualization {i}: {str(e)}\n{traceback.format_exc()}"
            outputs.extend((None, description, code, error_message))
        else:
            outputs.extend((image, description, code, explanation))

    # Always fill all 12 output slots, even if the model returned fewer sections.
    while len(outputs) < 12:
        outputs.extend((None, "Error", "No code generated", "An error occurred"))

    return tuple(outputs)

# Gradio interface: four rows (images, descriptions, code, explanations),
# each with one column per visualization.
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization with Gemini")
    api_key = gr.Textbox(label="Enter Gemini API Key", type="password")
    file = gr.File(label="Upload Excel or CSV file")
    instructions = gr.Textbox(label="Optional visualization instructions")
    submit = gr.Button("Generate Visualizations")

    slots = range(1, 4)
    with gr.Row():
        images = [gr.Image(label=f"Visualization {n}") for n in slots]
    with gr.Row():
        descriptions = [gr.Textbox(label=f"Description {n}") for n in slots]
    with gr.Row():
        codes = [gr.Code(language="python", label=f"Code {n}") for n in slots]
    with gr.Row():
        explanations = [gr.Textbox(label=f"Explanation/Error {n}") for n in slots]

    # process_file returns (image, description, code, explanation) per
    # visualization, so interleave the component lists in that order.
    outputs = [
        component
        for group in zip(images, descriptions, codes, explanations)
        for component in group
    ]

    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
        show_progress="full",  # documented literal; bools are legacy
    )

# Launch only when run as a script (as Spaces does), not on import.
if __name__ == "__main__":
    demo.launch()