Athspi committed on
Commit
5a42ff8
·
verified ·
1 Parent(s): 2269b2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -106
app.py CHANGED
@@ -19,111 +19,84 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
19
  os.makedirs(RESULT_FOLDER, exist_ok=True)
20
 
21
  def upload_image(image_data_url):
22
- """Handle base64 image upload and Gemini file upload"""
23
  try:
24
  header, encoded = image_data_url.split(',', 1)
25
- binary_data = base64.b64decode(encoded)
26
- ext = ".png" if "png" in header.lower() else ".jpg"
27
- temp_filename = secure_filename(f"temp_{os.urandom(8).hex()}{ext}")
28
- temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
29
-
30
- with open(temp_filepath, "wb") as f:
31
- f.write(binary_data)
32
-
33
- uploaded_file = client.files.upload(file=temp_filepath)
34
- os.remove(temp_filepath) # Clean up temporary file
35
- return uploaded_file
36
 
37
- except Exception as e:
38
- raise ValueError(f"Image processing error: {str(e)}")
 
 
 
 
 
 
 
39
 
40
  def is_prohibited_request(uploaded_file, object_type):
41
- """Check if request involves people/animals or their belongings"""
42
- model = "gemini-2.0-flash-lite"
43
- parts = [
44
- types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
45
- types.Part.from_text(text=f"Remove {object_type}")
 
 
 
46
  ]
47
 
48
- contents = [types.Content(role="user", parts=parts)]
 
 
 
 
 
49
 
50
- generate_content_config = types.GenerateContentConfig(
51
- system_instruction=[types.Part.from_text(text="""Analyze image and request to detect:
52
- 1. Direct removal of people/animals
53
- 2. Removal of items attached to/worn by people/animals
54
- 3. Removal of body parts or personal belongings
55
-
56
- Prohibited examples:
57
- - Person, dog, cat
58
- - Sunglasses on face, mask, hat
59
- - Phone in hand, watch on wrist
60
- - Eyes, hands, hair
61
-
62
- Allowed examples:
63
- - Background, car, tree
64
- - Sunglasses on table
65
- - Phone on desk
66
-
67
- Respond ONLY with 'Yes' or 'No'""")],
68
- temperature=0.0,
69
- max_output_tokens=1,
70
- )
71
 
72
- try:
73
- response = client.models.generate_content(
74
- model=model,
75
- contents=contents,
76
- config=generate_content_config
77
- )
78
- if response.candidates and response.candidates[0].content.parts:
79
- return response.candidates[0].content.parts[0].text.strip().lower() == "yes"
80
- return True # Default to safe mode if uncertain
81
- except Exception as e:
82
- print(f"Safety check failed: {str(e)}")
83
- return True # Block if check fails
84
 
85
- def generate_modified_image(uploaded_file, object_type):
86
- """Generate image with object removed using experimental model"""
87
  model = "gemini-2.0-flash-exp-image-generation"
88
  parts = [
89
  types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
90
- types.Part.from_text(text=f"Completely remove {object_type} from the image without leaving traces")
91
  ]
92
 
93
  contents = [types.Content(role="user", parts=parts)]
94
 
95
  generate_content_config = types.GenerateContentConfig(
96
- temperature=0.5,
97
- top_p=0.9,
98
- max_output_tokens=1024,
99
- response_modalities=["image"],
 
100
  safety_settings=[
101
- types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
102
- types.SafetySetting(category="HARM_CATEGORY_VIOLENCE", threshold="BLOCK_NONE")
103
- ]
104
  )
105
 
106
- try:
107
- for chunk in client.models.generate_content_stream(
108
- model=model,
109
- contents=contents,
110
- config=generate_content_config,
111
- ):
112
- if chunk.candidates and chunk.candidates[0].content.parts:
113
- part = chunk.candidates[0].content.parts[0]
114
- if part.inline_data:
115
- ext = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
116
- output_filename = secure_filename(f"result_{os.urandom(4).hex()}{ext}")
117
- output_path = os.path.join(RESULT_FOLDER, output_filename)
118
-
119
- with open(output_path, "wb") as f:
120
- f.write(part.inline_data.data)
121
-
122
- return output_path
123
- return None
124
- except Exception as e:
125
- print(f"Image generation failed: {str(e)}")
126
- return None
127
 
128
  @app.route("/")
129
  def index():
@@ -132,40 +105,36 @@ def index():
132
  @app.route("/process", methods=["POST"])
133
  def process():
134
  try:
135
- data = request.get_json()
136
- if not data or "image" not in data or "objectType" not in data:
137
- return jsonify({"success": False, "message": "Invalid request format"}), 400
138
-
139
- image_data = data["image"]
140
- object_type = data["objectType"].strip().lower()
141
 
142
- if not object_type:
143
- return jsonify({"success": False, "message": "Please specify an object to remove"}), 400
144
 
145
- # Process image upload
146
  uploaded_file = upload_image(image_data)
147
 
148
- # Safety check
149
  if is_prohibited_request(uploaded_file, object_type):
150
- return jsonify({
151
- "success": False,
152
- "message": "Cannot remove people, animals, or personal items"
153
- }), 403
 
 
 
154
 
155
- # Generate modified image
156
- result_path = generate_modified_image(uploaded_file, object_type)
157
- if not result_path:
158
  return jsonify({"success": False, "message": "Failed to generate image"}), 500
159
 
160
  return jsonify({
161
  "success": True,
162
- "resultUrl": f"/static/{os.path.basename(result_path)}"
163
  })
164
-
165
- except ValueError as e:
166
- return jsonify({"success": False, "message": str(e)}), 400
167
  except Exception as e:
168
- return jsonify({"success": False, "message": "Internal server error"}), 500
169
 
170
  if __name__ == "__main__":
171
  app.run(host="0.0.0.0", port=7860)
 
19
  os.makedirs(RESULT_FOLDER, exist_ok=True)
20
 
21
  def upload_image(image_data_url):
22
+ """Helper function to upload image to Gemini"""
23
  try:
24
  header, encoded = image_data_url.split(',', 1)
25
+ except ValueError:
26
+ raise ValueError("Invalid image data")
 
 
 
 
 
 
 
 
 
27
 
28
+ binary_data = base64.b64decode(encoded)
29
+ ext = ".png" if "png" in header.lower() else ".jpg"
30
+ temp_filename = secure_filename("temp_image" + ext)
31
+ temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
32
+
33
+ with open(temp_filepath, "wb") as f:
34
+ f.write(binary_data)
35
+
36
+ return client.files.upload(file=temp_filepath)
37
 
38
  def is_prohibited_request(uploaded_file, object_type):
39
+ """Check if request matches prohibited removal cases"""
40
+ object_type = object_type.lower()
41
+
42
+ # Prohibited cases
43
+ prohibited_requests = [
44
+ "remove sunglasses" in object_type and "table" not in object_type, # ❌ when worn
45
+ "remove phone" in object_type and "hand" in object_type, # ❌ when in hand
46
+ "remove eyes" in object_type # ❌ remove eyes
47
  ]
48
 
49
+ # Allowed cases
50
+ allowed_requests = [
51
+ "remove sunglasses" in object_type and "table" in object_type, # βœ… when on table
52
+ "remove car" in object_type, # βœ… remove car
53
+ "remove background" in object_type # βœ… remove background
54
+ ]
55
 
56
+ # Check for person/animal removal
57
+ person_animal_check = "remove person" in object_type or "remove animal" in object_type or \
58
+ "remove dog" in object_type or "remove cat" in object_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ return any(prohibited_requests) or person_animal_check
 
 
 
 
 
 
 
 
 
 
 
61
 
62
def generate_gemini_output(object_type, uploaded_file):
    """Ask the image-generation model to remove an object and save the result.

    Streams a response from "gemini-2.0-flash-exp-image-generation" for the
    prompt "Remove {object_type} from the image" plus the uploaded file, and
    writes any inline image data found in the stream into RESULT_FOLDER.

    Args:
        object_type: Text describing what to remove; interpolated into the prompt.
        uploaded_file: Object exposing ``uri`` and ``mime_type`` for the
            previously uploaded source image.

    Returns:
        Path of the last image written, or None if the stream contained no
        inline image data.
    """
    model = "gemini-2.0-flash-exp-image-generation"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type} from the image"),
    ]

    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
        ],
    )

    # One random token per call: results of concurrent requests no longer
    # share a file name, while several images within the same stream still
    # overwrite each other so the last one wins, as before.
    output_token = os.urandom(8).hex()

    result_image = None
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates:
            continue
        content = chunk.candidates[0].content
        if content is None or not content.parts:
            continue

        # Text is also a requested modality, so the image is not necessarily
        # the first part of a chunk; look at every part.
        for part in content.parts:
            inline = part.inline_data
            if not inline or not inline.data:
                continue

            guessed = mimetypes.guess_extension(inline.mime_type) if inline.mime_type else None
            file_extension = guessed or ".png"
            output_filename = secure_filename(f"generated_output_{output_token}{file_extension}")
            result_image_path = os.path.join(RESULT_FOLDER, output_filename)
            with open(result_image_path, "wb") as f:
                f.write(inline.data)
            result_image = result_image_path

    return result_image
 
 
 
 
100
 
101
  @app.route("/")
102
  def index():
 
105
  @app.route("/process", methods=["POST"])
106
  def process():
107
  try:
108
+ data = request.get_json(force=True)
109
+ image_data = data.get("image")
110
+ object_type = data.get("objectType", "").strip().lower()
 
 
 
111
 
112
+ if not image_data or not object_type:
113
+ return jsonify({"success": False, "message": "Missing required data"}), 400
114
 
115
+ # Upload image once
116
  uploaded_file = upload_image(image_data)
117
 
118
+ # Check for prohibited requests
119
  if is_prohibited_request(uploaded_file, object_type):
120
+ error_message = "Sorry, I can't assist with this request."
121
+ if "person" in object_type or "animal" in object_type or "cat" in object_type or "dog" in object_type:
122
+ error_message = "Sorry, I can't assist with removing people or animals."
123
+ return jsonify({"success": False, "message": error_message}), 400
124
+
125
+ # Generate output if allowed
126
+ result_image = generate_gemini_output(object_type, uploaded_file)
127
 
128
+ if not result_image:
 
 
129
  return jsonify({"success": False, "message": "Failed to generate image"}), 500
130
 
131
  return jsonify({
132
  "success": True,
133
+ "resultPath": f"/static/{os.path.basename(result_image)}"
134
  })
135
+
 
 
136
  except Exception as e:
137
+ return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
138
 
139
  if __name__ == "__main__":
140
  app.run(host="0.0.0.0", port=7860)