Spaces:
Running
Running
David Krajewski
commited on
Commit
·
76c8ea5
1
Parent(s):
07f3992
Moved init models
Browse files
app.py
CHANGED
@@ -226,23 +226,12 @@ class Drag:
|
|
226 |
self.device = device
|
227 |
|
228 |
ckpts_dir = "./ckpts/"
|
229 |
-
svd_ckpt = "./ckpts/stable-video-diffusion-img2vid-xt-1-1"
|
230 |
-
mofa_ckpt = "./ckpts/controlnet/ckpts/controlnet"
|
231 |
|
232 |
self.device = 'cuda'
|
233 |
self.weight_dtype = torch.float16
|
234 |
|
235 |
download_models(ckpts_dir)
|
236 |
|
237 |
-
print("Contents of ckpts_dir:", os.listdir(f"{ckpts_dir}/controlnet"))
|
238 |
-
|
239 |
-
self.pipeline, self.cmp = init_models(
|
240 |
-
svd_ckpt,
|
241 |
-
mofa_ckpt,
|
242 |
-
weight_dtype=self.weight_dtype,
|
243 |
-
device=self.device
|
244 |
-
)
|
245 |
-
|
246 |
self.height = height
|
247 |
self.width = width
|
248 |
self.model_length = model_length
|
@@ -500,6 +489,16 @@ class Drag:
|
|
500 |
|
501 |
@spaces.GPU(enable_queue=True, duration=240)
|
502 |
def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
self.pipeline = self.pipeline.to("cuda:0")
|
504 |
self.cmp = self.cmp.to("cuda:0")
|
505 |
|
|
|
226 |
self.device = device
|
227 |
|
228 |
ckpts_dir = "./ckpts/"
|
|
|
|
|
229 |
|
230 |
self.device = 'cuda'
|
231 |
self.weight_dtype = torch.float16
|
232 |
|
233 |
download_models(ckpts_dir)
|
234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
self.height = height
|
236 |
self.width = width
|
237 |
self.model_length = model_length
|
|
|
489 |
|
490 |
@spaces.GPU(enable_queue=True, duration=240)
|
491 |
def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
|
492 |
+
svd_ckpt = "./ckpts/stable-video-diffusion-img2vid-xt-1-1"
|
493 |
+
mofa_ckpt = "./ckpts/controlnet/ckpts/controlnet"
|
494 |
+
|
495 |
+
self.pipeline, self.cmp = init_models(
|
496 |
+
svd_ckpt,
|
497 |
+
mofa_ckpt,
|
498 |
+
weight_dtype=self.weight_dtype,
|
499 |
+
device=self.device
|
500 |
+
)
|
501 |
+
|
502 |
self.pipeline = self.pipeline.to("cuda:0")
|
503 |
self.cmp = self.cmp.to("cuda:0")
|
504 |
|