# NSFW-detection / app.py
import os
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import random
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
import torch
from transformers import pipeline as transformers_pipeline
import re
# ------------------------------------------------------------
# DEVICE SETUP
# ------------------------------------------------------------
# Prefer GPU when the Space provides it, otherwise CPU
# `@spaces.GPU` takes care of binding the call itself, but we still
# need a device handle for manual ops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------------------------------------------
# STABLE DIFFUSION XL PIPELINE
# ------------------------------------------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
"votepurchase/waiNSFWIllustrious_v120",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(device)
# Force the important sub-modules to fp16 for VRAM efficiency (GPU) or
# lower RAM use (CPU). The model weights already sit in fp16; we mirror
# that for the sub-components explicitly to avoid silent fp32 promotions
# that eat memory on ZeroGPU.
for sub in (pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.unet):
    sub.to(torch.float16)
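# Optional sanity check (illustrative only, hence commented out): confirm that
# no sub-module silently stayed in fp32 after the casts above.
#   assert all(p.dtype == torch.float16 for p in pipe.unet.parameters())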
# ------------------------------------------------------------
# LIGHTWEIGHT KOR→ENG TRANSLATOR (CPU-ONLY)
# ------------------------------------------------------------
# * Hugging Face Spaces occasionally trips over the full MarianMT
#   weights on ZeroGPU, raising an _untyped_storage_new_register error
#   at load time. We therefore wrap initialisation in try/except and
#   fall back to an identity function if no model can be loaded.
# * If you need translation and have a custom HF token, set the env var
#   HF_API_TOKEN so the smaller SMaLL-100 model can be pulled.
#
translator = None  # default stub → "no translator"
try:
    # First try the compact MarianMT Ko→En model.
    translator = transformers_pipeline(
        "translation",
        model="Helsinki-NLP/opus-mt-ko-en",
        device=-1,  # force CPU so CUDA never initialises in the main proc
    )
except Exception as marian_err:
    print("[WARN] MarianMT load failed →", marian_err)
    # Second chance: the compact multilingual SMaLL-100 model.
    try:
        translator = transformers_pipeline(
            "translation",
            model="alirezamsh/small100",
            src_lang="ko",  # SMaLL-100 is distilled from M2M100, which uses two-letter codes
            tgt_lang="en",
            device=-1,
        )
    except Exception as small_err:
        print("[WARN] SMaLL-100 load failed →", small_err)
        # Final fallback: identity, i.e. no translation, but the app still runs.
        translator = None
korean_regex = re.compile(r"[\uac00-\ud7af]+")
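# The range U+AC00–U+D7AF covers the Hangul-syllables Unicode block, so this
# matches any run of precomposed Korean text (isolated jamo are not matched).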
def maybe_translate(text: str) -> str:
    """Translate Korean → English if Korean chars are present and the translator is ready."""
    if translator is not None and korean_regex.search(text):
        try:
            out = translator(text, max_length=256, clean_up_tokenization_spaces=True)
            return out[0]["translation_text"]
        except Exception as e:
            print("[WARN] Translation failed at runtime →", e)
    return text
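# Illustrative behaviour (output assumes the Marian weights loaded):
#   maybe_translate("고양이 한 마리")  ->  e.g. "a cat"  (Hangul found, translated)
#   maybe_translate("a cat")           ->  "a cat"       (no Hangul, returned as-is)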
# ------------------------------------------------------------
# SDXL INFERENCE WRAPPER
# ------------------------------------------------------------
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1216
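# SDXL's VAE downsamples by a factor of 8, so width/height should stay
# multiples of 8; the sliders below step by 32. The 1216 px cap is presumably
# a VRAM guard for the shared ZeroGPU allocation.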
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    prompt = maybe_translate(prompt)
    negative_prompt = maybe_translate(negative_prompt)

    if len(prompt.split()) > 60:
        print("[WARN] Prompt >60 words — CLIP may truncate it.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        output_image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        return output_image
    except RuntimeError as e:
        print(f"[ERROR] Diffusion failed → {e}")
        return Image.new("RGB", (width, height), color=(0, 0, 0))
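# Illustrative direct call (hypothetical values; the Gradio UI normally
# supplies these arguments):
#   img = infer("a cat in a hat", "low quality", seed=42, randomize_seed=False,
#               width=1024, height=1024, guidance_scale=7.0,
#               num_inference_steps=28)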
# ------------------------------------------------------------
# UI LAYOUT + THEME (Pastel Lavender Background)
# ------------------------------------------------------------
css = """
body {background: #f2f1f7; color: #222; font-family: 'Noto Sans', sans-serif;}
#col-container {margin: 0 auto; max-width: 640px;}
.gr-button {background: #7fbdf6; color: #fff; border-radius: 8px;}
#prompt-box textarea {font-size: 1.1rem; height: 3rem; background: #fff; color: #222;}
"""
author_note = (
"**ℹ️ Automatic translation** — Korean prompts are translated to English "
"only if translation weights could be loaded. If not, Korean input will be "
"sent unchanged."
)
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
## 🖌️ Stable Diffusion XL Playground
{author_note}
"""
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                elem_id="prompt-box",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt (Korean or English, 60 words max)",
            )
            run_button = gr.Button("Generate", scale=0)
        result = gr.Image(label="", show_label=False)
        examples = gr.Examples(
            examples=[
                ["아름답고 섹시한 여자가 속옷을 입고 유혹하는 포즈, 관능적, 4K"],
                ["Seductive anime woman lounging in a dimly lit bar, adult anime style, ultra-detail"],
["a girl in a school uniform having her skirt pulled up by a boy, and then being fucked"],
["Moody mature anime scene of two lovers fuck under neon rain, sensual atmosphere"],
["Moody mature anime scene of two lovers kissing under neon rain, sensual atmosphere"],
],
inputs=[prompt],
)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value="text, talk bubble, low quality, watermark, signature",
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=28, step=1, value=28)
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result],
    )
demo.queue().launch()
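# queue() serialises concurrent generation requests (important on a single
# shared GPU); launch() then starts the Gradio server.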