# Spaces: Sleeping
import mimetypes
import os
from io import BytesIO

from flask import Flask, render_template, request, send_file
from google import genai
from google.genai import types

app = Flask(__name__)

# Initialize the Gemini client. The API key comes from the GEMINI_API_KEY
# environment variable (e.g. a deployment secret) instead of being hard-coded
# in source; if it is unset, None is passed and the SDK applies its own
# defaults (NOTE(review): confirm the fallback against the google-genai docs).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)
def save_binary_file(file_name, data):
    """Write the bytes in *data* to the path *file_name*.

    The file is opened in binary write mode, so any existing file at that
    path is truncated and replaced. Nothing is returned.
    """
    with open(file_name, mode="wb") as handle:
        handle.write(data)
def _upload_image(image):
    """Upload an image taken from ``request.files`` to the Gemini Files API.

    Args:
        image: A werkzeug ``FileStorage`` (what Flask's ``request.files``
            returns).

    Returns:
        The uploaded file handle; its ``uri`` and ``mime_type`` are used to
        reference the image in the prompt.
    """
    # Per the google-genai SDK, files.upload() accepts a path or an io.IOBase,
    # and for a stream it needs an explicit MIME type. FileStorage is neither,
    # so upload an in-memory copy of its bytes plus the type.
    mime_type = (
        image.mimetype
        or mimetypes.guess_type(image.filename or "")[0]
        # Last-resort guess, mirroring the ".png" fallback used for output.
        or "image/png"
    )
    return client.files.upload(
        file=BytesIO(image.read()),
        config=types.UploadFileConfig(mime_type=mime_type),
    )


def _save_inline_image(inline_data):
    """Write an image returned inline by the model under ``static/``; return its path."""
    extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".png"
    os.makedirs("static", exist_ok=True)
    # NOTE(review): the file name is fixed, so every request overwrites the
    # previous result, including one another user may still be viewing.
    file_path = os.path.join("static", f"generated_output{extension}")
    save_binary_file(file_path, inline_data.data)
    return file_path


def generate_gemini_output(user_input, image):
    """Ask Gemini to remove *user_input* from *image*.

    Args:
        user_input: Description of the element to remove. When empty or None,
            no text prompt is sent.
        image: Uploaded file from ``request.files``, or None/empty when the
            form was submitted without a file.

    Returns:
        A ``(result_text, result_image)`` tuple:

        - ``result_text``: all text the model streamed back, concatenated,
          or None if there was none.
        - ``result_image``: the path under ``static/`` of the last image the
          model returned, or None if there was none.

        Returns ``(None, None)`` without calling the API when there is neither
        an image nor a prompt.
    """
    model = "gemini-2.0-flash-exp-image-generation"

    # Prepare the input content: the uploaded image (if any), then the prompt.
    parts = []
    if image:
        uploaded_file = _upload_image(image)
        parts.append(
            types.Part.from_uri(
                file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type
            )
        )
    # Gemini "magic": turn the user's description into an explicit removal
    # instruction.
    if user_input:
        parts.append(types.Part.from_text(text=f"Remove {user_input} from the image"))
    if not parts:
        # Nothing to ask the model about; skip the API call entirely.
        return None, None

    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
        ],
        response_mime_type="text/plain",
    )

    text_pieces = []
    result_image = None
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        candidate = chunk.candidates[0] if chunk.candidates else None
        if candidate is None or not candidate.content or not candidate.content.parts:
            continue
        # A chunk may carry several parts (text and/or image); handle each one.
        for part in candidate.content.parts:
            if part.inline_data:
                result_image = _save_inline_image(part.inline_data)
            elif part.text:
                # Streamed text arrives in pieces; keep all of them, not just
                # the last one.
                text_pieces.append(part.text)

    result_text = "".join(text_pieces) or None
    return result_text, result_image
# NOTE(review): the route decorator was missing, so this view was never
# registered; "/" is assumed — confirm against the form action in index.html.
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the upload form and, after a POST, the Gemini result.

    A GET renders ``index.html`` with no result. A POST reads the
    ``user_input`` text field and the ``image_input`` file, then passes them
    to ``generate_gemini_output``. It renders the returned text and/or image
    path.
    """
    result_text = None
    result_image = None
    if request.method == "POST":
        user_input = request.form.get("user_input")
        image = request.files.get("image_input")
        try:
            result_text, result_image = generate_gemini_output(user_input, image)
        except Exception:
            # Request boundary: log the full traceback for the operator, but
            # show the user a short message instead of a bare 500 error page.
            app.logger.exception("Gemini request failed")
            result_text = "Sorry, something went wrong while processing your request. Please try again."
    return render_template("index.html", result_text=result_text, result_image=result_image)
if __name__ == "__main__":
    # Flask's built-in server, bound to every network interface (not just
    # localhost) on port 7860 so it is reachable from other machines.
    app.run(port=7860, host="0.0.0.0")