Athspi committed on
Commit
a36d15c
·
verified ·
1 Parent(s): 452d71e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -34
app.py CHANGED
@@ -1,38 +1,59 @@
1
- from flask import Flask, render_template, request, send_file
2
  import os
 
3
  import mimetypes
 
 
4
  from google import genai
5
  from google.genai import types
6
- from io import BytesIO
7
 
 
8
  app = Flask(__name__)
9
 
10
- # Initialize Gemini client
11
-
12
  client = genai.Client(api_key=GEMINI_API_KEY)
13
 
14
- def save_binary_file(file_name, data):
15
- """Save binary data to a file."""
16
- with open(file_name, "wb") as f:
17
- f.write(data)
18
-
19
- def generate_gemini_output(user_input, image):
 
 
 
 
 
20
  model = "gemini-2.0-flash-exp-image-generation"
21
-
22
- # Upload the image to Gemini
23
  files = []
24
- if image:
25
- uploaded_file = client.files.upload(file=image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  files.append(uploaded_file)
27
 
28
- # Prepare the input content
29
  parts = []
30
  if files:
31
  parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
32
-
33
- # Incorporate Gemini magic: modify the prompt to instruct removal of the user-specified element.
34
- if user_input:
35
- magic_prompt = f"Remove {user_input} from the image"
36
  parts.append(types.Part.from_text(text=magic_prompt))
37
 
38
  contents = [types.Content(role="user", parts=parts)]
@@ -52,6 +73,7 @@ def generate_gemini_output(user_input, image):
52
  result_text = None
53
  result_image = None
54
 
 
55
  for chunk in client.models.generate_content_stream(
56
  model=model,
57
  contents=contents,
@@ -59,29 +81,47 @@ def generate_gemini_output(user_input, image):
59
  ):
60
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
61
  continue
62
-
63
  part = chunk.candidates[0].content.parts[0]
64
-
65
  if part.inline_data:
66
- file_name = "generated_output"
67
  file_extension = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
68
- file_path = os.path.join("static", f"{file_name}{file_extension}")
69
- save_binary_file(file_path, part.inline_data.data)
70
- result_image = file_path
 
 
71
  else:
72
  result_text = part.text
73
 
74
  return result_text, result_image
75
 
76
- @app.route("/", methods=["GET", "POST"])
77
  def index():
78
- result_text = None
79
- result_image = None
80
- if request.method == "POST":
81
- user_input = request.form.get("user_input")
82
- image = request.files.get("image_input")
83
- result_text, result_image = generate_gemini_output(user_input, image)
84
- return render_template("index.html", result_text=result_text, result_image=result_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
- app.run(host="0.0.0.0", port=7860)
 
 
 
1
  import os
2
+ import base64
3
  import mimetypes
4
+ from flask import Flask, render_template, request, jsonify
5
+ from werkzeug.utils import secure_filename
6
  from google import genai
7
  from google.genai import types
 
8
 
9
# Flask application object; the routes below are registered on it.
app = Flask(__name__)

# The Gemini API key is read from the GEMINI_API_KEY environment variable
# (None when unset) and handed to the Gemini client.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)

# Working directories: decoded uploads are written to UPLOAD_FOLDER and
# generated images to RESULT_FOLDER. Both are created at import time if absent.
UPLOAD_FOLDER = "uploads"
RESULT_FOLDER = "static"
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_folder, exist_ok=True)
21
+
22
+ def generate_gemini_output(object_type, image_data_url):
23
+ """
24
+ Generate output from Gemini by removing the object specified.
25
+ The image_data_url is expected to be a base64 data URL.
26
+ """
27
  model = "gemini-2.0-flash-exp-image-generation"
 
 
28
  files = []
29
+
30
+ # Decode the image data from the data URL
31
+ if image_data_url:
32
+ try:
33
+ header, encoded = image_data_url.split(',', 1)
34
+ except ValueError:
35
+ raise ValueError("Invalid image data")
36
+ binary_data = base64.b64decode(encoded)
37
+ # Determine file extension from header
38
+ if "png" in header.lower():
39
+ ext = ".png"
40
+ else:
41
+ ext = ".jpg"
42
+ temp_filename = secure_filename("temp_image" + ext)
43
+ temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)
44
+ with open(temp_filepath, "wb") as f:
45
+ f.write(binary_data)
46
+ # Upload file to Gemini
47
+ uploaded_file = client.files.upload(file=temp_filepath)
48
  files.append(uploaded_file)
49
 
50
+ # Prepare content parts for Gemini
51
  parts = []
52
  if files:
53
  parts.append(types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type))
54
+ if object_type:
55
+ # Create Gemini magic prompt for object removal
56
+ magic_prompt = f"Remove {object_type} from the image"
 
57
  parts.append(types.Part.from_text(text=magic_prompt))
58
 
59
  contents = [types.Content(role="user", parts=parts)]
 
73
  result_text = None
74
  result_image = None
75
 
76
+ # Stream the output from Gemini API
77
  for chunk in client.models.generate_content_stream(
78
  model=model,
79
  contents=contents,
 
81
  ):
82
  if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
83
  continue
 
84
  part = chunk.candidates[0].content.parts[0]
 
85
  if part.inline_data:
 
86
  file_extension = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
87
+ output_filename = secure_filename("generated_output" + file_extension)
88
+ result_image_path = os.path.join(RESULT_FOLDER, output_filename)
89
+ with open(result_image_path, "wb") as f:
90
+ f.write(part.inline_data.data)
91
+ result_image = result_image_path # Path relative to static folder
92
  else:
93
  result_text = part.text
94
 
95
  return result_text, result_image
96
 
97
@app.route("/")
def index():
    """Serve the front-end page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
101
+
102
@app.route("/process", methods=["POST"])
def process():
    """Handle an object-removal request and reply with JSON.

    The request body must be a JSON object with two keys:
      - "image": the source image as a base64 data URL (string)
      - "objectType": text naming what to remove from the image (string)

    Responses:
      - 200 {"success": True, "resultPath": <url>, "resultText": <text or null>}
      - 400 when the body is not a JSON object, or either field is missing,
        blank, or not a string
      - 500 when no image was produced or any other error occurred
    """
    missing_input = {"success": False, "message": "Missing image data or object type."}
    try:
        # silent=True makes an unparseable body come back as None instead of
        # raising, so malformed requests are reported as 400 rather than
        # falling into the generic 500 handler below.
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify(missing_input), 400

        image_data = data.get("image")
        raw_object_type = data.get("objectType", "")
        # JSON null / numbers have no .strip(); treat them as "not provided".
        object_type = raw_object_type.strip() if isinstance(raw_object_type, str) else ""
        if not isinstance(image_data, str) or not image_data or not object_type:
            return jsonify(missing_input), 400

        # Generate Gemini output
        result_text, result_image = generate_gemini_output(object_type, image_data)
        if not result_image:
            return jsonify({"success": False, "message": "Failed to generate image."}), 500

        # The generated file is saved in the static folder; expose it under
        # the '/static' URL prefix using only its file name.
        image_url = f"/static/{os.path.basename(result_image)}"

        return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
    except Exception as e:
        # Top-level request boundary: keep the JSON error contract for the
        # client, but record the traceback so failures are diagnosable.
        app.logger.exception("Object removal request failed")
        return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500
124
 
125
  if __name__ == "__main__":
126
+ # Run the app; debug can be set to False in production.
127
+ app.run(host="0.0.0.0", port=int(os.getenv("PORT", 5000)), debug=True)