bluenevus commited on
Commit
7d1e58a
·
verified ·
1 Parent(s): 56d19f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -55
app.py CHANGED
@@ -1,59 +1,133 @@
1
- import gradio as gr
 
 
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
- import io
5
- import base64
6
- import google.generativeai as genai
7
-
8
- def process_file(api_key, file, instructions):
9
- # Set up Gemini API
10
- genai.configure(api_key=api_key)
11
- model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
-
13
- # Read the file
14
- if file.name.endswith('.csv'):
15
- df = pd.read_csv(file.name)
16
- else:
17
- df = pd.read_excel(file.name)
18
-
19
- # Analyze data and get visualization suggestions from Gemini
20
- data_description = df.describe().to_string()
21
- prompt = f"Given this data: {data_description}\n"
22
- if instructions:
23
- prompt += f"And these instructions: {instructions}\n"
24
- prompt += "Suggest 3 ways to visualize this data."
25
-
26
- response = model.generate_content(prompt)
27
- suggestions = response.text.split('\n')
28
-
29
- # Generate visualizations
30
- visualizations = []
31
- for i, suggestion in enumerate(suggestions[:3]):
32
- plt.figure(figsize=(10, 6))
33
- plt.title(f"Visualization {i+1}")
34
- plt.text(0.5, 0.5, suggestion, ha='center', va='center', wrap=True)
35
- buf = io.BytesIO()
36
- plt.savefig(buf, format='png')
37
- buf.seek(0)
38
- img_str = base64.b64encode(buf.getvalue()).decode()
39
- visualizations.append(f"data:image/png;base64,{img_str}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  plt.close()
 
 
 
 
 
 
 
 
41
 
42
- return visualizations
43
-
44
- # Gradio interface
45
- with gr.Blocks() as demo:
46
- gr.Markdown("Data Visualization with Gemini")
47
- api_key = gr.Textbox(label="Enter Gemini API Key", type="password")
48
- file = gr.File(label="Upload Excel or CSV file")
49
- instructions = gr.Textbox(label="Optional visualization instructions")
50
- submit = gr.Button("Generate Visualizations")
51
- outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
52
-
53
- submit.click(
54
- fn=process_file,
55
- inputs=[api_key, file, instructions],
56
- outputs=outputs
57
- )
58
-
59
- demo.launch(share=True)
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import HTMLResponse
4
+ from fastapi.staticfiles import StaticFiles
5
  import pandas as pd
6
  import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import os
9
+ import logging
10
+ from huggingface_hub import InferenceClient
11
+ from dotenv import load_dotenv
12
+ import hashlib
13
+
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ app = FastAPI()
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ app.mount("/static", StaticFiles(directory="static"), name="static")
32
+
33
+ API_TOKEN = os.getenv("HF_TOKEN")
34
+ if not API_TOKEN:
35
+ raise ValueError("HF_TOKEN environment variable not set.")
36
+
37
+ MODEL_NAME = "gemini-2.5-pro-preview-03-25"
38
+ client = InferenceClient(model=MODEL_NAME, token=API_TOKEN)
39
+
40
+ UPLOAD_DIR = "uploads"
41
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
42
+
43
+ IMAGES_DIR = os.path.join("static", "images")
44
+ os.makedirs(IMAGES_DIR, exist_ok=True)
45
+
46
+ @app.post("/upload/")
47
+ async def upload_file(file: UploadFile = File(...)):
48
+ if not file.filename.endswith((".xlsx", ".csv")):
49
+ raise HTTPException(status_code=400, detail="File must be an Excel (.xlsx) or CSV file")
50
+
51
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
52
+ with open(file_path, "wb") as buffer:
53
+ buffer.write(await file.read())
54
+
55
+ logger.info(f"File uploaded: {file.filename}")
56
+ return {"filename": file.filename}
57
+
58
+ @app.post("/generate-visualization/")
59
+ async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
60
+ file_path = os.path.join(UPLOAD_DIR, filename)
61
+
62
+ if not os.path.exists(file_path):
63
+ raise HTTPException(status_code=404, detail="File not found on server.")
64
+
65
+ try:
66
+ if filename.endswith('.csv'):
67
+ df = pd.read_csv(file_path)
68
+ else:
69
+ df = pd.read_excel(file_path)
70
+ if df.empty:
71
+ raise ValueError("File is empty.")
72
+ except Exception as e:
73
+ raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}")
74
+
75
+ input_text = f"""
76
+ Given the DataFrame 'df' with columns {', '.join(df.columns)} and preview:
77
+ {df.head().to_string()}
78
+ Write Python code to: {prompt}
79
+ - Use ONLY 'df' (no external data loading).
80
+ - Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).
81
+ - Include axis labels and a title.
82
+ - Output ONLY executable code (no comments, functions, print, or triple quotes).
83
+ """
84
+
85
+ try:
86
+ generated_code = client.text_generation(input_text, max_new_tokens=500)
87
+ logger.info(f"Generated code:\n{generated_code}")
88
+ except Exception as e:
89
+ raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}")
90
+
91
+ if not generated_code.strip():
92
+ raise HTTPException(status_code=500, detail="No code generated by the AI model.")
93
+
94
+ generated_code = generated_code.strip()
95
+ if generated_code.startswith('"""') or generated_code.startswith("'''"):
96
+ generated_code = generated_code.split('"""')[1] if '"""' in generated_code else generated_code.split("'''")[1]
97
+ if generated_code.endswith('"""') or generated_code.endswith("'''"):
98
+ generated_code = generated_code.rsplit('"""')[0] if '"""' in generated_code else generated_code.rsplit("'''")[0]
99
+ generated_code = generated_code.strip()
100
+
101
+ lines = generated_code.splitlines()
102
+ executable_code = "\n".join(
103
+ line.strip() for line in lines
104
+ if line.strip() and not line.strip().startswith(('#', 'def', 'class', '"""', "'''"))
105
+ and not any(kw in line for kw in ["pd.read_csv", "pd.read_excel", "http", "raise", "print"])
106
+ ).strip()
107
+
108
+ executable_code = executable_code.replace("plt.show()", "").strip()
109
+
110
+ logger.info(f"Executable code:\n{executable_code}")
111
+
112
+ plot_hash = hashlib.md5(f"{filename}_{prompt}".encode()).hexdigest()[:8]
113
+ plot_filename = f"plot_{plot_hash}.png"
114
+ plot_path = os.path.join(IMAGES_DIR, plot_filename)
115
+
116
+ try:
117
+ exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
118
+ exec(executable_code, exec_globals)
119
+ plt.savefig(plot_path, bbox_inches="tight")
120
  plt.close()
121
+ except Exception as e:
122
+ logger.error(f"Error executing code:\n{executable_code}\nException: {str(e)}")
123
+ raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}")
124
+
125
+ if not os.path.exists(plot_path):
126
+ raise HTTPException(status_code=500, detail="Plot file was not created.")
127
+
128
+ return {"plot_url": f"/static/images/{plot_filename}", "generated_code": generated_code}
129
 
130
+ @app.get("/")
131
+ async def serve_frontend():
132
+ with open("static/index.html", "r") as f:
133
+ return HTMLResponse(content=f.read())