Vedansh-7 commited on
Commit
10ba1e7
·
verified ·
1 Parent(s): 190a6d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -197,27 +197,30 @@ def load_model(model_path, device):
197
  checkpoint = torch.load(model_path, map_location=device)
198
 
199
  if 'model_state_dict' in checkpoint:
200
- # Filter out DiffusionModel-specific keys
201
  state_dict = {
202
- k[6:]: v for k, v in checkpoint['model_state_dict'].items()
203
- if k.startswith('model.') and not k.startswith('model.alpha_bars')
204
  }
205
 
206
- # Load into UNet only
207
- missing, unexpected = unet_model.load_state_dict(state_dict, strict=False)
208
 
209
- print(f"Loaded UNet weights. Missing keys: {missing}. Unexpected keys: {unexpected}")
210
-
211
- # Reinitialize diffusion model with loaded UNet
212
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
213
- else:
214
- # Handle case where it's not a training checkpoint
215
- diffusion_model.load_state_dict({
216
- k: v for k, v in checkpoint.items()
217
- if not k.startswith('alpha_bars')
218
- })
219
 
220
- print(f"Model successfully loaded from {model_path}")
 
 
 
 
 
 
 
 
 
 
 
221
  else:
222
  print(f"Weights file not found at {model_path}")
223
  print("Using randomly initialized weights")
 
197
  checkpoint = torch.load(model_path, map_location=device)
198
 
199
  if 'model_state_dict' in checkpoint:
200
+ # Handle training checkpoint format
201
  state_dict = {
202
+ k[6:]: v for k, v in checkpoint['model_state_dict'].items()
203
+ if k.startswith('model.')
204
  }
205
 
206
+ # Load UNet weights
207
+ unet_model.load_state_dict(state_dict, strict=False)
208
 
209
+ # Initialize diffusion model with loaded UNet
 
 
210
  diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
 
 
 
 
 
 
211
 
212
+ print(f"Loaded UNet weights from {model_path}")
213
+ else:
214
+ # Handle direct model weights format
215
+ try:
216
+ # First try loading full DiffusionModel
217
+ diffusion_model.load_state_dict(checkpoint)
218
+ print(f"Loaded full DiffusionModel from {model_path}")
219
+ except RuntimeError:
220
+ # If that fails, load just the UNet weights
221
+ unet_model.load_state_dict(checkpoint, strict=False)
222
+ diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
223
+ print(f"Loaded UNet weights only from {model_path}")
224
  else:
225
  print(f"Weights file not found at {model_path}")
226
  print("Using randomly initialized weights")