bluenevus commited on
Commit
218ce96
Β·
verified Β·
1 Parent(s): 6cff8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -66
app.py CHANGED
@@ -5,116 +5,98 @@ import io
5
  import google.generativeai as genai
6
  from PIL import Image
7
  import ast
 
8
 
9
  def process_file(api_key, file, instructions):
10
- # Configure Gemini API with correct model
11
  genai.configure(api_key=api_key)
12
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
13
 
14
  try:
15
- # Read uploaded file
16
- if file.name.endswith('.csv'):
17
- df = pd.read_csv(file.name)
18
- else:
19
- df = pd.read_excel(file.name)
20
  except Exception as e:
21
  print(f"File Error: {str(e)}")
22
- return [None, None, None]
23
 
24
- # Enhanced prompt with strict coding rules
25
- prompt = f"""Generate 3 distinct matplotlib visualization codes with these rules:
26
- 1. Use ONLY these variables: df (existing DataFrame), plt
27
- 2. No imports or additional data loading
28
- 3. Each visualization must:
29
- - Start with: plt.figure(figsize=(16,9), dpi=120)
30
- - Use plt.style.use('ggplot')
31
- - Include title, axis labels, and grid
32
- - End with plt.tight_layout()
33
- 4. Different chart types (bar, line, scatter, etc)
34
 
35
- Dataset columns: {list(df.columns)}
36
- Sample data: {df.head(3).to_dict()}
37
- User instructions: {instructions or 'None'}
38
 
39
  Format EXACTLY as:
40
  # Visualization 1
41
- plt.figure(figsize=(16,9), dpi=120)
42
- plt.style.use('ggplot')
43
- # Visualization code using df
44
- plt.tight_layout()
 
 
 
45
  """
46
 
47
  try:
48
  response = model.generate_content(prompt)
49
- code_blocks = response.text.split("# Visualization ")[1:4]
50
  except Exception as e:
51
  print(f"Gemini Error: {str(e)}")
52
- return [None, None, None]
53
 
54
  visualizations = []
55
  for i, block in enumerate(code_blocks, 1):
56
  buf = io.BytesIO()
57
  try:
58
- # Clean and validate generated code
59
- cleaned_code = '\n'.join([
60
- line.replace('data', 'df').split('#')[0].strip()
61
- for line in block.split('\n')[1:]
62
- if line.strip() and
63
- not any(s in line.lower() for s in ['import', 'data=', 'data ='])
64
- ])
 
 
 
 
 
65
 
66
- # Syntax check
67
  ast.parse(cleaned_code)
68
 
69
- # Execute code in controlled environment
70
  exec_env = {'df': df, 'plt': plt}
71
  plt.figure(figsize=(16, 9), dpi=120)
72
  exec(cleaned_code, exec_env)
73
 
74
- # Save HD image
75
  plt.savefig(buf, format='png', bbox_inches='tight')
76
  plt.close()
77
- buf.seek(0)
78
  visualizations.append(Image.open(buf))
79
  except Exception as e:
80
  print(f"Visualization {i} Error: {str(e)}")
81
- print(f"Problematic Code:\n{cleaned_code}")
82
  visualizations.append(None)
83
 
84
- # Ensure exactly 3 outputs
85
  return visualizations + [None]*(3-len(visualizations))
86
 
87
  # Gradio interface
88
- with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
89
- gr.Markdown("# πŸ” Data Visualization Generator")
90
 
91
  with gr.Row():
92
- api_key = gr.Textbox(
93
- label="πŸ”‘ Gemini API Key",
94
- type="password",
95
- placeholder="Enter your API key here"
96
- )
97
- file = gr.File(
98
- label="πŸ“ Upload Dataset",
99
- file_types=[".csv", ".xlsx"],
100
- type="filepath"
101
- )
102
-
103
- instructions = gr.Textbox(
104
- label="πŸ’‘ Custom Instructions",
105
- placeholder="E.g.: Compare sales trends across regions..."
106
- )
107
 
108
- submit = gr.Button("πŸš€ Generate Visualizations", variant="primary")
 
109
 
110
  with gr.Row():
111
- outputs = [
112
- gr.Image(
113
- label=f"Visualization {i+1}",
114
- width=600,
115
- height=400
116
- ) for i in range(3)
117
- ]
118
 
119
  submit.click(
120
  process_file,
@@ -122,5 +104,4 @@ with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
122
  outputs=outputs
123
  )
124
 
125
- if __name__ == "__main__":
126
- demo.launch()
 
5
  import google.generativeai as genai
6
  from PIL import Image
7
  import ast
8
+ import re
9
 
10
  def process_file(api_key, file, instructions):
 
11
  genai.configure(api_key=api_key)
12
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
13
 
14
  try:
15
+ df = pd.read_csv(file.name) if file.name.endswith('.csv') else pd.read_excel(file.name)
 
 
 
 
16
  except Exception as e:
17
  print(f"File Error: {str(e)}")
18
+ return [None]*3
19
 
20
+ # Enhanced prompt with strict syntax rules
21
+ prompt = f"""Generate 3 Python code blocks for matplotlib visualizations with:
22
+ 1. Perfect Python syntax
23
+ 2. No markdown or incomplete lines
24
+ 3. Each block must start with:
25
+ plt.figure(figsize=(16,9), dpi=120)
26
+ plt.style.use('ggplot')
27
+ 4. Use ONLY df and plt variables
28
+ 5. End with plt.tight_layout()
 
29
 
30
+ Columns: {list(df.columns)}
31
+ Sample: {df.head(3).to_dict()}
32
+ Instructions: {instructions or 'None'}
33
 
34
  Format EXACTLY as:
35
  # Visualization 1
36
+ [code]
37
+
38
+ # Visualization 2
39
+ [code]
40
+
41
+ # Visualization 3
42
+ [code]
43
  """
44
 
45
  try:
46
  response = model.generate_content(prompt)
47
+ code_blocks = re.split(r'# Visualization \d+', response.text)[1:4]
48
  except Exception as e:
49
  print(f"Gemini Error: {str(e)}")
50
+ return [None]*3
51
 
52
  visualizations = []
53
  for i, block in enumerate(code_blocks, 1):
54
  buf = io.BytesIO()
55
  try:
56
+ # Advanced code cleaning
57
+ cleaned_code = '\n'.join(
58
+ line.strip().replace("'", "").replace('"', '') # Remove stray quotes
59
+ for line in block.split('\n')
60
+ if line.strip() and
61
+ not any(c in line for c in ['`', '```', 'Annotation']) and
62
+ re.match(r'^[a-zA-Z0-9_().=, <>:]+$', line) # Basic syntax validation
63
+ )
64
+
65
+ # Add missing parentheses check
66
+ cleaned_code = re.sub(r'plt.style.use\([\'"]ggplot$',
67
+ 'plt.style.use("ggplot")', cleaned_code)
68
 
69
+ # Syntax validation
70
  ast.parse(cleaned_code)
71
 
72
+ # Execute code
73
  exec_env = {'df': df, 'plt': plt}
74
  plt.figure(figsize=(16, 9), dpi=120)
75
  exec(cleaned_code, exec_env)
76
 
 
77
  plt.savefig(buf, format='png', bbox_inches='tight')
78
  plt.close()
 
79
  visualizations.append(Image.open(buf))
80
  except Exception as e:
81
  print(f"Visualization {i} Error: {str(e)}")
82
+ print(f"Cleaned Code:\n{cleaned_code}")
83
  visualizations.append(None)
84
 
 
85
  return visualizations + [None]*(3-len(visualizations))
86
 
87
  # Gradio interface
88
+ with gr.Blocks() as demo:
89
+ gr.Markdown("# Professional Data Visualizer")
90
 
91
  with gr.Row():
92
+ api_key = gr.Textbox(label="Gemini API Key", type="password")
93
+ file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ instructions = gr.Textbox(label="Visualization Instructions")
96
+ submit = gr.Button("Generate", variant="primary")
97
 
98
  with gr.Row():
99
+ outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
 
 
 
 
 
 
100
 
101
  submit.click(
102
  process_file,
 
104
  outputs=outputs
105
  )
106
 
107
+ demo.launch()