JiantaoLin committed on
Commit
c5daa2d
·
1 Parent(s): ebe241c
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +24 -28
pipeline/kiss3d_wrapper.py CHANGED
@@ -78,7 +78,7 @@ def init_wrapper_from_config(config_path):
78
  flux_pipe.vae.enable_tiling()
79
 
80
  # load flux model and controlnet
81
- if flux_controlnet_pth is not None:
82
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
83
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
84
 
@@ -90,7 +90,7 @@ def init_wrapper_from_config(config_path):
90
 
91
  # load redux model
92
  flux_redux_pipe = None
93
- if flux_redux_pth is not None:
94
  flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16, token=access_token)
95
  flux_redux_pipe.text_encoder = flux_pipe.text_encoder
96
  flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
@@ -101,41 +101,37 @@ def init_wrapper_from_config(config_path):
101
 
102
  # logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
103
 
104
- # TODO: load pulid model
105
-
106
  # init multiview model
107
- logger.info('==> Loading multiview diffusion model ...')
108
- multiview_device = config_['multiview'].get('device', 'cpu')
109
- multiview_pipeline = DiffusionPipeline.from_pretrained(
110
- config_['multiview']['base_model'],
111
- custom_pipeline=config_['multiview']['custom_pipeline'],
112
- torch_dtype=torch.float16,
113
- )
114
- multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
115
- multiview_pipeline.scheduler.config, timestep_spacing='trailing'
116
- )
117
 
118
- # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
119
- unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
120
- if unet_ckpt_path is not None:
121
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')
122
- # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
123
- multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
124
 
125
  # multiview_pipeline.to(multiview_device)
126
  # logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
127
- # multiview_pipeline = None
128
 
129
 
130
  # load caption model
131
- logger.info('==> Loading caption model ...')
132
- caption_device = config_['caption'].get('device', 'cpu')
133
- caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
134
- torch_dtype=torch.bfloat16, trust_remote_code=True)
135
- caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
136
  # logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
137
- # caption_processor = None
138
- # caption_model = None
139
 
140
  # load reconstruction model
141
  logger.info('==> Loading reconstruction model ...')
 
78
  flux_pipe.vae.enable_tiling()
79
 
80
  # load flux model and controlnet
81
+ if flux_controlnet_pth is not None and False:
82
  flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth, torch_dtype=torch.bfloat16)
83
  flux_pipe = convert_flux_pipeline(flux_pipe, FluxControlNetImg2ImgPipeline, controlnet=[flux_controlnet])
84
 
 
90
 
91
  # load redux model
92
  flux_redux_pipe = None
93
+ if flux_redux_pth is not None and False:
94
  flux_redux_pipe = FluxPriorReduxPipeline.from_pretrained(flux_redux_pth, torch_dtype=torch.bfloat16, token=access_token)
95
  flux_redux_pipe.text_encoder = flux_pipe.text_encoder
96
  flux_redux_pipe.text_encoder_2 = flux_pipe.text_encoder_2
 
101
 
102
  # logger.warning(f"GPU memory allocated after load flux model on {flux_device}: {torch.cuda.memory_allocated(device=flux_device) / 1024**3} GB")
103
 
 
 
104
  # init multiview model
105
+ # logger.info('==> Loading multiview diffusion model ...')
106
+ # multiview_device = config_['multiview'].get('device', 'cpu')
107
+ # multiview_pipeline = DiffusionPipeline.from_pretrained(
108
+ # config_['multiview']['base_model'],
109
+ # custom_pipeline=config_['multiview']['custom_pipeline'],
110
+ # torch_dtype=torch.float16,
111
+ # )
112
+ # multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
113
+ # multiview_pipeline.scheduler.config, timestep_spacing='trailing'
114
+ # )
115
 
116
+ # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
117
+ # if unet_ckpt_path is not None:
118
+ # state_dict = torch.load(unet_ckpt_path, map_location='cpu')
119
+ # multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
 
 
120
 
121
  # multiview_pipeline.to(multiview_device)
122
  # logger.warning(f"GPU memory allocated after load multiview model on {multiview_device}: {torch.cuda.memory_allocated(device=multiview_device) / 1024**3} GB")
123
+ multiview_pipeline = None
124
 
125
 
126
  # load caption model
127
+ # logger.info('==> Loading caption model ...')
128
+ # caption_device = config_['caption'].get('device', 'cpu')
129
+ # caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
130
+ # torch_dtype=torch.bfloat16, trust_remote_code=True)
131
+ # caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
132
  # logger.warning(f"GPU memory allocated after load caption model on {caption_device}: {torch.cuda.memory_allocated(device=caption_device) / 1024**3} GB")
133
+ caption_processor = None
134
+ caption_model = None
135
 
136
  # load reconstruction model
137
  logger.info('==> Loading reconstruction model ...')