# TextureAnything / app.py
# Author: MrPio — commit 8c9e7fc ("Disable xFormers"), 2.47 kB
import hashlib
import io
import secrets
from pathlib import Path

import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from PIL import Image, ImageOps
# ---- Model loading ----
CACHE_DIR = "./cache"
CNET_MODEL = "MrPio/Texture-Anything_CNet-SD15"
SD_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Half precision only pays off (and is fully supported) on CUDA; CPU-only hosts
# fall back to float32, where float16 kernels are missing or very slow.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

controlnet = ControlNetModel.from_pretrained(
    CNET_MODEL, cache_dir=CACHE_DIR, torch_dtype=DTYPE
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    SD_MODEL,
    controlnet=controlnet,
    cache_dir=CACHE_DIR,
    torch_dtype=DTYPE,
    safety_checker=None,
)
# speed & memory optimizations
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# from_pretrained leaves the weights on the CPU; place the whole pipeline on the
# selected device so inference actually runs on the GPU when one is present.
pipe.to(DEVICE)
# NOTE: xFormers attention (pipe.enable_xformers_memory_efficient_attention())
# is intentionally disabled. pipe.enable_model_cpu_offload() is a low-VRAM
# alternative to the .to(DEVICE) call above; do not combine the two.
def pil2hash(image: Image.Image) -> str:
    """Return the hex SHA-256 digest of *image* encoded as PNG.

    The PNG encoding is produced entirely in memory; nothing is written to disk.
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        digest = hashlib.sha256(png_stream.getvalue())
    return digest.hexdigest()
def caption2hash(caption: str) -> str:
    """Return the hex SHA-256 digest of *caption* encoded as UTF-8."""
    caption_bytes = caption.encode("utf-8")
    return hashlib.sha256(caption_bytes).hexdigest()
# ---- Inference function ----
def infer(caption: str, condition_image: Image.Image, steps: int = 20, seed: int = 0, invert: bool = False):
    """Generate a texture for *caption* guided by *condition_image*.

    Results for explicit (non-zero) seeds are cached as PNG files under
    ``inferences/``. The cache key covers the (possibly inverted) condition
    image, the caption, the step count and the seed, so any change to one of
    them triggers a fresh generation.

    Args:
        caption: Text prompt passed to the diffusion pipeline.
        condition_image: ControlNet conditioning image; converted to RGB.
        steps: Number of denoising steps.
        seed: RNG seed. 0 draws a fresh random seed, as the UI label promises,
            and bypasses the cache because that result can never be requested
            again.
        invert: If true, invert the colors of the condition image first.

    Returns:
        The generated image.

    Raises:
        gr.Error: If no condition image was supplied.
    """
    if condition_image is None:
        raise gr.Error("Please upload a condition image.")
    # Gradio number/slider fields may deliver floats; the scheduler and RNG need ints.
    steps = int(steps)
    seed = int(seed)

    img = condition_image.convert("RGB")
    if invert:
        img = ImageOps.invert(img)

    cache_file = None
    if seed == 0:
        seed = secrets.randbelow(2**32)
    else:
        cache_file = Path("inferences") / f"{pil2hash(img)}_{caption2hash(caption)}_{steps}_{seed}.png"
        if cache_file.exists():
            # Copy the pixels out so the cached file's handle is closed right away.
            with Image.open(cache_file) as cached:
                return cached.copy()

    # A private generator gives the same noise as torch.manual_seed(seed) would,
    # without reseeding torch's global RNG as a side effect.
    generator = torch.Generator().manual_seed(seed)
    output = pipe(prompt=caption, image=img, num_inference_steps=steps, generator=generator).images[0]
    if cache_file is not None:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        output.save(cache_file)
    return output
# ---- Gradio UI + API ----
with gr.Blocks() as demo:
    gr.Markdown("## ControlNet + Stable Diffusion 1.5")
    with gr.Row():
        txt = gr.Textbox(label="Prompt", placeholder="Describe the texture...")
        cond = gr.Image(type="pil", label="Condition Image")
    with gr.Row():
        # step=1 and precision=0 make Gradio pass integers to infer() instead of floats.
        steps = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
        seed = gr.Number(value=0, precision=0, label="Seed (0 for random)")
        inv = gr.Checkbox(label="Invert UV colors?")
    btn = gr.Button("Generate")
    out = gr.Image(label="Output")
    # The explicit api_name exposes this event as the named API endpoint "/infer"
    # (e.g. gradio_client's predict(..., api_name="/infer")).
    btn.click(fn=infer, inputs=[txt, cond, steps, seed, inv], outputs=out, api_name="infer")
demo.launch(server_name="0.0.0.0", server_port=7860)