Spaces:
Sleeping
Sleeping
import os | |
import io | |
import base64 | |
import tempfile | |
from flask import Flask, render_template, request, jsonify | |
import google.generativeai as genai | |
from google.generativeai import types | |
from PIL import Image | |
# Hand the Gemini SDK its API key once, at import time. The key comes from the
# GEMINI_API_KEY environment variable (None when the variable is unset).
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))

# The Flask application object for this module.
app = Flask(__name__)
def save_image(image_data):
    """Decode a base64-encoded image and save it as a temporary PNG file.

    Args:
        image_data: The image as a data URL (``"data:image/png;base64,...."``)
            or as the bare base64 payload without a header.

    Returns:
        Path of the newly created ``.png`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        binascii.Error: If the payload is not valid base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: If the decoded bytes are not an image.
        OSError: If the image cannot be written as PNG. No file is left
            behind in that case.
    """
    # Strip an optional data-URL header. Base64 never contains a comma, so
    # everything after the first comma is the payload. Without a comma, the
    # whole string is the payload.
    header, sep, payload = image_data.partition(",")
    encoded = payload if sep else header
    image_bytes = base64.b64decode(encoded)

    with Image.open(io.BytesIO(image_bytes)) as image:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            # Closing the handle (end of the `with`) flushes the PNG to disk
            # before the path is handed to anyone who reads the file by name.
            with temp_file:
                image.save(temp_file, "PNG")
        except Exception:
            # Don't leave a partially written temp file behind.
            os.remove(temp_file.name)
            raise
    return temp_file.name
def remove_object_from_image(image_path, object_type, output_path="generated_output.png"):
    """Ask Gemini to remove an object from an image and save the edited image.

    The file at ``image_path`` is uploaded as ``image/png``. If ``object_type``
    is given, it is sent together with the instruction
    ``"Remove <object_type> from the image"``. The streamed response is
    scanned for inline image data.

    Args:
        image_path: Path of the PNG image to edit.
        object_type: Description of the object to remove. If falsy, only the
            image is sent, without any text instruction.
        output_path: Where to write the generated image. Defaults to
            ``"generated_output.png"`` (relative to the working directory).
            NOTE(review): a shared fixed path means concurrent calls that use
            the default overwrite each other's output.

    Returns:
        ``output_path`` if the model returned an image, otherwise ``None``.
        If several images are streamed back, the last one is kept.
    """
    # NOTE(review): `genai.files.upload`, `genai.models.generate_content_stream`
    # and `types.Part` / `types.Content` / `types.GenerateContentConfig` /
    # `types.SafetySetting` match the `google-genai` SDK (`from google import
    # genai`, used via `genai.Client(...)`), not the `google.generativeai`
    # package this file imports. Confirm which SDK is installed; with
    # `google.generativeai` these attribute lookups are expected to fail.
    uploaded_file = genai.files.upload(file=image_path)

    # Input parts: the uploaded image, plus the removal instruction if given.
    parts = [types.Part.from_uri(file_uri=uploaded_file.uri, mime_type="image/png")]
    if object_type:
        parts.append(types.Part.from_text(text=f"Remove {object_type} from the image"))
    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        # Request image output in addition to text.
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
        ],
        response_mime_type="text/plain",
    )

    image_bytes = None
    for chunk in genai.models.generate_content_stream(
        model="gemini-2.0-flash-exp-image-generation",
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates:
            continue
        content = chunk.candidates[0].content
        if not content or not content.parts:
            continue
        # The response can mix text and image parts, so the image is not
        # guaranteed to be parts[0]. Scan every part and keep the latest image.
        for part in content.parts:
            if part.inline_data:
                image_bytes = part.inline_data.data

    if image_bytes is None:
        return None
    # Write once, after the stream is complete, instead of once per chunk.
    with open(output_path, "wb") as f:
        f.write(image_bytes)
    return output_path
# Without a route registration this view is never reachable.
@app.route("/")
def index():
    """Serve the main page by rendering the ``index.html`` template."""
    return render_template("index.html")
# NOTE(review): the URL is inferred because no route was registered. Confirm it
# matches the POST request sent by templates/index.html.
@app.route("/process", methods=["POST"])
def process_image():
    """Remove an object from a posted image and report the outcome as JSON.

    Expects a JSON body ``{"image": <data URL or base64>, "objectType": <str>}``.
    ``objectType`` may be omitted.

    Returns:
        ``{"success": True, "resultPath": <path>}`` when an edited image was
        produced. Otherwise ``{"success": False, "message": <reason>}``.
        A missing or non-JSON body, or a missing ``image`` key, gets HTTP 400.
        Processing failures keep the original HTTP 200 status, so clients that
        only inspect ``success`` keep working.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "image" not in data:
        return jsonify({"success": False, "message": "Request JSON must include 'image'."}), 400
    image_data = data["image"]
    # remove_object_from_image skips the text instruction when this is falsy.
    object_type = data.get("objectType")

    image_path = None
    try:
        # Decoding is inside the try so bad image data yields a JSON error
        # instead of an unhandled exception.
        image_path = save_image(image_data)
        result_image = remove_object_from_image(image_path, object_type)
    except Exception as e:
        app.logger.exception("Image processing failed")
        return jsonify({"success": False, "message": str(e)})
    finally:
        # save_image creates the temp file with delete=False, so remove it here.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                app.logger.warning("Could not delete temporary file %s", image_path)

    if result_image is None:
        return jsonify({"success": False, "message": "The model did not return an image."})
    return jsonify({"success": True, "resultPath": result_image})
if __name__ == "__main__":
    # Built-in development server for running the file directly. When the app
    # is imported by a hosting provider's server, this branch does not run.
    app.run(port=7860, host="0.0.0.0")