Update gradio_sd3.py
Browse files- gradio_sd3.py +1 -1
gradio_sd3.py
CHANGED
@@ -23,7 +23,7 @@ example_path = os.path.join(os.path.dirname(__file__), 'examples')
|
|
23 |
fitdit_repo = "BoyuanJiang/FitDiT"
|
24 |
repo_path = snapshot_download(repo_id=fitdit_repo)
|
25 |
|
26 |
-
weight_dtype = torch.
|
27 |
device = "cpu"
|
28 |
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
|
29 |
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
|
|
|
23 |
fitdit_repo = "BoyuanJiang/FitDiT"
|
24 |
repo_path = snapshot_download(repo_id=fitdit_repo)
|
25 |
|
26 |
+
weight_dtype = torch.bfloat16
|
27 |
device = "cpu"
|
28 |
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
|
29 |
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
|