Spaces:
Sleeping
Sleeping
Yaron Koresh
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -22,16 +22,14 @@ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, Eul
|
|
22 |
import jax
|
23 |
import jax.numpy as jnp
|
24 |
|
25 |
-
class
|
26 |
def __init__(self):
|
27 |
super().__init__()
|
28 |
-
self.register_buffer('
|
29 |
|
30 |
-
def forward(self, x
|
31 |
-
|
32 |
-
|
33 |
-
self.a = torch.cat([self.a, new_tensor], dim=0)
|
34 |
-
return x
|
35 |
|
36 |
def forest_schnell():
|
37 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
@@ -204,9 +202,7 @@ def main():
|
|
204 |
global time
|
205 |
global last_motion
|
206 |
global base
|
207 |
-
global model
|
208 |
|
209 |
-
model = MyModel()
|
210 |
last_motion=None
|
211 |
fps=20
|
212 |
time=16
|
@@ -221,11 +217,13 @@ def main():
|
|
221 |
|
222 |
repo="stabilityai/sd-vae-ft-mse-original"
|
223 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
224 |
-
vae =
|
|
|
225 |
|
226 |
repo="ByteDance/SDXL-Lightning"
|
227 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
228 |
-
unet =
|
|
|
229 |
|
230 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
231 |
|
|
|
22 |
import jax
|
23 |
import jax.numpy as jnp
|
24 |
|
25 |
+
class Model(nn.Module):
|
26 |
def __init__(self):
|
27 |
super().__init__()
|
28 |
+
self.register_buffer('buffer', torch.ones(1, 1))
|
29 |
|
30 |
+
def forward(self, x):
|
31 |
+
new_tensor = torch.randn(1, 1)
|
32 |
+
self.buffer = torch.cat([self.buffer, new_tensor], dim=0)
|
|
|
|
|
33 |
|
34 |
def forest_schnell():
|
35 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
|
|
202 |
global time
|
203 |
global last_motion
|
204 |
global base
|
|
|
205 |
|
|
|
206 |
last_motion=None
|
207 |
fps=20
|
208 |
time=16
|
|
|
217 |
|
218 |
repo="stabilityai/sd-vae-ft-mse-original"
|
219 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
220 |
+
vae = Model()
|
221 |
+
vae(load_file(hf_hub_download(repo, ckpt), device=device))
|
222 |
|
223 |
repo="ByteDance/SDXL-Lightning"
|
224 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
225 |
+
unet = Model()
|
226 |
+
unet(load_file(hf_hub_download(repo, ckpt), device=device))
|
227 |
|
228 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
229 |
|