Spaces:
Paused
Paused
| import os | |
| from pathlib import Path | |
| import hydra | |
| import torch | |
| import yaml | |
| from omegaconf import OmegaConf | |
| from torch import nn | |
| from saicinpainting.training.trainers import load_checkpoint | |
| from saicinpainting.utils import register_debug_signal_handlers | |
class JITWrapper(nn.Module):
    """Adapter that lets a batch-dict model be called with two positional tensors.

    The wrapped model is invoked with a single dict argument; this wrapper
    builds that dict from ``image`` and ``mask`` and unwraps the result, so
    the pair can be passed directly as example inputs (as ``main`` does when
    tracing).
    """

    def __init__(self, model):
        """Store ``model``, the callable that consumes the batch dict."""
        super().__init__()
        self.model = model

    def forward(self, image, mask):
        """Run the wrapped model and return its ``"inpainted"`` entry.

        Args:
            image: value passed to the model under the ``"image"`` key.
            mask: value passed to the model under the ``"mask"`` key.

        Returns:
            The ``"inpainted"`` item of the mapping returned by the model.
        """
        result = self.model({"image": image, "mask": mask})
        return result["inpainted"]
def main(predict_config: OmegaConf):
    """Export a trained checkpoint as a traced TorchScript file and check it.

    The function reads the training config stored next to the checkpoint,
    loads the model, traces a ``JITWrapper`` around it with random example
    inputs, saves the traced module, then reloads the file and prints the
    summed absolute difference between the eager and the reloaded outputs.

    Args:
        predict_config: configuration providing ``model.path`` (directory
            holding ``config.yaml`` and a ``models`` sub-directory),
            ``model.checkpoint`` (checkpoint file name inside ``models``)
            and ``save_path`` (destination of the traced model).
    """
    register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log

    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
    with open(train_config_path, "r") as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    train_config.training_model.predict_only = True
    train_config.visualizer.kind = "noop"

    checkpoint_path = os.path.join(
        predict_config.model.path, "models", predict_config.model.checkpoint
    )
    model = load_checkpoint(
        train_config, checkpoint_path, strict=False, map_location="cpu"
    )
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The checkpoint is loaded with map_location="cpu", so the wrapper has to
    # be moved to the target device *before* it is called or traced; feeding
    # CUDA inputs to a module that still lives on the CPU fails with a
    # device-mismatch error.
    jit_model_wrapper = JITWrapper(model).to(device)

    # Random example inputs used both for the eager reference and the trace.
    image = torch.rand(1, 3, 120, 120, device=device)
    mask = torch.rand(1, 1, 120, 120, device=device)

    # Eager reference, computed on the same device as the traced output so
    # the two tensors can be subtracted below.
    output = jit_model_wrapper(image, mask)

    traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False)

    save_path = Path(predict_config.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving big-lama.pt model to {save_path}")
    # Pass a plain string, matching the torch.jit.load call below.
    traced_model.save(str(save_path))

    print("Checking jit model output...")
    jit_model = torch.jit.load(str(save_path))
    jit_output = jit_model(image, mask)
    diff = (output - jit_output).abs().sum()
    print(f"diff: {diff}")
# Script entry point: run the export only when this file is executed
# directly, not when it is imported.
# NOTE(review): main() is called here without arguments although it declares
# a required `predict_config` parameter. Presumably a config-injecting
# decorator on main() is meant to supply it (the file imports `hydra` but
# never uses it) -- confirm; as written this call raises TypeError.
if __name__ == "__main__":
    main()