# Hugging Face Space (status: Sleeping)
import base64
import mimetypes
import os
import tempfile
import uuid

from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from google import genai
from google.genai import types
# Initialize Flask app
app = Flask(__name__)

# Read the Gemini API key from environment variables (set as a secret in
# Hugging Face Spaces). Fail fast with an actionable message instead of a
# bare KeyError, and also reject an empty value.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError(
        "GEMINI_API_KEY environment variable is not set; "
        "configure it before starting the app."
    )
client = genai.Client(api_key=GEMINI_API_KEY)

# Working directories, anchored at the application root rather than the
# process CWD. RESULT_FOLDER is exactly the folder Flask serves at /static,
# so images written there are reachable via the URLs built in process().
UPLOAD_FOLDER = os.path.join(app.root_path, 'uploads')
RESULT_FOLDER = app.static_folder
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
def _decode_data_url(image_data_url):
    """Decode a base64 ``data:`` URL into ``(binary_data, ext)``.

    ``ext`` is ``".png"`` when the header mentions PNG and ``".jpg"``
    otherwise.

    Raises:
        ValueError: if there is no ``,`` separating header and payload.
            ``base64.b64decode`` may also raise ``binascii.Error`` (a
            ValueError subclass) for malformed payloads such as bad padding.
    """
    header, sep, encoded = image_data_url.partition(',')
    if not sep:
        raise ValueError("Invalid image data")
    binary_data = base64.b64decode(encoded)
    ext = ".png" if "png" in header.lower() else ".jpg"
    return binary_data, ext


def _upload_image_file(binary_data, ext):
    """Upload image bytes to the Gemini Files API and return the file handle.

    The bytes go to a uniquely named temp file in UPLOAD_FOLDER, so
    concurrent requests cannot overwrite each other's input. The local copy
    is deleted once the upload call returns or fails.
    """
    fd, temp_filepath = tempfile.mkstemp(
        prefix="temp_image_", suffix=ext, dir=UPLOAD_FOLDER
    )
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(binary_data)
        # The suffix is kept so the SDK can infer the MIME type from the path.
        return client.files.upload(file=temp_filepath)
    finally:
        os.remove(temp_filepath)


def _iter_stream_parts(stream):
    """Yield every content part of the first candidate of each streamed chunk.

    Chunks without candidates, content or parts are skipped.
    """
    for chunk in stream:
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        yield from chunk.candidates[0].content.parts


def _save_generated_image(inline_data):
    """Write an inline image blob into RESULT_FOLDER and return its path.

    A per-call unique filename prevents concurrent requests from serving
    each other's result. Note that these files accumulate in RESULT_FOLDER;
    nothing here removes them.
    """
    file_extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".png"
    output_filename = secure_filename(f"generated_output_{uuid.uuid4().hex}{file_extension}")
    result_image_path = os.path.join(RESULT_FOLDER, output_filename)
    with open(result_image_path, "wb") as f:
        f.write(inline_data.data)
    return result_image_path


def generate_gemini_output(object_type, image_data_url):
    """
    Generate output from Gemini by removing the specified object.

    Args:
        object_type: Name of the object to remove; used in the prompt
            "Remove {object_type} from the image".
        image_data_url: Base64 data URL of the source image. If falsy, no
            image part is sent.

    Returns:
        ``(result_text, result_image)``. ``result_text`` is the full text
        streamed back by the first (lite) model call, or None if it produced
        no text. ``result_image`` is the filesystem path of the saved
        generated image, or None if no image was produced.

    Raises:
        ValueError: if ``image_data_url`` is not a well-formed data URL.
    """
    text_model = "gemini-2.0-flash-lite"  # Use the lite model for text-based responses
    image_model = "gemini-2.0-flash-exp-image-generation"
    system_instruction_text = (
        "Your AI finds user requests about removing objects from images.\n"
        "If the user asks to remove a person or animal, respond with 'No'."
    )

    # Prepare content parts for Gemini
    parts = []
    if image_data_url:
        binary_data, ext = _decode_data_url(image_data_url)
        uploaded_file = _upload_image_file(binary_data, ext)
        parts.append(types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type))
    if object_type:
        # Gemini magic prompt: instruct the model to remove the specified object
        magic_prompt = f"Remove {object_type} from the image"
        parts.append(types.Part.from_text(text=magic_prompt))
    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_mime_type="text/plain",
        system_instruction=[
            types.Part.from_text(text=system_instruction_text),
        ],
    )

    # Stream the text response. Each chunk carries only a fragment, so
    # concatenate all of them rather than keeping just the last one.
    text_stream = client.models.generate_content_stream(
        model=text_model,
        contents=contents,
        config=generate_content_config,
    )
    result_text = "".join(part.text for part in _iter_stream_parts(text_stream) if part.text) or None

    # If the response contains "no", switch to the image generation model.
    # NOTE(review): this is a case-insensitive substring test, so any text
    # containing "no" (e.g. "cannot", "not") also triggers image generation.
    # Meanwhile the system instruction reserves 'No' for person/animal
    # removals. Confirm which gating is actually intended.
    if result_text and "no" in result_text.lower():
        generate_content_config.response_modalities = ["image", "text"]
        # NOTE(review): the same config, including system_instruction, is
        # reused for the image model; verify that model accepts system
        # instructions.
        image_stream = client.models.generate_content_stream(
            model=image_model,
            contents=contents,
            config=generate_content_config,
        )
        # Inspect every part: the image is not necessarily the first part of
        # its chunk.
        for part in _iter_stream_parts(image_stream):
            if part.inline_data:
                return result_text, _save_generated_image(part.inline_data)
    return result_text, None
@app.route("/")
def index():
    """Serve the front-end page.

    The comment in the original states index.html contains the complete
    HTML/CSS/JS inline. Without the route decorator this view was never
    registered with Flask.
    """
    return render_template("index.html")
@app.route("/process", methods=["POST"])
def process():
    """Handle an object-removal request.

    Expects a JSON object body with ``"image"`` (a base64 data URL string)
    and ``"objectType"`` (a non-empty string).

    Responses:
        200 ``{"success": True, "resultPath": "/static/<file>", "resultText": ...}``
        400 when the body is not a JSON object or a field is missing, empty
            or not a string.
        500 when no image was generated (message is the model's text if
            any), or on any unexpected error (which is also logged).
    """
    missing_msg = "Missing image data or object type."
    try:
        # silent=True turns malformed JSON into None, so it becomes a clean
        # 400 instead of a BadRequest being reported as a 500.
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"success": False, "message": "Request body must be a JSON object."}), 400
        image_data = data.get("image")
        object_type = data.get("objectType") or ""
        # Reject non-string values up front; previously they crashed later
        # (e.g. None.strip()) and were reported as server errors.
        if not isinstance(image_data, str) or not isinstance(object_type, str):
            return jsonify({"success": False, "message": missing_msg}), 400
        object_type = object_type.strip()
        if not image_data or not object_type:
            return jsonify({"success": False, "message": missing_msg}), 400

        # Generate output using Gemini
        result_text, result_image = generate_gemini_output(object_type, image_data)
        if not result_image:
            return jsonify({"success": False, "message": result_text or "Failed to generate image."}), 500
        # Create a URL to serve the image from the static folder.
        image_url = f"/static/{os.path.basename(result_image)}"
        return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
    except Exception as e:
        # Top-level request boundary: log the traceback, then report to client.
        app.logger.exception("Failed to process request")
        return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
if __name__ == "__main__":
    # Run the app on the port given by the PORT environment variable,
    # defaulting to 7860 (the port this app used for Hugging Face Spaces).
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))