	| """ | |
| Export the predictions of a model for a given dataloader (e.g. ImageFolder). | |
| Use a standalone script with `python3 -m dsfm.scipts.export_predictions dir` | |
| or call from another script. | |
| """ | |
| from pathlib import Path | |
| import h5py | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from .tensor import batch_to_device | |
| def export_predictions( | |
| loader, | |
| model, | |
| output_file, | |
| as_half=False, | |
| keys="*", | |
| callback_fn=None, | |
| optional_keys=[], | |
| ): | |
| assert keys == "*" or isinstance(keys, (tuple, list)) | |
| Path(output_file).parent.mkdir(exist_ok=True, parents=True) | |
| hfile = h5py.File(str(output_file), "w") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = model.to(device).eval() | |
| for data_ in tqdm(loader): | |
| data = batch_to_device(data_, device, non_blocking=True) | |
| pred = model(data) | |
| if callback_fn is not None: | |
| pred = {**callback_fn(pred, data), **pred} | |
| if keys != "*": | |
| if len(set(keys) - set(pred.keys())) > 0: | |
| raise ValueError(f"Missing key {set(keys) - set(pred.keys())}") | |
| pred = {k: v for k, v in pred.items() if k in keys + optional_keys} | |
| assert len(pred) > 0 | |
| # renormalization | |
| for k in pred.keys(): | |
| if k.startswith("keypoints"): | |
| idx = k.replace("keypoints", "") | |
| scales = 1.0 / ( | |
| data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] | |
| ) | |
| pred[k] = pred[k] * scales[None] | |
| if k.startswith("lines"): | |
| idx = k.replace("lines", "") | |
| scales = 1.0 / ( | |
| data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] | |
| ) | |
| pred[k] = pred[k] * scales[None] | |
| if k.startswith("orig_lines"): | |
| idx = k.replace("orig_lines", "") | |
| scales = 1.0 / ( | |
| data["scales"] if len(idx) == 0 else data[f"view{idx}"]["scales"] | |
| ) | |
| pred[k] = pred[k] * scales[None] | |
| pred = {k: v[0].cpu().numpy() for k, v in pred.items()} | |
| if as_half: | |
| for k in pred: | |
| dt = pred[k].dtype | |
| if (dt == np.float32) and (dt != np.float16): | |
| pred[k] = pred[k].astype(np.float16) | |
| try: | |
| name = data["name"][0] | |
| grp = hfile.create_group(name) | |
| for k, v in pred.items(): | |
| grp.create_dataset(k, data=v) | |
| except RuntimeError: | |
| continue | |
| del pred | |
| hfile.close() | |
| return output_file | |
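

# Minimal usage sketch for calling export_predictions from another script.
# Assumptions (not part of this module): `MyDataset` and `MyModel` are
# hypothetical placeholders for your own dataset and model, and the output
# path and key names are examples. The loader should yield batches of size 1,
# each with a unique "name" entry used as the HDF5 group name.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(MyDataset(...), batch_size=1, num_workers=4)
#     model = MyModel(...)
#     export_predictions(
#         loader,
#         model,
#         "outputs/predictions.h5",
#         as_half=True,
#         keys=["keypoints0", "descriptors0"],
#     )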
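#
# Reading the exported file back (sketch; the path and key names are the same
# assumptions as above): each sample is stored as a group named after
# data["name"][0], with one dataset per exported prediction key.
#
#     import h5py
#
#     with h5py.File("outputs/predictions.h5", "r") as f:
#         kp = f[name]["keypoints0"][()]  # numpy array for one sample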
