tejani commited on
Commit
d1730e5
·
verified ·
1 Parent(s): d9e1f28

Update gradio_sd3.py

Browse files
Files changed (1) hide show
  1. gradio_sd3.py +1 -1
gradio_sd3.py CHANGED
@@ -23,7 +23,7 @@ example_path = os.path.join(os.path.dirname(__file__), 'examples')
23
  fitdit_repo = "BoyuanJiang/FitDiT"
24
  repo_path = snapshot_download(repo_id=fitdit_repo)
25
 
26
- weight_dtype = torch.bfloat4
27
  device = "cpu"
28
  transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
29
  transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
 
23
  fitdit_repo = "BoyuanJiang/FitDiT"
24
  repo_path = snapshot_download(repo_id=fitdit_repo)
25
 
26
+ weight_dtype = torch.bfloat16
27
  device = "cpu"
28
  transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
29
  transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)