Update app.py
app.py CHANGED
@@ -21,21 +21,18 @@ torch.backends.cudnn.deterministic = True
 
 # Initialize the models
 base_model = "runwayml/stable-diffusion-v1-5"
-dtype = torch.float16
 
 # Load the custom UNet
 unet = UNet2DConditionModelEx.from_pretrained(
     base_model,
-    subfolder="unet"
-    torch_dtype=dtype
+    subfolder="unet"
 )
 
 unet = unet.add_extra_conditions("ow-gbi-control-lora")
 
 pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
     base_model,
-    unet=unet
-    torch_dtype=dtype
+    unet=unet
 )
 
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
@@ -45,9 +42,6 @@ pipe.load_lora_weights(
     weight_name="40kHalf.safetensors"
 )
 
-# Move to CPU initially
-pipe = pipe.to("cpu")
-
 def get_random_condition_image():
     conditions_dir = Path("conditions")
     if conditions_dir.exists():
@@ -63,16 +57,16 @@ def get_canny_image(image, low_threshold=100, high_threshold=200):
     elif isinstance(image, str):
         image = np.array(Image.open(image))
 
-    if len(image.shape) == 2:
+    if len(image.shape) == 2:
         image = np.stack([image] * 3, axis=-1)
-    elif image.shape[2] == 4:
+    elif image.shape[2] == 4:
         image = image[..., :3]
 
     canny_image = cv2.Canny(image, low_threshold, high_threshold)
     canny_image = np.stack([canny_image] * 3, axis=-1)
     return Image.fromarray(canny_image)
 
-@spaces.GPU(duration=300)
+@spaces.GPU(duration=300)
 def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold, seed, progress=gr.Progress()):
     if input_image is None:
         raise gr.Error("Please provide an input image!")
@@ -89,35 +83,22 @@ def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps,
         progress(0.1, desc="Processing input image...")
         canny_image = get_canny_image(input_image, low_threshold, high_threshold)
 
-        progress(0.2, desc="Moving model to device...")
-        # Move pipeline to GPU
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        pipe.to(device)
-
         progress(0.3, desc="Generating image...")
         with torch.no_grad():
             image = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
-                num_inference_steps=steps,
-                guidance_scale=guidance_scale,
+                num_inference_steps=int(steps),  # Convert to int
+                guidance_scale=float(guidance_scale),  # Convert to float
                 image=canny_image,
                 extra_condition_scale=1.0,
                 generator=generator
             ).images[0]
 
-        progress(0.9, desc="Moving model back to CPU...")
-        # Move back to CPU to free up GPU memory
-        pipe.to("cpu")
-        torch.cuda.empty_cache()
-
         progress(1.0, desc="Done!")
         return canny_image, image
 
     except Exception as e:
-        # Move back to CPU in case of error
-        pipe.to("cpu")
-        torch.cuda.empty_cache()
         raise gr.Error(f"An error occurred: {str(e)}")
 
 def random_image_click():
@@ -204,13 +185,11 @@ with gr.Blocks() as demo:
         cache_examples=True
     )
 
-    # Handle the random image button
    random_image_btn.click(
        fn=random_image_click,
        outputs=input_image
    )
 
-    # Handle the generate button
    generate.click(
        fn=generate_image,
        inputs=[
@@ -226,5 +205,5 @@ with gr.Blocks() as demo:
        outputs=[canny_output, result]
    )
 
-demo.queue()
-demo.launch()
+demo.queue()
+demo.launch(share=True)
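
Pieced together from the hunks above, the Canny preprocessing helper after this commit reads roughly as follows. This is a sketch rather than the exact file contents: the first branch handling PIL image input sits above the visible diff context and is assumed here, and the imports are added only to make the snippet self-contained.

import cv2
import numpy as np
from PIL import Image

def get_canny_image(image, low_threshold=100, high_threshold=200):
    # Assumed branch: the diff context starts at the `elif`, so a PIL-input
    # case like this presumably precedes it in app.py.
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, str):
        image = np.array(Image.open(image))

    # Normalize to a 3-channel uint8 array before edge detection.
    if len(image.shape) == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[..., :3]

    canny_image = cv2.Canny(image, low_threshold, high_threshold)
    canny_image = np.stack([canny_image] * 3, axis=-1)  # Canny output is single-channel; expand back to 3
    return Image.fromarray(canny_image)

In the Space it is called as shown in the generate_image hunk, with the Gradio image input and the two threshold sliders: get_canny_image(input_image, low_threshold, high_threshold).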