Clone04 commited on
Commit
81e34e0
·
verified ·
1 Parent(s): 1c3b830

Update gradio_sd3.py

Browse files
Files changed (1) hide show
  1. gradio_sd3.py +1 -1
gradio_sd3.py CHANGED
@@ -24,7 +24,7 @@ fitdit_repo = "BoyuanJiang/FitDiT"
24
  repo_path = snapshot_download(repo_id=fitdit_repo)
25
 
26
  weight_dtype = torch.bfloat16
27
- device = "cuda"
28
  transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
29
  transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
30
  pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
 
24
  repo_path = snapshot_download(repo_id=fitdit_repo)
25
 
26
  weight_dtype = torch.bfloat16
27
+ device = "cpu"
28
  transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
29
  transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
30
  pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))