# ALike demo: detect keypoints per frame and track them with mutual-NN matching.
| import copy | |
| import os | |
| import cv2 | |
| import glob | |
| import logging | |
| import argparse | |
| import numpy as np | |
| from tqdm import tqdm | |
| from alike import ALike, configs | |
class ImageLoader(object):
    """Load frames from a camera, a video file, or a directory of images.

    The source is selected from ``filepath``:
      * ``"cameraX"``  -> OpenCV camera index ``X`` (mode ``"camera"``)
      * existing file  -> video file (mode ``"video"``)
      * existing dir   -> sorted ``*.png`` / ``*.jpg`` / ``*.ppm`` images
                          (mode ``"images"``)

    Raises:
        IOError: if the camera/video cannot be opened or the path is invalid.
    """

    def __init__(self, filepath: str):
        # Default frame cap; used as-is for endless camera streams.
        self.N = 3000
        if filepath.startswith("camera"):
            camera = int(filepath[6:])  # e.g. "camera0" -> 0
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f"Opened camera {camera}")
            self.mode = "camera"
        elif os.path.exists(filepath):
            if os.path.isfile(filepath):
                self.cap = cv2.VideoCapture(filepath)
                if not self.cap.isOpened():
                    raise IOError(f"Can't open video {filepath}!")
                rate = self.cap.get(cv2.CAP_PROP_FPS)
                # Last frame is dropped (reported counts are often off by one).
                self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                duration = self.N / rate
                logging.info(f"Opened video {filepath}")
                logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
                self.mode = "video"
            else:
                self.images = (
                    glob.glob(os.path.join(filepath, "*.png"))
                    + glob.glob(os.path.join(filepath, "*.jpg"))
                    + glob.glob(os.path.join(filepath, "*.ppm"))
                )
                self.images.sort()
                self.N = len(self.images)
                logging.info(f"Loading {self.N} images")
                self.mode = "images"
        else:
            raise IOError(
                "Error filepath (camerax/path of images/path of videos): ", filepath
            )

    def __getitem__(self, item):
        """Return frame ``item`` as a BGR ndarray, or None past the end of a stream."""
        if self.mode == "camera" or self.mode == "video":
            if item > self.N:
                return None
            ret, img = self.cap.read()
            if not ret:
                # BUG FIX: the original raised a plain string, which is a
                # TypeError in Python 3; raise a real exception instead.
                raise IOError("Can't read image from camera")
            if self.mode == "video":
                # Seek so the next read starts at the requested frame.
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
        elif self.mode == "images":
            filename = self.images[item]
            img = cv2.imread(filename)
            if img is None:
                raise Exception("Error reading image %s" % filename)
        return img

    def __len__(self):
        return self.N
class SimpleTracker(object):
    """Frame-to-frame keypoint tracker using mutual-nearest-neighbour matching."""

    def __init__(self):
        # Keypoints/descriptors of the previous frame; None until first update.
        self.pts_prev = None
        self.desc_prev = None

    def update(self, img, pts, desc):
        """Match the new frame against the previous one and draw the result.

        Args:
            img: BGR image to draw on (not modified; a deep copy is made).
            pts: (N, 2) keypoint coordinates for this frame.
            desc: (N, D) descriptors for this frame.

        Returns:
            (out, N_matches): annotated image and the number of matches
            (0 on the first frame, which only draws the keypoints).
        """
        N_matches = 0
        if self.pts_prev is None:
            # First frame: just remember it and draw the detections.
            self.pts_prev = pts
            self.desc_prev = desc

            out = copy.deepcopy(img)
            for pt1 in pts:
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
            N_matches = len(matches)

            out = copy.deepcopy(img)
            for pt1, pt2 in zip(mpts1, mpts2):
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                p2 = (int(round(pt2[0])), int(round(pt2[1])))
                cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)

            self.pts_prev = pts
            self.desc_prev = desc

        return out, N_matches

    def mnn_mather(self, desc1, desc2, sim_th: float = 0.9):
        """Mutual-nearest-neighbour matching on descriptor similarity.

        Generalized: the similarity threshold is now a parameter instead of
        a hard-coded 0.9 (default preserves the original behavior). The
        method name keeps the original spelling to preserve the interface.

        Args:
            desc1: (N1, D) descriptors (assumed L2-normalized, so the dot
                product is a cosine similarity — TODO confirm with caller).
            desc2: (N2, D) descriptors.
            sim_th: similarities below this are zeroed before matching.

        Returns:
            (M, 2) int array of index pairs (i in desc1, j in desc2) that
            are mutual nearest neighbours.
        """
        sim = desc1 @ desc2.transpose()
        sim[sim < sim_th] = 0
        nn12 = np.argmax(sim, axis=1)
        nn21 = np.argmax(sim, axis=0)
        ids1 = np.arange(0, sim.shape[0])
        # Keep i only when it is also the best match of its own best match.
        mask = ids1 == nn21[nn12]
        matches = np.stack([ids1[mask], nn12[mask]])
        return matches.transpose()
if __name__ == "__main__":
    # ---- command line -------------------------------------------------
    cli = argparse.ArgumentParser(description="ALike Demo.")
    cli.add_argument("input", type=str, default="",
                     help='Image directory or movie file or "camera0" (for webcam0).')
    cli.add_argument("--model", choices=["alike-t", "alike-s", "alike-n", "alike-l"],
                     default="alike-t", help="The model configuration")
    cli.add_argument("--device", type=str, default="cuda",
                     help="Running device (default: cuda).")
    cli.add_argument("--top_k", type=int, default=-1,
                     help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)")
    cli.add_argument("--scores_th", type=float, default=0.2,
                     help="Detector score threshold (default: 0.2).")
    cli.add_argument("--n_limit", type=int, default=5000,
                     help="Maximum number of keypoints to be detected (default: 5000).")
    cli.add_argument("--no_display", action="store_true",
                     help="Do not display images to screen. Useful if running remotely (default: False).")
    cli.add_argument("--no_sub_pixel", action="store_true",
                     help="Do not detect sub-pixel keypoints (default: False).")
    args = cli.parse_args()

    logging.basicConfig(level=logging.INFO)

    # ---- setup --------------------------------------------------------
    loader = ImageLoader(args.input)
    model = ALike(
        **configs[args.model],
        device=args.device,
        top_k=args.top_k,
        scores_th=args.scores_th,
        n_limit=args.n_limit,
    )
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    # ---- main loop ----------------------------------------------------
    runtimes = []
    pbar = tqdm(loader)
    for frame in pbar:
        if frame is None:  # end of a camera/video stream
            break

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pred = model(rgb, sub_pixel=not args.no_sub_pixel)
        kpts, desc = pred["keypoints"], pred["descriptors"]
        runtimes.append(pred["time"])

        out, N_matches = tracker.update(frame, kpts, desc)

        ave_fps = (1.0 / np.stack(runtimes)).mean()
        status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
        pbar.set_description(status)

        if not args.no_display:
            cv2.setWindowTitle(args.model, args.model + ": " + status)
            cv2.imshow(args.model, out)
            if cv2.waitKey(1) == ord("q"):
                break

    logging.info("Finished!")
    if not args.no_display:
        logging.info("Press any key to exit!")
        cv2.waitKey()