Athspi commited on
Commit
9479bea
·
verified ·
1 Parent(s): a36d15c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -16
app.py CHANGED
@@ -9,8 +9,9 @@ from google.genai import types
9
  # Initialize Flask app
10
  app = Flask(__name__)
11
 
12
- # Set your Gemini API key (or set GEMINI_API_KEY in your environment)
13
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 
14
  client = genai.Client(api_key=GEMINI_API_KEY)
15
 
16
  # Create necessary directories
@@ -21,8 +22,8 @@ os.makedirs(RESULT_FOLDER, exist_ok=True)
21
 
22
  def generate_gemini_output(object_type, image_data_url):
23
  """
24
- Generate output from Gemini by removing the object specified.
25
- The image_data_url is expected to be a base64 data URL.
26
  """
27
  model = "gemini-2.0-flash-exp-image-generation"
28
  files = []
@@ -35,10 +36,7 @@ def generate_gemini_output(object_type, image_data_url):
35
  raise ValueError("Invalid image data")
36
  binary_data = base64.b64decode(encoded)
37
  # Determine file extension from header
38
- if "png" in header.lower():
39
- ext = ".png"
40
- else:
41
- ext = ".jpg"
42
  temp_filename = secure_filename("temp_image" + ext)
43
  temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
44
  with open(temp_filepath, "wb") as f:
@@ -52,7 +50,7 @@ def generate_gemini_output(object_type, image_data_url):
52
  if files:
53
  parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
54
  if object_type:
55
- # Create Gemini magic prompt for object removal
56
  magic_prompt = f"Remove {object_type} from the image"
57
  parts.append(types.Part.from_text(text=magic_prompt))
58
 
@@ -73,7 +71,7 @@ def generate_gemini_output(object_type, image_data_url):
73
  result_text = None
74
  result_image = None
75
 
76
- # Stream the output from Gemini API
77
  for chunk in client.models.generate_content_stream(
78
  model=model,
79
  contents=contents,
@@ -96,26 +94,25 @@ def generate_gemini_output(object_type, image_data_url):
96
 
97
  @app.route("/")
98
  def index():
99
- # Render the front-end HTML (which includes inline CSS and JS)
100
  return render_template("index.html")
101
 
102
  @app.route("/process", methods=["POST"])
103
  def process():
104
  try:
105
- # Expecting JSON with keys "image" (base64 data URL) and "objectType"
106
  data = request.get_json(force=True)
107
  image_data = data.get("image")
108
  object_type = data.get("objectType", "").strip()
109
  if not image_data or not object_type:
110
  return jsonify({"success": False, "message": "Missing image data or object type."}), 400
111
 
112
- # Generate Gemini output
113
  result_text, result_image = generate_gemini_output(object_type, image_data)
114
  if not result_image:
115
  return jsonify({"success": False, "message": "Failed to generate image."}), 500
116
 
117
  # Create a URL to serve the image from the static folder.
118
- # Assuming your static folder is served at '/static'
119
  image_url = f"/static/{os.path.basename(result_image)}"
120
 
121
  return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
@@ -123,5 +120,5 @@ def process():
123
  return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
124
 
125
  if __name__ == "__main__":
126
- # Run the app; debug can be set to False in production.
127
- app.run(host="0.0.0.0", port=int(os.getenv("PORT", 5000)), debug=True)
 
9
  # Initialize Flask app
10
  app = Flask(__name__)
11
 
12
+ # Set your Gemini API key via Hugging Face Spaces environment variables.
13
+ # Do not include a default fallback; the environment must supply GEMINI_API_KEY.
14
+ GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
15
  client = genai.Client(api_key=GEMINI_API_KEY)
16
 
17
  # Create necessary directories
 
22
 
23
  def generate_gemini_output(object_type, image_data_url):
24
  """
25
+ Generate output from Gemini by removing the specified object.
26
+ Expects the image_data_url to be a base64 data URL.
27
  """
28
  model = "gemini-2.0-flash-exp-image-generation"
29
  files = []
 
36
  raise ValueError("Invalid image data")
37
  binary_data = base64.b64decode(encoded)
38
  # Determine file extension from header
39
+ ext = ".png" if "png" in header.lower() else ".jpg"
 
 
 
40
  temp_filename = secure_filename("temp_image" + ext)
41
  temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
42
  with open(temp_filepath, "wb") as f:
 
50
  if files:
51
  parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
52
  if object_type:
53
+ # Gemini magic prompt: instruct the model to remove the specified object
54
  magic_prompt = f"Remove {object_type} from the image"
55
  parts.append(types.Part.from_text(text=magic_prompt))
56
 
 
71
  result_text = None
72
  result_image = None
73
 
74
+ # Stream output from Gemini API
75
  for chunk in client.models.generate_content_stream(
76
  model=model,
77
  contents=contents,
 
94
 
95
  @app.route("/")
96
  def index():
97
+ # Render the front-end HTML (which contains complete HTML/CSS/JS inline)
98
  return render_template("index.html")
99
 
100
  @app.route("/process", methods=["POST"])
101
  def process():
102
  try:
103
+ # Expect JSON with keys "image" (base64 data URL) and "objectType"
104
  data = request.get_json(force=True)
105
  image_data = data.get("image")
106
  object_type = data.get("objectType", "").strip()
107
  if not image_data or not object_type:
108
  return jsonify({"success": False, "message": "Missing image data or object type."}), 400
109
 
110
+ # Generate output using Gemini
111
  result_text, result_image = generate_gemini_output(object_type, image_data)
112
  if not result_image:
113
  return jsonify({"success": False, "message": "Failed to generate image."}), 500
114
 
115
  # Create a URL to serve the image from the static folder.
 
116
  image_url = f"/static/{os.path.basename(result_image)}"
117
 
118
  return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
 
120
  return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
121
 
122
  if __name__ == "__main__":
123
+ # Run the app on port 5000 or the port provided by the environment (for Hugging Face Spaces)
124
+ app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))