Spaces:
Paused
Paused
ran the ruff formatter
Browse files- inference.py +12 -6
- xora/pipelines/pipeline_xora_video.py +5 -3
inference.py
CHANGED
|
@@ -240,10 +240,14 @@ def main():
|
|
| 240 |
assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
|
| 241 |
assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
|
| 242 |
assert (
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
# Paths for the separate mode directories
|
| 249 |
ckpt_dir = Path(args.ckpt_dir)
|
|
@@ -296,8 +300,10 @@ def main():
|
|
| 296 |
torch.manual_seed(args.seed)
|
| 297 |
if torch.cuda.is_available():
|
| 298 |
torch.cuda.manual_seed(args.seed)
|
| 299 |
-
|
| 300 |
-
generator = torch.Generator(
|
|
|
|
|
|
|
| 301 |
|
| 302 |
images = pipeline(
|
| 303 |
num_inference_steps=args.num_inference_steps,
|
|
|
|
| 240 |
assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
|
| 241 |
assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
|
| 242 |
assert (
|
| 243 |
+
(
|
| 244 |
+
height,
|
| 245 |
+
width,
|
| 246 |
+
args.num_frames,
|
| 247 |
+
)
|
| 248 |
+
in RECOMMENDED_RESOLUTIONS
|
| 249 |
+
or args.custom_resolution
|
| 250 |
+
), f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
|
| 251 |
|
| 252 |
# Paths for the separate mode directories
|
| 253 |
ckpt_dir = Path(args.ckpt_dir)
|
|
|
|
| 300 |
torch.manual_seed(args.seed)
|
| 301 |
if torch.cuda.is_available():
|
| 302 |
torch.cuda.manual_seed(args.seed)
|
| 303 |
+
|
| 304 |
+
generator = torch.Generator(
|
| 305 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 306 |
+
).manual_seed(args.seed)
|
| 307 |
|
| 308 |
images = pipeline(
|
| 309 |
num_inference_steps=args.num_inference_steps,
|
xora/pipelines/pipeline_xora_video.py
CHANGED
|
@@ -1010,9 +1010,11 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
| 1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
| 1011 |
# Choose the appropriate context manager based on `mixed_precision`
|
| 1012 |
if mixed_precision:
|
| 1013 |
-
if
|
| 1014 |
-
raise NotImplementedError(
|
| 1015 |
-
|
|
|
|
|
|
|
| 1016 |
context_manager = torch.autocast(device, dtype=torch.bfloat16)
|
| 1017 |
else:
|
| 1018 |
context_manager = nullcontext() # Dummy context manager
|
|
|
|
| 1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
| 1011 |
# Choose the appropriate context manager based on `mixed_precision`
|
| 1012 |
if mixed_precision:
|
| 1013 |
+
if "xla" in device.type:
|
| 1014 |
+
raise NotImplementedError(
|
| 1015 |
+
"Mixed precision is not supported yet on XLA devices."
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
context_manager = torch.autocast(device, dtype=torch.bfloat16)
|
| 1019 |
else:
|
| 1020 |
context_manager = nullcontext() # Dummy context manager
|