# Source: Hugging Face Space "Gttg", file app.py (2.78 kB).
# Last commit: 49012e3 "Update app.py" by Athspi (verified).
from flask import Flask, render_template, request, send_file
import os
import mimetypes
from google import genai
from google.genai import types
from io import BytesIO
app = Flask(__name__)

# Initialize Gemini client.
# The API key is taken from the GEMINI_API_KEY environment variable so that
# the name is actually defined at import time and no secret has to be
# committed to the source file.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=GEMINI_API_KEY)
def save_binary_file(file_name, data):
    """Write ``data`` to ``file_name`` in binary mode.

    The file is opened with mode ``"wb"``, so an existing file at that path
    is truncated and replaced.

    Args:
        file_name: Path of the file to write.
        data: Bytes-like payload to store.
    """
    sink = open(file_name, mode="wb")
    try:
        sink.write(data)
    finally:
        # Always release the handle, even if the write fails.
        sink.close()
def generate_gemini_output(user_input, image):
    """Ask Gemini to remove ``user_input`` from ``image`` and collect the reply.

    Args:
        user_input: Description of the element to remove. When empty or
            ``None`` no text instruction is sent.
        image: The picture to edit. Either a file-like object exposing
            ``read()`` (such as the upload object taken from
            ``request.files``) or any value ``client.files.upload`` accepts
            directly. When falsy no image is sent.

    Returns:
        A ``(result_text, result_image)`` tuple. ``result_text`` is the
        concatenation of all text fragments streamed back, or ``None`` when
        there were none. ``result_image`` is the path under ``static`` where
        the most recent returned image was written, or ``None`` when no
        image came back.
    """
    model = "gemini-2.0-flash-exp-image-generation"

    # Prepare the input content: the uploaded image first, then the prompt.
    parts = []
    if image:
        if hasattr(image, "read"):
            # A raw stream has no path to infer the MIME type from, so the
            # bytes are buffered and the type is passed explicitly
            # (see the google-genai `files.upload` reference).
            mime_type = getattr(image, "mimetype", None) or mimetypes.guess_type(
                getattr(image, "filename", None) or ""
            )[0]
            uploaded_file = client.files.upload(
                file=BytesIO(image.read()),
                config=types.UploadFileConfig(mime_type=mime_type),
            )
        else:
            # Not file-like: hand it to the SDK unchanged.
            uploaded_file = client.files.upload(file=image)
        parts.append(
            types.Part.from_uri(
                file_uri=uploaded_file.uri,
                mime_type=uploaded_file.mime_type,
            )
        )

    # Incorporate Gemini magic: modify the prompt to instruct removal of the
    # user-specified element.
    if user_input:
        magic_prompt = f"Remove {user_input} from the image"
        parts.append(types.Part.from_text(text=magic_prompt))

    if not parts:
        # Neither an image nor a prompt was supplied; there is nothing to
        # send, so skip the model call instead of issuing an empty request.
        return None, None

    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
        ],
        response_mime_type="text/plain",
    )

    text_fragments = []
    result_image = None
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        # Look at every part of the chunk, not only the first one, so that
        # text and image data arriving together are both kept.
        for part in chunk.candidates[0].content.parts:
            if part.inline_data:
                file_name = "generated_output"
                file_extension = mimetypes.guess_extension(part.inline_data.mime_type) or ".png"
                file_path = os.path.join("static", f"{file_name}{file_extension}")
                save_binary_file(file_path, part.inline_data.data)
                result_image = file_path
            elif part.text:
                # Streamed text arrives in pieces; accumulate them rather
                # than keeping only the last fragment.
                text_fragments.append(part.text)

    result_text = "".join(text_fragments) if text_fragments else None
    return result_text, result_image
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the main page and, on POST, run the submitted edit request.

    For a POST the ``user_input`` form field and the ``image_input`` file
    field are passed to ``generate_gemini_output`` and its two results are
    given to the ``index.html`` template. For any other method the template
    is rendered with both results set to ``None``.
    """
    if request.method != "POST":
        return render_template("index.html", result_text=None, result_image=None)

    prompt = request.form.get("user_input")
    upload = request.files.get("image_input")
    text_out, image_out = generate_gemini_output(prompt, upload)
    return render_template("index.html", result_text=text_out, result_image=image_out)
if __name__ == "__main__":
    # Start the Flask server on all interfaces, port 7860, when this file
    # is executed directly rather than imported.
    app.run(port=7860, host="0.0.0.0")