Spaces:
Sleeping
Sleeping
import os | |
import base64 | |
import mimetypes | |
from flask import Flask, render_template, request, jsonify | |
from werkzeug.utils import secure_filename | |
from google import genai | |
from google.genai import types | |
# Flask application object used by the view functions and the __main__ guard.
app = Flask(__name__)

# Gemini client. Reading the key with [] (not .get) makes a missing
# GEMINI_API_KEY fail at import time with KeyError.
GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
client = genai.Client(api_key=GEMINI_API_KEY)

# Working directories: temporary uploads, and generated results that are
# written into the "static" directory.
UPLOAD_FOLDER = "uploads"
RESULT_FOLDER = "static"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
def upload_image(image_data_url):
    """Decode a base64 data URL, stage it on disk, and upload it to Gemini.

    Args:
        image_data_url: String of the form ``"<header>,<base64 payload>"``.
            The header is only inspected for the substring ``"png"`` to pick
            the temporary file extension; anything else is staged as ``.jpg``.

    Returns:
        The object returned by ``client.files.upload``.

    Raises:
        ValueError: On any failure (malformed input, decode error, disk
            error, or upload error). The original exception is attached as
            ``__cause__``.
    """
    try:
        header, encoded = image_data_url.split(",", 1)
        binary_data = base64.b64decode(encoded)
        ext = ".png" if "png" in header.lower() else ".jpg"
        temp_filename = secure_filename(f"temp_{os.urandom(8).hex()}{ext}")
        temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
        try:
            with open(temp_filepath, "wb") as f:
                f.write(binary_data)
            return client.files.upload(file=temp_filepath)
        finally:
            # Always remove the staged file, including when the write or the
            # upload fails, so failed requests do not accumulate files.
            try:
                os.remove(temp_filepath)
            except FileNotFoundError:
                # open() itself failed; there is nothing to clean up.
                pass
    except Exception as e:
        raise ValueError(f"Image processing error: {str(e)}") from e
def is_prohibited_request(uploaded_file, object_type):
    """Ask Gemini whether the removal request targets people/animals or their belongings.

    The check fails closed: the request is treated as prohibited unless the
    model answers with an explicit "No".

    Args:
        uploaded_file: Uploaded file object; its ``uri`` and ``mime_type``
            attributes are passed to the model.
        object_type: Text describing what the user wants removed.

    Returns:
        ``False`` only when the model's answer is "no" (case-insensitive,
        ignoring surrounding whitespace and punctuation). ``True`` for "yes",
        for any other or missing answer, and when the call raises.
    """
    model = "gemini-2.0-flash-lite"
    system_prompt = """Analyze image and request to detect:
1. Direct removal of people/animals
2. Removal of items attached to/worn by people/animals
3. Removal of body parts or personal belongings
Prohibited examples:
- Person, dog, cat
- Sunglasses on face, mask, hat
- Phone in hand, watch on wrist
- Eyes, hands, hair
Allowed examples:
- Background, car, tree
- Sunglasses on table
- Phone on desk
Respond ONLY with 'Yes' or 'No'"""
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type}"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=system_prompt)],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        candidates = response.candidates
        if not candidates:
            return True  # No answer at all: block.
        content = candidates[0].content
        if content is None or not content.parts:
            return True  # Empty answer: block.
        # The text may be None or carry stray punctuation ("No."); normalise
        # it before comparing.
        answer = (content.parts[0].text or "").strip().strip(".,!'\"").lower()
        # Only an explicit "no" permits the edit. Comparing against "yes"
        # would let every unexpected reply through.
        return answer != "no"
    except Exception as e:
        print(f"Safety check failed: {str(e)}")
        return True  # Block if check fails
def generate_modified_image(uploaded_file, object_type):
    """Ask the Gemini image-generation model to remove an object and save the result.

    Args:
        uploaded_file: Uploaded file object; its ``uri`` and ``mime_type``
            attributes are passed to the model.
        object_type: Text describing what should be removed.

    Returns:
        Path (inside ``RESULT_FOLDER``) of the first image found in the
        streamed response, or ``None`` if the stream contains no image or the
        call fails.
    """
    model = "gemini-2.0-flash-exp-image-generation"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Completely remove {object_type} from the image without leaving traces"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.5,
        top_p=0.9,
        max_output_tokens=1024,
        # The Gemini image-generation docs state that image-only output is
        # not supported for this model family; request text alongside image.
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
            types.SafetySetting(category="HARM_CATEGORY_VIOLENCE", threshold="BLOCK_NONE"),
        ],
    )
    try:
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            # A chunk may carry no content; skip it instead of letting an
            # AttributeError abort the whole stream.
            if content is None or not content.parts:
                continue
            # The image is not necessarily the first part (text may precede
            # it), so scan every part of the chunk.
            for part in content.parts:
                inline = part.inline_data
                if not inline or not inline.data:
                    continue
                ext = mimetypes.guess_extension(inline.mime_type or "") or ".png"
                output_filename = secure_filename(f"result_{os.urandom(4).hex()}{ext}")
                output_path = os.path.join(RESULT_FOLDER, output_filename)
                with open(output_path, "wb") as f:
                    f.write(inline.data)
                return output_path
        return None
    except Exception as e:
        print(f"Image generation failed: {str(e)}")
        return None
# Register the view on the app; without a URL rule Flask never dispatches
# to this function and every request returns 404.
@app.route("/")
def index():
    """Render the ``index.html`` template."""
    return render_template("index.html")
# Register the view on the app; without a URL rule this handler is unreachable.
# NOTE(review): the "/process" path and POST method are assumed from the
# function name and JSON body -- confirm against the front-end that calls it.
@app.route("/process", methods=["POST"])
def process():
    """Handle an object-removal request sent as a JSON body.

    Expects a JSON object with ``"image"`` (a base64 data URL) and
    ``"objectType"`` (text naming what to remove).

    Returns:
        A JSON response, with a status code on every non-success path:

        * ``{"success": True, "resultUrl": ...}`` on success.
        * 400 for a missing or malformed body, an empty object type, or an
          image that ``upload_image`` rejects.
        * 403 when ``is_prohibited_request`` blocks the request.
        * 500 when no image was generated or an unexpected error occurs.
    """
    try:
        # silent=True returns None for a bad body or content type. Without
        # it get_json() raises an HTTPException, which the broad handler
        # below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "image" not in data or "objectType" not in data:
            return jsonify({"success": False, "message": "Invalid request format"}), 400

        image_data = data["image"]
        object_type = data["objectType"]
        # Reject non-string fields here; calling .strip() on them would
        # otherwise surface as a 500.
        if not isinstance(image_data, str) or not isinstance(object_type, str):
            return jsonify({"success": False, "message": "Invalid request format"}), 400

        object_type = object_type.strip().lower()
        if not object_type:
            return jsonify({"success": False, "message": "Please specify an object to remove"}), 400

        # Process image upload
        uploaded_file = upload_image(image_data)

        # Safety check
        if is_prohibited_request(uploaded_file, object_type):
            return jsonify({
                "success": False,
                "message": "Cannot remove people, animals, or personal items"
            }), 403

        # Generate modified image
        result_path = generate_modified_image(uploaded_file, object_type)
        if not result_path:
            return jsonify({"success": False, "message": "Failed to generate image"}), 500

        return jsonify({
            "success": True,
            "resultUrl": f"/static/{os.path.basename(result_path)}"
        })
    except ValueError as e:
        return jsonify({"success": False, "message": str(e)}), 400
    except Exception:
        # Log the traceback server-side; the client only gets a generic message.
        app.logger.exception("Unhandled error while processing request")
        return jsonify({"success": False, "message": "Internal server error"}), 500
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces, port 7860, when the
    # file is executed directly.
    app.run(host="0.0.0.0", port=7860)