dkebudi commited on
Commit
3c525b4
·
verified ·
1 Parent(s): effc90e

printing training data's pixel values with """print(f"{batch['pixel_values'].shape=}")"""

Browse files
train_dreambooth_lora_sdxl_advanced.py CHANGED
@@ -1893,6 +1893,7 @@ def main(args):
1893
  latents_cache = []
1894
  for batch in tqdm(train_dataloader, desc="Caching latents"):
1895
  with torch.no_grad():
 
1896
  batch["pixel_values"] = batch["pixel_values"].to(
1897
  accelerator.device, non_blocking=True, dtype=torch.float32
1898
  )
 
1893
  latents_cache = []
1894
  for batch in tqdm(train_dataloader, desc="Caching latents"):
1895
  with torch.no_grad():
1896
+ print(f"{batch['pixel_values'].shape=}")
1897
  batch["pixel_values"] = batch["pixel_values"].to(
1898
  accelerator.device, non_blocking=True, dtype=torch.float32
1899
  )