multimodalart HF Staff committed on
Commit
fc7c434
·
verified ·
1 Parent(s): 63dfb76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -8
app.py CHANGED
@@ -55,6 +55,32 @@ def rewrite_prompt(input_prompt):
55
 
56
  # --- 2. Preprocessor Functions ---
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def extract_canny(input_image):
59
  image = np.array(input_image)
60
  image = cv2.Canny(image, 100, 200)
@@ -104,7 +130,7 @@ anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEE
104
  print("All models loaded.")
105
 
106
  def get_control_image(input_image, control_mode):
107
- """A master function to select and run the correct preprocessor."""
108
  if control_mode == "Canny":
109
  return extract_canny(input_image)
110
  elif control_mode == "Soft Edge":
@@ -143,6 +169,8 @@ def generate(
143
  if not prompt:
144
  raise gr.Error("Please enter a prompt.")
145
 
 
 
146
  if randomize_seed:
147
  seed = random.randint(0, MAX_SEED)
148
 
@@ -151,7 +179,7 @@ def generate(
151
  print(f"Original prompt: {prompt}\nEnhanced prompt: {enhanced_prompt}")
152
  prompt = enhanced_prompt
153
 
154
- control_image = get_control_image(image, conditioning)
155
  generator = torch.Generator(device=device).manual_seed(int(seed))
156
 
157
  generated_image = pipe(
@@ -159,8 +187,8 @@ def generate(
159
  negative_prompt=negative_prompt,
160
  control_image=control_image,
161
  controlnet_conditioning_scale=controlnet_conditioning_scale,
162
- width=image.width,
163
- height=image.height,
164
  num_inference_steps=int(num_inference_steps),
165
  guidance_scale=guidance_scale,
166
  generator=generator,
@@ -229,8 +257,4 @@ with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
229
  )
230
 
231
  if __name__ == "__main__":
232
- if not os.path.exists("assets"):
233
- os.makedirs("assets")
234
- print("Created 'assets' directory. Please add example images for the Gradio examples to work.")
235
-
236
  demo.launch()
 
55
 
56
  # --- 2. Preprocessor Functions ---
57
 
58
def resize_image(input_image, max_size=1024):
    """Resize an image so that its longest side is about ``max_size`` pixels.

    The aspect ratio is preserved up to rounding. Both output dimensions are
    rounded down to a multiple of 8 and never go below 8. Images smaller than
    ``max_size`` are scaled up as well as down.

    Args:
        input_image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (used here as a PIL image).
        max_size: Target length, in pixels, of the longest side.

    Returns:
        The result of ``input_image.resize`` for the computed dimensions.

    Raises:
        ValueError: If the image reports a non-positive width or height.
    """
    w, h = input_image.size
    if w <= 0 or h <= 0:
        raise ValueError(f"Cannot resize an image with size {w}x{h}.")

    # Scale with exact integer arithmetic. Going through a float aspect ratio
    # can land a hair below the true integer result, and the round-down to a
    # multiple of 8 below would then drop a full 8 pixels.
    if w > h:
        new_w = max_size
        new_h = max_size * h // w
    else:
        new_h = max_size
        new_w = max_size * w // h

    # Round down to a multiple of 8, but keep at least 8 pixels per side so
    # extreme aspect ratios or a tiny max_size never yield a zero dimension.
    new_w = max(8, new_w - new_w % 8)
    new_h = max(8, new_h - new_h % 8)

    return input_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
82
+
83
+
84
  def extract_canny(input_image):
85
  image = np.array(input_image)
86
  image = cv2.Canny(image, 100, 200)
 
130
  print("All models loaded.")
131
 
132
  def get_control_image(input_image, control_mode):
133
+ """A master function to select and run the correct preprocessor on a pre-resized image."""
134
  if control_mode == "Canny":
135
  return extract_canny(input_image)
136
  elif control_mode == "Soft Edge":
 
169
  if not prompt:
170
  raise gr.Error("Please enter a prompt.")
171
 
172
+ resized_image = resize_image(image, max_size=1024)
173
+
174
  if randomize_seed:
175
  seed = random.randint(0, MAX_SEED)
176
 
 
179
  print(f"Original prompt: {prompt}\nEnhanced prompt: {enhanced_prompt}")
180
  prompt = enhanced_prompt
181
 
182
+ control_image = get_control_image(resized_image, conditioning)
183
  generator = torch.Generator(device=device).manual_seed(int(seed))
184
 
185
  generated_image = pipe(
 
187
  negative_prompt=negative_prompt,
188
  control_image=control_image,
189
  controlnet_conditioning_scale=controlnet_conditioning_scale,
190
+ width=resized_image.width,
191
+ height=resized_image.height,
192
  num_inference_steps=int(num_inference_steps),
193
  guidance_scale=guidance_scale,
194
  generator=generator,
 
257
  )
258
 
259
  if __name__ == "__main__":
 
 
 
 
260
  demo.launch()