File size: 3,365 Bytes
e5c238d
68780eb
a97d8ed
68780eb
 
 
 
 
e5c238d
68780eb
 
d2236a2
68780eb
 
e5c238d
68780eb
e5c238d
 
68780eb
 
 
 
 
 
 
 
 
e5c238d
68780eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5c238d
68780eb
d2236a2
68780eb
 
 
 
 
 
 
 
 
 
 
 
 
 
d2236a2
68780eb
d2236a2
68780eb
 
 
 
d2236a2
68780eb
 
 
 
 
 
d2236a2
68780eb
 
e5c238d
68780eb
 
 
 
e5c238d
68780eb
e5c238d
68780eb
 
0516a8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import io
import base64
import tempfile
from flask import Flask, render_template, request, jsonify
from google import genai
from google.genai import types
from PIL import Image

# Initialize the Gemini client with the API key from the environment.
# The google-genai SDK (`from google import genai`) has no module-level
# `configure()` function -- that helper belongs to the legacy
# `google.generativeai` package -- so the key is handed directly to the
# client constructor instead.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

# Initialize Flask app
app = Flask(__name__)

def save_image(image_data):
    """Decode a base64-encoded image and store it as a temporary PNG file.

    Args:
        image_data: Either a data URL in the form
            ``"data:image/png;base64,...."`` or a bare base64 string
            without the ``data:`` header.

    Returns:
        The path of the PNG file that was written. The file is created
        with ``delete=False``, so the caller is responsible for removing it.

    Raises:
        binascii.Error: If the payload is not valid base64.
        Exception: Whatever Pillow raises when the decoded bytes cannot be
            opened as an image or re-encoded as PNG.
    """
    # Everything before the first comma is the data-URL header. A payload
    # with no comma is treated as bare base64 rather than rejected.
    _header, separator, payload = image_data.partition(',')
    encoded = payload if separator else image_data
    image_bytes = base64.b64decode(encoded)

    with Image.open(io.BytesIO(image_bytes)) as image:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            # Closing the handle flushes the PNG to disk before the path is
            # handed to other code, and avoids leaking a file descriptor.
            with temp_file:
                image.save(temp_file, "PNG")
        except Exception:
            # Do not leave a half-written file behind when saving fails.
            os.unlink(temp_file.name)
            raise
    return temp_file.name

def remove_object_from_image(image_path, object_type, output_path="generated_output.png"):
    """Use the Gemini API to remove a specified object from an image.

    Args:
        image_path: Path of the PNG image to upload to Gemini.
        object_type: Description of the object to remove. When falsy, no
            text instruction is sent and only the image is submitted.
        output_path: Where the generated image is written. Defaults to
            ``"generated_output.png"`` in the current working directory.
            NOTE(review): concurrent requests using the default overwrite
            each other's output; pass a unique path if that matters.

    Returns:
        ``output_path`` if Gemini returned image data, otherwise ``None``.
    """
    # Upload the image file to Gemini
    uploaded_file = client.files.upload(file=image_path)

    # Prepare the input parts:
    # 1. The image file
    parts = [types.Part.from_uri(file_uri=uploaded_file.uri, mime_type="image/png")]
    # 2. The text instruction asking for the removal
    if object_type:
        parts.append(types.Part.from_text(text=f"Remove {object_type} from the image"))

    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=["image", "text"],
        safety_settings=[types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")],
        response_mime_type="text/plain",
    )

    result_image = None

    # Process the generation stream from Gemini
    for chunk in client.models.generate_content_stream(
        model="gemini-2.0-flash-exp-image-generation",
        contents=contents,
        config=generate_content_config,
    ):
        if not chunk.candidates:
            continue
        content = chunk.candidates[0].content
        if not content or not content.parts:
            continue
        # Inspect every part: with mixed image/text output the image is not
        # guaranteed to be the first part of a chunk.
        for part in content.parts:
            inline_data = part.inline_data
            if inline_data and inline_data.data:
                # Save the generated binary image data; a later image in the
                # stream replaces an earlier one.
                with open(output_path, "wb") as f:
                    f.write(inline_data.data)
                result_image = output_path

    return result_image

@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    landing_page = render_template('index.html')
    return landing_page

@app.route('/process', methods=['POST'])
def process_image():
    """Handle an object-removal request sent as a JSON POST body.

    Expected JSON fields:
        image: Base64 image data (data URL) to process. Required.
        objectType: Description of the object to remove. Optional.

    Returns:
        A JSON response with ``success`` and either ``resultPath`` (on
        success) or ``message`` (on failure). A malformed request yields
        HTTP 400; processing failures keep the default status code and
        report the error in the JSON body.
    """
    # silent=True returns None for a missing or invalid JSON body instead of
    # raising, so the client always receives a JSON error rather than HTML.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'message': 'Request body must be a JSON object.'}), 400

    image_data = data.get('image')
    if not image_data or not isinstance(image_data, str):
        return jsonify({'success': False, 'message': "Missing required field 'image'."}), 400
    object_type = data.get('objectType')

    image_path = None
    try:
        # Save the uploaded image locally. This sits inside the try block so
        # that undecodable image data is reported as a JSON error too.
        image_path = save_image(image_data)

        # Use Gemini to remove the object from the image
        result_image = remove_object_from_image(image_path, object_type)
        if result_image is None:
            return jsonify({'success': False, 'message': 'No image was generated.'})
        return jsonify({'success': True, 'resultPath': result_image})
    except Exception as e:
        # Top-level request boundary: log the traceback and report the
        # failure in the JSON body.
        app.logger.exception("Image processing failed")
        return jsonify({'success': False, 'message': str(e)})
    finally:
        # The temporary upload is no longer needed once processing is done.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                app.logger.warning("Could not remove temporary file %s", image_path)

if __name__ == '__main__':
    # Entry point for running the app directly during local testing. In
    # production the hosting provider is expected to manage the server.
    app.run(
        port=7860,
        host="0.0.0.0",
    )