# Gttg / app.py — Hugging Face Space file by Athspi.
# Last change: commit fea0355, "Update app.py" (verified), 4.75 kB.
import base64
import mimetypes
import os
import tempfile
import uuid

from flask import Flask, render_template, request, jsonify
from google import genai
from google.genai import types
from werkzeug.utils import secure_filename
# Initialize Flask app
app = Flask(__name__)

# Set your Gemini API key via Hugging Face Spaces environment variables.
# Do not include a default fallback; the environment must supply GEMINI_API_KEY
# (a missing variable raises KeyError at import time, by design).
GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
client = genai.Client(api_key=GEMINI_API_KEY)

# Create necessary directories.
# UPLOAD_FOLDER holds transient copies of incoming images (relative to the CWD).
UPLOAD_FOLDER = 'uploads'
# Generated images must be written to the directory Flask actually serves under
# /static/. That is app.static_folder (an absolute path under app.root_path).
# A bare 'static' is CWD-relative and diverges from it whenever the server is
# started from a different working directory, which makes result URLs 404.
RESULT_FOLDER = app.static_folder
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
def _decode_image_data_url(image_data_url):
    """Decode a base64 ``data:`` URL into ``(binary_data, file_extension)``.

    Raises:
        ValueError: "Invalid image data" when the URL has no ``,`` separator;
            ``binascii.Error`` (a ``ValueError`` subclass) for bad base64.
    """
    try:
        header, encoded = image_data_url.split(',', 1)
    except ValueError:
        raise ValueError("Invalid image data") from None
    binary_data = base64.b64decode(encoded)
    # Header looks like "data:image/webp;base64". Derive the extension from the
    # declared MIME type so the staged file is named correctly. The upload's
    # mime_type presumably comes from the file extension — confirm against the
    # google-genai SDK. Fall back to the original png/jpg rule otherwise.
    mime_type = header.partition(':')[2].partition(';')[0].strip().lower()
    ext = mimetypes.guess_extension(mime_type) if mime_type.startswith("image/") else None
    if not ext:
        ext = ".png" if "png" in header.lower() else ".jpg"
    return binary_data, ext


def _upload_image_bytes(binary_data, ext):
    """Upload *binary_data* to Gemini via a temporary file; return the file handle.

    The bytes are staged in a uniquely named file inside UPLOAD_FOLDER, so
    concurrent requests never overwrite each other. The local copy is removed
    once the upload call returns or fails.
    """
    fd, temp_filepath = tempfile.mkstemp(prefix="temp_image_", suffix=ext, dir=UPLOAD_FOLDER)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(binary_data)
        return client.files.upload(file=temp_filepath)
    finally:
        try:
            os.remove(temp_filepath)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the
            # upload result or the exception already propagating.
            pass


def _save_generated_image(inline_data):
    """Write one inline image part returned by Gemini to RESULT_FOLDER; return its path."""
    # guess_extension() cannot take None, so guard a missing mime_type.
    file_extension = mimetypes.guess_extension(inline_data.mime_type or "") or ".png"
    # Unique name per result. A fixed name lets concurrent requests clobber each
    # other's output and lets browsers display a cached, stale image.
    # NOTE(review): results now accumulate in RESULT_FOLDER; add periodic
    # cleanup if disk usage matters for this deployment.
    output_filename = secure_filename(f"generated_output_{uuid.uuid4().hex}{file_extension}")
    result_image_path = os.path.join(RESULT_FOLDER, output_filename)
    with open(result_image_path, "wb") as f:
        f.write(inline_data.data)
    return result_image_path


def generate_gemini_output(object_type, image_data_url):
    """
    Generate output from Gemini by removing the specified object.

    Args:
        object_type: What to remove from the image. When falsy, no instruction
            text is sent.
        image_data_url: Base64 ``data:`` URL of the source image. When falsy,
            no image is sent.

    Returns:
        A tuple ``(result_text, result_image)``:
        - ``result_text``: all text streamed by the model, concatenated, or
          ``None`` if there was none.
        - ``result_image``: filesystem path of the last image the model
          returned (written under RESULT_FOLDER), or ``None`` if there was none.

    Raises:
        ValueError: if *image_data_url* is not a valid base64 data URL.
        Errors from the Gemini client propagate unchanged.
    """
    model = "gemini-2.0-flash-exp-image-generation"

    # Prepare content parts for Gemini
    parts = []
    if image_data_url:
        binary_data, ext = _decode_image_data_url(image_data_url)
        uploaded_file = _upload_image_bytes(binary_data, ext)
        parts.append(types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type))
    if object_type:
        # Gemini magic prompt: instruct the model to remove the specified object
        magic_prompt = f"Remove {object_type} from the image"
        parts.append(types.Part.from_text(text=magic_prompt))
    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF"),
        ],
        response_mime_type="text/plain",
    )

    text_fragments = []
    result_image = None
    # Stream output from Gemini API. A chunk may carry several parts (e.g. text
    # plus an image), and text arrives in pieces across chunks. So every part of
    # every chunk is inspected, and text is accumulated rather than overwritten.
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        for part in chunk.candidates[0].content.parts:
            if part.inline_data:
                result_image = _save_generated_image(part.inline_data)
            elif part.text:
                text_fragments.append(part.text)

    result_text = "".join(text_fragments) or None
    return result_text, result_image
@app.route("/")
def index():
    """Serve the single-page front end.

    The ``index.html`` template carries all of its markup, styling and
    scripts inline, so it is rendered without any template context.
    """
    page = "index.html"
    return render_template(page)
@app.route("/process", methods=["POST"])
def process():
    """Remove an object from a submitted image using Gemini.

    Expects a JSON object body with:
    - ``"image"``: a base64 data URL.
    - ``"objectType"``: what to remove.

    Responds with JSON:
    - On success: ``{"success": True, "resultPath": "/static/<file>", "resultText": ...}``.
    - On failure: ``{"success": False, "message": ...}``, with status 400 for
      missing or malformed input and 500 when generation fails.
    """
    # silent=True: malformed JSON yields None (rejected below as a 400) instead
    # of raising BadRequest, which the catch-all used to report as a 500.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"success": False, "message": "Request body must be a JSON object."}), 400

    image_data = data.get("image")
    object_type = data.get("objectType")
    # Treat null / non-string objectType as missing rather than crashing on .strip().
    if not isinstance(object_type, str):
        object_type = ""
    object_type = object_type.strip()
    if not isinstance(image_data, str) or not image_data or not object_type:
        return jsonify({"success": False, "message": "Missing image data or object type."}), 400

    try:
        # Generate output using Gemini
        result_text, result_image = generate_gemini_output(object_type, image_data)
    except Exception as e:
        # Top-level request boundary: record the traceback server-side, and
        # keep the previous client-facing error message.
        app.logger.exception("Gemini generation failed")
        return jsonify({"success": False, "message": f"Error: {str(e)}"}), 500

    if not result_image:
        return jsonify({"success": False, "message": "Failed to generate image."}), 500
    # Generated files are written into the static folder, so they are served
    # under /static/<basename>.
    image_url = f"/static/{os.path.basename(result_image)}"
    return jsonify({"success": True, "resultPath": image_url, "resultText": result_text})
if __name__ == "__main__":
    # Hugging Face Spaces expects the app on port 7860 (the default here).
    # A PORT environment variable, when the hosting platform supplies one,
    # takes precedence.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))