Spaces:
Runtime error
Runtime error
printing training data's pixel values with """print(f"{batch['pixel_values'].shape=}")"""
Browse files
train_dreambooth_lora_sdxl_advanced.py
CHANGED
@@ -1893,6 +1893,7 @@ def main(args):
|
|
1893 |
latents_cache = []
|
1894 |
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
1895 |
with torch.no_grad():
|
|
|
1896 |
batch["pixel_values"] = batch["pixel_values"].to(
|
1897 |
accelerator.device, non_blocking=True, dtype=torch.float32
|
1898 |
)
|
|
|
1893 |
latents_cache = []
|
1894 |
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
1895 |
with torch.no_grad():
|
1896 |
+
print(f"{batch['pixel_values'].shape=}")
|
1897 |
batch["pixel_values"] = batch["pixel_values"].to(
|
1898 |
accelerator.device, non_blocking=True, dtype=torch.float32
|
1899 |
)
|