File size: 6,154 Bytes
605bf7b
b0a339e
8e6ca2b
a36d15c
 
8e6ca2b
 
 
e5c238d
 
b0a339e
9479bea
8e6ca2b
 
b0a339e
a36d15c
 
 
 
 
b0a339e
19af451
b0a339e
 
19af451
 
 
 
 
 
 
 
 
 
 
b0a339e
19af451
 
b0a339e
 
19af451
b0a339e
 
 
 
605bf7b
b0a339e
 
 
 
19af451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a339e
 
605bf7b
b0a339e
605bf7b
b0a339e
 
 
 
605bf7b
b0a339e
 
19af451
605bf7b
19af451
 
b0a339e
19af451
 
b0a339e
 
 
19af451
b0a339e
 
68780eb
b0a339e
 
19af451
 
 
 
605bf7b
19af451
 
 
68780eb
b0a339e
19af451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e6ca2b
a36d15c
68780eb
a36d15c
 
 
 
 
19af451
 
 
 
 
 
b0a339e
19af451
 
b0a339e
19af451
b0a339e
 
19af451
b0a339e
 
 
19af451
 
b0a339e
19af451
 
 
b0a339e
 
 
 
19af451
b0a339e
19af451
 
 
a36d15c
19af451
8e6ca2b
 
19af451
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import base64
import mimetypes
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
from google import genai
from google.genai import types

# Flask application object; the route handlers below are registered on it.
app = Flask(__name__)

# Gemini API client. The key is read from the environment at import time,
# so a missing GEMINI_API_KEY raises KeyError before the app starts.
GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
client = genai.Client(api_key=GEMINI_API_KEY)

# Working directories: decoded uploads are staged in UPLOAD_FOLDER and
# generated results are written to RESULT_FOLDER. Both are created on import.
UPLOAD_FOLDER = "uploads"
RESULT_FOLDER = "static"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)

def upload_image(image_data_url):
    """Decode a base64 data URL, stage it on disk and upload it to Gemini.

    The payload after the first comma is base64-decoded and written to a
    randomly named temporary file inside ``UPLOAD_FOLDER``. That file is
    passed to ``client.files.upload`` and is always removed afterwards,
    whether or not the upload succeeded.

    Args:
        image_data_url: String of the form ``"<header>,<base64 payload>"``.
            The file extension is ``.png`` when the header mentions "png",
            otherwise ``.jpg``.

    Returns:
        Whatever ``client.files.upload`` returns for the staged file.

    Raises:
        ValueError: If the input cannot be split/decoded, the file cannot be
            written, or the upload fails. The original exception is attached
            as ``__cause__``.
    """
    temp_filepath = None
    try:
        header, encoded = image_data_url.split(',', 1)
        binary_data = base64.b64decode(encoded)
        ext = ".png" if "png" in header.lower() else ".jpg"
        temp_filename = secure_filename(f"temp_{os.urandom(8).hex()}{ext}")
        temp_filepath = os.path.join(UPLOAD_FOLDER, temp_filename)

        with open(temp_filepath, "wb") as f:
            f.write(binary_data)

        return client.files.upload(file=temp_filepath)

    except Exception as e:
        # Callers rely on ValueError for every failure; keep the cause chained
        # so the real error is still visible in tracebacks.
        raise ValueError(f"Image processing error: {str(e)}") from e

    finally:
        # Clean up the staged file on success AND on failure so a failed
        # upload does not leave user images behind on disk.
        if temp_filepath is not None:
            try:
                os.remove(temp_filepath)
            except OSError:
                # Best-effort: the file may never have been created, and a
                # cleanup problem must not mask the upload result/error.
                pass

def is_prohibited_request(uploaded_file, object_type):
    """Ask a Gemini model whether the removal request targets people/animals.

    The uploaded image and the text ``"Remove <object_type>"`` are sent to
    ``gemini-2.0-flash-lite`` together with a system instruction that asks
    for a bare 'Yes' (prohibited) or 'No' (allowed) answer.

    The check fails closed: the request is treated as allowed only when the
    model explicitly answers "no". A missing, empty or unexpected answer, or
    any exception raised while calling the model, is treated as prohibited.

    Args:
        uploaded_file: Object exposing ``uri`` and ``mime_type`` attributes
            (as read below) that identifies the uploaded image.
        object_type: Description of the object the user wants removed.

    Returns:
        ``True`` if the request must be blocked, ``False`` if it may proceed.
    """
    model = "gemini-2.0-flash-lite"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Remove {object_type}")
    ]

    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text="""Analyze image and request to detect:
1. Direct removal of people/animals
2. Removal of items attached to/worn by people/animals
3. Removal of body parts or personal belongings

Prohibited examples:
- Person, dog, cat
- Sunglasses on face, mask, hat
- Phone in hand, watch on wrist
- Eyes, hands, hair

Allowed examples:
- Background, car, tree
- Sunglasses on table
- Phone on desk

Respond ONLY with 'Yes' or 'No'""")],
        temperature=0.0,
        max_output_tokens=1,
    )

    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generate_content_config
        )
        if response.candidates and response.candidates[0].content.parts:
            raw_answer = response.candidates[0].content.parts[0].text or ""
            # Tolerate surrounding whitespace/quotes/punctuation around the verdict.
            answer = raw_answer.strip().strip(".!'\"").lower()
            # Fail closed: only an explicit "no" lets the request through.
            # Comparing against "yes" would silently allow any unexpected
            # or truncated reply.
            return answer != "no"
        return True  # Default to safe mode if uncertain
    except Exception as e:
        print(f"Safety check failed: {str(e)}")
        return True  # Block if check fails

def generate_modified_image(uploaded_file, object_type):
    """Stream an edited image from Gemini and save the first image returned.

    The uploaded image plus a removal instruction are sent to
    ``gemini-2.0-flash-exp-image-generation``. The streamed response is
    scanned chunk by chunk, and every part of every chunk is inspected. The
    first part carrying inline binary data is written to ``RESULT_FOLDER``
    under a random ``result_<hex><ext>`` name.

    Args:
        uploaded_file: Object exposing ``uri`` and ``mime_type`` attributes
            (as read below) that identifies the uploaded image.
        object_type: Description of the object to remove from the image.

    Returns:
        Path of the saved image, or ``None`` if the stream contained no
        image data or any error occurred (the error is printed).
    """
    model = "gemini-2.0-flash-exp-image-generation"
    parts = [
        types.Part.from_uri(file_uri=uploaded_file.uri, mime_type=uploaded_file.mime_type),
        types.Part.from_text(text=f"Completely remove {object_type} from the image without leaving traces")
    ]

    contents = [types.Content(role="user", parts=parts)]

    generate_content_config = types.GenerateContentConfig(
        temperature=0.5,
        top_p=0.9,
        max_output_tokens=1024,
        response_modalities=["image"],
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
            types.SafetySetting(category="HARM_CATEGORY_VIOLENCE", threshold="BLOCK_NONE")
        ]
    )

    try:
        for chunk in client.models.generate_content_stream(
            model=model,
            contents=contents,
            config=generate_content_config,
        ):
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            # A chunk may carry no content at all; skip it instead of letting
            # an AttributeError abort the whole stream.
            if content is None or not content.parts:
                continue

            # The image is not guaranteed to be the first part of a chunk,
            # so look at every part rather than only parts[0].
            for part in content.parts:
                inline = part.inline_data
                if not inline or not inline.data:
                    continue

                ext = mimetypes.guess_extension(inline.mime_type or "") or ".png"
                output_filename = secure_filename(f"result_{os.urandom(4).hex()}{ext}")
                output_path = os.path.join(RESULT_FOLDER, output_filename)

                with open(output_path, "wb") as f:
                    f.write(inline.data)

                return output_path
        return None
    except Exception as e:
        print(f"Image generation failed: {str(e)}")
        return None

@app.route("/")
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template("index.html")
    return page

@app.route("/process", methods=["POST"])
def process():
    """Handle an object-removal request posted as JSON.

    Expects a JSON object with two string fields:
        ``image``: base64 data URL of the source image.
        ``objectType``: description of the object to remove.

    Flow: validate the payload, upload the image, run the safety check and,
    if the request is allowed, generate the edited image.

    Returns:
        A JSON response (with HTTP status):
        200: ``{"success": True, "resultUrl": "/static/<file>"}``
        400: malformed/missing payload, empty object, or image processing error
        403: the safety check flagged the request
        500: image generation failed or an unexpected error occurred
    """
    try:
        # silent=True yields None for a missing/invalid JSON body instead of
        # raising, so bad input is answered with 400 below rather than being
        # swallowed by the generic handler and reported as 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "image" not in data or "objectType" not in data:
            return jsonify({"success": False, "message": "Invalid request format"}), 400

        image_data = data["image"]
        object_type = data["objectType"]

        # Both fields must be strings; anything else is a client error.
        if not isinstance(image_data, str) or not isinstance(object_type, str):
            return jsonify({"success": False, "message": "Invalid request format"}), 400

        object_type = object_type.strip().lower()
        if not object_type:
            return jsonify({"success": False, "message": "Please specify an object to remove"}), 400

        # Process image upload
        uploaded_file = upload_image(image_data)

        # Safety check
        if is_prohibited_request(uploaded_file, object_type):
            return jsonify({
                "success": False,
                "message": "Cannot remove people, animals, or personal items"
            }), 403

        # Generate modified image
        result_path = generate_modified_image(uploaded_file, object_type)
        if not result_path:
            return jsonify({"success": False, "message": "Failed to generate image"}), 500

        return jsonify({
            "success": True,
            "resultUrl": f"/static/{os.path.basename(result_path)}"
        })

    except ValueError as e:
        return jsonify({"success": False, "message": str(e)}), 400
    except Exception:
        # Keep the response generic, but record the traceback server-side so
        # unexpected failures are not lost.
        app.logger.exception("Unhandled error while processing /process request")
        return jsonify({"success": False, "message": "Internal server error"}), 500

# Script entry point: start Flask's built-in server on all interfaces, port 7860.
if __name__ == "__main__":
    app.run(
        host="0.0.0.0",
        port=7860,
    )