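"""Gradio demo: edge-conditioned Stable Diffusion via ControlLoRA v3.

Takes an input image, extracts Canny edges with OpenCV, and feeds them as an
extra condition to a Stable Diffusion v1.5 pipeline through a ControlLoRA
adapter (custom UNet and pipeline classes come from the local model.py and
pipeline.py, with LoRA weights loaded from the local "models" directory).
"""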
import gradio as gr
import torch
import numpy as np
import cv2
from diffusers import UniPCMultistepScheduler
from model import UNet2DConditionModelEx
from pipeline import StableDiffusionControlLoraV3Pipeline
from PIL import Image
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub when an HF_TOKEN secret is set
# (login(token=None) falls back to an interactive prompt, which fails in a Space)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Model configuration
base_model = "runwayml/stable-diffusion-v1-5"

# torch.cuda.is_available() does not raise, so no try/except is needed; pick the
# device first and match the dtype to it (float16 is poorly supported on CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load the custom UNet; device placement is handled once for the whole pipeline below
unet = UNet2DConditionModelEx.from_pretrained(
    base_model,
    subfolder="unet",
    torch_dtype=dtype,
)

# Add conditioning with ow-gbi-control-lora
unet = unet.add_extra_conditions("ow-gbi-control-lora")

# Create the pipeline with the custom UNet. Recent diffusers releases reject
# device_map="auto" for pipelines (only "balanced" is supported), so move the
# pipeline to the chosen device explicitly instead.
pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
    base_model,
    unet=unet,
    torch_dtype=dtype,
)
pipe = pipe.to(device)

# Swap in a faster scheduler (UniPC reaches good quality in fewer steps)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Load the ControlLoRA weights
pipe.load_lora_weights(
    "models",
    weight_name="40kHalf.safetensors"
)
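
# Optional memory saver: attention slicing trades a little speed for lower VRAM.
# enable_attention_slicing() is the standard diffusers pipeline API; whether this
# custom pipeline inherits it depends on its base class, hence left commented out.
# pipe.enable_attention_slicing()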

def get_canny_image(image, low_threshold=100, high_threshold=200):
    """Run Canny edge detection and return a 3-channel PIL image."""
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Drop the alpha channel if present (Canny handles RGB or grayscale input)
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]

    canny_image = cv2.Canny(image, low_threshold, high_threshold)
    # Canny returns a single channel; replicate it so the pipeline sees RGB
    canny_image = np.stack([canny_image] * 3, axis=-1)
    return Image.fromarray(canny_image)

def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold):
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    canny_image = get_canny_image(input_image, low_threshold, high_threshold)

    with torch.no_grad():
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),  # sliders return floats; the scheduler needs an int
            guidance_scale=guidance_scale,
            image=canny_image,
            extra_condition_scale=1.0
        ).images[0]

    return canny_image, image

# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                low_threshold = gr.Slider(minimum=1, maximum=255, value=100, step=1, label="Canny Low Threshold")
                high_threshold = gr.Slider(minimum=1, maximum=255, value=200, step=1, label="Canny High Threshold")
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
            steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Steps")
            generate = gr.Button("Generate")
        
        with gr.Column():
            canny_output = gr.Image(label="Canny Edge Detection")
            result = gr.Image(label="Generated Image")
    
    generate.click(
        fn=generate_image,
        inputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            steps,
            low_threshold,
            high_threshold
        ],
        outputs=[canny_output, result]
    )
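
# Optional: calling demo.queue() before launch() serializes simultaneous
# requests on the GPU (queue() is the standard Gradio mechanism for
# long-running handlers).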

demo.launch()