Spaces:
Running
Running
Yaron Koresh
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -22,6 +22,21 @@ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, Eul
|
|
22 |
import jax
|
23 |
import jax.numpy as jnp
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def forest_schnell():
|
26 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
27 |
return PIPE
|
@@ -193,7 +208,9 @@ def main():
|
|
193 |
global time
|
194 |
global last_motion
|
195 |
global base
|
196 |
-
|
|
|
|
|
197 |
last_motion=None
|
198 |
fps=20
|
199 |
time=16
|
@@ -208,13 +225,12 @@ def main():
|
|
208 |
|
209 |
repo="stabilityai/sd-vae-ft-mse-original"
|
210 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
211 |
-
|
212 |
vae = "./vae"
|
213 |
|
214 |
repo="ByteDance/SDXL-Lightning"
|
215 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
216 |
-
|
217 |
-
unet = "./unet"
|
218 |
|
219 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
220 |
|
|
|
22 |
import jax
|
23 |
import jax.numpy as jnp
|
24 |
|
25 |
+
class MyModel(nn.Module):
|
26 |
+
def __init__(self):
|
27 |
+
super().__init__()
|
28 |
+
self.register_buffer('a', torch.ones(1, 1))
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor, extend: bool):
|
31 |
+
if extend:
|
32 |
+
new_tensor = torch.randn(1, 1)
|
33 |
+
self.a = torch.cat([self.a, new_tensor], dim=0)
|
34 |
+
return x
|
35 |
+
|
36 |
+
for _ in range(10):
|
37 |
+
out = model(x, extend=True)
|
38 |
+
print(model.state_dict())
|
39 |
+
|
40 |
def forest_schnell():
|
41 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
42 |
return PIPE
|
|
|
208 |
global time
|
209 |
global last_motion
|
210 |
global base
|
211 |
+
global model
|
212 |
+
|
213 |
+
model = MyModel()
|
214 |
last_motion=None
|
215 |
fps=20
|
216 |
time=16
|
|
|
225 |
|
226 |
repo="stabilityai/sd-vae-ft-mse-original"
|
227 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
228 |
+
vae = model(load_file(hf_hub_download(repo, ckpt), device=device)
|
229 |
vae = "./vae"
|
230 |
|
231 |
repo="ByteDance/SDXL-Lightning"
|
232 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
233 |
+
unet = model(load_file(hf_hub_download(repo, ckpt), device=device)
|
|
|
234 |
|
235 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
236 |
|