naonauno committed · Commit 097fdcb (verified) · 1 Parent(s): e14967b

Update app.py

Files changed (1): app.py (+104, -11)
app.py CHANGED
@@ -9,13 +9,19 @@ from PIL import Image
 import os
 from huggingface_hub import login
 import spaces
+import random
+from pathlib import Path
 
 # Login using the token
 login(token=os.environ.get("HF_TOKEN"))
 
+# For deterministic generation
+torch.manual_seed(42)
+torch.backends.cudnn.deterministic = True
+
 # Initialize the models
 base_model = "runwayml/stable-diffusion-v1-5"
-dtype = torch.float16 # A100 works better with float16
+dtype = torch.float16
 
 # Load the custom UNet
 unet = UNet2DConditionModelEx.from_pretrained(
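The seeding block added here pins PyTorch's global RNG and forces deterministic cuDNN kernels, so two runs of the Space with identical inputs produce the same image. A minimal sketch of the same idea; the seed_everything name and the extra benchmark flag are my additions, not part of this commit:

    import torch

    def seed_everything(seed: int = 42) -> None:
        # Seeds the CPU RNG and the RNGs of all CUDA devices.
        torch.manual_seed(seed)
        # Prefer repeatable cuDNN kernels over auto-tuned (faster) ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # not in the commit; disables nondeterministic autotuning

    seed_everything(42)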
@@ -24,25 +30,30 @@ unet = UNet2DConditionModelEx.from_pretrained(
     torch_dtype=dtype
 )
 
-# Add conditioning with ow-gbi-control-lora
 unet = unet.add_extra_conditions("ow-gbi-control-lora")
 
-# Create the pipeline with custom UNet
 pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
     base_model,
     unet=unet,
     torch_dtype=dtype
 )
 
-# Use a faster scheduler
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
-# Load the ControlLoRA weights
 pipe.load_lora_weights(
     "models",
     weight_name="40kHalf.safetensors"
 )
 
+def get_random_condition_image():
+    conditions_dir = Path("conditions")
+    if conditions_dir.exists():
+        image_files = list(conditions_dir.glob("*.[jp][pn][g]")) # matches .jpg, .png, .jpeg
+        if image_files:
+            random_image = random.choice(image_files)
+            return str(random_image)
+    return None
+
 def get_canny_image(image, low_threshold=100, high_threshold=200):
     if isinstance(image, Image.Image):
         image = np.array(image)
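One caveat on the new get_random_condition_image helper: the character-class glob "*.[jp][pn][g]" only matches three-letter extensions, so it picks up .jpg and .png (and, incidentally, .jng and .ppg) but not .jpeg or uppercase variants, despite what the inline comment suggests. A sketch of an explicit suffix filter that avoids the surprise; the function name and allowed set are my assumptions:

    import random
    from pathlib import Path

    def pick_random_condition_image(conditions_dir: str = "conditions"):
        # Compare lowercased suffixes so .jpeg and .PNG are accepted too.
        allowed = {".jpg", ".jpeg", ".png"}
        root = Path(conditions_dir)
        if not root.exists():
            return None
        candidates = [p for p in root.iterdir() if p.suffix.lower() in allowed]
        return str(random.choice(candidates)) if candidates else None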
@@ -54,8 +65,16 @@ def get_canny_image(image, low_threshold=100, high_threshold=200):
     canny_image = np.stack([canny_image] * 3, axis=-1)
     return Image.fromarray(canny_image)
 
-@spaces.GPU(duration=120) # Set GPU allocation duration to 120 seconds
-def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold):
+@spaces.GPU(duration=120)
+def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold, seed):
+    if seed is not None and seed != "":
+        try:
+            generator = torch.Generator().manual_seed(int(seed))
+        except ValueError:
+            generator = torch.Generator()
+    else:
+        generator = torch.Generator()
+
     canny_image = get_canny_image(input_image, low_threshold, high_threshold)
 
     with torch.no_grad():
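The seed handling above relies on the diffusers generator argument: passing a seeded torch.Generator fixes the initial latent noise, which is what makes a numeric seed reproduce an image. A CPU generator (the default device) works for this even when the pipeline runs on GPU. A condensed sketch of the same parsing logic; make_generator is an illustrative name:

    import torch

    def make_generator(seed_text):
        # Blank or non-numeric input falls back to an unseeded generator.
        try:
            return torch.Generator().manual_seed(int(seed_text))
        except (TypeError, ValueError):
            return torch.Generator()

    # e.g. pipe(prompt, generator=make_generator("42"), ...).images[0]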
@@ -65,29 +84,102 @@ def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps,
             num_inference_steps=steps,
             guidance_scale=guidance_scale,
             image=canny_image,
-            extra_condition_scale=1.0
+            extra_condition_scale=1.0,
+            generator=generator
         ).images[0]
 
     return canny_image, image
 
+def random_image_click():
+    image_path = get_random_condition_image()
+    if image_path:
+        return Image.open(image_path)
+    return None
+
+# Example data
+examples = [
+    [
+        "conditions/example1.jpg", # Replace with actual paths
+        "a futuristic cyberpunk city",
+        "blurry, bad quality",
+        7.5,
+        50,
+        100,
+        200,
+        42
+    ],
+    [
+        "conditions/example2.jpg", # Replace with actual paths
+        "a serene mountain landscape",
+        "dark, gloomy",
+        7.0,
+        40,
+        120,
+        180,
+        123
+    ]
+]
+
 # Create the Gradio interface
 with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        # Control LoRA v3 Demo
+        ⚠️ Warning: This is a demo of Control LoRA v3. Please be aware that generation can take several minutes.
+        The model uses edge detection to guide the image generation process.
+        """
+    )
+
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(label="Input Image", type="numpy")
-            prompt = gr.Textbox(label="Prompt")
-            negative_prompt = gr.Textbox(label="Negative Prompt")
+            random_image_btn = gr.Button("Load Random Reference Image")
+
+            prompt = gr.Textbox(
+                label="Prompt",
+                placeholder="Enter your prompt here... (e.g., 'a futuristic cyberpunk city')"
+            )
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt",
+                placeholder="Enter things you don't want to see... (e.g., 'blurry, bad quality')"
+            )
             with gr.Row():
                 low_threshold = gr.Slider(minimum=1, maximum=255, value=100, label="Canny Low Threshold")
                 high_threshold = gr.Slider(minimum=1, maximum=255, value=200, label="Canny High Threshold")
             guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
             steps = gr.Slider(minimum=1, maximum=100, value=50, label="Steps")
+            seed = gr.Textbox(label="Seed (empty for random)", placeholder="Enter a number for reproducible results")
             generate = gr.Button("Generate")
 
         with gr.Column():
            canny_output = gr.Image(label="Canny Edge Detection")
            result = gr.Image(label="Generated Image")
 
+    # Set up example gallery
+    gr.Examples(
+        examples=examples,
+        inputs=[
+            input_image,
+            prompt,
+            negative_prompt,
+            guidance_scale,
+            steps,
+            low_threshold,
+            high_threshold,
+            seed
+        ],
+        outputs=[canny_output, result],
+        fn=generate_image,
+        cache_examples=True
+    )
+
+    # Handle the random image button
+    random_image_btn.click(
+        fn=random_image_click,
+        outputs=input_image
+    )
+
+    # Handle the generate button
     generate.click(
         fn=generate_image,
         inputs=[
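Note that cache_examples=True tells Gradio to run generate_image on both examples when the Space starts and to store the outputs, so clicking an example is instant but startup pays for two full 40-50 step generations. If that cost is unwanted, caching can be turned off; the snippet below assumes the component names from the diff:

    gr.Examples(
        examples=examples,
        inputs=[input_image, prompt, negative_prompt, guidance_scale,
                steps, low_threshold, high_threshold, seed],
        outputs=[canny_output, result],
        fn=generate_image,
        cache_examples=False,  # True precomputes every example at startup
    )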
@@ -97,7 +189,8 @@ with gr.Blocks() as demo:
             guidance_scale,
             steps,
             low_threshold,
-            high_threshold
+            high_threshold,
+            seed
         ],
         outputs=[canny_output, result]
     )
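Finally, the @spaces.GPU(duration=120) decorator in the third hunk is the ZeroGPU hook: on Spaces ZeroGPU hardware, a GPU is attached to the process only while a decorated function runs, for at most the requested number of seconds per call. A minimal sketch of the pattern; the function name and body are illustrative, not from the commit:

    import spaces
    import torch

    @spaces.GPU(duration=120)  # request up to 120 s of GPU time per call
    def run_inference(pipe, prompt):
        # CUDA is only guaranteed to be available inside the decorated call.
        with torch.no_grad():
            return pipe(prompt).images[0]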