# Gttg / app.py — copied from the Hugging Face Space file viewer.
# Author: Athspi · Commit: "Update app.py" (d529638, verified) · Size: 3.31 kB
import os
import io
import base64
import tempfile
from flask import Flask, render_template, request, jsonify
import google.generativeai as genai
from google.generativeai import types
from PIL import Image
# The Gemini API key is taken from the GEMINI_API_KEY environment variable;
# os.environ.get yields None when that variable is unset.
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))

# Flask application object; the @app.route handlers in this module attach to it.
app = Flask(__name__)
def save_image(image_data):
    """Decode a base64 image, re-encode it as PNG in a temp file, and return the path.

    Args:
        image_data: A data URL such as ``"data:image/png;base64,...."`` or a
            bare base64 payload without the ``data:`` prefix.

    Returns:
        Path of a new ``.png`` file. It is created with ``delete=False``, so
        the caller owns it and must remove it when done.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``).
        OSError: if Pillow cannot identify or write the image
            (e.g. ``PIL.UnidentifiedImageError``).
    """
    # Everything after the last comma is the payload. Base64 never contains
    # commas, so a bare base64 string (no data-URL header) is used whole.
    _, _, encoded = image_data.rpartition(',')
    image_bytes = base64.b64decode(encoded)

    with Image.open(io.BytesIO(image_bytes)) as image:
        # NOTE(review): Pillow cannot write every mode as PNG (e.g. CMYK
        # JPEGs); such uploads raise OSError from the save below.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            # Closing the handle flushes the PNG to disk and avoids leaking
            # one open file descriptor per request.
            with temp_file:
                image.save(temp_file, "PNG")
        except Exception:
            # delete=False means a failed save would otherwise leave a stray file.
            os.remove(temp_file.name)
            raise
    return temp_file.name
def remove_object_from_image(image_path, object_type):
    """Ask Gemini to remove *object_type* from the image at *image_path*.

    The image file is uploaded to Gemini. A streamed generation request is then
    made with the image. If *object_type* is truthy, the text
    ``"Remove <object_type> from the image"`` is included as well.

    Args:
        image_path: Path of a local PNG file (e.g. one produced by ``save_image``).
        object_type: Description of the object to remove. When falsy, only the
            image is sent, with no instruction text.

    Returns:
        ``"generated_output.png"`` (relative to the current working directory)
        if the model streamed back an image, otherwise ``None``. If several
        images arrive, the last one is kept.
    """
    # NOTE(review): genai.files.upload / genai.models.generate_content_stream and
    # types.Part / Content / GenerateContentConfig / SafetySetting appear to be
    # the `google-genai` SDK's Client API. This file imports the legacy
    # `google.generativeai` package, which does not seem to provide them at
    # module level. Confirm the installed SDK; with `google-genai`, these calls
    # go through `genai.Client(api_key=...)`.
    uploaded_file = genai.files.upload(file=image_path)

    # Input parts: the uploaded image, then the optional removal instruction.
    parts = [types.Part.from_uri(file_uri=uploaded_file.uri, mime_type="image/png")]
    if object_type:
        parts.append(types.Part.from_text(text=f"Remove {object_type} from the image"))
    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")],
        response_mime_type="text/plain",
    )

    image_bytes = None
    for chunk in genai.models.generate_content_stream(
        model="gemini-2.0-flash-exp-image-generation",
        contents=contents,
        config=generate_content_config,
    ):
        candidates = chunk.candidates
        if not candidates or not candidates[0].content:
            continue
        # A chunk may carry several parts (e.g. text followed by the image),
        # so scan all of them instead of only the first.
        for part in candidates[0].content.parts or []:
            if part.inline_data:
                image_bytes = part.inline_data.data

    if image_bytes is None:
        return None

    # NOTE(review): every request writes this same fixed filename, so
    # concurrent requests can overwrite each other's result.
    file_name = "generated_output.png"
    with open(file_name, "wb") as f:
        f.write(image_bytes)
    return file_name
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)
@app.route('/process', methods=['POST'])
def process_image():
    """Handle an object-removal request posted as JSON.

    Request body: ``{"image": <base64 data URL or bare base64>, "objectType": <str>}``.
    ``objectType`` may be omitted or empty. In that case the image is sent to
    Gemini without a removal instruction.

    Responses (all JSON):
        * ``{"success": True, "resultPath": <path>}`` when Gemini returned an image.
        * ``{"success": False, "message": ...}`` with HTTP 400 when the body is
          not JSON or the image is missing or undecodable.
        * ``{"success": False, "message": ...}`` with HTTP 200 when the Gemini
          call raises or returns no image.
    """
    # silent=True returns None for a non-JSON body instead of raising, so the
    # client always gets a JSON reply.
    data = request.get_json(silent=True)
    image_data = data.get('image') if isinstance(data, dict) else None
    if not isinstance(image_data, str) or not image_data:
        return jsonify({'success': False,
                        'message': "Request body must be JSON with a non-empty 'image' string."}), 400
    object_type = data.get('objectType')

    # Decode inside a try so bad base64/image data becomes a JSON error rather
    # than an HTML 500 page.
    try:
        image_path = save_image(image_data)
    except (ValueError, OSError) as e:
        return jsonify({'success': False, 'message': f'Could not decode image: {e}'}), 400

    try:
        result_image = remove_object_from_image(image_path, object_type)
    except Exception as e:
        app.logger.exception("Object removal via Gemini failed")
        return jsonify({'success': False, 'message': str(e)})
    finally:
        # save_image creates the file with delete=False. The upload has
        # finished by now, so drop the local copy to avoid filling the disk.
        try:
            os.remove(image_path)
        except OSError:
            app.logger.warning("Could not delete temporary file %s", image_path)

    if result_image is None:
        return jsonify({'success': False, 'message': 'Gemini did not return an edited image.'})
    return jsonify({'success': True, 'resultPath': result_image})
if __name__ == '__main__':
    # Development entry point only; a hosting provider runs its own server in production.
    app.run(port=7860, host="0.0.0.0")