JiantaoLin commited on
Commit
18b35a5
·
1 Parent(s): 5d625d7
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +3 -2
pipeline/kiss3d_wrapper.py CHANGED
@@ -69,13 +69,14 @@ def init_wrapper_from_config(config_path):
69
  # flux_lora_pth = config_['flux'].get('lora', None)
70
  flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
71
  flux_redux_pth = config_['flux'].get('redux', None)
72
-
73
  if flux_base_model_pth.endswith('safetensors'):
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
  # flux_pipe.enable_vae_slicing()
78
- flux_pipe.enable_vae_tiling()
 
79
 
80
  # flux_pipe.enable_sequential_cpu_offload()
81
  # load flux model and controlnet
 
69
  # flux_lora_pth = config_['flux'].get('lora', None)
70
  flux_lora_pth = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="rgb_normal_large.safetensors", repo_type="model", token=access_token)
71
  flux_redux_pth = config_['flux'].get('redux', None)
72
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
73
  if flux_base_model_pth.endswith('safetensors'):
74
  flux_pipe = FluxImg2ImgPipeline.from_single_file(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
75
  else:
76
  flux_pipe = FluxImg2ImgPipeline.from_pretrained(flux_base_model_pth, torch_dtype=dtype_[flux_dtype], token=access_token)
77
  # flux_pipe.enable_vae_slicing()
78
+ # flux_pipe.enable_vae_tiling()
79
+ flux_pipe.vae = taef1
80
 
81
  # flux_pipe.enable_sequential_cpu_offload()
82
  # load flux model and controlnet