Yaron Koresh commited on
Commit
7679af5
·
verified ·
1 Parent(s): 6afd65a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -22,6 +22,21 @@ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, Eul
22
  import jax
23
  import jax.numpy as jnp
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def forest_schnell():
26
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
27
  return PIPE
@@ -193,7 +208,9 @@ def main():
193
  global time
194
  global last_motion
195
  global base
196
-
 
 
197
  last_motion=None
198
  fps=20
199
  time=16
@@ -208,13 +225,12 @@ def main():
208
 
209
  repo="stabilityai/sd-vae-ft-mse-original"
210
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
211
- save_file(load_file(hf_hub_download(repo, ckpt), device=device),"./vae")
212
  vae = "./vae"
213
 
214
  repo="ByteDance/SDXL-Lightning"
215
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
216
- save_file(load_file(hf_hub_download(repo, ckpt), device=device),"./unet")
217
- unet = "./unet"
218
 
219
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
220
 
 
22
  import jax
23
  import jax.numpy as jnp
24
 
25
+ class MyModel(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.register_buffer('a', torch.ones(1, 1))
29
+
30
+ def forward(self, x: torch.Tensor, extend: bool):
31
+ if extend:
32
+ new_tensor = torch.randn(1, 1)
33
+ self.a = torch.cat([self.a, new_tensor], dim=0)
34
+ return x
35
+
36
+ for _ in range(10):
37
+ out = model(x, extend=True)
38
+ print(model.state_dict())
39
+
40
  def forest_schnell():
41
  PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
42
  return PIPE
 
208
  global time
209
  global last_motion
210
  global base
211
+ global model
212
+
213
+ model = MyModel()
214
  last_motion=None
215
  fps=20
216
  time=16
 
225
 
226
  repo="stabilityai/sd-vae-ft-mse-original"
227
  ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
228
+ vae = model(load_file(hf_hub_download(repo, ckpt), device=device)
229
  vae = "./vae"
230
 
231
  repo="ByteDance/SDXL-Lightning"
232
  ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
233
+ unet = model(load_file(hf_hub_download(repo, ckpt), device=device)
 
234
 
235
  #repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
236