# Space status: Sleeping. File size: 6,154 bytes.
import os
import base64
import mimetypes
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from google import genai
from google.genai import types
app = Flask(__name__)

# Initialize the Gemini client. Fail fast at import time with an explicit
# message (instead of a bare KeyError) when the key is missing or empty.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY environment variable is not set")
client = genai.Client(api_key=GEMINI_API_KEY)

# Scratch folder for decoded uploads (relative to the working directory).
UPLOAD_FOLDER = 'uploads'
# Generated images are exposed under the /static/ URL, which Flask serves from
# app.static_folder (resolved against the app's root path, not the current
# working directory), so results must be written there.
RESULT_FOLDER = app.static_folder
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
def upload_image(image_data_url):
    """Decode a base64 data URL, write it to a temp file and upload it to Gemini.

    Args:
        image_data_url: A data URL such as ``"data:image/png;base64,<payload>"``.
            Everything after the first comma is base64-decoded; the header is
            only used to choose a ``.png`` (if it mentions "png") or ``.jpg``
            extension for the temporary file.

    Returns:
        The file object returned by ``client.files.upload``.

    Raises:
        ValueError: On any failure (malformed data URL, bad base64, disk or
            upload error). The original exception is chained as the cause.
    """
    temp_filepath = None
    try:
        header, encoded = image_data_url.split(',', 1)
        binary_data = base64.b64decode(encoded)
        ext = ".png" if "png" in header.lower() else ".jpg"
        temp_filename = secure_filename(f"temp_{os.urandom(8).hex()}{ext}")
        temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
        with open(temp_filepath, "wb") as f:
            f.write(binary_data)
        return client.files.upload(file=temp_filepath)
    except Exception as e:
        raise ValueError(f"Image processing error: {str(e)}") from e
    finally:
        # Always remove the temp file, including when the upload fails, so
        # failed requests do not accumulate files in UPLOAD_FOLDER.
        if temp_filepath is not None:
            try:
                os.remove(temp_filepath)
            except OSError:
                pass  # best-effort: the file may never have been created
def is_prohibited_request(uploaded_file, object_type):
    """Ask a Gemini text model whether removing ``object_type`` is disallowed.

    The model is instructed (see ``system_prompt``) to flag removal of
    people/animals, items attached to or worn by them, and body parts or
    personal belongings, and to answer only 'Yes' or 'No'.

    Args:
        uploaded_file: Gemini file object (from ``upload_image``); its ``uri``
            and ``mime_type`` are attached to the prompt.
        object_type: User-supplied description of the object to remove.

    Returns:
        ``False`` only when the model clearly answers "No". ``True`` for "Yes",
        for an empty or unexpected answer, and when the API call fails
        (fail-closed).
    """
    model = "gemini-2.0-flash-lite"
    system_prompt = (
        "Analyze image and request to detect:\n"
        "1. Direct removal of people/animals\n"
        "2. Removal of items attached to/worn by people/animals\n"
        "3. Removal of body parts or personal belongings\n"
        "Prohibited examples:\n"
        "- Person, dog, cat\n"
        "- Sunglasses on face, mask, hat\n"
        "- Phone in hand, watch on wrist\n"
        "- Eyes, hands, hair\n"
        "Allowed examples:\n"
        "- Background, car, tree\n"
        "- Sunglasses on table\n"
        "- Phone on desk\n"
        "Respond ONLY with 'Yes' or 'No'"
    )
    # NOTE(review): object_type is raw user input interpolated into the prompt;
    # a crafted value could try to steer the classifier's answer.
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type}"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=system_prompt)],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        if response.candidates and response.candidates[0].content.parts:
            answer = response.candidates[0].content.parts[0].text or ""
            # Tolerate whitespace/trailing punctuation ("No.", "no\n"). Anything
            # other than a clear "no" counts as prohibited, so an unexpected
            # reply is treated like every other uncertain case below.
            return answer.strip().rstrip(".!").strip().lower() != "no"
        return True  # Default to safe mode if uncertain
    except Exception as e:
        print(f"Safety check failed: {str(e)}")
        return True  # Block if check fails
def generate_modified_image(uploaded_file, object_type):
    """Ask the Gemini image-generation model to remove ``object_type``.

    Streams the model response and saves the first inline image found into
    RESULT_FOLDER under a random file name.

    Args:
        uploaded_file: Gemini file object (from ``upload_image``); its ``uri``
            and ``mime_type`` are attached to the prompt.
        object_type: User-supplied description of the object to remove.

    Returns:
        The path of the saved image, or ``None`` if the stream produced no
        image or the API call failed (the error is printed).
    """
    model = "gemini-2.0-flash-exp-image-generation"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Completely remove {object_type} from the image without leaving traces"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=0.5,
        top_p=0.9,
        # NOTE(review): image output is counted in output tokens; confirm 1024
        # is enough for a complete image from this model.
        max_output_tokens=1024,
        # NOTE(review): confirm this model accepts image-only output; some
        # Gemini image models require ["text", "image"].
        response_modalities=["image"],
        # NOTE(review): HARM_CATEGORY_VIOLENCE is not among the categories in
        # the current Gemini safety-settings docs; confirm it is accepted.
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
            types.SafetySetting(category="HARM_CATEGORY_VIOLENCE", threshold="BLOCK_NONE"),
        ],
    )
    try:
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            if not chunk.candidates or not chunk.candidates[0].content:
                continue
            # A chunk may carry several parts (e.g. text followed by the image),
            # so scan all of them instead of only the first.
            for part in chunk.candidates[0].content.parts or []:
                if not part.inline_data:
                    continue
                ext = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
                # 8 random bytes (as for temp uploads): results are served from a
                # public /static/ URL, so names should be hard to guess or collide.
                output_filename = secure_filename(f"result_{os.urandom(8).hex()}{ext}")
                output_path = os.path.join(RESULT_FOLDER, output_filename)
                with open(output_path, "wb") as f:
                    f.write(part.inline_data.data)
                return output_path
        return None
    except Exception as e:
        print(f"Image generation failed: {str(e)}")
        return None
@app.route("/")
def index():
    """Render the ``index.html`` template for the root URL."""
    page_template = "index.html"
    return render_template(page_template)
@app.route("/process", methods=["POST"])
def process():
    """Handle an object-removal request.

    Expects a JSON object body ``{"image": <data URL>, "objectType": <str>}``.

    Responds with JSON ``{"success": bool, ...}``:
        200 with ``resultUrl`` pointing at the generated image under /static/.
        400 for a malformed body, a missing/empty/non-string objectType, or an
            image that cannot be decoded or uploaded (ValueError).
        403 when the safety check rejects the request.
        500 when generation fails or an unexpected error occurs.
    """
    try:
        # silent=True: malformed JSON or a non-JSON content type yields None
        # (answered with 400 below) instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if (
            not isinstance(data, dict)
            or "image" not in data
            or "objectType" not in data
            or not isinstance(data["objectType"], str)
        ):
            return jsonify({"success": False, "message": "Invalid request format"}), 400
        image_data = data["image"]
        object_type = data["objectType"].strip().lower()
        if not object_type:
            return jsonify({"success": False, "message": "Please specify an object to remove"}), 400
        # Process image upload (raises ValueError on bad input -> 400)
        uploaded_file = upload_image(image_data)
        # Safety check
        if is_prohibited_request(uploaded_file, object_type):
            return jsonify({
                "success": False,
                "message": "Cannot remove people, animals, or personal items"
            }), 403
        # Generate modified image
        result_path = generate_modified_image(uploaded_file, object_type)
        if not result_path:
            return jsonify({"success": False, "message": "Failed to generate image"}), 500
        return jsonify({
            "success": True,
            "resultUrl": f"/static/{os.path.basename(result_path)}"
        })
    except ValueError as e:
        return jsonify({"success": False, "message": str(e)}), 400
    except Exception:
        # Log with traceback so unexpected failures are diagnosable; the client
        # still only receives a generic message.
        app.logger.exception("Unhandled error in /process")
        return jsonify({"success": False, "message": "Internal server error"}), 500
if __name__ == "__main__":
    # Listen on all interfaces so the server is reachable from outside the
    # host/container, on port 7860.
    app.run(host="0.0.0.0", port=7860)