import sys
from pathlib import Path
import torch
from zipfile import ZipFile
import os
import sklearn
import gdown

from ..utils.base_model import BaseModel

sys.path.append(
    str(Path(__file__).parent / "../../third_party/deep-image-retrieval")
)
os.environ["DB_ROOT"] = ""  # required by dirtorch

from dirtorch.utils import common  # noqa: E402
from dirtorch.extract_features import load_model  # noqa: E402

# The DIR model checkpoints (pickle files) include sklearn.decomposition.pca,
# which has been deprecated in sklearn v0.24
# and must be explicitly imported with `from sklearn.decomposition import PCA`.
# This is a hacky workaround to maintain forward compatibility.
sys.modules["sklearn.decomposition.pca"] = sklearn.decomposition._pca


class DIR(BaseModel):
    default_conf = {
        "model_name": "Resnet-101-AP-GeM",
        "whiten_name": "Landmarks_clean",
        "whiten_params": {
            "whitenp": 0.25,
            "whitenv": None,
            "whitenm": 1.0,
        },
        "pooling": "gem",
        "gemp": 3,
    }
    required_inputs = ["image"]

    dir_models = {
        "Resnet-101-AP-GeM": "https://docs.google.com/uc?export=download&id=1UWJGDuHtzaQdFhSMojoYVQjmCXhIwVvy",
    }

    def _init(self, conf):
        checkpoint = Path(
            torch.hub.get_dir(), "dirtorch", conf["model_name"] + ".pt"
        )
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            gdown.download(str(link), str(checkpoint) + ".zip", quiet=False)
            zf = ZipFile(str(checkpoint) + ".zip", "r")
            zf.extractall(checkpoint.parent)
            zf.close()
            os.remove(str(checkpoint) + ".zip")

        self.net = load_model(checkpoint, False)  # first load on CPU
        if conf["whiten_name"]:
            assert conf["whiten_name"] in self.net.pca

    def _forward(self, data):
        image = data["image"]
        assert image.shape[1] == 3
        mean = self.net.preprocess["mean"]
        std = self.net.preprocess["std"]
        image = image - image.new_tensor(mean)[:, None, None]
        image = image / image.new_tensor(std)[:, None, None]

        desc = self.net(image)
        desc = desc.unsqueeze(0)  # batch dimension
        if self.conf["whiten_name"]:
            pca = self.net.pca[self.conf["whiten_name"]]
            desc = common.whiten_features(
                desc.cpu().numpy(), pca, **self.conf["whiten_params"]
            )
            desc = torch.from_numpy(desc)

        return {
            "global_descriptor": desc,
        }