from argparse import Namespace

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from custom_controlnet_aux.util import (common_input_validate, custom_hf_download,
                                        resize_image_with_pad, UNIMATCH_MODEL_NAME)
from .unimatch.unimatch import UniMatch
from .utils.flow_viz import flow_to_image


def inference_flow(model,
                   image1,  # np array of HWC
                   image2,
                   padding_factor=8,
                   inference_size=None,
                   attn_type='swin',
                   attn_splits_list=None,
                   corr_radius_list=None,
                   prop_radius_list=None,
                   num_reg_refine=1,
                   pred_bidir_flow=False,
                   pred_bwd_flow=False,
                   fwd_bwd_consistency_check=False,
                   device="cpu",
                   **kwargs,
                   ):
    fixed_inference_size = inference_size
    transpose_img = False

    # HWC numpy arrays -> 1CHW float tensors on the target device
    image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0).to(device)
    image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0).to(device)

    # the model is trained with size: width > height
    if image1.size(-2) > image1.size(-1):
        image1 = torch.transpose(image1, -2, -1)
        image2 = torch.transpose(image2, -2, -1)
        transpose_img = True

    # round each spatial dim up to the nearest multiple of padding_factor
    nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
                    int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]

    # resize to nearest size or specified size
    inference_size = nearest_size if fixed_inference_size is None else fixed_inference_size
    assert isinstance(inference_size, (list, tuple))
    ori_size = image1.shape[-2:]

    # resize before inference
    if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
        image1 = F.interpolate(image1, size=inference_size, mode='bilinear', align_corners=True)
        image2 = F.interpolate(image2, size=inference_size, mode='bilinear', align_corners=True)

    if pred_bwd_flow:
        image1, image2 = image2, image1

    results_dict = model(image1, image2,
                         attn_type=attn_type,
                         attn_splits_list=attn_splits_list,
                         corr_radius_list=corr_radius_list,
                         prop_radius_list=prop_radius_list,
                         num_reg_refine=num_reg_refine,
                         task='flow',
                         pred_bidir_flow=pred_bidir_flow,
                         )

    flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]

    # resize back and rescale the flow values to original-image pixel units
    if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
        flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear', align_corners=True)
        flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
        flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]

    # transpose back to the original orientation
    if transpose_img:
        flow_pr = torch.transpose(flow_pr, -2, -1)

    flow = flow_pr[0].permute(1, 2, 0).cpu().numpy()  # [H, W, 2]
    vis_image = flow_to_image(flow)

    # with bidirectional prediction, return the backward flow instead
    if pred_bidir_flow:
        assert flow_pr.size(0) == 2  # [2, 2, H, W]: forward and backward flow
        flow_bwd = flow_pr[1].permute(1, 2, 0).cpu().numpy()  # [H, W, 2]
        vis_image = flow_to_image(flow_bwd)
        flow = flow_bwd

    return flow, vis_image
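# For reference, the padding step above rounds each spatial dimension up to the
# nearest multiple of `padding_factor` before inference. A minimal sketch with
# illustrative (made-up) frame dimensions:
#
#   >>> import numpy as np
#   >>> padding_factor = 32  # value used by the gmflow-scale2* configs below
#   >>> h, w = 500, 888
#   >>> [int(np.ceil(s / padding_factor)) * padding_factor for s in (h, w)]
#   [512, 896]
#
# The predicted flow is then bilinearly resized back to the original resolution,
# with the u/v components rescaled by the width/height ratios so the displacement
# vectors stay in original-image pixel units.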
# Hardcoded configs matching the released GMFlow/UniMatch flow checkpoints.
MODEL_CONFIGS = {
    "gmflow-scale1": Namespace(
        num_scales=1,
        upsample_factor=8,
        attn_type="swin",
        feature_channels=128,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        attn_splits_list=[2],
        corr_radius_list=[-1],
        prop_radius_list=[-1],
        reg_refine=False,
        num_reg_refine=1
    ),
    "gmflow-scale2": Namespace(
        num_scales=2,
        upsample_factor=4,
        padding_factor=32,
        attn_type="swin",
        feature_channels=128,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        attn_splits_list=[2, 8],
        corr_radius_list=[-1, 4],
        prop_radius_list=[-1, 1],
        reg_refine=False,
        num_reg_refine=1
    ),
    "gmflow-scale2-regrefine6": Namespace(
        num_scales=2,
        upsample_factor=4,
        padding_factor=32,
        attn_type="swin",
        feature_channels=128,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        attn_splits_list=[2, 8],
        corr_radius_list=[-1, 4],
        prop_radius_list=[-1, 1],
        reg_refine=True,
        num_reg_refine=6
    )
}


class UnimatchDetector:
    def __init__(self, unimatch, config_args):
        self.unimatch = unimatch
        self.config_args = config_args
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=UNIMATCH_MODEL_NAME, filename="gmflow-scale2-regrefine6-mixdata.pth"):
        model_path = custom_hf_download(pretrained_model_or_path, filename)

        # Iterate in reverse so the most specific key matches first
        # ("gmflow-scale2" is a prefix of "gmflow-scale2-regrefine6").
        config_args = None
        for key in list(MODEL_CONFIGS.keys())[::-1]:
            if key in filename:
                config_args = MODEL_CONFIGS[key]
                break
        assert config_args, f"Couldn't find hardcoded Unimatch config for {filename}"

        model = UniMatch(feature_channels=config_args.feature_channels,
                         num_scales=config_args.num_scales,
                         upsample_factor=config_args.upsample_factor,
                         num_head=config_args.num_head,
                         ffn_dim_expansion=config_args.ffn_dim_expansion,
                         num_transformer_layers=config_args.num_transformer_layers,
                         reg_refine=config_args.reg_refine,
                         task='flow')
        sd = torch.load(model_path, map_location="cpu")
        model.load_state_dict(sd['model'])

        return cls(model, config_args)

    def to(self, device):
        self.unimatch.to(device)
        self.device = device
        return self

    def __call__(self, image1, image2, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC",
                 pred_bwd_flow=False, pred_bidir_flow=False, **kwargs):
        assert image1.shape == image2.shape, f"[Unimatch] image1 and image2 must have the same size, got {image1.shape} and {image2.shape}"
        image1, output_type = common_input_validate(image1, output_type, **kwargs)
        #image1, remove_pad = resize_image_with_pad(image1, detect_resolution, upscale_method)
        image2, output_type = common_input_validate(image2, output_type, **kwargs)
        #image2, remove_pad = resize_image_with_pad(image2, detect_resolution, upscale_method)

        with torch.no_grad():
            flow, vis_image = inference_flow(self.unimatch, image1, image2, device=self.device,
                                             pred_bwd_flow=pred_bwd_flow, pred_bidir_flow=pred_bidir_flow,
                                             **vars(self.config_args))

        if output_type == "pil":
            vis_image = Image.fromarray(vis_image)

        return flow, vis_image
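if __name__ == "__main__":
    # Minimal usage sketch. The frame file names and output path below are
    # placeholder assumptions, not files shipped with this package.
    detector = UnimatchDetector.from_pretrained()
    detector = detector.to("cuda" if torch.cuda.is_available() else "cpu")

    frame1 = np.array(Image.open("frame1.png").convert("RGB"))
    frame2 = np.array(Image.open("frame2.png").convert("RGB"))

    flow, vis = detector(frame1, frame2, output_type="pil")
    print(f"flow: {flow.shape} {flow.dtype}")  # (H, W, 2) per-pixel displacement field
    vis.save("flow_vis.png")  # color-coded flow visualization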