Yaron Koresh commited on
Commit
711f84f
·
verified ·
1 Parent(s): 0b7510d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -22,16 +22,14 @@ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, Eul
22
  import jax
23
  import jax.numpy as jnp
24
 
25
- class MyModel(nn.Module):
26
  def __init__(self):
27
  super().__init__()
28
- self.register_buffer('a', torch.ones(1, 1))
29
 
30
- def forward(self, x, extend=False):
31
- if extend:
32
- new_tensor = torch.randn(1, 1)
33
- self.a = torch.cat([self.a, new_tensor], dim=0)
34
- return x
35
 
36
  def forest_schnell():
37
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
@@ -204,9 +202,7 @@ def main():
204
  global time
205
  global last_motion
206
  global base
207
- global model
208
 
209
- model = MyModel()
210
  last_motion=None
211
  fps=20
212
  time=16
@@ -221,11 +217,13 @@ def main():
221
 
222
  repo="stabilityai/sd-vae-ft-mse-original"
223
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
224
- vae = model(load_file(hf_hub_download(repo, ckpt), device=device))
 
225
 
226
  repo="ByteDance/SDXL-Lightning"
227
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
228
- unet = model(load_file(hf_hub_download(repo, ckpt), device=device))
 
229
 
230
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
231
 
 
22
  import jax
23
  import jax.numpy as jnp
24
 
25
+ class Model(nn.Module):
26
  def __init__(self):
27
  super().__init__()
28
+ self.register_buffer('buffer', torch.ones(1, 1))
29
 
30
+ def forward(self, x):
31
+ new_tensor = torch.randn(1, 1)
32
+ self.buffer = torch.cat([self.buffer, new_tensor], dim=0)
 
 
33
 
34
  def forest_schnell():
35
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
 
202
  global time
203
  global last_motion
204
  global base
 
205
 
 
206
  last_motion=None
207
  fps=20
208
  time=16
 
217
 
218
  repo="stabilityai/sd-vae-ft-mse-original"
219
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
220
+ vae = Model()
221
+ vae(load_file(hf_hub_download(repo, ckpt), device=device))
222
 
223
  repo="ByteDance/SDXL-Lightning"
224
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
225
+ unet = Model()
226
+ unet(load_file(hf_hub_download(repo, ckpt), device=device))
227
 
228
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
229