Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -197,27 +197,30 @@ def load_model(model_path, device):
|
|
197 |
checkpoint = torch.load(model_path, map_location=device)
|
198 |
|
199 |
if 'model_state_dict' in checkpoint:
|
200 |
-
#
|
201 |
state_dict = {
|
202 |
-
k[6:]: v for k, v in checkpoint['model_state_dict'].items()
|
203 |
-
if k.startswith('model.')
|
204 |
}
|
205 |
|
206 |
-
# Load
|
207 |
-
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
# Reinitialize diffusion model with loaded UNet
|
212 |
diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
|
213 |
-
else:
|
214 |
-
# Handle case where it's not a training checkpoint
|
215 |
-
diffusion_model.load_state_dict({
|
216 |
-
k: v for k, v in checkpoint.items()
|
217 |
-
if not k.startswith('alpha_bars')
|
218 |
-
})
|
219 |
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
else:
|
222 |
print(f"Weights file not found at {model_path}")
|
223 |
print("Using randomly initialized weights")
|
|
|
197 |
checkpoint = torch.load(model_path, map_location=device)
|
198 |
|
199 |
if 'model_state_dict' in checkpoint:
|
200 |
+
# Handle training checkpoint format
|
201 |
state_dict = {
|
202 |
+
k[6:]: v for k, v in checkpoint['model_state_dict'].items()
|
203 |
+
if k.startswith('model.')
|
204 |
}
|
205 |
|
206 |
+
# Load UNet weights
|
207 |
+
unet_model.load_state_dict(state_dict, strict=False)
|
208 |
|
209 |
+
# Initialize diffusion model with loaded UNet
|
|
|
|
|
210 |
diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
+
print(f"Loaded UNet weights from {model_path}")
|
213 |
+
else:
|
214 |
+
# Handle direct model weights format
|
215 |
+
try:
|
216 |
+
# First try loading full DiffusionModel
|
217 |
+
diffusion_model.load_state_dict(checkpoint)
|
218 |
+
print(f"Loaded full DiffusionModel from {model_path}")
|
219 |
+
except RuntimeError:
|
220 |
+
# If that fails, load just the UNet weights
|
221 |
+
unet_model.load_state_dict(checkpoint, strict=False)
|
222 |
+
diffusion_model = DiffusionModel(unet_model, timesteps=TIMESTEPS).to(device)
|
223 |
+
print(f"Loaded UNet weights only from {model_path}")
|
224 |
else:
|
225 |
print(f"Weights file not found at {model_path}")
|
226 |
print("Using randomly initialized weights")
|