Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,160 +1,58 @@
|
|
1 |
-
|
2 |
-
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
3 |
from fastapi.responses import JSONResponse
|
4 |
import tempfile
|
5 |
from dotenv import load_dotenv
|
6 |
import os
|
7 |
-
import google.generativeai as genai
|
8 |
import json
|
9 |
-
import logging
|
10 |
|
11 |
load_dotenv()
|
12 |
-
# Configure logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
16 |
app = FastAPI()
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
API_KEY = os.getenv("GOOGLE_API_KEY") # Use environment variable or replace directly
|
22 |
-
|
23 |
-
if not API_KEY:
|
24 |
-
logger.error("GEMINI_API_KEY environment variable not set.")
|
25 |
-
# You might want to raise an exception or exit here in a real application
|
26 |
-
# For now, we'll let it proceed but it will fail later if the placeholder key is invalid
|
27 |
-
|
28 |
-
# Configure the Gemini client globally
|
29 |
-
try:
|
30 |
-
genai.configure(api_key=API_KEY)
|
31 |
-
logger.info("Google Gemini client configured successfully.")
|
32 |
-
except Exception as e:
|
33 |
-
logger.error(f"Failed to configure Google Gemini client: {e}")
|
34 |
-
# Handle configuration error appropriately
|
35 |
-
|
36 |
-
# Initialize the Generative Model globally
|
37 |
-
# Use a model that supports image input, like gemini-1.5-flash-latest or gemini-pro-vision
|
38 |
-
# gemini-1.5-flash is generally recommended now
|
39 |
-
try:
|
40 |
-
model = genai.GenerativeModel("gemini-2.0-flash") # Using the recommended flash model
|
41 |
-
logger.info(f"Google Gemini model '{model.model_name}' initialized.")
|
42 |
-
except Exception as e:
|
43 |
-
logger.error(f"Failed to initialize Google Gemini model: {e}")
|
44 |
-
# Handle model initialization error appropriately
|
45 |
|
46 |
-
# --- FastAPI Endpoint ---
|
47 |
@app.post("/rate-outfit/")
|
48 |
async def rate_outfit(image: UploadFile = File(...), category: str = Form(...)):
|
49 |
-
logger.info(f"Received request to rate outfit. Category: {category}, Image: {image.filename}, Content-Type: {image.content_type}")
|
50 |
-
|
51 |
if image.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
|
52 |
-
|
53 |
-
raise HTTPException(status_code=400, detail="Please upload a valid image file (jpeg, png, jpg).")
|
54 |
|
55 |
-
tmp_path = None
|
56 |
try:
|
57 |
-
# Save image to temp file safely
|
58 |
-
# Using a context manager ensures the file is closed properly
|
59 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image.filename)[1]) as tmp:
|
60 |
-
|
61 |
-
tmp.write(content)
|
62 |
tmp_path = tmp.name
|
63 |
-
logger.info(f"Image saved temporarily to: {tmp_path}")
|
64 |
|
65 |
-
# Upload image to Gemini using the recommended function
|
66 |
-
logger.info("Uploading image to Gemini...")
|
67 |
-
# The new API uses genai.upload_file directly
|
68 |
uploaded_file = genai.upload_file(path=tmp_path, display_name=image.filename)
|
69 |
-
logger.info(f"Image uploaded successfully: {uploaded_file.name}")
|
70 |
-
|
71 |
|
72 |
-
# Define the prompt clearly
|
73 |
prompt = (
|
74 |
f"You are an AI fashion assistant. Based on the category '{category}', analyze the provided image. "
|
75 |
-
"Extract the following information and provide the response ONLY as a valid JSON object, without
|
76 |
-
"
|
77 |
-
|
78 |
-
'"Feedback": "Concise advice (1-2 sentences) on how the look could be improved or styled differently."}'
|
79 |
-
" --- IMPORTANT SAFETY CHECK: If the image contains nudity, offensive content, any religious context, political figure, or anything inappropriate for a fashion context, respond ONLY with the following JSON: "
|
80 |
-
'{"error": "Please upload an appropriate image"} --- '
|
81 |
-
"Focus on being concise and eye-catching."
|
82 |
)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
content_parts = [prompt, uploaded_file] # Pass the UploadedFile object
|
87 |
-
|
88 |
-
logger.info("Generating content with Gemini model...")
|
89 |
-
# Generate content
|
90 |
-
response = model.generate_content(content_parts)
|
91 |
-
logger.info("Received response from Gemini.")
|
92 |
-
# logger.debug(f"Raw Gemini response text: {response.text}") # Optional: Log raw response for debugging
|
93 |
-
|
94 |
-
# Clean and parse the response
|
95 |
-
text_response = response.text.strip()
|
96 |
-
|
97 |
-
# Robust cleaning: Remove potential markdown code blocks
|
98 |
-
if text_response.startswith("```json"):
|
99 |
-
text_response = text_response[7:] # Remove ```json\n
|
100 |
-
if text_response.endswith("```"):
|
101 |
-
text_response = text_response[:-3] # Remove ```
|
102 |
-
text_response = text_response.strip() # Strip again after removing markdown
|
103 |
-
|
104 |
-
logger.info(f"Cleaned Gemini response text: {text_response}")
|
105 |
-
|
106 |
-
# Attempt to parse the cleaned JSON
|
107 |
-
try:
|
108 |
-
result = json.loads(text_response)
|
109 |
-
# Validate if the result contains expected keys or the error key
|
110 |
-
if "error" in result:
|
111 |
-
logger.warning(f"Gemini detected inappropriate image: {result['error']}")
|
112 |
-
# Return a different status code for client-side handling? (e.g., 400 Bad Request)
|
113 |
-
# raise HTTPException(status_code=400, detail=result['error'])
|
114 |
-
# Or just return the error JSON as requested by some flows:
|
115 |
-
return JSONResponse(content=result, status_code=200) # Or 400 depending on desired API behavior
|
116 |
-
elif "Tag" not in result or "Feedback" not in result:
|
117 |
-
logger.error(f"Gemini response missing expected keys 'Tag' or 'Feedback'. Got: {result}")
|
118 |
-
raise HTTPException(status_code=500, detail="AI response format error: Missing expected keys.")
|
119 |
-
|
120 |
-
logger.info(f"Successfully parsed Gemini response: {result}")
|
121 |
-
return JSONResponse(content=result)
|
122 |
-
|
123 |
-
except json.JSONDecodeError as json_err:
|
124 |
-
logger.error(f"Failed to decode JSON response from Gemini: {json_err}")
|
125 |
-
logger.error(f"Invalid JSON string received: {text_response}")
|
126 |
-
raise HTTPException(status_code=500, detail="AI response format error: Invalid JSON.")
|
127 |
-
except Exception as parse_err: # Catch other potential errors during parsing/validation
|
128 |
-
logger.error(f"Error processing Gemini response: {parse_err}")
|
129 |
-
raise HTTPException(status_code=500, detail="Error processing AI response.")
|
130 |
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
logger.warning(f"Gemini blocked the prompt or response due to safety settings: {block_err}")
|
134 |
-
# Return a generic safety message or the specific error JSON
|
135 |
-
error_response = {"error": "Request blocked due to safety policies. Please ensure the image is appropriate."}
|
136 |
-
# It's often better to return a 400 Bad Request here
|
137 |
-
return JSONResponse(content=error_response, status_code=400)
|
138 |
|
139 |
except Exception as e:
|
140 |
-
logger.error(f"
|
141 |
-
|
142 |
-
raise HTTPException(status_code=500, detail="An internal server error occurred.")
|
143 |
|
144 |
finally:
|
145 |
-
# Cleanup temp image file if it was created
|
146 |
if tmp_path and os.path.exists(tmp_path):
|
147 |
-
|
148 |
-
os.remove(tmp_path)
|
149 |
-
logger.info(f"Temporary file {tmp_path} removed.")
|
150 |
-
except OSError as e:
|
151 |
-
logger.error(f"Error removing temporary file {tmp_path}: {e}")
|
152 |
|
153 |
-
# --- To Run (if this is the main script) ---
|
154 |
if __name__ == "__main__":
|
155 |
import uvicorn
|
156 |
-
|
157 |
-
# Example (Linux/macOS): export GEMINI_API_KEY='your_actual_api_key'
|
158 |
-
# # Example (Windows CMD): set GEMINI_API_KEY=your_actual_api_key
|
159 |
-
# # Example (Windows PowerShell): $env:GEMINI_API_KEY='your_actual_api_key'
|
160 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, Form
|
|
|
2 |
from fastapi.responses import JSONResponse
|
3 |
import tempfile
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
6 |
+
import google.generativeai as genai
|
7 |
import json
|
8 |
+
import logging
|
9 |
|
10 |
load_dotenv()
|
|
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
API_KEY = os.getenv("GOOGLE_API_KEY")
|
17 |
+
genai.configure(api_key=API_KEY)
|
18 |
+
model = genai.GenerativeModel("gemini-2.0-flash")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
|
|
20 |
@app.post("/rate-outfit/")
|
21 |
async def rate_outfit(image: UploadFile = File(...), category: str = Form(...)):
|
|
|
|
|
22 |
if image.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
|
23 |
+
return JSONResponse(content={"message": "Please review the image before uploading"})
|
|
|
24 |
|
25 |
+
tmp_path = None
|
26 |
try:
|
|
|
|
|
27 |
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image.filename)[1]) as tmp:
|
28 |
+
tmp.write(await image.read())
|
|
|
29 |
tmp_path = tmp.name
|
|
|
30 |
|
|
|
|
|
|
|
31 |
uploaded_file = genai.upload_file(path=tmp_path, display_name=image.filename)
|
|
|
|
|
32 |
|
|
|
33 |
prompt = (
|
34 |
f"You are an AI fashion assistant. Based on the category '{category}', analyze the provided image. "
|
35 |
+
"Extract the following information and provide the response ONLY as a valid JSON object, without explanations. "
|
36 |
+
"Schema: {\"Tag\": \"catchy phrase\", \"Feedback\": \"concise advice\"}. "
|
37 |
+
"If inappropriate, respond ONLY with {\"error\": \"Please upload an appropriate image\"}."
|
|
|
|
|
|
|
|
|
38 |
)
|
39 |
|
40 |
+
response = model.generate_content([prompt, uploaded_file])
|
41 |
+
result = json.loads(response.text.strip().replace('```json', '').replace('```', '').strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
if "error" in result:
|
44 |
+
return JSONResponse(content={"message": "Please review the image before uploading"})
|
45 |
|
46 |
+
return JSONResponse(content=result)
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
except Exception as e:
|
49 |
+
logger.error(f"Error: {e}")
|
50 |
+
return JSONResponse(content={"message": "Please review the image before uploading"})
|
|
|
51 |
|
52 |
finally:
|
|
|
53 |
if tmp_path and os.path.exists(tmp_path):
|
54 |
+
os.remove(tmp_path)
|
|
|
|
|
|
|
|
|
55 |
|
|
|
56 |
if __name__ == "__main__":
|
57 |
import uvicorn
|
58 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|
|
|