RohitGandikota committed
Commit 74c6f31 · verified · Parent: 87efe0f

Update app.py

Files changed (1): app.py (+55 -55)
app.py CHANGED
@@ -89,31 +89,31 @@ MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
 
-base_model_id = "black-forest-labs/FLUX.1-schnell"
-max_sequence_length = 256
-flux_pipe = FluxPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-flux_pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch_dtype)
-flux_pipe = flux_pipe.to(device).to(torch_dtype)
-# pipe.enable_sequential_cpu_offload()
-transformer = flux_pipe.transformer
+# base_model_id = "black-forest-labs/FLUX.1-schnell"
+# max_sequence_length = 256
+# flux_pipe = FluxPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
+# flux_pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch_dtype)
+# flux_pipe = flux_pipe.to(device).to(torch_dtype)
+# # pipe.enable_sequential_cpu_offload()
+# transformer = flux_pipe.transformer
 
-## Change these parameters based on how you trained your sliderspace sliders
-train_method = 'flux-attn'
-rank = 1
-alpha = 1
+# ## Change these parameters based on how you trained your sliderspace sliders
+# train_method = 'flux-attn'
+# rank = 1
+# alpha = 1
 
-flux_networks = {}
-modules = DEFAULT_TARGET_REPLACE
-modules += UNET_TARGET_REPLACE_MODULE_CONV
-for i in range(1):
-    flux_networks[i] = LoRANetwork(
-        transformer,
-        rank=int(rank),
-        multiplier=1.0,
-        alpha=int(alpha),
-        train_method=train_method,
-        fast_init=True,
-    ).to(device, dtype=torch_dtype)
+# flux_networks = {}
+# modules = DEFAULT_TARGET_REPLACE
+# modules += UNET_TARGET_REPLACE_MODULE_CONV
+# for i in range(1):
+#     flux_networks[i] = LoRANetwork(
+#         transformer,
+#         rank=int(rank),
+#         multiplier=1.0,
+#         alpha=int(alpha),
+#         train_method=train_method,
+#         fast_init=True,
+#     ).to(device, dtype=torch_dtype)
 
 
 def update_sliderspace_choices(model_choice):
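
The hunk above disables the FLUX setup by commenting out every line rather than deleting it. As a point of reference only, a minimal sketch of the same initialisation kept behind an explicit switch might look like the following (ENABLE_FLUX is a hypothetical flag, not part of this commit; device, torch_dtype and LoRANetwork are assumed to be defined earlier in app.py):

from diffusers import FluxPipeline, AutoencoderTiny

ENABLE_FLUX = False  # hypothetical flag; the commit itself just comments the block out

flux_pipe = None
flux_networks = {}
if ENABLE_FLUX:
    # FLUX.1-schnell with the tiny VAE swapped in, mirroring the disabled lines
    flux_pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch_dtype
    )
    flux_pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch_dtype)
    flux_pipe = flux_pipe.to(device)
    # a single rank-1 LoRA slider network over the FLUX transformer
    flux_networks[0] = LoRANetwork(
        flux_pipe.transformer,
        rank=1,
        multiplier=1.0,
        alpha=1,
        train_method="flux-attn",
        fast_init=True,
    ).to(device, dtype=torch_dtype)
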
@@ -175,38 +175,38 @@ def infer(
             height=height,
             generator=generator,
         ).images[0]
-    else:
-        sliderspace_path = f"flux_sliderspace_weights/{slider_space}/slider_{int(discovered_directions.split(' ')[-1])-1}.pt"
-        for net in flux_networks:
-            flux_networks[net].load_state_dict(torch.load(sliderspace_path))
-            flux_networks[net].set_lora_slider(-1*slider_scale)
-        with flux_networks[0]:
-            pass
-
-        # original image
-        generator = torch.Generator().manual_seed(seed)
-        image = flux_pipe(
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            generator=generator,
-            max_sequence_length = 256,
-        ).images[0]
+    # else:
+    #     sliderspace_path = f"flux_sliderspace_weights/{slider_space}/slider_{int(discovered_directions.split(' ')[-1])-1}.pt"
+    #     for net in flux_networks:
+    #         flux_networks[net].load_state_dict(torch.load(sliderspace_path))
+    #         flux_networks[net].set_lora_slider(-1*slider_scale)
+    #     with flux_networks[0]:
+    #         pass
+
+    #     # original image
+    #     generator = torch.Generator().manual_seed(seed)
+    #     image = flux_pipe(
+    #         prompt=prompt,
+    #         guidance_scale=guidance_scale,
+    #         num_inference_steps=num_inference_steps,
+    #         width=width,
+    #         height=height,
+    #         generator=generator,
+    #         max_sequence_length = 256,
+    #     ).images[0]
 
-        # edited image
-        generator = torch.Generator().manual_seed(seed)
-        with flux_networks[0]:
-            slider_image = flux_pipe(
-                prompt=prompt,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                width=width,
-                height=height,
-                generator=generator,
-                max_sequence_length = 256,
-            ).images[0]
+    #     # edited image
+    #     generator = torch.Generator().manual_seed(seed)
+    #     with flux_networks[0]:
+    #         slider_image = flux_pipe(
+    #             prompt=prompt,
+    #             guidance_scale=guidance_scale,
+    #             num_inference_steps=num_inference_steps,
+    #             width=width,
+    #             height=height,
+    #             generator=generator,
+    #             max_sequence_length = 256,
+    #         ).images[0]
 
     return image, slider_image, seed
 
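
The disabled branch follows the same two-pass pattern as the SDXL path: load one discovered direction into the LoRA network, set its strength, then render the same seed twice, once with the network inactive and once inside its context manager, so only the slider differs between the two images. A condensed sketch of that pattern (render_pair is a hypothetical helper; it assumes flux_pipe and flux_networks from the block above):

import torch

def render_pair(prompt, seed, slider_path, slider_scale,
                guidance_scale, num_inference_steps, width, height):
    # load the chosen direction and set its strength (the sign flips the direction)
    flux_networks[0].load_state_dict(torch.load(slider_path))
    flux_networks[0].set_lora_slider(-1 * slider_scale)

    # original image: slider network not applied
    generator = torch.Generator().manual_seed(seed)
    image = flux_pipe(prompt=prompt, guidance_scale=guidance_scale,
                      num_inference_steps=num_inference_steps,
                      width=width, height=height, generator=generator,
                      max_sequence_length=256).images[0]

    # edited image: same seed, slider weights active inside the context manager
    generator = torch.Generator().manual_seed(seed)
    with flux_networks[0]:
        slider_image = flux_pipe(prompt=prompt, guidance_scale=guidance_scale,
                                 num_inference_steps=num_inference_steps,
                                 width=width, height=height, generator=generator,
                                 max_sequence_length=256).images[0]
    return image, slider_image
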
 
@@ -280,7 +280,7 @@ with gr.Blocks(css=css) as demo:
 
     # Add model selection dropdown
    model_choice = gr.Dropdown(
-        choices=["SDXL-DMD", "FLUX-Schnell"],
+        choices=["SDXL-DMD"],
         label="Model",
         value="SDXL-DMD"
    )
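
With "FLUX-Schnell" removed, the dropdown only ever reports "SDXL-DMD", so update_sliderspace_choices is now called with a single possible value. A minimal sketch of how such a dropdown is typically wired to that callback inside a gr.Blocks context (the .change hook and the slider_space component name are assumptions, not shown in this hunk):

import gradio as gr

with gr.Blocks() as demo:
    model_choice = gr.Dropdown(choices=["SDXL-DMD"], label="Model", value="SDXL-DMD")
    slider_space = gr.Dropdown(label="SliderSpace", choices=[])  # hypothetical component

    # repopulate the available sliderspaces whenever the model selection changes
    model_choice.change(fn=update_sliderspace_choices,
                        inputs=model_choice,
                        outputs=slider_space)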
 