"""Gradio demo Space for Trainingless – see plan.md for details."""

from __future__ import annotations

import base64
import io
import os
import time
import uuid
from datetime import datetime
from typing import Tuple, Optional

import requests
from dotenv import load_dotenv
from PIL import Image
from collections import defaultdict, deque
import threading

import gradio as gr
from supabase import create_client, Client
from transformers import pipeline

# -----------------------------------------------------------------------------
# Environment & Supabase setup
# -----------------------------------------------------------------------------
# Pull variables from a local .env file when one exists. On HF Spaces the same
# names are injected as Secrets, so this call is harmless there.
load_dotenv()

SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
# Server-only *secret* key: lets this backend bypass Row Level Security.
SUPABASE_SECRET_KEY: str = os.getenv("SUPABASE_SECRET_KEY", "")

# Edge Function endpoint that processes a job; overridable for other functions.
SUPABASE_FUNCTION_URL: str = os.getenv(
    "SUPABASE_FUNCTION_URL",
    f"{SUPABASE_URL}/functions/v1/process-image",
)

# Storage bucket receiving uploads. It must be *public*: uploaded images are
# handed to the backend as public object URLs.
UPLOAD_BUCKET = os.getenv("SUPABASE_UPLOAD_BUCKET", "images")

# Timeout, in seconds, for the HTTP call that triggers the Edge Function.
REQUEST_TIMEOUT = int(os.getenv("SUPABASE_FN_TIMEOUT", "240"))

# Workflow identifiers understood by the Edge Function.
WORKFLOW_CHOICES = ["eyewear", "footwear", "dress", "top"]

# Fail fast: nothing below can work without both settings.
if not (SUPABASE_URL and SUPABASE_SECRET_KEY):
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SECRET_KEY must be set in the environment."
    )

# -----------------------------------------------------------------------------
# Supabase client – server-side: authenticate with secret key (bypasses RLS)
# -----------------------------------------------------------------------------
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)
# Ensure the uploads bucket exists (idempotent). Creating it requires the
# secret (service-role) key; if creation fails, create the bucket manually.


def _bucket_name(bucket: object) -> Optional[str]:
    """Return *bucket*'s name whether the client yields dicts or objects.

    NOTE(review): storage3 returns bucket objects exposing a ``name``
    attribute rather than plain dicts; subscripting those with ``["name"]``
    raises TypeError, which previously aborted the whole bucket check.
    Both shapes are accepted here.
    """
    if isinstance(bucket, dict):
        return bucket.get("name")
    return getattr(bucket, "name", None)


try:
    buckets = supabase.storage.list_buckets()  # type: ignore[attr-defined]
    bucket_names = (
        {_bucket_name(b) for b in buckets} if isinstance(buckets, list) else set()
    )
    if UPLOAD_BUCKET not in bucket_names:
        # Attempt to create the bucket. This fails without sufficient privileges;
        # in that case the bucket has to be created by hand.
        try:
            # storage3 takes bucket settings through the ``options`` mapping;
            # a bare ``public=True`` keyword is not part of create_bucket().
            supabase.storage.create_bucket(UPLOAD_BUCKET, options={"public": True})
            print(f"[startup] Created bucket '{UPLOAD_BUCKET}'.")
        except Exception as create_exc:  # noqa: BLE001
            print(f"[startup] Could not create bucket '{UPLOAD_BUCKET}': {create_exc!r}")
except Exception as exc:  # noqa: BLE001
    # Non-fatal. The bucket probably already exists or we don't have perms.
    print(f"[startup] Bucket check/create raised {exc!r}. Continuing…")

# -----------------------------------------------------------------------------
# NSFW Filter Setup
# -----------------------------------------------------------------------------
# Load the classifier once at startup. If loading fails, nsfw_classifier stays
# None and is_nsfw_content() lets every image through.
print("[startup] Loading NSFW classifier...")
try:
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
    print("[startup] NSFW classifier loaded successfully")
except Exception as exc:  # noqa: BLE001
    print(f"[startup] Failed to load NSFW classifier: {exc!r}")
    nsfw_classifier = None

# -----------------------------------------------------------------------------
# Rate Limiting
# -----------------------------------------------------------------------------
# Rate limiter: 5 requests per IP per hour
RATE_LIMIT_REQUESTS = 5
RATE_LIMIT_WINDOW = 3600  # 1 hour in seconds

# client id -> time.time() stamps of accepted requests, oldest first.
request_tracker = defaultdict(deque)
# Guards request_tracker. cleanup_rate_limiter() is driven by a timer thread
# while check_rate_limit() runs per request (Gradio may use several worker
# threads). Iterating a dict that another thread inserts into raises
# "dictionary changed size during iteration".
_rate_limit_lock = threading.Lock()


def get_client_ip(request: gr.Request) -> str:
    """Return an identifier for the caller, used as the rate-limit key.

    Candidates are tried in this order:
      1. the socket peer address;
      2. the first ``X-Forwarded-For`` entry;
      3. ``X-Real-IP``, ``CF-Connecting-IP``, ``True-Client-IP`` and
         ``X-Client-IP``.

    The first non-empty candidate that is not loopback (``::1`` / ``127.*``)
    wins. Failing that, the peer address is returned even if it is loopback.
    The last resort is ``"session-<session_hash>"``. ``"no-request"`` is
    returned when *request* is falsy.

    NOTE(review): behind a reverse proxy the peer address is the proxy's, which
    would put every user in one bucket. The forwarded headers, on the other
    hand, are client-controlled and spoofable. Confirm the right precedence
    for the deployment.
    """
    if not request:
        return "no-request"

    client = request.client if hasattr(request, "client") else None
    client_host = getattr(client, "host", None) if client else None

    candidates = [client_host]
    if hasattr(request, "headers"):
        headers = request.headers
        # Proxies append hops; the left-most entry is the originating client.
        candidates.append(headers.get("X-Forwarded-For", "").split(",")[0].strip())
        candidates.extend(
            headers.get(name, "")
            for name in ("X-Real-IP", "CF-Connecting-IP", "True-Client-IP", "X-Client-IP")
        )

    for ip in candidates:
        if ip and ip.strip() and ip != "::1" and not ip.startswith("127."):
            return ip.strip()

    if client_host:
        return client_host

    # If all else fails, use a session-based identifier.
    session_id = getattr(request, "session_hash", "unknown-session")
    return f"session-{session_id}"


def check_rate_limit(client_ip: str) -> bool:
    """Record a request for *client_ip* if under quota.

    Returns True if the request is allowed (and counted), False if blocked.
    """
    current_time = time.time()
    with _rate_limit_lock:
        user_requests = request_tracker[client_ip]
        # Drop timestamps that have slid out of the window.
        while user_requests and current_time - user_requests[0] > RATE_LIMIT_WINDOW:
            user_requests.popleft()
        if len(user_requests) < RATE_LIMIT_REQUESTS:
            user_requests.append(current_time)
            return True
        used = len(user_requests)
    # Log rate limit hit for monitoring (outside the lock).
    print(f"[RATE_LIMIT] IP {client_ip} exceeded limit ({used}/{RATE_LIMIT_REQUESTS})")
    return False


def cleanup_rate_limiter():
    """Prune expired timestamps and forget clients with no recent requests."""
    current_time = time.time()
    with _rate_limit_lock:
        stale_ips = []
        for ip, timestamps in request_tracker.items():
            while timestamps and current_time - timestamps[0] > RATE_LIMIT_WINDOW:
                timestamps.popleft()
            if not timestamps:
                stale_ips.append(ip)
        for ip in stale_ips:
            del request_tracker[ip]
        active = len(request_tracker)
    print(f"[RATE_LIMITER] Cleaned up {len(stale_ips)} inactive IPs. Active IPs: {active}")


# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
def pil_to_bytes(img: Image.Image) -> bytes:
    """Convert PIL Image to PNG bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()


def upload_image_to_supabase(img: Image.Image, path: str) -> str:
    """Upload *img* as PNG under ``UPLOAD_BUCKET/path`` and return its **public URL**.

    An existing object at *path* is overwritten (upsert).
    """
    data = pil_to_bytes(img)
    supabase.storage.from_(UPLOAD_BUCKET).upload(
        path,
        data,
        {"content-type": "image/png", "upsert": "true"},  # upsert must be string
    )  # type: ignore[attr-defined]
    return f"{SUPABASE_URL}/storage/v1/object/public/{UPLOAD_BUCKET}/{path}"


# job id -> result_image_url, filled by the Realtime callback.
_RESULT_CACHE: dict[str, str] = {}

# Seconds between direct DB polls while a Realtime subscription is active
# (a safety net in case a Realtime event is missed).
_POLL_INTERVAL_S = 5.0


def _realtime_callback(payload: dict, job_id: str) -> None:
    """Stash the result URL for *job_id* when an UPDATE marks it completed."""
    new = payload.get("new", {})  # type: ignore[index]
    if new.get("status") == "completed":
        _RESULT_CACHE[job_id] = new.get("result_image_url")


def wait_for_job_completion(job_id: str, timeout_s: int = 600) -> Optional[str]:
    """Block until the ``processing_jobs`` row *job_id* reaches ``completed``.

    A Realtime subscription on the row is attempted first; its callback puts
    the result URL into ``_RESULT_CACHE``. The row is also polled directly:
    on every ~1 s loop iteration when the subscription could not be set up,
    otherwise every ``_POLL_INTERVAL_S`` seconds.

    Returns the row's ``result_image_url`` (may be ``None``), or ``None`` if
    *timeout_s* elapses first.
    """
    completed_image: Optional[str] = None
    channel = None
    did_subscribe = False
    try:
        # Docs: https://supabase.com/docs/reference/python/creating-channels
        # One channel per job so concurrent jobs don't tear down each other's
        # subscription in remove_channel().
        channel = (
            supabase.channel(f"job_{job_id}")
            .on(
                "postgres_changes",
                {
                    "event": "UPDATE",
                    "schema": "public",
                    "table": "processing_jobs",
                    "filter": f"id=eq.{job_id}",
                },
                lambda payload: _realtime_callback(payload, job_id),
            )
            .subscribe()
        )
        did_subscribe = True
    except Exception as exc:  # noqa: BLE001
        print(f"[wait] Realtime subscription failed – will poll: {exc!r}")

    start = time.time()
    last_poll = float("-inf")
    try:
        while time.time() - start < timeout_s:
            cached = _RESULT_CACHE.pop(job_id, None)
            if cached:
                completed_image = cached
                break

            now = time.time()
            # The old "elapsed % 5 == 0" test on a float almost never held, so
            # the fallback poll never ran once subscribed; track the last poll.
            if not did_subscribe or now - last_poll >= _POLL_INTERVAL_S:
                last_poll = now
                try:
                    data = (
                        supabase.table("processing_jobs")
                        .select("status,result_image_url")
                        .eq("id", job_id)
                        .single()
                        .execute()
                    )
                except Exception as exc:  # noqa: BLE001
                    # Transient failure: keep waiting until the timeout.
                    print(f"[wait] Polling job {job_id} failed: {exc!r}")
                else:
                    row = data.data
                    if row and row["status"] == "completed":
                        completed_image = row.get("result_image_url")
                        break
            time.sleep(1)
    finally:
        if did_subscribe:
            try:
                supabase.remove_channel(channel)
            except Exception as exc:  # noqa: BLE001
                print(f"[wait] Could not remove channel for job {job_id}: {exc!r}")
        # Drop any entry the callback wrote after polling won the race.
        _RESULT_CACHE.pop(job_id, None)

    return completed_image


MAX_PIXELS = 1_500_000  # 1.5 megapixels ceiling for each uploaded image


def downscale_image(img: Image.Image, max_pixels: int = MAX_PIXELS) -> Image.Image:
    """Downscale *img* proportionally so that width×height ≤ *max_pixels*.

    If the image is already small enough, it is returned unchanged.
    """
    w, h = img.size
    if w * h <= max_pixels:
        return img
    scale = (max_pixels / (w * h)) ** 0.5  # uniform scaling factor
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, Image.LANCZOS)


def _public_storage_url(path: str) -> str:
    """Return a public (https) URL given an object *path* inside any bucket.

    If *path* already looks like a full URL, it is returned unchanged.
    """
    if path.startswith(("http://", "https://")):
        return path
    # Ensure no leading slash.
    return f"{SUPABASE_URL}/storage/v1/object/public/{path.lstrip('/')}"


def is_nsfw_content(img: Image.Image) -> bool:
    """Return True if *img* should be blocked as explicit content.

    Thresholds are deliberately strict so that fashion content (lingerie,
    swimwear) is allowed. The check fails open: it returns False when the
    classifier is missing or raises.

    NOTE(review): Falconsai/nsfw_image_detection appears to emit only
    "normal"/"nsfw" labels, in which case the "porn" branch never fires and
    only the 0.95 "nsfw" threshold applies — confirm.
    """
    if nsfw_classifier is None:
        print("[NSFW] Classifier not available, skipping check")
        return False

    try:
        results = nsfw_classifier(img)
        print(f"[NSFW] Classification results: {results}")

        for result in results:
            label = result['label'].lower()
            score = result['score']
            print(f"[NSFW] Label: {label}, Score: {score:.3f}")

            if label == 'porn' and score > 0.85:
                print(f"[NSFW] BLOCKED - Explicit pornographic content detected with {score:.3f} confidence")
                return True
            if label in ['nsfw', 'explicit'] and score > 0.95:
                print(f"[NSFW] BLOCKED - {label} detected with {score:.3f} confidence")
                return True

        print("[NSFW] Content approved (fashion/lingerie content allowed)")
        return False

    except Exception as exc:  # noqa: BLE001
        print(f"[NSFW] Error during classification: {exc!r}")
        # Fail open - don't block if classifier has issues
        return False


# -----------------------------------------------------------------------------
# Main generate function
# -----------------------------------------------------------------------------
def fetch_image_if_url(img):
    """Download *img* when it is an http(s) URL string; otherwise return it as-is.

    URL inputs are returned as RGB ``PIL.Image`` objects. Raises
    ``requests.HTTPError`` on a non-2xx response and ``requests.Timeout`` if
    the server is unresponsive.

    NOTE(review): this fetches arbitrary caller-supplied URLs server-side
    (SSRF surface); consider restricting hosts if the API is public.
    """
    if isinstance(img, str) and img.startswith(("http://", "https://")):
        print(f"[FETCH] Downloading image from URL: {img}")
        # Bounded timeout so a stalled host cannot hang the worker forever.
        resp = requests.get(img, headers={"x-api-origin": "hf/demo"}, timeout=30)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    return img


def _result_to_image(result: str) -> Image.Image:
    """Decode a job result (``data:image`` URI or storage path/URL) to RGBA."""
    if result.startswith("data:image"):
        _header, b64 = result.split(",", 1)
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGBA")
    result_url = _public_storage_url(result)
    resp_img = requests.get(result_url, timeout=30, headers={"x-api-origin": "hf/demo"})
    resp_img.raise_for_status()
    return Image.open(io.BytesIO(resp_img.content)).convert("RGBA")


def generate(
    base_img: Image.Image,
    garment_img: Image.Image,
    workflow_choice: str,
    mask_img: Optional[Image.Image],  # Optional mask (hidden in the UI, API only)
    request: gr.Request,
) -> Image.Image:
    """Run one virtual try-on job end to end and return the result image.

    Steps:
      1. Validate inputs and apply the rate limit.
      2. Download URL inputs and screen the product image for NSFW content.
      3. Upload the images to Supabase storage.
      4. Insert a ``processing_jobs`` row and trigger the Edge Function.
      5. Wait for completion and decode the result.

    Raises:
        gr.Error: missing images, rate limit exceeded, product image rejected,
            Edge Function error, or no result before the wait timed out.
    """
    # fetch_image_if_url() passes None through unchanged, so validating first
    # is equivalent and avoids work for obviously bad requests.
    if base_img is None or garment_img is None:
        raise gr.Error("Please provide both images.")

    # Rate-limit *before* any outbound download so a throttled client cannot
    # keep making the server fetch arbitrary URLs.
    client_ip = get_client_ip(request)
    if not check_rate_limit(client_ip):
        raise gr.Error("Rate Limit Quota Exceeded - Visit studio.yourmirror.io to sign up for unlimited use")

    base_img = fetch_image_if_url(base_img)
    garment_img = fetch_image_if_url(garment_img)
    if mask_img is not None:
        mask_img = fetch_image_if_url(mask_img)

    # NSFW content filtering - only check product image
    print("[NSFW] Checking product image for inappropriate content...")
    if is_nsfw_content(garment_img):
        raise gr.Error("Product image contains inappropriate content. Please use a different image.")

    # 1. Persist the images to Supabase storage under a per-job folder.
    job_id = str(uuid.uuid4())
    folder = f"user_uploads/gradio/{job_id}"
    base_path = f"{folder}/{uuid.uuid4().hex}.png"
    garment_path = f"{folder}/{uuid.uuid4().hex}.png"

    base_img = downscale_image(base_img)
    garment_img = downscale_image(garment_img)

    base_url = upload_image_to_supabase(base_img, base_path)
    garment_url = upload_image_to_supabase(garment_img, garment_path)

    # Optional mask image (if provided by ComfyUI or a future web UI).
    mask_url = None
    if mask_img is not None:
        print("[MASK] Processing user-provided mask image")
        mask_path = f"{folder}/{uuid.uuid4().hex}.png"
        mask_img = downscale_image(mask_img)
        mask_url = upload_image_to_supabase(mask_img, mask_path)
        print(f"[MASK] Uploaded mask: {mask_url}")
    else:
        print("[MASK] No mask provided - will use base image fallback")

    # Without a mask, the base image URL is sent in the mask slot.
    effective_mask_url = mask_url if mask_url else base_url

    # 2. Insert the job row. The client uses the secret key, so RLS is bypassed.
    insert_payload = {
        "id": job_id,
        "status": "queued",
        "base_image_path": base_url,
        "garment_image_path": garment_url,
        "mask_image_path": effective_mask_url,  # Track actual mask used
        "access_token": str(uuid.uuid4()),
        "created_at": datetime.utcnow().isoformat(),
    }
    supabase.table("processing_jobs").insert(insert_payload).execute()

    # 3. Trigger edge function (unknown workflow names fall back to eyewear).
    workflow_choice = (workflow_choice or "eyewear").lower()
    if workflow_choice not in WORKFLOW_CHOICES:
        workflow_choice = "eyewear"

    fn_payload = {
        "baseImageUrl": base_url,
        "garmentImageUrl": garment_url,
        "maskImageUrl": effective_mask_url,
        "jobId": job_id,
        "workflowType": workflow_choice,
    }

    # Log mask selection for debugging
    if mask_url:
        print(f"[API] Using user-provided mask: {mask_url}")
    else:
        print(f"[API] Using base image as mask fallback: {base_url}")

    headers = {
        "Content-Type": "application/json",
        "apikey": SUPABASE_SECRET_KEY,
        "Authorization": f"Bearer {SUPABASE_SECRET_KEY}",
        "x-api-origin": "hf/demo",
    }

    resp = requests.post(
        SUPABASE_FUNCTION_URL,
        json=fn_payload,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    if not resp.ok:
        raise gr.Error(f"Backend error: {resp.text}")

    # 4. Wait for completion via realtime (or polling fallback)
    result = wait_for_job_completion(job_id)
    if not result:
        raise gr.Error("Timed out waiting for job to finish.")

    # 5. Result may be a base64 data URI or a storage path/URL; normalise.
    return _result_to_image(result)


# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
description = "Upload a person photo (Base) and a product image. Select between Eyewear, Footwear, Full-Body Garments, or Top Garments to switch between the four available models. Click 👉 **Generate** to try on a product."  # noqa: E501
def _example_thumbnail(path: str) -> gr.Image:
    """Build a small, non-interactive, label-less preview of an example asset."""
    return gr.Image(
        value=path,
        interactive=False,
        height=120,
        width=120,
        show_label=False,
    )


with gr.Blocks(title="YOURMIRROR.IO - SM4LL-VTON Demo") as demo:
    # Header
    gr.Markdown("# SM4LL-VTON PRE-RELEASE DEMO | YOURMIRROR.IO | Virtual Try-On")
    gr.Markdown(description)

    IMG_SIZE = 256

    # NOTE(review): confirm which container the disclaimer and spacing
    # Markdown blocks belong to; the nesting below is a best reconstruction.
    with gr.Row():
        # Left column: example images (clicking one loads it into an input).
        with gr.Column(scale=1):
            gr.Markdown("### base image examples")
            with gr.Row():
                base_example_1 = _example_thumbnail("assets/base_image-1.jpg")
                base_example_2 = _example_thumbnail("assets/base_image-2.jpg")
                base_example_3 = _example_thumbnail("assets/base_image-3.jpg")

            gr.Markdown("### product examples")
            with gr.Row():
                product_example_1 = _example_thumbnail("assets/product_image-1.jpg")
                product_example_2 = _example_thumbnail("assets/product_image-2.jpg")
                product_example_3 = _example_thumbnail("assets/product_image-3.jpg")

        # Second column: input fields
        with gr.Column(scale=1):
            base_in = gr.Image(
                label="Base Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            garment_in = gr.Image(
                label="Product Image",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            mask_in = gr.Image(
                label="Mask Image (Optional)",
                type="pil",
                height=IMG_SIZE,
                width=IMG_SIZE,
                visible=False,  # Hidden from UI but available for API
            )

        # Third column: result
        with gr.Column(scale=2):
            result_out = gr.Image(
                label="Result",
                height=512,
                width=512,
            )

        # Right column: controls
        with gr.Column(scale=1):
            workflow_selector = gr.Radio(
                choices=[
                    ("Eyewear", "eyewear"),
                    ("Footwear", "footwear"),
                    ("Full-Body Garment", "dress"),
                    ("Top Garment", "top"),
                ],
                value="eyewear",
                label="Model",
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")

            # Disclaimer box
            gr.Markdown("""
Disclaimer:
""")

    # Add spacing.
    # NOTE(review): this literal looks like it lost its markup (probably HTML
    # line breaks); as written it renders nothing visible — confirm.
    gr.Markdown("\n\n")

    # Information section
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
📄 Read the Technical Report here: sm4ll-vton.github.io/sm4llvton/

🎥 Watch the in-depth YouTube video here: YouTube Video Tutorial

🚀 Sign up for APIs and SDK on YourMirror: yourmirror.io

💬 Want to chat?
andrea@andreabaioni.com | andrea@yourmirror.io
a@puliatti.com | alex@yourmirror.io
""")

    # Wire up interaction
    generate_btn.click(
        generate,
        inputs=[base_in, garment_in, workflow_selector, mask_in],
        outputs=result_out,
    )

    # Select handlers for example images (explicit zero-arg lambdas so Gradio
    # passes no event data).
    base_example_1.select(lambda: Image.open("assets/base_image-1.jpg"), outputs=base_in)
    base_example_2.select(lambda: Image.open("assets/base_image-2.jpg"), outputs=base_in)
    base_example_3.select(lambda: Image.open("assets/base_image-3.jpg"), outputs=base_in)

    product_example_1.select(lambda: Image.open("assets/product_image-1.jpg"), outputs=garment_in)
    product_example_2.select(lambda: Image.open("assets/product_image-2.jpg"), outputs=garment_in)
    product_example_3.select(lambda: Image.open("assets/product_image-3.jpg"), outputs=garment_in)


# Periodic cleanup for the rate limiter (every 10 minutes).
_CLEANUP_INTERVAL_S = 600.0


def _schedule_cleanup() -> None:
    """Arm a one-shot *daemon* timer that runs periodic_cleanup().

    Daemon threads do not keep the interpreter alive. With the default
    (non-daemon) Timer, shutdown would wait for the pending timer and then
    re-arm it forever.
    """
    timer = threading.Timer(_CLEANUP_INTERVAL_S, periodic_cleanup)
    timer.daemon = True
    timer.start()


def periodic_cleanup():
    """Prune the rate limiter, then schedule the next run.

    Rescheduling happens in ``finally`` so that one failed cleanup does not
    silently stop all future cleanups.
    """
    try:
        cleanup_rate_limiter()
    except Exception as exc:  # noqa: BLE001
        print(f"[RATE_LIMITER] Cleanup failed: {exc!r}")
    finally:
        _schedule_cleanup()


# Start cleanup timer
_schedule_cleanup()

# Run app if executed directly (e.g. `python app.py`). HF Spaces launches via
# `python app.py` automatically if it finds `app.py` at repo root, but our file
# lives in a sub-folder, so we keep the guard.
if __name__ == "__main__":
    demo.launch()