owiedotch commited on
Commit
d37d209
·
verified ·
1 Parent(s): b40a827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -12,6 +12,7 @@ import spaces
12
  import einops
13
  import math
14
  import random
 
15
 
16
  def download_file(url, filename):
17
  response = requests.get(url, stream=True)
@@ -48,7 +49,6 @@ def setup_environment():
48
 
49
  setup_environment()
50
 
51
- # Importing from the CCSR folder
52
  from ldm.xformers_state import disable_xformers
53
  from model.q_sampler import SpacedSampler
54
  from model.ccsr_stage1 import ControlLDM
@@ -89,28 +89,30 @@ def process(
89
  f"seed={seed}\n"
90
  f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
91
  )
92
- if seed == -1:
93
- seed = random.randint(0, 2**32 - 1)
94
- torch.manual_seed(seed)
95
 
96
- # Resize the input image
97
  if sr_scale != 1:
98
- new_size = tuple(math.ceil(x * sr_scale) for x in control_img.size)
99
- control_img = control_img.resize(new_size, Image.BICUBIC)
 
 
100
 
101
  input_size = control_img.size
102
 
103
- # Prepare the control image
104
  if not tile_diffusion:
105
  control_img = auto_resize(control_img, 512)
106
  else:
107
  control_img = auto_resize(control_img, tile_diffusion_size)
108
 
 
109
  control_img = control_img.resize(
110
  tuple((s // 64 + 1) * 64 for s in control_img.size), Image.LANCZOS
111
  )
112
  control_img = np.array(control_img)
113
 
 
114
  control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
115
  control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
116
  height, width = control.size(-2), control.size(-1)
@@ -145,10 +147,10 @@ def process(
145
 
146
  return preds
147
 
148
- def update_output_resolution(image):
149
  if image is not None:
150
  width, height = image.size
151
- return f"Current resolution: {width}x{height}. Output resolution: {int(width*sr_scale.value)}x{int(height*sr_scale.value)}"
152
  return "Upload an image to see the output resolution"
153
 
154
  block = gr.Blocks().queue()
@@ -166,24 +168,23 @@ with block:
166
 
167
  with gr.Accordion("Options", open=False):
168
  with gr.Column():
169
- num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1, info="Number of output images to generate.")
170
- strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01, info="Strength of the control signal.")
171
- positive_prompt = gr.Textbox(label="Positive Prompt", value="", info="Positive text prompt to guide the image generation.")
172
  negative_prompt = gr.Textbox(
173
  label="Negative Prompt",
174
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
175
- info="Negative text prompt to avoid undesirable features."
176
  )
177
- cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1, info="Scale for classifier-free guidance.")
178
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1, info="Number of diffusion steps.")
179
- use_color_fix = gr.Checkbox(label="Use Color Correction", value=True, info="Apply color correction to the output image.")
180
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231, info="Random seed for reproducibility. Set to -1 for a random seed.")
181
- tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False, info="Enable tiled diffusion for large images.")
182
- tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256, info="Size of each tile for tiled diffusion.")
183
- tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128, info="Stride between tiles for tiled diffusion.")
184
 
185
  with gr.Column():
186
- result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery")
187
 
188
  inputs = [
189
  input_image,
@@ -203,8 +204,8 @@ with block:
203
  run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
204
 
205
  # Update output resolution when image is uploaded or SR scale is changed
206
- input_image.change(update_output_resolution, inputs=[input_image], outputs=[output_resolution])
207
- sr_scale.change(update_output_resolution, inputs=[input_image], outputs=[output_resolution])
208
 
209
  # Disable SR scale slider when no image is uploaded
210
  input_image.change(
 
12
  import einops
13
  import math
14
  import random
15
+ import pytorch_lightning as pl
16
 
17
  def download_file(url, filename):
18
  response = requests.get(url, stream=True)
 
49
 
50
  setup_environment()
51
 
 
52
  from ldm.xformers_state import disable_xformers
53
  from model.q_sampler import SpacedSampler
54
  from model.ccsr_stage1 import ControlLDM
 
89
  f"seed={seed}\n"
90
  f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}"
91
  )
92
+ pl.seed_everything(seed)
 
 
93
 
94
+ # Resize lr
95
  if sr_scale != 1:
96
+ control_img = control_img.resize(
97
+ tuple(math.ceil(x * sr_scale) for x in control_img.size),
98
+ Image.BICUBIC
99
+ )
100
 
101
  input_size = control_img.size
102
 
103
+ # Resize the lr image
104
  if not tile_diffusion:
105
  control_img = auto_resize(control_img, 512)
106
  else:
107
  control_img = auto_resize(control_img, tile_diffusion_size)
108
 
109
+ # Resize image to be multiples of 64
110
  control_img = control_img.resize(
111
  tuple((s // 64 + 1) * 64 for s in control_img.size), Image.LANCZOS
112
  )
113
  control_img = np.array(control_img)
114
 
115
+ # Convert to tensor (NCHW, [0,1])
116
  control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
117
  control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
118
  height, width = control.size(-2), control.size(-1)
 
147
 
148
  return preds
149
 
150
+ def update_output_resolution(image, scale):
151
  if image is not None:
152
  width, height = image.size
153
+ return f"Current resolution: {width}x{height}. Output resolution: {int(width*scale)}x{int(height*scale)}"
154
  return "Upload an image to see the output resolution"
155
 
156
  block = gr.Blocks().queue()
 
168
 
169
  with gr.Accordion("Options", open=False):
170
  with gr.Column():
171
+ num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
172
+ strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
173
+ positive_prompt = gr.Textbox(label="Positive Prompt", value="")
174
  negative_prompt = gr.Textbox(
175
  label="Negative Prompt",
176
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
 
177
  )
178
+ cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
179
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
180
+ use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
181
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
182
+ tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
183
+ tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
184
+ tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
185
 
186
  with gr.Column():
187
+ result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(grid=2, height="auto")
188
 
189
  inputs = [
190
  input_image,
 
204
  run_button.click(fn=process, inputs=inputs, outputs=[result_gallery])
205
 
206
  # Update output resolution when image is uploaded or SR scale is changed
207
+ input_image.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
208
+ sr_scale.change(update_output_resolution, inputs=[input_image, sr_scale], outputs=[output_resolution])
209
 
210
  # Disable SR scale slider when no image is uploaded
211
  input_image.change(