File size: 3,530 Bytes
4fc79a4
12ce912
 
 
bca92aa
ec365ce
3c50a2d
4fc79a4
601022d
9a4fc1b
ec365ce
 
b06a85b
ec365ce
9a4fc1b
bca92aa
 
9a4fc1b
601022d
 
 
 
 
 
3c50a2d
601022d
 
 
 
 
 
b06a85b
601022d
b06a85b
 
 
601022d
 
3c50a2d
601022d
b06a85b
9a4fc1b
bca92aa
ec365ce
601022d
1b2886c
b06a85b
ec365ce
 
1b2886c
ec365ce
1b2886c
ec365ce
1b2886c
ec365ce
601022d
1b2886c
 
 
601022d
1b2886c
ec365ce
9a4fc1b
1b2886c
9a4fc1b
1b2886c
b06a85b
1b2886c
12ce912
b06a85b
12ce912
9a4fc1b
3c50a2d
 
 
ec365ce
3c50a2d
b06a85b
9a4fc1b
601022d
ec365ce
5be932a
23a3b49
ec365ce
601022d
6cff8d5
601022d
bca92aa
 
601022d
bca92aa
 
 
601022d
b06a85b
bca92aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import ast
import io
import re
import traceback

import gradio as gr
import google.generativeai as genai
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw

def process_file(file, instructions, api_key):
    try:
        # Initialize Gemini
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-pro')
        
        # Read uploaded file
        file_path = file.name
        df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
        
        # Generate visualization code using Gemini
        prompt = f"""
        Analyze the following dataset and instructions:
        
        Data columns: {list(df.columns)}
        Instructions: {instructions}
        
        Based on this, create 3 appropriate visualizations. For each visualization, provide:
        1. A title
        2. The most suitable plot type (choose from: bar, line, scatter, hist)
        3. The column to use for the x-axis
        4. The column to use for the y-axis (use None for histograms)

        Return your response as a Python list of tuples:
        [
            ("Title 1", "plot_type1", "x_column1", "y_column1"),
            ("Title 2", "plot_type2", "x_column2", "y_column2"),
            ("Title 3", "plot_type3", "x_column3", "y_column3")
        ]
        """
        
        response = model.generate_content(prompt)
        plots = eval(response.text)
        
        # Generate visualizations
        images = []
        for plot in plots:
            fig, ax = plt.subplots(figsize=(10, 6))
            title, plot_type, x, y = plot
            
            if plot_type == 'bar':
                df.plot(kind='bar', x=x, y=y, ax=ax)
            elif plot_type == 'line':
                df.plot(kind='line', x=x, y=y, ax=ax)
            elif plot_type == 'scatter':
                df.plot(kind='scatter', x=x, y=y, ax=ax)
            elif plot_type == 'hist':
                df[x].hist(ax=ax)
            
            ax.set_title(title)
            ax.set_xlabel(x)
            ax.set_ylabel(y if y else 'Frequency')
            plt.tight_layout()
            
            buf = io.BytesIO()
            plt.savefig(buf, format='png')
            buf.seek(0)
            img = Image.open(buf)
            images.append(img)
            plt.close(fig)

        return images if len(images) == 3 else images + [Image.new('RGB', (800, 600), (255,255,255))]*(3-len(images))

    except Exception as e:
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print to console for debugging
        error_image = Image.new('RGB', (800, 400), (255, 255, 255))
        draw = ImageDraw.Draw(error_image)
        draw.text((10, 10), error_message, fill=(255, 0, 0))
        return [error_image] * 3

# Dashboard layout: dataset upload + free-text instructions side by side, an
# API-key field, a submit button, and three image slots that receive the
# three images returned by process_file.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        file = gr.File(
            label="Upload Dataset",
            file_types=[".csv", ".xlsx"],
        )
        instructions = gr.Textbox(
            label="Analysis Instructions",
            placeholder="Describe the analysis you want...",
        )

    api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")

    # One output slot per returned image, labelled 1..3.
    output_images = [gr.Image(label=f"Visualization {n}") for n in range(1, 4)]

    # Wire the button: the three inputs map positionally onto process_file's parameters.
    submit.click(
        fn=process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )

if __name__ == "__main__":
    demo.launch()