bluenevus commited on
Commit
eb443e3
Β·
verified Β·
1 Parent(s): 9a6f839

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -46
app.py CHANGED
@@ -6,80 +6,87 @@ import google.generativeai as genai
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
9
- # Configure Gemini API
10
  genai.configure(api_key=api_key)
11
- model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
 
13
- # Read uploaded file
14
- if file.name.endswith('.csv'):
15
- df = pd.read_csv(file.name)
16
- else:
17
- df = pd.read_excel(file.name)
 
 
 
18
 
19
- # Generate analysis prompt
20
- data_description = df.describe().to_string()
21
- columns_info = "\n".join([f"{col}: {df[col].dtype}" for col in df.columns])
22
- prompt = f"""
23
- Given this dataset:
24
- Columns and types:
25
- {columns_info}
26
 
27
- Data summary:
28
- {data_description}
 
 
 
29
 
30
- User instructions: {instructions if instructions else 'No specific instructions provided.'}
31
 
32
- Generate Python code for 3 different matplotlib visualizations. Only provide executable code without explanations.
33
- Format as:
34
  # Visualization 1
 
35
  [code]
 
 
36
  # Visualization 2
37
- [code]
38
- # Visualization 3
39
- [code]
40
  """
41
 
42
- # Get Gemini response
43
  response = model.generate_content(prompt)
44
- code = response.text.strip()
45
 
46
- # Process visualizations
47
  visualizations = []
48
- for i, viz_code in enumerate(code.split("# Visualization")[1:4], 1):
49
- plt.figure(figsize=(10, 6))
50
  try:
51
- # Clean and execute generated code
52
- cleaned_code = '\n'.join(line.strip() for line in viz_code.split('\n') if line.strip())
53
- exec(cleaned_code, {'df': df, 'plt': plt})
54
- plt.title(f"Visualization {i}")
55
 
56
- # Convert plot to Gradio-compatible format
 
 
 
 
 
 
57
  buf = io.BytesIO()
58
- plt.savefig(buf, format='png', bbox_inches='tight')
59
  plt.close()
60
  buf.seek(0)
61
  visualizations.append(Image.open(buf))
62
  except Exception as e:
63
- print(f"Visualization {i} error: {str(e)}")
64
- visualizations.append(None)
 
 
 
 
65
 
66
- # Ensure 3 outputs for Gradio
67
- return visualizations + [None]*(3-len(visualizations))
68
 
69
- # Gradio interface
70
- with gr.Blocks() as demo:
71
- gr.Markdown("""# AI-Powered Data Visualizer
72
- Upload your data file and get automatic visualizations powered by Gemini""")
73
 
74
  with gr.Row():
75
- api_key = gr.Textbox(label="Gemini API Key", type="password")
76
- instructions = gr.Textbox(label="Custom Instructions (optional)")
77
 
78
- file = gr.File(label="Dataset (CSV/Excel)", file_types=[".csv", ".xlsx"])
79
- submit = gr.Button("Generate Insights", variant="primary")
80
 
81
  with gr.Row():
82
- outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
83
 
84
  submit.click(
85
  process_file,
 
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
 
9
  genai.configure(api_key=api_key)
10
+ model = genai.GenerativeModel('gemini-2.5-pro-latest')
11
 
12
+ # Read file with error handling
13
+ try:
14
+ if file.name.endswith('.csv'):
15
+ df = pd.read_csv(file.name)
16
+ else:
17
+ df = pd.read_excel(file.name)
18
+ except Exception as e:
19
+ return [f"File Error: {str(e)}"] * 3
20
 
21
+ # Enhanced prompt with visualization requirements
22
+ prompt = f"""Generate exactly 3 distinct matplotlib visualization codes for this dataset:
23
+ Columns: {list(df.columns)}
24
+ Data types: {dict(df.dtypes)}
25
+ Sample data: {df.head(3).to_dict()}
 
 
26
 
27
+ Requirements:
28
+ 1. 1920x1080 resolution (figsize=(16,9), dpi=120)
29
+ 2. Professional styling (seaborn style, grid lines)
30
+ 3. Clear labels and titles
31
+ 4. Diverse chart types (at least 1 must be non-basic)
32
 
33
+ User instructions: {instructions or 'No specific instructions'}
34
 
35
+ Format response strictly as:
 
36
  # Visualization 1
37
+ plt.figure(figsize=(16,9), dpi=120)
38
  [code]
39
+ plt.tight_layout()
40
+
41
  # Visualization 2
42
+ ...
 
 
43
  """
44
 
 
45
  response = model.generate_content(prompt)
46
+ code_blocks = response.text.split("# Visualization ")[1:4] # Force 3 max
47
 
 
48
  visualizations = []
49
+ for i, block in enumerate(code_blocks, 1):
 
50
  try:
51
+ # HD configuration
52
+ plt.figure(figsize=(16, 9), dpi=120)
53
+ plt.style.use('seaborn')
 
54
 
55
+ # Clean and execute code
56
+ cleaned_code = '\n'.join([line.strip() for line in block.split('\n')[1:] if line.strip()])
57
+ exec(cleaned_code, {'df': df, 'plt': plt})
58
+ plt.title(f"Visualization {i}", fontsize=14)
59
+ plt.tight_layout()
60
+
61
+ # Save HD image
62
  buf = io.BytesIO()
63
+ plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
64
  plt.close()
65
  buf.seek(0)
66
  visualizations.append(Image.open(buf))
67
  except Exception as e:
68
+ print(f"Visualization {i} failed: {str(e)}")
69
+ visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
70
+
71
+ # Ensure exactly 3 outputs with fallback
72
+ while len(visualizations) < 3:
73
+ visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
74
 
75
+ return visualizations[:3]
 
76
 
77
+ # Gradio interface with HD layout
78
+ with gr.Blocks(theme=gr.themes.Default(spacing_size="xl")) as demo:
79
+ gr.Markdown("# **HD Data Visualizer** πŸ“Šβœ¨")
 
80
 
81
  with gr.Row():
82
+ api_key = gr.Textbox(label="πŸ”‘ Gemini API Key", type="password")
83
+ file = gr.File(label="πŸ“ Upload Dataset", file_types=[".csv", ".xlsx"])
84
 
85
+ instructions = gr.Textbox(label="πŸ’‘ Custom Instructions (optional)", placeholder="E.g.: Focus on time series patterns...")
86
+ submit = gr.Button("πŸš€ Generate Visualizations", variant="primary")
87
 
88
  with gr.Row():
89
+ outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
90
 
91
  submit.click(
92
  process_file,