File size: 3,279 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
29f87f1
4fc79a4
 
 
eb443e3
4fc79a4
eb443e3
 
 
 
 
 
 
 
4fc79a4
eb443e3
 
 
 
 
4fc79a4
eb443e3
 
 
 
 
4fc79a4
eb443e3
4fc79a4
eb443e3
1e2b302
eb443e3
29f87f1
eb443e3
 
1e2b302
eb443e3
7d1e58a
 
4fc79a4
eb443e3
1e2b302
4fc79a4
eb443e3
e6ca829
eb443e3
 
 
e6ca829
eb443e3
 
 
 
 
 
 
e6ca829
eb443e3
e6ca829
29f87f1
 
e6ca829
eb443e3
 
 
 
 
 
e6ca829
eb443e3
4fc79a4
eb443e3
 
 
29f87f1
 
eb443e3
 
29f87f1
eb443e3
 
4fc79a4
 
eb443e3
e6ca829
4fc79a4
29f87f1
4fc79a4
29f87f1
4fc79a4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image

_NUM_VISUALIZATIONS = 3
_PLACEHOLDER_SIZE = (1920, 1080)
_PLACEHOLDER_COLOR = (73, 109, 137)


def _placeholder_image():
    """Return the solid-colour image used when a chart cannot be rendered."""
    return Image.new('RGB', _PLACEHOLDER_SIZE, color=_PLACEHOLDER_COLOR)


def _load_dataframe(file):
    """Read the uploaded dataset: CSV when the name ends in .csv, else Excel."""
    if file is None:
        raise ValueError("no file was uploaded")
    # Depending on version/configuration Gradio passes either a path string
    # or a tempfile-like object exposing the path as `.name`.
    path = file if isinstance(file, str) else file.name
    if path.lower().endswith('.csv'):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _pick_style():
    """Return a seaborn-like matplotlib style name that is actually installed."""
    # 'seaborn' was renamed to 'seaborn-v0_8' and the old name later removed;
    # asking for a missing style raises, so probe what is available.
    for candidate in ('seaborn-v0_8', 'seaborn'):
        if candidate in plt.style.available:
            return candidate
    return 'default'


def _clean_generated_code(block):
    """Turn one '# Visualization N' section of the reply into runnable code."""
    import textwrap

    # The first line is the remainder of the "# Visualization N" header.
    body = block.split('\n')[1:]
    kept = [
        line.rstrip()
        for line in body
        if line.strip() and not line.strip().startswith('```')
    ]
    # Remove only the *common* indentation so loops/ifs keep their structure.
    code = textwrap.dedent('\n'.join(kept))
    try:
        compile(code, '<visualization>', 'exec')
    except SyntaxError:
        # Inconsistently indented flat code: fall back to stripping every
        # line, which is what this function historically did.
        code = '\n'.join(line.strip() for line in kept)
    return code


def _render_visualization(code, df, index, style):
    """Execute one generated snippet and return the resulting chart as an image."""
    open_before = set(plt.get_fignums())
    try:
        # Apply the style before the figure exists so it affects the whole
        # chart, and restore rcParams afterwards.
        with plt.style.context(style):
            plt.figure(figsize=(16, 9), dpi=120)
            # SECURITY: this executes model-generated code in-process with no
            # sandbox. Only run this app where that risk is acceptable.
            exec(code, {'df': df, 'plt': plt, 'pd': pd})
            plt.title(f"Visualization {index}", fontsize=14)
            plt.tight_layout()

            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now rather than lazily from the buffer
        return image
    finally:
        # Close every figure opened during this render (ours plus any the
        # generated code created), on success and on failure alike.
        for num in set(plt.get_fignums()) - open_before:
            plt.close(num)


def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib charts of the uploaded data and render them.

    Args:
        api_key: Gemini API key used to configure the client.
        file: The uploaded dataset as provided by the Gradio File component
            (a path string or an object with a ``.name`` path).
        instructions: Optional free-text guidance appended to the prompt.

    Returns:
        A list of exactly three PIL images. A chart that fails to render, or
        that the model did not provide, is replaced by a solid placeholder.

    Raises:
        gr.Error: If the dataset cannot be read or the Gemini request fails,
            so the message is shown in the UI.
    """
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro-latest')

    try:
        df = _load_dataframe(file)
    except Exception as e:
        # The outputs are image components, so an error string cannot be
        # returned to them; surface it through Gradio's error mechanism.
        raise gr.Error(f"File Error: {str(e)}") from e

    # Enhanced prompt with visualization requirements
    prompt = f"""Generate exactly 3 distinct matplotlib visualization codes for this dataset:
    Columns: {list(df.columns)}
    Data types: {dict(df.dtypes)}
    Sample data: {df.head(3).to_dict()}
    
    Requirements:
    1. 1920x1080 resolution (figsize=(16,9), dpi=120)
    2. Professional styling (seaborn style, grid lines)
    3. Clear labels and titles
    4. Diverse chart types (at least 1 must be non-basic)
    
    User instructions: {instructions or 'No specific instructions'}
    
    Format response strictly as:
    # Visualization 1
    plt.figure(figsize=(16,9), dpi=120)
    [code]
    plt.tight_layout()
    
    # Visualization 2
    ...
    """

    try:
        response = model.generate_content(prompt)
        reply = response.text
    except Exception as e:
        raise gr.Error(f"Gemini request failed: {str(e)}") from e

    # Everything before the first marker is preamble; keep at most 3 sections.
    code_blocks = reply.split("# Visualization ")[1:1 + _NUM_VISUALIZATIONS]
    style = _pick_style()

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            code = _clean_generated_code(block)
            visualizations.append(_render_visualization(code, df, i, style))
        except Exception as e:
            # Best effort: one bad snippet must not sink the other charts.
            print(f"Visualization {i} failed: {str(e)}")
            visualizations.append(_placeholder_image())

    # Ensure exactly 3 outputs with fallback
    while len(visualizations) < _NUM_VISUALIZATIONS:
        visualizations.append(_placeholder_image())

    return visualizations[:_NUM_VISUALIZATIONS]

# Gradio interface with HD layout.
# Inputs (API key, dataset, optional instructions) feed process_file, whose
# three returned images fill the three Image components below.
# The emoji in the labels were mis-encoded (UTF-8 read through a single-byte
# codepage) and are restored here.
with gr.Blocks(theme=gr.themes.Default(spacing_size="xl")) as demo:
    gr.Markdown("# **HD Data Visualizer**  📊✨")
    
    with gr.Row():
        api_key = gr.Textbox(label="🔑 Gemini API Key", type="password")
        file = gr.File(label="📁 Upload Dataset", file_types=[".csv", ".xlsx"])
    
    instructions = gr.Textbox(label="💡 Custom Instructions (optional)", placeholder="E.g.: Focus on time series patterns...")
    submit = gr.Button("🚀 Generate Visualizations", variant="primary")
    
    with gr.Row():
        # One image slot per visualization returned by process_file.
        outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

demo.launch()