Spaces:
Running
Running
"""Gradio demo Space for Trainingless – see plan.md for details.""" | |
from __future__ import annotations | |
import base64 | |
import io | |
import os | |
import time | |
import uuid | |
from datetime import datetime | |
from typing import Tuple, Optional | |
import requests | |
from dotenv import load_dotenv | |
from PIL import Image | |
from collections import defaultdict, deque | |
import threading | |
import gradio as gr | |
from supabase import create_client, Client | |
from transformers import pipeline | |
# ----------------------------------------------------------------------------- | |
# Environment & Supabase setup | |
# ----------------------------------------------------------------------------- | |
# Load .env file *once* when running locally. The HF Spaces runtime injects the | |
# same names via its Secrets mechanism, so calling load_dotenv() is harmless. | |
load_dotenv() | |
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "") | |
# Use a *secret* (server-only) key so the backend bypasses RLS. | |
SUPABASE_SECRET_KEY: str = os.getenv("SUPABASE_SECRET_KEY", "") | |
# (Optional) You can override which Edge Function gets called. | |
SUPABASE_FUNCTION_URL: str = os.getenv( | |
"SUPABASE_FUNCTION_URL", f"{SUPABASE_URL}/functions/v1/process-image" | |
) | |
# Storage bucket for uploads. Must be *public*. | |
UPLOAD_BUCKET = os.getenv("SUPABASE_UPLOAD_BUCKET", "images") | |
REQUEST_TIMEOUT = int(os.getenv("SUPABASE_FN_TIMEOUT", "240")) # seconds | |
# Available model workflows recognised by edge function | |
WORKFLOW_CHOICES = [ | |
"eyewear", | |
"footwear", | |
"dress", | |
"top", | |
] | |
if not SUPABASE_URL or not SUPABASE_SECRET_KEY: | |
raise RuntimeError( | |
"SUPABASE_URL and SUPABASE_SECRET_KEY must be set in the environment." | |
) | |
# ----------------------------------------------------------------------------- | |
# Supabase client – server-side: authenticate with secret key (bypasses RLS) | |
# ----------------------------------------------------------------------------- | |
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SECRET_KEY) | |
# Ensure the uploads bucket exists (idempotent). This requires service role *once*; | |
try: | |
buckets = supabase.storage.list_buckets() # type: ignore[attr-defined] | |
bucket_names = {b["name"] for b in buckets} if isinstance(buckets, list) else set() | |
if UPLOAD_BUCKET not in bucket_names: | |
# Attempt to create bucket (will fail w/ anon key – inform user to create) | |
try: | |
supabase.storage.create_bucket( | |
UPLOAD_BUCKET, | |
public=True, | |
) | |
print(f"[startup] Created bucket '{UPLOAD_BUCKET}'.") | |
except Exception as create_exc: # noqa: BLE001 | |
print(f"[startup] Could not create bucket '{UPLOAD_BUCKET}': {create_exc!r}") | |
except Exception as exc: # noqa: BLE001 | |
# Non-fatal. The bucket probably already exists or we don't have perms. | |
print(f"[startup] Bucket check/create raised {exc!r}. Continuing…") | |
# ----------------------------------------------------------------------------- | |
# NSFW Filter Setup | |
# ----------------------------------------------------------------------------- | |
# Initialize NSFW classifier at startup | |
print("[startup] Loading NSFW classifier...") | |
try: | |
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") | |
print("[startup] NSFW classifier loaded successfully") | |
except Exception as exc: | |
print(f"[startup] Failed to load NSFW classifier: {exc!r}") | |
nsfw_classifier = None | |
# ----------------------------------------------------------------------------- | |
# Rate Limiting | |
# ----------------------------------------------------------------------------- | |
# Rate limiter: 5 requests per IP per hour | |
RATE_LIMIT_REQUESTS = 5 | |
RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds | |
request_tracker = defaultdict(deque) | |
def get_client_ip(request: gr.Request) -> str: | |
"""Extract client IP with multiple fallback methods.""" | |
if not request: | |
return "no-request" | |
# Try multiple sources for IP address | |
ip_sources = [ | |
# Direct client host | |
getattr(request.client, 'host', None) if hasattr(request, 'client') and request.client else None, | |
# Common proxy headers | |
request.headers.get('X-Forwarded-For', '').split(',')[0].strip() if hasattr(request, 'headers') else None, | |
request.headers.get('X-Real-IP', '') if hasattr(request, 'headers') else None, | |
request.headers.get('CF-Connecting-IP', '') if hasattr(request, 'headers') else None, # Cloudflare | |
request.headers.get('True-Client-IP', '') if hasattr(request, 'headers') else None, | |
request.headers.get('X-Client-IP', '') if hasattr(request, 'headers') else None, | |
] | |
# Return first valid IP found | |
for ip in ip_sources: | |
if ip and ip.strip() and ip != '::1' and not ip.startswith('127.'): | |
return ip.strip() | |
# Final fallbacks | |
if hasattr(request, 'client') and request.client: | |
client_host = getattr(request.client, 'host', None) | |
if client_host: | |
return client_host | |
# If all else fails, use a session-based identifier | |
session_id = getattr(request, 'session_hash', 'unknown-session') | |
return f"session-{session_id}" | |
def check_rate_limit(client_ip: str) -> bool: | |
"""Check if IP has exceeded rate limit. Returns True if allowed, False if blocked.""" | |
current_time = time.time() | |
user_requests = request_tracker[client_ip] | |
# Remove requests outside the time window | |
while user_requests and current_time - user_requests[0] > RATE_LIMIT_WINDOW: | |
user_requests.popleft() | |
# Check if under limit | |
if len(user_requests) < RATE_LIMIT_REQUESTS: | |
user_requests.append(current_time) | |
return True | |
else: | |
# Log rate limit hit for monitoring | |
print(f"[RATE_LIMIT] IP {client_ip} exceeded limit ({len(user_requests)}/{RATE_LIMIT_REQUESTS})") | |
return False | |
def cleanup_rate_limiter(): | |
"""Periodic cleanup to prevent memory issues.""" | |
current_time = time.time() | |
ips_to_remove = [] | |
for ip, requests in request_tracker.items(): | |
# Remove old requests | |
while requests and current_time - requests[0] > RATE_LIMIT_WINDOW: | |
requests.popleft() | |
# If no recent requests, mark IP for removal | |
if not requests: | |
ips_to_remove.append(ip) | |
# Clean up empty entries | |
for ip in ips_to_remove: | |
del request_tracker[ip] | |
print(f"[RATE_LIMITER] Cleaned up {len(ips_to_remove)} inactive IPs. Active IPs: {len(request_tracker)}") | |
# ----------------------------------------------------------------------------- | |
# Helper functions | |
# ----------------------------------------------------------------------------- | |
def pil_to_bytes(img: Image.Image) -> bytes: | |
"""Convert PIL Image to PNG bytes.""" | |
with io.BytesIO() as buffer: | |
img.save(buffer, format="PNG") | |
return buffer.getvalue() | |
def upload_image_to_supabase(img: Image.Image, path: str) -> str: | |
"""Upload image under `UPLOAD_BUCKET/path` and return **public URL**.""" | |
data = pil_to_bytes(img) | |
# Overwrite if exists | |
supabase.storage.from_(UPLOAD_BUCKET).upload( | |
path, | |
data, | |
{"content-type": "image/png", "upsert": "true"}, # upsert must be string | |
) # type: ignore[attr-defined] | |
public_url = ( | |
f"{SUPABASE_URL}/storage/v1/object/public/{UPLOAD_BUCKET}/{path}" | |
) | |
return public_url | |
def wait_for_job_completion(job_id: str, timeout_s: int = 600) -> Optional[str]: | |
"""Subscribe to the single row via Realtime. Fallback to polling every 5 s.""" | |
# First try realtime subscription (non-blocking). If it errors, fall back. | |
completed_image: Optional[str] = None | |
did_subscribe = False | |
try: | |
# Docs: https://supabase.com/docs/reference/python/creating-channels | |
channel = ( | |
supabase.channel("job_channel") | |
.on( | |
"postgres_changes", | |
{ | |
"event": "UPDATE", | |
"schema": "public", | |
"table": "processing_jobs", | |
"filter": f"id=eq.{job_id}", | |
}, | |
lambda payload: _realtime_callback(payload, job_id), | |
) | |
.subscribe() | |
) | |
did_subscribe = True | |
except Exception as exc: # noqa: BLE001 | |
print(f"[wait] Realtime subscription failed – will poll: {exc!r}") | |
start = time.time() | |
while time.time() - start < timeout_s: | |
if _RESULT_CACHE.get(job_id): | |
completed_image = _RESULT_CACHE.pop(job_id) | |
break | |
if not did_subscribe or (time.time() - start) % 5 == 0: | |
# Poll once every ~5 s | |
data = ( | |
supabase.table("processing_jobs") | |
.select("status,result_image_url") | |
.eq("id", job_id) | |
.single() | |
.execute() | |
) | |
if data.data and data.data["status"] == "completed": | |
completed_image = data.data.get("result_image_url") | |
break | |
time.sleep(1) | |
try: | |
if did_subscribe: | |
supabase.remove_channel(channel) | |
except Exception: # noqa: PIE786, BLE001 | |
pass | |
return completed_image | |
_RESULT_CACHE: dict[str, str] = {} | |
def _realtime_callback(payload: dict, job_id: str) -> None: | |
new = payload.get("new", {}) # type: ignore[index] | |
if new.get("status") == "completed": | |
_RESULT_CACHE[job_id] = new.get("result_image_url") | |
MAX_PIXELS = 1_500_000 # 1.5 megapixels ceiling for each uploaded image | |
def downscale_image(img: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image: | |
"""Downscale *img* proportionally so that width×height ≤ *max_pixels*. | |
If the image is already small enough, it is returned unchanged. | |
""" | |
w, h = img.size | |
if w * h <= max_pixels: | |
return img | |
scale = (max_pixels / (w * h)) ** 0.5 # uniform scaling factor | |
new_size = (max(1, int(w * scale)), max(1, int(h * scale))) | |
return img.resize(new_size, Image.LANCZOS) | |
def _public_storage_url(path: str) -> str: | |
"""Return a public (https) URL given an object *path* inside any bucket. | |
If *path* already looks like a full URL, it is returned unchanged. | |
""" | |
if path.startswith("http://") or path.startswith("https://"): | |
return path | |
# Ensure no leading slash. | |
return f"{SUPABASE_URL}/storage/v1/object/public/{path.lstrip('/')}" | |
def is_nsfw_content(img: Image.Image) -> bool: | |
"""Check if image contains explicit pornographic content using Hugging Face transformer. | |
Designed to allow legitimate fashion content (lingerie, swimwear) while blocking explicit porn. | |
""" | |
if nsfw_classifier is None: | |
print("[NSFW] Classifier not available, skipping check") | |
return False | |
try: | |
# Run classification | |
results = nsfw_classifier(img) | |
print(f"[NSFW] Classification results: {results}") | |
# Check for explicit pornographic content only | |
for result in results: | |
label = result['label'].lower() | |
score = result['score'] | |
print(f"[NSFW] Label: {label}, Score: {score:.3f}") | |
# Only block explicit pornographic content with very high confidence | |
# Allow fashion content (lingerie, swimwear) by being more restrictive | |
if label == 'porn' and score > 0.85: # Higher threshold, only "porn" label | |
print(f"[NSFW] BLOCKED - Explicit pornographic content detected with {score:.3f} confidence") | |
return True | |
elif label in ['nsfw', 'explicit'] and score > 0.95: # Very high threshold for broader categories | |
print(f"[NSFW] BLOCKED - {label} detected with {score:.3f} confidence") | |
return True | |
print("[NSFW] Content approved (fashion/lingerie content allowed)") | |
return False | |
except Exception as exc: | |
print(f"[NSFW] Error during classification: {exc!r}") | |
# Fail open - don't block if classifier has issues | |
return False | |
# ----------------------------------------------------------------------------- | |
# Main generate function | |
# ----------------------------------------------------------------------------- | |
def fetch_image_if_url(img): | |
""" | |
If img is a string and looks like a URL, download and return as PIL.Image. | |
Otherwise, return as-is (assume already PIL.Image). | |
""" | |
if isinstance(img, str) and (img.startswith("http://") or img.startswith("https://")): | |
print(f"[FETCH] Downloading image from URL: {img}") | |
resp = requests.get(img, headers={"x-api-origin": "hf/demo"}) | |
resp.raise_for_status() | |
from PIL import Image | |
return Image.open(io.BytesIO(resp.content)).convert("RGB") | |
return img | |
def generate( | |
base_img: Image.Image, | |
garment_img: Image.Image, | |
workflow_choice: str, | |
mask_img: Optional[Image.Image], # NEW: Optional mask parameter | |
request: gr.Request | |
) -> Image.Image: | |
base_img = fetch_image_if_url(base_img) | |
garment_img = fetch_image_if_url(garment_img) | |
if mask_img is not None: | |
mask_img = fetch_image_if_url(mask_img) | |
if base_img is None or garment_img is None: | |
raise gr.Error("Please provide both images.") | |
# Rate limiting check | |
client_ip = get_client_ip(request) | |
if not check_rate_limit(client_ip): | |
raise gr.Error("Rate Limit Quota Exceeded - Visit studio.yourmirror.io to sign up for unlimited use") | |
# NSFW content filtering - only check product image | |
print(f"[NSFW] Checking product image for inappropriate content...") | |
if is_nsfw_content(garment_img): | |
raise gr.Error("Product image contains inappropriate content. Please use a different image.") | |
# 1. Persist both images to Supabase storage | |
job_id = str(uuid.uuid4()) | |
folder = f"user_uploads/gradio/{job_id}" | |
base_filename = f"{uuid.uuid4().hex}.png" | |
garment_filename = f"{uuid.uuid4().hex}.png" | |
base_path = f"{folder}/{base_filename}" | |
garment_path = f"{folder}/{garment_filename}" | |
base_img = downscale_image(base_img) | |
garment_img = downscale_image(garment_img) | |
base_url = upload_image_to_supabase(base_img, base_path) | |
garment_url = upload_image_to_supabase(garment_img, garment_path) | |
# Handle optional mask image (if provided by ComfyUI or future web UI) | |
mask_url = None | |
if mask_img is not None: | |
print(f"[MASK] Processing user-provided mask image") | |
mask_filename = f"{uuid.uuid4().hex}.png" | |
mask_path = f"{folder}/{mask_filename}" | |
mask_img = downscale_image(mask_img) | |
mask_url = upload_image_to_supabase(mask_img, mask_path) | |
print(f"[MASK] Uploaded mask: {mask_url}") | |
else: | |
print(f"[MASK] No mask provided - will use base image fallback") | |
# 2. Insert new row into processing_jobs (anon key, relies on open RLS) | |
token_for_row = str(uuid.uuid4()) | |
insert_payload = { | |
"id": job_id, | |
"status": "queued", | |
"base_image_path": base_url, | |
"garment_image_path": garment_url, | |
"mask_image_path": mask_url if mask_url else base_url, # Track actual mask used | |
"access_token": token_for_row, | |
"created_at": datetime.utcnow().isoformat(), | |
} | |
supabase.table("processing_jobs").insert(insert_payload).execute() | |
# 3. Trigger edge function | |
workflow_choice = (workflow_choice or "eyewear").lower() | |
if workflow_choice not in WORKFLOW_CHOICES: | |
workflow_choice = "eyewear" | |
fn_payload = { | |
"baseImageUrl": base_url, | |
"garmentImageUrl": garment_url, | |
# 🎭 Smart fallback: use provided mask OR base image (much better than garment!) | |
"maskImageUrl": mask_url if mask_url else base_url, | |
"jobId": job_id, | |
"workflowType": workflow_choice, | |
} | |
# Log mask selection for debugging | |
if mask_url: | |
print(f"[API] Using user-provided mask: {mask_url}") | |
else: | |
print(f"[API] Using base image as mask fallback: {base_url}") | |
headers = { | |
"Content-Type": "application/json", | |
"apikey": SUPABASE_SECRET_KEY, | |
"Authorization": f"Bearer {SUPABASE_SECRET_KEY}", | |
"x-api-origin": "hf/demo", | |
} | |
resp = requests.post( | |
SUPABASE_FUNCTION_URL, | |
json=fn_payload, | |
headers=headers, | |
timeout=REQUEST_TIMEOUT, | |
) | |
if not resp.ok: | |
raise gr.Error(f"Backend error: {resp.text}") | |
# 4. Wait for completion via realtime (or polling fallback) | |
result = wait_for_job_completion(job_id) | |
if not result: | |
raise gr.Error("Timed out waiting for job to finish.") | |
# Result may be base64 data URI or http URL; normalise. | |
if result.startswith("data:image"): | |
header, b64 = result.split(",", 1) | |
img_bytes = base64.b64decode(b64) | |
result_img = Image.open(io.BytesIO(img_bytes)).convert("RGBA") | |
else: | |
result_url = _public_storage_url(result) | |
resp_img = requests.get(result_url, timeout=30, headers={"x-api-origin": "hf/demo"}) | |
resp_img.raise_for_status() | |
result_img = Image.open(io.BytesIO(resp_img.content)).convert("RGBA") | |
return result_img | |
# ----------------------------------------------------------------------------- | |
# Gradio UI | |
# ----------------------------------------------------------------------------- | |
description = "Upload a person photo (Base) and a product image. Select between Eyewear, Footwear, Full-Body Garments, or Top Garments to switch between the four available models. Click 👉 **Generate** to try on a product." # noqa: E501 | |
with gr.Blocks(title="YOURMIRROR.IO - SM4LL-VTON Demo") as demo: | |
# Header | |
gr.Markdown("# SM4LL-VTON PRE-RELEASE DEMO | YOURMIRROR.IO | Virtual Try-On") | |
gr.Markdown(description) | |
IMG_SIZE = 256 | |
with gr.Row(): | |
# Left column: Example images | |
with gr.Column(scale=1): | |
gr.Markdown("### base image examples") | |
with gr.Row(): | |
base_example_1 = gr.Image( | |
value="assets/base_image-1.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
base_example_2 = gr.Image( | |
value="assets/base_image-2.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
base_example_3 = gr.Image( | |
value="assets/base_image-3.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
gr.Markdown("### product examples") | |
with gr.Row(): | |
product_example_1 = gr.Image( | |
value="assets/product_image-1.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
product_example_2 = gr.Image( | |
value="assets/product_image-2.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
product_example_3 = gr.Image( | |
value="assets/product_image-3.jpg", | |
interactive=False, | |
height=120, | |
width=120, | |
show_label=False, | |
) | |
# Second column: Input fields | |
with gr.Column(scale=1): | |
base_in = gr.Image( | |
label="Base Image", | |
type="pil", | |
height=IMG_SIZE, | |
width=IMG_SIZE, | |
) | |
garment_in = gr.Image( | |
label="Product Image", | |
type="pil", | |
height=IMG_SIZE, | |
width=IMG_SIZE, | |
) | |
mask_in = gr.Image( | |
label="Mask Image (Optional)", | |
type="pil", | |
height=IMG_SIZE, | |
width=IMG_SIZE, | |
visible=False, # Hidden from UI but available for API | |
) | |
# Third column: Result | |
with gr.Column(scale=2): | |
result_out = gr.Image( | |
label="Result", | |
height=512, | |
width=512, | |
) | |
# Right column: Controls | |
with gr.Column(scale=1): | |
workflow_selector = gr.Radio( | |
choices=[ | |
("Eyewear", "eyewear"), | |
("Footwear", "footwear"), | |
("Full-Body Garment", "dress"), | |
("Top Garment", "top"), | |
], | |
value="eyewear", | |
label="Model", | |
) | |
generate_btn = gr.Button("Generate", variant="primary", size="lg") | |
# Disclaimer box | |
gr.Markdown(""" | |
<div style="background-color: #495057; color: white; border-radius: 8px; padding: 15px; margin-top: 15px; font-size: 14px;"> | |
<strong>Disclaimer:</strong> | |
<ul style="margin: 8px 0; padding-left: 20px;"> | |
<li>Depending on whether the selected model is already loaded, generations take between 20 and 80 seconds</li> | |
<li>If the automasking process doesn't find a target, it will throw an error (e.g.: no feet in a Footwear request)</li> | |
<li>The Full-Body Garment model is able to generate dresses AND copy full looks, although this latter feature is highly experimental. You can provide a target full look worn by another person, and the model will treat it as a single full-body garment</li> | |
<li>Supported formats: JPG, JPEG, WEBP, PNG; unsupported formats: GIF, AVIF</li> | |
</ul> | |
</div> | |
""") | |
# Add spacing | |
gr.Markdown("<br><br>") | |
# Information section | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown(""" | |
<div style="font-size: 18px; line-height: 2.0;"> | |
📄 <strong>Read the Technical Report here:</strong> <a href="https://sm4ll-vton.github.io/sm4llvton/" target="_blank">sm4ll-vton.github.io/sm4llvton/</a> | |
<br><br> | |
🎥 <strong>Watch the in-depth YouTube video here:</strong> <a href="https://youtu.be/5o1OjWV4gsk" target="_blank">YouTube Video Tutorial</a> | |
<br><br> | |
🚀 <strong>Sign up for APIs and SDK on YourMirror:</strong> <a href="https://yourmirror.io" target="_blank">yourmirror.io</a> | |
<br><br> | |
💬 <strong>Want to chat?</strong><br> | |
[email protected] | [email protected]<br> | |
[email protected] | [email protected] | |
</div> | |
""") | |
# Wire up interaction | |
generate_btn.click( | |
generate, | |
inputs=[base_in, garment_in, workflow_selector, mask_in], | |
outputs=result_out, | |
) | |
# Select handlers for example images | |
base_example_1.select(lambda: Image.open("assets/base_image-1.jpg"), outputs=base_in) | |
base_example_2.select(lambda: Image.open("assets/base_image-2.jpg"), outputs=base_in) | |
base_example_3.select(lambda: Image.open("assets/base_image-3.jpg"), outputs=base_in) | |
product_example_1.select(lambda: Image.open("assets/product_image-1.jpg"), outputs=garment_in) | |
product_example_2.select(lambda: Image.open("assets/product_image-2.jpg"), outputs=garment_in) | |
product_example_3.select(lambda: Image.open("assets/product_image-3.jpg"), outputs=garment_in) | |
# Periodic cleanup for rate limiter (runs every 10 minutes) | |
def periodic_cleanup(): | |
cleanup_rate_limiter() | |
# Schedule next cleanup | |
threading.Timer(600.0, periodic_cleanup).start() # 10 minutes | |
# Start cleanup timer | |
threading.Timer(600.0, periodic_cleanup).start() | |
# Run app if executed directly (e.g. `python app.py`). HF Spaces launches via | |
# `python app.py` automatically if it finds `app.py` at repo root, but our file | |
# lives in a sub-folder, so we keep the guard. | |
if __name__ == "__main__": | |
demo.launch() |