Athspi committed on
Commit
7b62da2
·
verified ·
1 Parent(s): 5a42ff8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -42
app.py CHANGED
@@ -35,51 +35,146 @@ def upload_image(image_data_url):
35
 
36
  return client.files.upload(file=temp_filepath)
37
 
38
def is_prohibited_request(uploaded_file, object_type):
    """Return True when the removal request matches a prohibited case.

    Matching is purely textual: case-insensitive substring checks on
    ``object_type``. ``uploaded_file`` is accepted for interface
    compatibility but is not inspected.

    Prohibited cases:
      * "remove sunglasses" unless the text also mentions "table"
      * "remove phone" when the text also mentions "hand"
      * "remove eyes"
      * "remove person", "remove animal", "remove dog" or "remove cat"

    Everything else (e.g. "remove car", "remove background") is allowed.
    """
    # Tolerate a missing value instead of raising AttributeError on None.
    text = (object_type or "").lower()

    if "remove sunglasses" in text and "table" not in text:
        return True
    if "remove phone" in text and "hand" in text:
        return True
    if "remove eyes" in text:
        return True

    living_targets = ("remove person", "remove animal", "remove dog", "remove cat")
    return any(phrase in text for phrase in living_targets)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def generate_gemini_output(object_type, uploaded_file):
63
- """Generate image using gemini-2.0-flash-exp-image-generation"""
64
  model = "gemini-2.0-flash-exp-image-generation"
65
  parts = [
66
  types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
67
  types.Part.from_text(text=f"Remove {object_type} from the image")
68
  ]
69
-
70
  contents = [types.Content(role="user", parts=parts)]
71
-
72
  generate_content_config = types.GenerateContentConfig(
73
  temperature=1,
74
  top_p=0.95,
75
  top_k=40,
76
  max_output_tokens=8192,
77
  response_modalities=["image", "text"],
78
- safety_settings=[
79
- types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
80
- ],
81
  )
82
-
83
  result_image = None
84
  for chunk in client.models.generate_content_stream(
85
  model=model,
@@ -95,7 +190,6 @@ def generate_gemini_output(object_type, uploaded_file):
95
  with open(result_image_path, "wb") as f:
96
  f.write(part.inline_data.data)
97
  result_image = result_image_path
98
-
99
  return result_image
100
 
101
  @app.route("/")
@@ -107,27 +201,49 @@ def process():
107
  try:
108
  data = request.get_json(force=True)
109
  image_data = data.get("image")
110
- object_type = data.get("objectType", "").strip().lower()
111
 
112
  if not image_data or not object_type:
113
  return jsonify({"success": False, "message": "Missing required data"}), 400
114
 
115
- # Upload image once
116
  uploaded_file = upload_image(image_data)
117
-
118
- # Check for prohibited requests
119
- if is_prohibited_request(uploaded_file, object_type):
120
- error_message = "Sorry, I can't assist with this request."
121
- if "person" in object_type or "animal" in object_type or "cat" in object_type or "dog" in object_type:
122
- error_message = "Sorry, I can't assist with removing people or animals."
123
- return jsonify({"success": False, "message": error_message}), 400
124
-
125
- # Generate output if allowed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  result_image = generate_gemini_output(object_type, uploaded_file)
127
-
128
  if not result_image:
129
  return jsonify({"success": False, "message": "Failed to generate image"}), 500
130
-
131
  return jsonify({
132
  "success": True,
133
  "resultPath": f"/static/{os.path.basename(result_image)}"
 
35
 
36
  return client.files.upload(file=temp_filepath)
37
 
38
def normalize_object_type(object_type):
    """Normalize a removal request down to the object it names.

    The text is lower-cased, split on whitespace, and each word has
    sentence punctuation trimmed from its edges, so "Remove:" is still
    recognised as a verb and "sunglasses." compares equal to "sunglasses".
    The action verbs remove/delete/erase/eliminate are then dropped and
    the remaining words are joined with single spaces.

    If every word is an action verb, the cleaned words are returned as-is
    rather than an empty string.
    """
    action_verbs = {'remove', 'delete', 'erase', 'eliminate'}
    # Only sentence punctuation is trimmed; symbols such as '+' or '#'
    # may be part of the object's name and are left alone.
    edge_punctuation = ".,;:!?\"'"

    words = [word.strip(edge_punctuation) for word in object_type.lower().split()]
    words = [word for word in words if word]
    remaining = [word for word in words if word not in action_verbs]
    return ' '.join(remaining or words)
44
+
45
def check_if_person_animal(uploaded_file, object_type):
    """Ask the model whether the requested removal targets a person or animal.

    Sends the uploaded image together with "Remove {object_type}" to
    gemini-2.0-flash-lite under a Yes/No system instruction.

    Returns True only when the model answers "yes". Returns False for any
    other answer, for an empty or text-less response, and when the call
    raises (the error is printed), so a failed check never raises to the
    caller.
    """
    model = "gemini-2.0-flash-lite"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type}"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    system_prompt = (
        "Determine if the user wants to remove a person or animal.\n"
        "Respond ONLY with 'Yes' or 'No'. Examples:\n"
        "- Remove person → Yes\n"
        "- Remove dog → Yes\n"
        "- Remove sunglasses → No"
    )
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=system_prompt)],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        if not response.candidates:
            return False
        content = response.candidates[0].content
        if content is None or not content.parts:
            return False
        answer = content.parts[0].text
        if not answer:
            # First part carries no text; treat as "no answer", not an error.
            return False
        # Accept "Yes", "yes." and similar; anything else counts as "no".
        return answer.strip().strip(".,!\"'").lower() == "yes"
    except Exception as e:
        # Best-effort check: report the failure and let the request proceed.
        print(f"Error checking person/animal: {str(e)}")
        return False
76
+
77
def check_other_entities(uploaded_file, object_type):
    """Ask the model whether other people or animals appear in the image.

    Sends the uploaded image together with "Remove {object_type}" to
    gemini-2.0-flash-lite; the system instruction asks whether anyone or
    anything living is present besides ``object_type``.

    Returns True only when the model answers "yes". Returns False for any
    other answer, for an empty or text-less response, and when the call
    raises (the error is printed), so a failed check never raises to the
    caller.
    """
    model = "gemini-2.0-flash-lite"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type}"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    # NOTE(review): object_type is user-supplied and is interpolated into
    # the system instruction; consider constraining it upstream.
    system_prompt = (
        "Analyze this image. Are there any other people or animals\n"
        f"besides the {object_type}? Respond ONLY with 'Yes' or 'No'."
    )
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=system_prompt)],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        if not response.candidates:
            return False
        content = response.candidates[0].content
        if content is None or not content.parts:
            return False
        answer = content.parts[0].text
        if not answer:
            # First part carries no text; treat as "no answer", not an error.
            return False
        # Accept "Yes", "yes." and similar; anything else counts as "no".
        return answer.strip().strip(".,!\"'").lower() == "yes"
    except Exception as e:
        # Best-effort check: report the failure and let the request proceed.
        print(f"Error checking other entities: {str(e)}")
        return False
105
+
106
def check_sunglasses_state(uploaded_file):
    """Ask the model whether sunglasses are being worn in the image.

    Sends the uploaded image and the question to gemini-2.0-flash-lite
    under a Yes/No system instruction.

    Returns True only when the model answers "yes". Returns False for any
    other answer, for an empty or text-less response, and when the call
    raises (the error is printed), so a failed check never raises to the
    caller.
    """
    model = "gemini-2.0-flash-lite"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text="Are sunglasses being worn in this image?"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[
            types.Part.from_text(text="Respond ONLY with 'Yes' or 'No'."),
        ],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        if not response.candidates:
            return False
        content = response.candidates[0].content
        if content is None or not content.parts:
            return False
        answer = content.parts[0].text
        if not answer:
            # First part carries no text; treat as "no answer", not an error.
            return False
        # Accept "Yes", "yes." and similar; anything else counts as "no".
        return answer.strip().strip(".,!\"'").lower() == "yes"
    except Exception as e:
        # Best-effort check: report the failure and let the request proceed.
        print(f"Error checking sunglasses: {str(e)}")
        return False
133
+
134
def check_phone_state(uploaded_file):
    """Ask the model whether a phone is being held in a hand in the image.

    Sends the uploaded image and the question to gemini-2.0-flash-lite
    under a Yes/No system instruction.

    Returns True only when the model answers "yes". Returns False for any
    other answer, for an empty or text-less response, and when the call
    raises (the error is printed), so a failed check never raises to the
    caller.
    """
    model = "gemini-2.0-flash-lite"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text="Is a phone being held in hand?"),
    ]
    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        system_instruction=[
            types.Part.from_text(text="Respond ONLY with 'Yes' or 'No'."),
        ],
        temperature=0.0,
        max_output_tokens=1,
    )
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config,
        )
        if not response.candidates:
            return False
        content = response.candidates[0].content
        if content is None or not content.parts:
            return False
        answer = content.parts[0].text
        if not answer:
            # First part carries no text; treat as "no answer", not an error.
            return False
        # Accept "Yes", "yes." and similar; anything else counts as "no".
        return answer.strip().strip(".,!\"'").lower() == "yes"
    except Exception as e:
        # Best-effort check: report the failure and let the request proceed.
        print(f"Error checking phone: {str(e)}")
        return False
161
 
162
  def generate_gemini_output(object_type, uploaded_file):
163
+ """Generate image using Gemini"""
164
  model = "gemini-2.0-flash-exp-image-generation"
165
  parts = [
166
  types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
167
  types.Part.from_text(text=f"Remove {object_type} from the image")
168
  ]
 
169
  contents = [types.Content(role="user", parts=parts)]
 
170
  generate_content_config = types.GenerateContentConfig(
171
  temperature=1,
172
  top_p=0.95,
173
  top_k=40,
174
  max_output_tokens=8192,
175
  response_modalities=["image", "text"],
176
+ safety_settings=[types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")],
 
 
177
  )
 
178
  result_image = None
179
  for chunk in client.models.generate_content_stream(
180
  model=model,
 
190
  with open(result_image_path, "wb") as f:
191
  f.write(part.inline_data.data)
192
  result_image = result_image_path
 
193
  return result_image
194
 
195
  @app.route("/")
 
201
  try:
202
  data = request.get_json(force=True)
203
  image_data = data.get("image")
204
+ object_type = data.get("objectType", "").strip()
205
 
206
  if not image_data or not object_type:
207
  return jsonify({"success": False, "message": "Missing required data"}), 400
208
 
 
209
  uploaded_file = upload_image(image_data)
210
+ normalized_object = normalize_object_type(object_type)
211
+
212
+ # Prohibited categories check
213
+ if normalized_object == 'eyes':
214
+ return jsonify({
215
+ "success": False,
216
+ "message": "Sorry, I can't assist with removing eyes."
217
+ }), 400
218
+
219
+ # State checks
220
+ if normalized_object == 'sunglasses':
221
+ if check_sunglasses_state(uploaded_file):
222
+ return jsonify({
223
+ "success": False,
224
+ "message": "Can't remove sunglasses while being worn."
225
+ }), 400
226
+
227
+ if normalized_object == 'phone':
228
+ if check_phone_state(uploaded_file):
229
+ return jsonify({
230
+ "success": False,
231
+ "message": "Can't remove phones while being held."
232
+ }), 400
233
+
234
+ # Person/animal checks
235
+ if check_if_person_animal(uploaded_file, normalized_object):
236
+ if check_other_entities(uploaded_file, normalized_object):
237
+ return jsonify({
238
+ "success": False,
239
+ "message": "Can't remove people/animals when others are present."
240
+ }), 400
241
+
242
+ # Generate output
243
  result_image = generate_gemini_output(object_type, uploaded_file)
 
244
  if not result_image:
245
  return jsonify({"success": False, "message": "Failed to generate image"}), 500
246
+
247
  return jsonify({
248
  "success": True,
249
  "resultPath": f"/static/{os.path.basename(result_image)}"