Yaron Koresh commited on
Commit
8e3d0f0
·
verified ·
1 Parent(s): 21d5371

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -31,6 +31,12 @@ class Model(nn.Module):
31
  new_tensor = torch.randn(1, 1)
32
  self.buffer = torch.cat([self.buffer, new_tensor], dim=0)
33
 
 
 
 
 
 
 
34
  def forest_schnell():
35
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
36
  return PIPE
@@ -217,15 +223,11 @@ def main():
217
 
218
  repo="stabilityai/sd-vae-ft-mse-original"
219
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
220
- vae = Model()
221
- vae(load_file(hf_hub_download(repo, ckpt), device=device))
222
- vae = ModelMixin(vae)
223
 
224
  repo="ByteDance/SDXL-Lightning"
225
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
226
- unet = Model()
227
- unet(load_file(hf_hub_download(repo, ckpt), device=device))
228
- unet = ModelMixin(unet)
229
 
230
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
231
 
 
31
  new_tensor = torch.randn(1, 1)
32
  self.buffer = torch.cat([self.buffer, new_tensor], dim=0)
33
 
34
+ def dict2model(dict,model=Model()):
35
+ model(dict)
36
+ mix = ModelMixin()
37
+ mix(model)
38
+ return mix
39
+
40
  def forest_schnell():
41
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
42
  return PIPE
 
223
 
224
  repo="stabilityai/sd-vae-ft-mse-original"
225
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
226
+ vae = dict2model(load_file(hf_hub_download(repo, ckpt), device=device))
 
 
227
 
228
  repo="ByteDance/SDXL-Lightning"
229
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
230
+ unet = dict2model(load_file(hf_hub_download(repo, ckpt), device=device))
 
 
231
 
232
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
233