Athspi committed
Commit 605bf7b · verified · 1 Parent(s): c90c576

Update app.py

Files changed (1)
1. app.py +103 -55
app.py CHANGED
@@ -1,5 +1,5 @@
-import os
 import base64
+import os
 import mimetypes
 from flask import Flask, render_template, request, jsonify
 from werkzeug.utils import secure_filename
@@ -9,7 +9,8 @@ from google.genai import types
 # Initialize Flask app
 app = Flask(__name__)
 
-# Read the Gemini API key from environment variables (set in Hugging Face Spaces)
+# Set your Gemini API key via Hugging Face Spaces environment variables.
+# Do not include a default fallback; the environment must supply GEMINI_API_KEY.
 GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
 client = genai.Client(api_key=GEMINI_API_KEY)
 
@@ -19,76 +20,119 @@ RESULT_FOLDER = os.path.join('static')
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 os.makedirs(RESULT_FOLDER, exist_ok=True)
 
+def analyze_object_removal_request(object_type):
+    """
+    Analyzes the object removal request using gemini-2.0-flash-lite to check if it's about people or animals.
+    Returns True if it's a person/animal removal, False otherwise.
+    """
+    model_text_check = "gemini-2.0-flash-lite"
+    contents_text_check = [
+        types.Content(
+            role="user",
+            parts=[
+                types.Part.from_text(text=f"Is '{object_type}' a person or animal? Answer yes or no."),
+            ],
+        ),
+    ]
+    generate_content_config_text_check = types.GenerateContentConfig(
+        temperature=0.1,  # Lower temperature for more deterministic yes/no answers
+        top_p=0.95,
+        top_k=40,
+        max_output_tokens=256,  # Limit output tokens for quick analysis
+        response_mime_type="text/plain",
+        system_instruction=[
+            types.Part.from_text(text="""You are a helpful AI assistant. Determine if the user's object removal request is about a person or animal. Respond with only 'yes' or 'no'."""),
+        ],
+    )
+
+    try:
+        response_text_check = client.models.generate_content(
+            model=model_text_check,
+            contents=contents_text_check,
+            config=generate_content_config_text_check,
+        )
+        if response_text_check.text:
+            lower_text = response_text_check.text.strip().lower()
+            if "yes" in lower_text:
+                return True  # It's likely a person or animal
+            elif "no" in lower_text:
+                return False  # It's likely not a person or animal
+            else:
+                # If the response is unclear, err on the side of caution (treat as person/animal)
+                print(f"Warning: Unclear text analysis response: '{response_text_check.text}'. Treating as potential person/animal removal.")
+                return True  # Be conservative
+        else:
+            print("Warning: No text response from text analysis model.")
+            return True  # Be conservative if no response
+    except Exception as e:
+        print(f"Error during text analysis: {e}")
+        return True  # Be conservative on error
+
 def generate_gemini_output(object_type, image_data_url):
     """
-    Generate output from Gemini by removing the specified object.
+    Generate output from Gemini by removing the specified object, with initial text analysis.
    Expects the image_data_url to be a base64 data URL.
     """
-    model = "gemini-2.0-flash-lite"  # Use the lite model for text-based responses
+
+    # Analyze the object type using gemini-2.0-flash-lite
+    if analyze_object_removal_request(object_type):
+        return "Sorry, I can't assist with removing people or animals.", None  # Text result, no image
+
+    model_image_gen = "gemini-2.0-flash-exp-image-generation"  # Switch to image generation model if not person/animal
     files = []
 
-    # Decode the image data from the data URL
+    # Decode the image data from the data URL (same as before)
     if image_data_url:
         try:
             header, encoded = image_data_url.split(',', 1)
-        except ValueError:
-            raise ValueError("Invalid image data")
-        binary_data = base64.b64decode(encoded)
-        # Determine file extension from header
-        ext = ".png" if "png" in header.lower() else ".jpg"
-        temp_filename = secure_filename("temp_image" + ext)
-        temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
-        with open(temp_filepath, "wb") as f:
-            f.write(binary_data)
-        # Upload file to Gemini
-        uploaded_file = client.files.upload(file=temp_filepath)
-        files.append(uploaded_file)
-
-    # Prepare content parts for Gemini
+            binary_data = base64.b64decode(encoded)
+            mime_type = header.split(':')[1].split(';')[0]
+            ext = mimetypes.guess_extension(mime_type) or ".png"
+            if ext not in ['.jpg', '.jpeg', '.png']:
+                raise ValueError("Invalid image format. Only JPG, JPEG, and PNG are supported.")
+
+            temp_filename = secure_filename("temp_image" + ext)
+            temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
+            with open(temp_filepath, "wb") as f:
+                f.write(binary_data)
+            uploaded_file = client.files.upload(file=temp_filepath)
+            files.append(uploaded_file)
+            os.remove(temp_filepath)
+
+        except (ValueError, base64.binascii.Error) as e:
+            raise ValueError(f"Invalid image data: {str(e)}") from e
+        except Exception as e:
+            raise ValueError(f"Error processing image: {str(e)}") from e
+
+    # Prepare content parts for Gemini (same as before)
     parts = []
     if files:
         parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
     if object_type:
-        # Gemini magic prompt: instruct the model to remove the specified object
         magic_prompt = f"Remove {object_type} from the image"
         parts.append(types.Part.from_text(text=magic_prompt))
 
     contents = [types.Content(role="user", parts=parts)]
 
-    generate_content_config = types.GenerateContentConfig(
+    generate_content_config_image_gen = types.GenerateContentConfig(  # Config for image generation model
         temperature=1,
         top_p=0.95,
         top_k=40,
         max_output_tokens=8192,
-        response_mime_type="text/plain",
-        system_instruction=[
-            types.Part.from_text(text="""Your AI finds user requests about removing objects from images.
-If the user asks to remove a person or animal, respond with 'No'."""),
+        response_modalities=["image", "text"],
+        safety_settings=[
+            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
         ],
     )
 
     result_text = None
+    result_image = None
 
-    # Stream output from Gemini API
-    for chunk in client.models.generate_content_stream(
-        model=model,
-        contents=contents,
-        config=generate_content_config,
-    ):
-        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
-            continue
-        part = chunk.candidates[0].content.parts[0]
-        if part.text:
-            result_text = part.text
-
-    # If the response is "No", switch to the image generation model
-    if result_text and "no" in result_text.lower():
-        model = "gemini-2.0-flash-exp-image-generation"
-        generate_content_config.response_modalities = ["image", "text"]
+    try:
         for chunk in client.models.generate_content_stream(
-            model=model,
+            model=model_image_gen,  # Use image generation model here
             contents=contents,
-            config=generate_content_config,
+            config=generate_content_config_image_gen,
         ):
             if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
                 continue
@@ -99,38 +143,42 @@ If the user asks to remove a person or animal, respond with 'No'."""),
                 result_image_path = os.path.join(RESULT_FOLDER, output_filename)
                 with open(result_image_path, "wb") as f:
                     f.write(part.inline_data.data)
-                result_image = result_image_path  # Path relative to static folder
-                return result_text, result_image
+                result_image = result_image_path
+            else:
+                result_text = part.text
+    except genai.APIError as e:
+        raise RuntimeError(f"Gemini API Error: {str(e)}") from e
+    except Exception as e:
+        raise RuntimeError(f"An unexpected error occurred during Gemini processing: {str(e)}") from e
 
-    return result_text, None
+    return result_text, result_image  # May return text error or image path/None
 
 @app.route("/")
 def index():
-    # Render the front-end HTML (which contains complete HTML/CSS/JS inline)
     return render_template("index.html")
 
 @app.route("/process", methods=["POST"])
 def process():
     try:
-        # Expect JSON with keys "image" (base64 data URL) and "objectType"
         data = request.get_json(force=True)
         image_data = data.get("image")
         object_type = data.get("objectType", "").strip()
         if not image_data or not object_type:
             return jsonify({"success": False, "message": "Missing image data or object type."}), 400
 
-        # Generate output using Gemini
+        # Generate output using Gemini (now with text analysis first)
        result_text, result_image = generate_gemini_output(object_type, image_data)
-        if not result_image:
-            return jsonify({"success": False, "message": result_text or "Failed to generate image."}), 500
 
-        # Create a URL to serve the image from the static folder.
-        image_url = f"/static/{os.path.basename(result_image)}"
+        if result_text and not result_image:  # Text result means error or text response
+            return jsonify({"success": False, "message": result_text}), 400  # Send back text error
 
-        return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
+        if not result_image:  # Still check for image failure if no text error
+            return jsonify({"success": False, "message": "Failed to generate image. The object may be too large or complex, or the image may not be suitable."}), 500
+
+        image_url = f"/static/{os.path.basename(result_image)}"
+        return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})  # resultText might be None or text from image model
     except Exception as e:
         return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
 
 if __name__ == "__main__":
-    # Run the app on port 5000 or the port provided by the environment (for Hugging Face Spaces)
     app.run(host="0.0.0.0", port=7860)
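
For reference, a minimal client-side sketch of how the updated /process endpoint could be exercised once the Space is running. It assumes the app is reachable at http://localhost:7860 and that a local file named sample.png exists; the host, file name, and object type are hypothetical, and the requests library is an extra dependency used only for this example.

import base64
import mimetypes

import requests

# Hypothetical input file; any small JPG or PNG works.
image_path = "sample.png"
mime_type = mimetypes.guess_type(image_path)[0] or "image/png"
with open(image_path, "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {
    # process() expects a base64 data URL under "image" and the object name under "objectType".
    "image": f"data:{mime_type};base64,{encoded}",
    "objectType": "lamp post",
}
resp = requests.post("http://localhost:7860/process", json=payload)
data = resp.json()
if data.get("success"):
    print("Result image served at:", data["resultPath"])
else:
    print("Request failed:", data.get("message"))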