LTT commited on
Commit
cde3140
Β·
verified Β·
1 Parent(s): fce0236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -97,7 +97,7 @@ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(d
97
  flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(dtype=torch.bfloat16)
98
  flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
99
  flux_pipe.load_lora_weights(flux_lora_ckpt_path)
100
-
101
 
102
 
103
  # lrm
@@ -109,7 +109,7 @@ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt",
109
  state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
110
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
111
  model.load_state_dict(state_dict, strict=True)
112
-
113
 
114
  @spaces.GPU
115
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
@@ -280,8 +280,6 @@ def reconstruct_3d_model(images, prompt):
280
 
281
  # Gradio ζŽ₯口函数
282
  def gradio_pipeline(prompt, seed):
283
- flux_pipe.to(device=device, dtype=torch.bfloat16)
284
- model = model.to(device)
285
  model.init_flexicubes_geometry(device, fovy=50.0)
286
  model = model.eval()
287
  # η”Ÿζˆε€šθ§†ε›Ύε›Ύεƒ
 
97
  flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(dtype=torch.bfloat16)
98
  flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
99
  flux_pipe.load_lora_weights(flux_lora_ckpt_path)
100
+ flux_pipe.to(device=device, dtype=torch.bfloat16)
101
 
102
 
103
  # lrm
 
109
  state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
110
  state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
111
  model.load_state_dict(state_dict, strict=True)
112
+ model = model.to(device)
113
 
114
  @spaces.GPU
115
  def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
 
280
 
281
  # Gradio ζŽ₯口函数
282
  def gradio_pipeline(prompt, seed):
 
 
283
  model.init_flexicubes_geometry(device, fovy=50.0)
284
  model = model.eval()
285
  # η”Ÿζˆε€šθ§†ε›Ύε›Ύεƒ