|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler |
|
from model import UNet2DConditionModelEx |
|
from pipeline import StableDiffusionControlLoraV3Pipeline |
|
from PIL import Image |
|
import os |
|
from huggingface_hub import login |
|
import spaces |
|
import random |
|
from pathlib import Path |
|
|
|
|
|
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() without a token falls back to an interactive prompt, which
# cannot be answered when the app runs headless.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
|
|
|
|
|
# Reproducibility settings: set the cuDNN "deterministic" flag and seed
# torch's global random number generator with a fixed value.
torch.backends.cudnn.deterministic = True
torch.manual_seed(42)
|
|
|
|
|
# Checkpoint identifier and the dtype every component is loaded with.
base_model = "runwayml/stable-diffusion-v1-5"
dtype = torch.float16

# Load the "unet" subfolder of the base checkpoint through the project's
# UNet2DConditionModelEx class, then keep whatever add_extra_conditions()
# returns for the condition named "ow-gbi-control-lora".
unet = UNet2DConditionModelEx.from_pretrained(
    base_model, subfolder="unet", torch_dtype=dtype
).add_extra_conditions("ow-gbi-control-lora")

# Assemble the pipeline around the customised UNet.
# NOTE(review): nothing visible here moves the pipeline to a device; confirm
# whether it is expected to be placed on the GPU requested via @spaces.GPU.
pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
    base_model, unet=unet, torch_dtype=dtype
)

# Replace the pipeline's scheduler with UniPC, built from the config of the
# scheduler the pipeline was loaded with.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# LoRA weights are read from "models", file "40kHalf.safetensors".
pipe.load_lora_weights("models", weight_name="40kHalf.safetensors")
|
|
|
def get_random_condition_image():
    """Pick a random image file from the local ``conditions`` directory.

    Files are selected by extension (``.jpg``, ``.jpeg`` or ``.png``),
    compared case-insensitively. Sub-directories are ignored even if their
    name ends in one of those extensions.

    Returns:
        The path of the chosen file as a string, or ``None`` when the
        directory is missing or contains no matching file.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    conditions_dir = Path("conditions")
    if not conditions_dir.is_dir():
        return None

    # Sorted so the candidate order does not depend on directory listing
    # order (keeps the pick reproducible when ``random`` is seeded).
    image_files = sorted(
        entry
        for entry in conditions_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_suffixes
    )
    if not image_files:
        return None
    return str(random.choice(image_files))
|
|
|
def get_canny_image(image, low_threshold=100, high_threshold=200):
    """Run Canny edge detection and return the edges as a 3-channel image.

    Args:
        image: A PIL image or a NumPy array. Arrays may be 2-D (single
            channel) or 3-D with the channel axis last; when a 3-D array has
            four channels only the first three are kept.
        low_threshold: First threshold passed to ``cv2.Canny``.
        high_threshold: Second threshold passed to ``cv2.Canny``.

    Returns:
        A ``PIL.Image.Image`` built from the edge map stacked three times
        along the last axis.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("An input image is required for edge detection.")

    if isinstance(image, Image.Image):
        image = np.array(image)

    # Only 3-D arrays have a channel axis; checking ndim first keeps 2-D
    # (grayscale) input from failing on shape[2].
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]

    edges = cv2.Canny(image, low_threshold, high_threshold)
    return Image.fromarray(np.stack([edges] * 3, axis=-1))
|
|
|
def _make_generator(seed):
    """Return a torch.Generator seeded from *seed*, or randomly seeded."""
    generator = torch.Generator()
    if seed is not None and seed != "":
        try:
            return generator.manual_seed(int(seed))
        except (TypeError, ValueError, OverflowError, RuntimeError):
            # Best effort, as before: an unusable seed is ignored and a
            # random one is used instead.
            pass
    # A freshly constructed Generator starts from a fixed default seed, so it
    # must be reseeded explicitly to get a different result on every run.
    generator.seed()
    return generator


@spaces.GPU(duration=120)
def generate_image(
    input_image,
    prompt,
    negative_prompt,
    guidance_scale,
    steps,
    low_threshold,
    high_threshold,
    seed,
):
    """Generate an image guided by the Canny edges of ``input_image``.

    Args:
        input_image: Source image handed to ``get_canny_image``.
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps; rounded to the nearest integer
            because UI sliders may deliver floats.
        low_threshold: Lower Canny threshold.
        high_threshold: Upper Canny threshold.
        seed: Integer-like seed for reproducible output. ``None``, an empty
            string, or a value that cannot be used as a seed selects a
            random seed.

    Returns:
        A ``(canny_image, image)`` tuple: the edge map used as the condition
        and the first image produced by the pipeline.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an internal error.
        raise gr.Error("Please provide an input image before generating.")

    generator = _make_generator(seed)
    canny_image = get_canny_image(input_image, low_threshold, high_threshold)

    with torch.no_grad():
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(round(steps)),
            guidance_scale=guidance_scale,
            image=canny_image,
            extra_condition_scale=1.0,
            generator=generator,
        ).images[0]

    return canny_image, image
|
|
|
def random_image_click():
    """Open a randomly chosen condition image for the input widget.

    Returns:
        The image opened with ``Image.open`` from the path supplied by
        ``get_random_condition_image``, or ``None`` when no path is returned.
    """
    chosen_path = get_random_condition_image()
    return Image.open(chosen_path) if chosen_path else None
|
|
|
|
|
# Example rows offered in the UI. Column order follows generate_image's
# parameters: input image path, prompt, negative prompt, guidance scale,
# steps, Canny low threshold, Canny high threshold, seed.
examples = [
    [
        "conditions/example1.jpg", "a futuristic cyberpunk city", "blurry, bad quality",
        7.5, 50, 100, 200, 42,
    ],
    [
        "conditions/example2.jpg", "a serene mountain landscape", "dark, gloomy",
        7.0, 40, 120, 180, 123,
    ],
]
|
|
|
|
|
# Gradio UI: inputs on the left, the Canny edge map and generated image on
# the right, followed by example rows and the event wiring.
# NOTE(review): widget nesting was reconstructed; confirm which sliders are
# meant to share the inner Row.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Control LoRA v3 Demo
        ⚠️ Warning: This is a demo of Control LoRA v3. Please be aware that generation can take several minutes.
        The model uses edge detection to guide the image generation process.
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            random_image_btn = gr.Button("Load Random Reference Image")

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here... (e.g., 'a futuristic cyberpunk city')",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Enter things you don't want to see... (e.g., 'blurry, bad quality')",
            )
            with gr.Row():
                low_threshold = gr.Slider(
                    minimum=1, maximum=255, value=100, step=1, label="Canny Low Threshold"
                )
                high_threshold = gr.Slider(
                    minimum=1, maximum=255, value=200, step=1, label="Canny High Threshold"
                )
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
            # step=1 keeps the step count integral; without an explicit step
            # the slider can produce fractional values.
            steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Steps")
            seed = gr.Textbox(
                label="Seed (empty for random)",
                placeholder="Enter a number for reproducible results",
            )
            generate = gr.Button("Generate")

        with gr.Column():
            canny_output = gr.Image(label="Canny Edge Detection")
            result = gr.Image(label="Generated Image")

    # The examples and the Generate button feed generate_image the same
    # components, in the order of its parameters.
    generation_inputs = [
        input_image,
        prompt,
        negative_prompt,
        guidance_scale,
        steps,
        low_threshold,
        high_threshold,
        seed,
    ]
    generation_outputs = [canny_output, result]

    gr.Examples(
        examples=examples,
        inputs=generation_inputs,
        outputs=generation_outputs,
        fn=generate_image,
        cache_examples=True,
    )

    # Fill the input image widget with a random file from the conditions set.
    random_image_btn.click(fn=random_image_click, outputs=input_image)

    generate.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

demo.launch()