Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,38 +1,59 @@
|
|
1 |
-
from flask import Flask, render_template, request, send_file
|
2 |
import os
|
|
|
3 |
import mimetypes
|
|
|
|
|
4 |
from google import genai
|
5 |
from google.genai import types
|
6 |
-
from io import BytesIO
|
7 |
|
|
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
20 |
model = "gemini-2.0-flash-exp-image-generation"
|
21 |
-
|
22 |
-
# Upload the image to Gemini
|
23 |
files = []
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
files.append(uploaded_file)
|
27 |
|
28 |
-
# Prepare
|
29 |
parts = []
|
30 |
if files:
|
31 |
parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
magic_prompt = f"Remove {user_input} from the image"
|
36 |
parts.append(types.Part.from_text(text=magic_prompt))
|
37 |
|
38 |
contents = [types.Content(role="user", parts=parts)]
|
@@ -52,6 +73,7 @@ def generate_gemini_output(user_input, image):
|
|
52 |
result_text = None
|
53 |
result_image = None
|
54 |
|
|
|
55 |
for chunk in client.models.generate_content_stream(
|
56 |
model=model,
|
57 |
contents=contents,
|
@@ -59,29 +81,47 @@ def generate_gemini_output(user_input, image):
|
|
59 |
):
|
60 |
if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
|
61 |
continue
|
62 |
-
|
63 |
part = chunk.candidates[0].content.parts[0]
|
64 |
-
|
65 |
if part.inline_data:
|
66 |
-
file_name = "generated_output"
|
67 |
file_extension = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
71 |
else:
|
72 |
result_text = part.text
|
73 |
|
74 |
return result_text, result_image
|
75 |
|
76 |
-
@app.route("/"
|
77 |
def index():
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
if __name__ == "__main__":
|
87 |
-
app.
|
|
|
|
|
|
1 |
import os
|
2 |
+
import base64
|
3 |
import mimetypes
|
4 |
+
from flask import Flask, render_template, request, jsonify
|
5 |
+
from werkzeug.utils import secure_filename
|
6 |
from google import genai
|
7 |
from google.genai import types
|
|
|
8 |
|
9 |
+
# Initialize Flask app
|
10 |
app = Flask(__name__)
|
11 |
|
12 |
+
# Set your Gemini API key (or set GEMINI_API_KEY in your environment)
|
13 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
14 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
15 |
|
16 |
+
# Create necessary directories
|
17 |
+
UPLOAD_FOLDER = 'uploads'
|
18 |
+
RESULT_FOLDER = os.path.join('static')
|
19 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
20 |
+
os.makedirs(RESULT_FOLDER, exist_ok=True)
|
21 |
+
|
22 |
+
def generate_gemini_output(object_type, image_data_url):
|
23 |
+
"""
|
24 |
+
Generate output from Gemini by removing the object specified.
|
25 |
+
The image_data_url is expected to be a base64 data URL.
|
26 |
+
"""
|
27 |
model = "gemini-2.0-flash-exp-image-generation"
|
|
|
|
|
28 |
files = []
|
29 |
+
|
30 |
+
# Decode the image data from the data URL
|
31 |
+
if image_data_url:
|
32 |
+
try:
|
33 |
+
header, encoded = image_data_url.split(',', 1)
|
34 |
+
except ValueError:
|
35 |
+
raise ValueError("Invalid image data")
|
36 |
+
binary_data = base64.b64decode(encoded)
|
37 |
+
# Determine file extension from header
|
38 |
+
if "png" in header.lower():
|
39 |
+
ext = ".png"
|
40 |
+
else:
|
41 |
+
ext = ".jpg"
|
42 |
+
temp_filename = secure_filename("temp_image" + ext)
|
43 |
+
temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
|
44 |
+
with open(temp_filepath, "wb") as f:
|
45 |
+
f.write(binary_data)
|
46 |
+
# Upload file to Gemini
|
47 |
+
uploaded_file = client.files.upload(file=temp_filepath)
|
48 |
files.append(uploaded_file)
|
49 |
|
50 |
+
# Prepare content parts for Gemini
|
51 |
parts = []
|
52 |
if files:
|
53 |
parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
|
54 |
+
if object_type:
|
55 |
+
# Create Gemini magic prompt for object removal
|
56 |
+
magic_prompt = f"Remove {object_type} from the image"
|
|
|
57 |
parts.append(types.Part.from_text(text=magic_prompt))
|
58 |
|
59 |
contents = [types.Content(role="user", parts=parts)]
|
|
|
73 |
result_text = None
|
74 |
result_image = None
|
75 |
|
76 |
+
# Stream the output from Gemini API
|
77 |
for chunk in client.models.generate_content_stream(
|
78 |
model=model,
|
79 |
contents=contents,
|
|
|
81 |
):
|
82 |
if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
|
83 |
continue
|
|
|
84 |
part = chunk.candidates[0].content.parts[0]
|
|
|
85 |
if part.inline_data:
|
|
|
86 |
file_extension = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
|
87 |
+
output_filename = secure_filename("generated_output" + file_extension)
|
88 |
+
result_image_path = os.path.join(RESULT_FOLDER, output_filename)
|
89 |
+
with open(result_image_path, "wb") as f:
|
90 |
+
f.write(part.inline_data.data)
|
91 |
+
result_image = result_image_path # Path relative to static folder
|
92 |
else:
|
93 |
result_text = part.text
|
94 |
|
95 |
return result_text, result_image
|
96 |
|
97 |
+
@app.route("/")
|
98 |
def index():
|
99 |
+
# Render the front-end HTML (which includes inline CSS and JS)
|
100 |
+
return render_template("index.html")
|
101 |
+
|
102 |
+
@app.route("/process", methods=["POST"])
|
103 |
+
def process():
|
104 |
+
try:
|
105 |
+
# Expecting JSON with keys "image" (base64 data URL) and "objectType"
|
106 |
+
data = request.get_json(force=True)
|
107 |
+
image_data = data.get("image")
|
108 |
+
object_type = data.get("objectType", "").strip()
|
109 |
+
if not image_data or not object_type:
|
110 |
+
return jsonify({"success": False, "message": "Missing image data or object type."}), 400
|
111 |
+
|
112 |
+
# Generate Gemini output
|
113 |
+
result_text, result_image = generate_gemini_output(object_type, image_data)
|
114 |
+
if not result_image:
|
115 |
+
return jsonify({"success": False, "message": "Failed to generate image."}), 500
|
116 |
+
|
117 |
+
# Create a URL to serve the image from the static folder.
|
118 |
+
# Assuming your static folder is served at '/static'
|
119 |
+
image_url = f"/static/{os.path.basename(result_image)}"
|
120 |
+
|
121 |
+
return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
|
122 |
+
except Exception as e:
|
123 |
+
return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
|
124 |
|
125 |
if __name__ == "__main__":
|
126 |
+
# Run the app; debug can be set to False in production.
|
127 |
+
app.run(host="0.0.0.0", port=int(os.getenv("PORT", 5000)), debug=True)
|