Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,8 +9,9 @@ from google.genai import types
|
|
9 |
# Initialize Flask app
|
10 |
app = Flask(__name__)
|
11 |
|
12 |
-
# Set your Gemini API key
|
13 |
-
|
|
|
14 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
15 |
|
16 |
# Create necessary directories
|
@@ -21,8 +22,8 @@ os.makedirs(RESULT_FOLDER, exist_ok=True)
|
|
21 |
|
22 |
def generate_gemini_output(object_type, image_data_url):
|
23 |
"""
|
24 |
-
Generate output from Gemini by removing the object
|
25 |
-
|
26 |
"""
|
27 |
model = "gemini-2.0-flash-exp-image-generation"
|
28 |
files = []
|
@@ -35,10 +36,7 @@ def generate_gemini_output(object_type, image_data_url):
|
|
35 |
raise ValueError("Invalid image data")
|
36 |
binary_data = base64.b64decode(encoded)
|
37 |
# Determine file extension from header
|
38 |
-
if "png" in header.lower()
|
39 |
-
ext = ".png"
|
40 |
-
else:
|
41 |
-
ext = ".jpg"
|
42 |
temp_filename = secure_filename("temp_image" + ext)
|
43 |
temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
|
44 |
with open(temp_filepath, "wb") as f:
|
@@ -52,7 +50,7 @@ def generate_gemini_output(object_type, image_data_url):
|
|
52 |
if files:
|
53 |
parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
|
54 |
if object_type:
|
55 |
-
#
|
56 |
magic_prompt = f"Remove {object_type} from the image"
|
57 |
parts.append(types.Part.from_text(text=magic_prompt))
|
58 |
|
@@ -73,7 +71,7 @@ def generate_gemini_output(object_type, image_data_url):
|
|
73 |
result_text = None
|
74 |
result_image = None
|
75 |
|
76 |
-
# Stream
|
77 |
for chunk in client.models.generate_content_stream(
|
78 |
model=model,
|
79 |
contents=contents,
|
@@ -96,26 +94,25 @@ def generate_gemini_output(object_type, image_data_url):
|
|
96 |
|
97 |
@app.route("/")
|
98 |
def index():
|
99 |
-
# Render the front-end HTML (which
|
100 |
return render_template("index.html")
|
101 |
|
102 |
@app.route("/process", methods=["POST"])
|
103 |
def process():
|
104 |
try:
|
105 |
-
#
|
106 |
data = request.get_json(force=True)
|
107 |
image_data = data.get("image")
|
108 |
object_type = data.get("objectType", "").strip()
|
109 |
if not image_data or not object_type:
|
110 |
return jsonify({"success": False, "message": "Missing image data or object type."}), 400
|
111 |
|
112 |
-
# Generate Gemini
|
113 |
result_text, result_image = generate_gemini_output(object_type, image_data)
|
114 |
if not result_image:
|
115 |
return jsonify({"success": False, "message": "Failed to generate image."}), 500
|
116 |
|
117 |
# Create a URL to serve the image from the static folder.
|
118 |
-
# Assuming your static folder is served at '/static'
|
119 |
image_url = f"/static/{os.path.basename(result_image)}"
|
120 |
|
121 |
return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
|
@@ -123,5 +120,5 @@ def process():
|
|
123 |
return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
|
124 |
|
125 |
if __name__ == "__main__":
|
126 |
-
# Run the app
|
127 |
-
app.run(host="0.0.0.0", port=int(os.
|
|
|
9 |
# Initialize Flask app
|
10 |
app = Flask(__name__)
|
11 |
|
12 |
+
# Set your Gemini API key via Hugging Face Spaces environment variables.
|
13 |
+
# Do not include a default fallback; the environment must supply GEMINI_API_KEY.
|
14 |
+
GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
|
15 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
16 |
|
17 |
# Create necessary directories
|
|
|
22 |
|
23 |
def generate_gemini_output(object_type, image_data_url):
|
24 |
"""
|
25 |
+
Generate output from Gemini by removing the specified object.
|
26 |
+
Expects the image_data_url to be a base64 data URL.
|
27 |
"""
|
28 |
model = "gemini-2.0-flash-exp-image-generation"
|
29 |
files = []
|
|
|
36 |
raise ValueError("Invalid image data")
|
37 |
binary_data = base64.b64decode(encoded)
|
38 |
# Determine file extension from header
|
39 |
+
ext = ".png" if "png" in header.lower() else ".jpg"
|
|
|
|
|
|
|
40 |
temp_filename = secure_filename("temp_image" + ext)
|
41 |
temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
|
42 |
with open(temp_filepath, "wb") as f:
|
|
|
50 |
if files:
|
51 |
parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
|
52 |
if object_type:
|
53 |
+
# Gemini magic prompt: instruct the model to remove the specified object
|
54 |
magic_prompt = f"Remove {object_type} from the image"
|
55 |
parts.append(types.Part.from_text(text=magic_prompt))
|
56 |
|
|
|
71 |
result_text = None
|
72 |
result_image = None
|
73 |
|
74 |
+
# Stream output from Gemini API
|
75 |
for chunk in client.models.generate_content_stream(
|
76 |
model=model,
|
77 |
contents=contents,
|
|
|
94 |
|
95 |
@app.route("/")
|
96 |
def index():
|
97 |
+
# Render the front-end HTML (which contains complete HTML/CSS/JS inline)
|
98 |
return render_template("index.html")
|
99 |
|
100 |
@app.route("/process", methods=["POST"])
|
101 |
def process():
|
102 |
try:
|
103 |
+
# Expect JSON with keys "image" (base64 data URL) and "objectType"
|
104 |
data = request.get_json(force=True)
|
105 |
image_data = data.get("image")
|
106 |
object_type = data.get("objectType", "").strip()
|
107 |
if not image_data or not object_type:
|
108 |
return jsonify({"success": False, "message": "Missing image data or object type."}), 400
|
109 |
|
110 |
+
# Generate output using Gemini
|
111 |
result_text, result_image = generate_gemini_output(object_type, image_data)
|
112 |
if not result_image:
|
113 |
return jsonify({"success": False, "message": "Failed to generate image."}), 500
|
114 |
|
115 |
# Create a URL to serve the image from the static folder.
|
|
|
116 |
image_url = f"/static/{os.path.basename(result_image)}"
|
117 |
|
118 |
return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
|
|
|
120 |
return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
+
# Run the app on port 5000 or the port provided by the environment (for Hugging Face Spaces)
|
124 |
+
app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))
|