# sm4ll-VTON-Demo / app.py
# NOTE: the following lines were Hugging Face page-scrape residue (author and
# commit metadata, not valid Python) and have been converted to comments:
#   risunobushi's picture
#   Update to new Demo rate limits. Unlimited use on studio.yourmirror.io
#   52b13b2 verified
"""Gradio demo Space for Trainingless – see plan.md for details."""
from __future__ import annotations
import base64
import io
import os
import time
import uuid
from datetime import datetime
from typing import Tuple, Optional
import requests
from dotenv import load_dotenv
from PIL import Image
from collections import defaultdict, deque
import threading
import gradio as gr
from supabase import create_client, Client
from transformers import pipeline
# -----------------------------------------------------------------------------
# Environment & Supabase setup
# -----------------------------------------------------------------------------
# Load .env file *once* when running locally. The HF Spaces runtime injects the
# same names via its Secrets mechanism, so calling load_dotenv() is harmless.
load_dotenv()
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
# Use a *secret* (server-only) key so the backend bypasses RLS.
SUPABASE_SECRET_KEY: str = os.getenv("SUPABASE_SECRET_KEY", "")
# (Optional) You can override which Edge Function gets called.
# Default points at the `process-image` function on the same Supabase project.
SUPABASE_FUNCTION_URL: str = os.getenv(
    "SUPABASE_FUNCTION_URL", f"{SUPABASE_URL}/functions/v1/process-image"
)
# Storage bucket for uploads. Must be *public* (result URLs are built by hand
# below, without signed-URL machinery).
UPLOAD_BUCKET = os.getenv("SUPABASE_UPLOAD_BUCKET", "images")
# How long (seconds) to wait on the synchronous edge-function HTTP call.
REQUEST_TIMEOUT = int(os.getenv("SUPABASE_FN_TIMEOUT", "240"))  # seconds
# Available model workflows recognised by edge function; `generate()` falls
# back to "eyewear" for any value outside this list.
WORKFLOW_CHOICES = [
    "eyewear",
    "footwear",
    "dress",
    "top",
]
# Fail fast at import time: without both credentials nothing below can work.
if not SUPABASE_URL or not SUPABASE_SECRET_KEY:
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SECRET_KEY must be set in the environment."
    )
# -----------------------------------------------------------------------------
# Supabase client – server-side: authenticate with secret key (bypasses RLS)
# -----------------------------------------------------------------------------
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)
# Ensure the uploads bucket exists (idempotent). This requires service role *once*;
# any failure here is logged and ignored — uploads will surface their own errors.
try:
    buckets = supabase.storage.list_buckets()  # type: ignore[attr-defined]
    # NOTE(review): newer supabase-py versions return bucket *objects* exposing
    # `.name` rather than dicts, which would make `b["name"]` raise — the outer
    # except swallows that, so the create-branch may be skipped. TODO confirm
    # against the installed client version.
    bucket_names = {b["name"] for b in buckets} if isinstance(buckets, list) else set()
    if UPLOAD_BUCKET not in bucket_names:
        # Attempt to create bucket (will fail w/ anon key – inform user to create)
        try:
            supabase.storage.create_bucket(
                UPLOAD_BUCKET,
                public=True,
            )
            print(f"[startup] Created bucket '{UPLOAD_BUCKET}'.")
        except Exception as create_exc:  # noqa: BLE001
            print(f"[startup] Could not create bucket '{UPLOAD_BUCKET}': {create_exc!r}")
except Exception as exc:  # noqa: BLE001
    # Non-fatal. The bucket probably already exists or we don't have perms.
    print(f"[startup] Bucket check/create raised {exc!r}. Continuing…")
# -----------------------------------------------------------------------------
# NSFW Filter Setup
# -----------------------------------------------------------------------------
# Initialize NSFW classifier at startup. Loading can fail (offline, OOM, model
# renamed); in that case `nsfw_classifier` stays None and `is_nsfw_content`
# fails open (no filtering) rather than blocking the demo.
print("[startup] Loading NSFW classifier...")
try:
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
    print("[startup] NSFW classifier loaded successfully")
except Exception as exc:
    print(f"[startup] Failed to load NSFW classifier: {exc!r}")
    nsfw_classifier = None
# -----------------------------------------------------------------------------
# Rate Limiting
# -----------------------------------------------------------------------------
# Rate limiter: 5 requests per IP per hour (sliding window).
RATE_LIMIT_REQUESTS = 5
RATE_LIMIT_WINDOW = 3600  # 1 hour in seconds
# Maps client identifier -> deque of request timestamps (epoch seconds).
# NOTE(review): this dict is mutated from Gradio request threads and from the
# cleanup Timer thread without a lock — presumably benign under CPython's GIL
# for these deque ops, but confirm before relying on exact counts.
request_tracker = defaultdict(deque)
def get_client_ip(request: gr.Request) -> str:
    """Best-effort client identifier for rate limiting.

    Tries the direct socket peer, then the usual reverse-proxy / CDN headers,
    then falls back to the Gradio session hash. Loopback and IPv6-localhost
    addresses are skipped so a proxy header can win over them.
    """
    if not request:
        return "no-request"

    has_headers = hasattr(request, 'headers')
    candidates = []

    # Direct client host (socket peer) comes first.
    if hasattr(request, 'client') and request.client:
        candidates.append(getattr(request.client, 'host', None))

    # Then the common proxy / CDN headers, in decreasing trust order.
    if has_headers:
        candidates.append(request.headers.get('X-Forwarded-For', '').split(',')[0].strip())
        for header_name in ('X-Real-IP', 'CF-Connecting-IP', 'True-Client-IP', 'X-Client-IP'):
            candidates.append(request.headers.get(header_name, ''))

    # First non-empty, non-loopback candidate wins.
    for candidate in candidates:
        if candidate and candidate.strip() and candidate != '::1' and not candidate.startswith('127.'):
            return candidate.strip()

    # Final fallbacks: accept even a loopback peer address if we have one.
    if hasattr(request, 'client') and request.client:
        peer_host = getattr(request.client, 'host', None)
        if peer_host:
            return peer_host

    # If all else fails, use a session-based identifier.
    return f"session-{getattr(request, 'session_hash', 'unknown-session')}"
def check_rate_limit(client_ip: str) -> bool:
    """Record a request for *client_ip* if it is under quota.

    Implements a sliding one-hour window over `request_tracker`. Returns True
    (and appends the current timestamp) when the request is allowed, False
    when the IP has exhausted its quota.
    """
    now = time.time()
    history = request_tracker[client_ip]

    # Evict timestamps that have aged out of the window.
    while history and now - history[0] > RATE_LIMIT_WINDOW:
        history.popleft()

    # Guard clause: over quota → log for monitoring and reject.
    if len(history) >= RATE_LIMIT_REQUESTS:
        print(f"[RATE_LIMIT] IP {client_ip} exceeded limit ({len(history)}/{RATE_LIMIT_REQUESTS})")
        return False

    history.append(now)
    return True
def cleanup_rate_limiter():
    """Drop IPs with no requests inside the rate-limit window.

    Periodic housekeeping so `request_tracker` does not grow without bound.
    Called from a repeating timer thread (see `periodic_cleanup`).
    """
    current_time = time.time()
    ips_to_remove = []
    # FIX: the loop variable was named `requests`, shadowing the imported
    # `requests` HTTP module inside this function — renamed to `history`.
    for ip, history in request_tracker.items():
        # Evict timestamps older than the window.
        while history and current_time - history[0] > RATE_LIMIT_WINDOW:
            history.popleft()
        # No recent requests → the whole entry can go.
        if not history:
            ips_to_remove.append(ip)
    # Delete outside the iteration to avoid mutating the dict while looping.
    for ip in ips_to_remove:
        del request_tracker[ip]
    print(f"[RATE_LIMITER] Cleaned up {len(ips_to_remove)} inactive IPs. Active IPs: {len(request_tracker)}")
# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def pil_to_bytes(img: Image.Image) -> bytes:
    """Serialise *img* as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()
def upload_image_to_supabase(img: Image.Image, path: str) -> str:
    """Upload *img* (as PNG) to `UPLOAD_BUCKET/path` and return its public URL.

    Existing objects at the same path are overwritten (upsert). The bucket is
    assumed to be public, so the URL is constructed directly.
    """
    payload = pil_to_bytes(img)
    # "upsert" must be a *string*, not a bool, per the storage client API.
    file_options = {"content-type": "image/png", "upsert": "true"}
    supabase.storage.from_(UPLOAD_BUCKET).upload(
        path,
        payload,
        file_options,
    )  # type: ignore[attr-defined]
    return f"{SUPABASE_URL}/storage/v1/object/public/{UPLOAD_BUCKET}/{path}"
def wait_for_job_completion(job_id: str, timeout_s: int = 600) -> Optional[str]:
    """Wait until job *job_id* completes or *timeout_s* elapses.

    Primary path: a Supabase Realtime subscription whose callback fills
    `_RESULT_CACHE`. Fallback: poll the `processing_jobs` row directly —
    every second when the subscription failed, every ~5 s otherwise.

    Returns the `result_image_url` value (http URL or data URI) or None on
    timeout.
    """
    completed_image: Optional[str] = None
    did_subscribe = False
    channel = None  # FIX: ensure bound even if subscription raises early
    try:
        # Docs: https://supabase.com/docs/reference/python/creating-channels
        channel = (
            supabase.channel("job_channel")
            .on(
                "postgres_changes",
                {
                    "event": "UPDATE",
                    "schema": "public",
                    "table": "processing_jobs",
                    "filter": f"id=eq.{job_id}",
                },
                lambda payload: _realtime_callback(payload, job_id),
            )
            .subscribe()
        )
        did_subscribe = True
    except Exception as exc:  # noqa: BLE001
        print(f"[wait] Realtime subscription failed – will poll: {exc!r}")
    start = time.time()
    # FIX: the previous gate `(time.time() - start) % 5 == 0` compared a float
    # modulo to zero, which is almost never true — so the poll fallback
    # effectively never ran while subscribed. Track an explicit deadline.
    next_poll = start
    while time.time() - start < timeout_s:
        if _RESULT_CACHE.get(job_id):
            completed_image = _RESULT_CACHE.pop(job_id)
            break
        now = time.time()
        if not did_subscribe or now >= next_poll:
            next_poll = now + 5  # poll at most every ~5 s while subscribed
            data = (
                supabase.table("processing_jobs")
                .select("status,result_image_url")
                .eq("id", job_id)
                .single()
                .execute()
            )
            if data.data and data.data["status"] == "completed":
                completed_image = data.data.get("result_image_url")
                break
        time.sleep(1)
    # Best-effort channel teardown; failures here are harmless.
    try:
        if did_subscribe and channel is not None:
            supabase.remove_channel(channel)
    except Exception:  # noqa: PIE786, BLE001
        pass
    return completed_image
# Maps job-id -> result image URL; filled by the realtime callback below and
# drained by wait_for_job_completion().
_RESULT_CACHE: dict[str, str] = {}


def _realtime_callback(payload: dict, job_id: str) -> None:
    """Cache the result URL when a realtime UPDATE reports job completion."""
    row = payload.get("new", {})  # type: ignore[index]
    if row.get("status") == "completed":
        _RESULT_CACHE[job_id] = row.get("result_image_url")
MAX_PIXELS = 1_500_000  # 1.5 megapixels ceiling for each uploaded image


def downscale_image(img: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image:
    """Shrink *img* proportionally until width*height fits within *max_pixels*.

    Images already within the budget are returned unchanged (same object).
    """
    width, height = img.size
    pixel_count = width * height
    if pixel_count <= max_pixels:
        return img
    # Uniform scale factor applied to both axes; floor at 1 px per side.
    factor = (max_pixels / pixel_count) ** 0.5
    target = (max(1, int(width * factor)), max(1, int(height * factor)))
    return img.resize(target, Image.LANCZOS)
def _public_storage_url(path: str) -> str:
    """Map a storage object *path* to its public https URL.

    Strings that are already absolute URLs pass through unchanged.
    """
    if path.startswith(("http://", "https://")):
        return path
    # Strip any leading slash so the joined URL has no double separator.
    return f"{SUPABASE_URL}/storage/v1/object/public/{path.lstrip('/')}"
def is_nsfw_content(img: Image.Image) -> bool:
    """Return True when *img* is classified as explicit pornographic content.

    Thresholds are deliberately high so legitimate fashion content (lingerie,
    swimwear) is allowed. Fails open: a missing classifier or a classification
    error never blocks the request.
    """
    if nsfw_classifier is None:
        print("[NSFW] Classifier not available, skipping check")
        return False
    try:
        predictions = nsfw_classifier(img)
        print(f"[NSFW] Classification results: {predictions}")
        for prediction in predictions:
            label = prediction['label'].lower()
            score = prediction['score']
            print(f"[NSFW] Label: {label}, Score: {score:.3f}")
            # Only the dedicated "porn" label blocks at 0.85; broader labels
            # need near-certainty so fashion imagery passes.
            if label == 'porn' and score > 0.85:
                print(f"[NSFW] BLOCKED - Explicit pornographic content detected with {score:.3f} confidence")
                return True
            if label in ['nsfw', 'explicit'] and score > 0.95:
                print(f"[NSFW] BLOCKED - {label} detected with {score:.3f} confidence")
                return True
        print("[NSFW] Content approved (fashion/lingerie content allowed)")
        return False
    except Exception as exc:
        print(f"[NSFW] Error during classification: {exc!r}")
        # Fail open - don't block if classifier has issues
        return False
# -----------------------------------------------------------------------------
# Main generate function
# -----------------------------------------------------------------------------
def fetch_image_if_url(img):
    """Return *img*, downloading and decoding it first when given a URL string.

    Non-string inputs (assumed to already be PIL images) and strings that are
    not http(s) URLs are returned unchanged.

    Raises:
        requests.HTTPError: when the download returns a non-2xx status.
    """
    if isinstance(img, str) and img.startswith(("http://", "https://")):
        print(f"[FETCH] Downloading image from URL: {img}")
        # FIX: bounded timeout — the original call had none, so a stalled
        # remote host could hang this worker thread indefinitely.
        resp = requests.get(img, headers={"x-api-origin": "hf/demo"}, timeout=30)
        resp.raise_for_status()
        # Uses the module-level PIL Image import (the redundant local
        # `from PIL import Image` was removed).
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    return img
def generate(
    base_img: Image.Image,
    garment_img: Image.Image,
    workflow_choice: str,
    mask_img: Optional[Image.Image],  # NEW: Optional mask parameter
    request: gr.Request
) -> Image.Image:
    """Run one virtual try-on job end to end and return the result image.

    Pipeline: normalise URL inputs -> rate limit -> NSFW screen -> upload
    inputs to Supabase storage -> insert a `processing_jobs` row -> trigger
    the edge function -> wait for completion -> fetch/decode the result.

    Args:
        base_img: Person photo (PIL image, or URL string fetched below).
        garment_img: Product image (PIL image or URL string).
        workflow_choice: One of WORKFLOW_CHOICES; anything else falls back
            to "eyewear".
        mask_img: Optional mask; when absent the base image is reused as the
            mask (the edge function treats that as "automask").
        request: Gradio request object, used only for per-IP rate limiting.

    Raises:
        gr.Error: missing inputs, rate-limit hit, NSFW detection, backend
            HTTP error, or completion timeout.
    """
    # Inputs may arrive as URL strings when called via the API.
    base_img = fetch_image_if_url(base_img)
    garment_img = fetch_image_if_url(garment_img)
    if mask_img is not None:
        mask_img = fetch_image_if_url(mask_img)
    if base_img is None or garment_img is None:
        raise gr.Error("Please provide both images.")
    # Rate limiting check (per best-effort client IP; header-spoofable).
    client_ip = get_client_ip(request)
    if not check_rate_limit(client_ip):
        raise gr.Error("Rate Limit Quota Exceeded - Visit studio.yourmirror.io to sign up for unlimited use")
    # NSFW content filtering - only check product image
    print(f"[NSFW] Checking product image for inappropriate content...")
    if is_nsfw_content(garment_img):
        raise gr.Error("Product image contains inappropriate content. Please use a different image.")
    # 1. Persist both images to Supabase storage under a per-job folder.
    job_id = str(uuid.uuid4())
    folder = f"user_uploads/gradio/{job_id}"
    base_filename = f"{uuid.uuid4().hex}.png"
    garment_filename = f"{uuid.uuid4().hex}.png"
    base_path = f"{folder}/{base_filename}"
    garment_path = f"{folder}/{garment_filename}"
    # Cap upload size at MAX_PIXELS before transfer.
    base_img = downscale_image(base_img)
    garment_img = downscale_image(garment_img)
    base_url = upload_image_to_supabase(base_img, base_path)
    garment_url = upload_image_to_supabase(garment_img, garment_path)
    # Handle optional mask image (if provided by ComfyUI or future web UI)
    mask_url = None
    if mask_img is not None:
        print(f"[MASK] Processing user-provided mask image")
        mask_filename = f"{uuid.uuid4().hex}.png"
        mask_path = f"{folder}/{mask_filename}"
        mask_img = downscale_image(mask_img)
        mask_url = upload_image_to_supabase(mask_img, mask_path)
        print(f"[MASK] Uploaded mask: {mask_url}")
    else:
        print(f"[MASK] No mask provided - will use base image fallback")
    # 2. Insert new row into processing_jobs (anon key, relies on open RLS)
    token_for_row = str(uuid.uuid4())
    insert_payload = {
        "id": job_id,
        "status": "queued",
        "base_image_path": base_url,
        "garment_image_path": garment_url,
        "mask_image_path": mask_url if mask_url else base_url,  # Track actual mask used
        "access_token": token_for_row,
        # NOTE(review): datetime.utcnow() is deprecated (Python 3.12);
        # datetime.now(timezone.utc).isoformat() would append "+00:00" —
        # confirm the backend's timestamp parsing before changing.
        "created_at": datetime.utcnow().isoformat(),
    }
    supabase.table("processing_jobs").insert(insert_payload).execute()
    # 3. Trigger edge function (validate/normalise the workflow name first).
    workflow_choice = (workflow_choice or "eyewear").lower()
    if workflow_choice not in WORKFLOW_CHOICES:
        workflow_choice = "eyewear"
    fn_payload = {
        "baseImageUrl": base_url,
        "garmentImageUrl": garment_url,
        # 🎭 Smart fallback: use provided mask OR base image (much better than garment!)
        "maskImageUrl": mask_url if mask_url else base_url,
        "jobId": job_id,
        "workflowType": workflow_choice,
    }
    # Log mask selection for debugging
    if mask_url:
        print(f"[API] Using user-provided mask: {mask_url}")
    else:
        print(f"[API] Using base image as mask fallback: {base_url}")
    headers = {
        "Content-Type": "application/json",
        "apikey": SUPABASE_SECRET_KEY,
        "Authorization": f"Bearer {SUPABASE_SECRET_KEY}",
        "x-api-origin": "hf/demo",
    }
    resp = requests.post(
        SUPABASE_FUNCTION_URL,
        json=fn_payload,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if not resp.ok:
        raise gr.Error(f"Backend error: {resp.text}")
    # 4. Wait for completion via realtime (or polling fallback)
    result = wait_for_job_completion(job_id)
    if not result:
        raise gr.Error("Timed out waiting for job to finish.")
    # Result may be base64 data URI or http URL; normalise.
    if result.startswith("data:image"):
        header, b64 = result.split(",", 1)
        img_bytes = base64.b64decode(b64)
        result_img = Image.open(io.BytesIO(img_bytes)).convert("RGBA")
    else:
        result_url = _public_storage_url(result)
        resp_img = requests.get(result_url, timeout=30, headers={"x-api-origin": "hf/demo"})
        resp_img.raise_for_status()
        result_img = Image.open(io.BytesIO(resp_img.content)).convert("RGBA")
    return result_img
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Blurb rendered under the page title.
description = "Upload a person photo (Base) and a product image. Select between Eyewear, Footwear, Full-Body Garments, or Top Garments to switch between the four available models. Click 👉 **Generate** to try on a product."  # noqa: E501
with gr.Blocks(title="YOURMIRROR.IO - SM4LL-VTON Demo") as demo:
    # Header
    gr.Markdown("# SM4LL-VTON PRE-RELEASE DEMO | YOURMIRROR.IO | Virtual Try-On")
    gr.Markdown(description)
    # Square size (px) of the interactive input widgets.
    IMG_SIZE = 256
    with gr.Row():
        # Left column: Example images (display-only thumbnails; clicking one
        # copies it into the corresponding input via the .select handlers below)
        with gr.Column(scale=1):
            gr.Markdown("### base image examples")
            with gr.Row():
                base_example_1 = gr.Image(
                    value="assets/base_image-1.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
                base_example_2 = gr.Image(
                    value="assets/base_image-2.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
                base_example_3 = gr.Image(
                    value="assets/base_image-3.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
            gr.Markdown("### product examples")
            with gr.Row():
                product_example_1 = gr.Image(
                    value="assets/product_image-1.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
                product_example_2 = gr.Image(
                    value="assets/product_image-2.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
                product_example_3 = gr.Image(
                    value="assets/product_image-3.jpg",
                    interactive=False,
                    height=120,
                    width=120,
                    show_label=False,
                )
        # Second column: Input fields
        with gr.Column(scale=1):
            base_in = gr.Image(
                label="Base Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            garment_in = gr.Image(
                label="Product Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            mask_in = gr.Image(
                label="Mask Image (Optional)",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
                visible=False,  # Hidden from UI but available for API
            )
        # Third column: Result
        with gr.Column(scale=2):
            result_out = gr.Image(
                label="Result",
                height=512,
                width=512,
            )
        # Right column: Controls. Radio values map to WORKFLOW_CHOICES keys.
        with gr.Column(scale=1):
            workflow_selector = gr.Radio(
                choices=[
                    ("Eyewear", "eyewear"),
                    ("Footwear", "footwear"),
                    ("Full-Body Garment", "dress"),
                    ("Top Garment", "top"),
                ],
                value="eyewear",
                label="Model",
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            # Disclaimer box
            gr.Markdown("""
            <div style="background-color: #495057; color: white; border-radius: 8px; padding: 15px; margin-top: 15px; font-size: 14px;">
            <strong>Disclaimer:</strong>
            <ul style="margin: 8px 0; padding-left: 20px;">
            <li>Depending on whether the selected model is already loaded, generations take between 20 and 80 seconds</li>
            <li>If the automasking process doesn't find a target, it will throw an error (e.g.: no feet in a Footwear request)</li>
            <li>The Full-Body Garment model is able to generate dresses AND copy full looks, although this latter feature is highly experimental. You can provide a target full look worn by another person, and the model will treat it as a single full-body garment</li>
            <li>Supported formats: JPG, JPEG, WEBP, PNG; unsupported formats: GIF, AVIF</li>
            </ul>
            </div>
            """)
    # Add spacing
    gr.Markdown("<br><br>")
    # Information section
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            <div style="font-size: 18px; line-height: 2.0;">
            📄 <strong>Read the Technical Report here:</strong> <a href="https://sm4ll-vton.github.io/sm4llvton/" target="_blank">sm4ll-vton.github.io/sm4llvton/</a>
            <br><br>
            🎥 <strong>Watch the in-depth YouTube video here:</strong> <a href="https://youtu.be/5o1OjWV4gsk" target="_blank">YouTube Video Tutorial</a>
            <br><br>
            🚀 <strong>Sign up for APIs and SDK on YourMirror:</strong> <a href="https://yourmirror.io" target="_blank">yourmirror.io</a>
            <br><br>
            💬 <strong>Want to chat?</strong><br>
            [email protected] | [email protected]<br>
            [email protected] | [email protected]
            </div>
            """)
    # Wire up interaction. `generate` also takes a gr.Request parameter,
    # which Gradio injects automatically (it is not listed in `inputs`).
    generate_btn.click(
        generate,
        inputs=[base_in, garment_in, workflow_selector, mask_in],
        outputs=result_out,
    )
    # Select handlers for example images: clicking a thumbnail loads the
    # corresponding asset into the editable input slot.
    base_example_1.select(lambda: Image.open("assets/base_image-1.jpg"), outputs=base_in)
    base_example_2.select(lambda: Image.open("assets/base_image-2.jpg"), outputs=base_in)
    base_example_3.select(lambda: Image.open("assets/base_image-3.jpg"), outputs=base_in)
    product_example_1.select(lambda: Image.open("assets/product_image-1.jpg"), outputs=garment_in)
    product_example_2.select(lambda: Image.open("assets/product_image-2.jpg"), outputs=garment_in)
    product_example_3.select(lambda: Image.open("assets/product_image-3.jpg"), outputs=garment_in)
# Periodic cleanup for rate limiter (runs every 10 minutes)
def periodic_cleanup():
    """Purge idle IPs from the rate limiter, then reschedule itself."""
    cleanup_rate_limiter()
    # Schedule next cleanup
    _schedule_cleanup()


def _schedule_cleanup():
    """Start the next cleanup on a daemon timer.

    FIX: the original Timer threads were non-daemon, so a pending 10-minute
    timer kept the interpreter alive after the Gradio server exited.
    """
    timer = threading.Timer(600.0, periodic_cleanup)  # 10 minutes
    timer.daemon = True
    timer.start()


# Start cleanup timer
_schedule_cleanup()
# Run app if executed directly (e.g. `python app.py`). HF Spaces launches via
# `python app.py` automatically if it finds `app.py` at repo root, but our file
# lives in a sub-folder, so we keep the guard.
if __name__ == "__main__":
    # launch() blocks, serving the Blocks app defined above on the default port.
    demo.launch()