Spaces:
Running
Running
File size: 4,824 Bytes
789c2b2 b0b5f60 789c2b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
Hugging Face's logo
Hugging Face
Models
Datasets
Spaces
Community
Docs
Enterprise
Pricing
Spaces:
AItool
/
RIFE_Interpolation
like
0
Logs
App
Files
Community
Settings
RIFE_Interpolation
/
inference_img.py
AItool's picture
AItool
Create inference_img.py
b0b5f60
verified
3 days ago
raw
Copy download link
history
blame
edit
delete
4.48 kB
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
# Directory that receives the generated frames.
OUTPUT_PATH = "/home/user/app/output/"

# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings("ignore")

# Inference only: autograd is switched off globally.
torch.set_grad_enabled(False)

# Probe CUDA once and derive both the target device and the cuDNN setup from it.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
if _cuda_available:
    torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# Command-line interface. Each entry is (flag, add_argument keyword options);
# entries are registered in the order listed.
_CLI_OPTIONS = (
    ('--img', {'dest': 'img', 'nargs': 2, 'required': True}),
    ('--exp', {'default': 2, 'type': int}),
    ('--ratio', {
        'default': 0,
        'type': float,
        'help': 'inference ratio between two images with 0 - 1 range',
    }),
    ('--rthreshold', {
        'default': 0.02,
        'type': float,
        'help': 'returns image when actual ratio falls in given range threshold',
    }),
    ('--rmaxcycles', {
        'default': 8,
        'type': int,
        'help': 'limit max number of bisectional cycles',
    }),
    ('--model', {
        'dest': 'modelDir',
        'type': str,
        'default': 'train_log',
        'help': 'directory with trained model files',
    }),
)

parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed immediately at import time; --img is mandatory and takes two paths.
args = parser.parse_args()
# Load the interpolation network from args.modelDir. The RIFE_HDv3 wrapper is
# tried first; if importing it or loading its weights fails, the IFNet_HDv3
# wrapper is used instead. A failure of the fallback itself propagates.
try:
    from train_log.RIFE_HDv3 import Model
    model = Model()
    model.load_model(args.modelDir, -1)
    print("Loaded RIFE_HDv3 model.")
    print("Checkpoint reached RIFE!")
except Exception as err:
    # `except Exception` instead of a bare `except:` so that Ctrl-C and
    # SystemExit are not swallowed and mistaken for a model-loading failure.
    # The reason is reported so an unexpected fallback can be diagnosed.
    print(f"RIFE_HDv3 could not be loaded ({err!r}); falling back to IFNet_HDv3.")
    from train_log.IFNet_HDv3 import Model
    model = Model()
    model.load_model(args.modelDir, -1)
    print("Loaded IFNet_HDv3 model.")
    print("Checkpoint reached IFNet!")

# NOTE(review): presumably these switch the network to evaluation mode and
# move it onto the selected device; confirm against the Model class.
model.eval()
model.device()
def _load_frame(path, flags, normalize):
    """Read one image and return it as a 1 x C x H x W tensor on ``device``.

    Args:
        path: Image file to read.
        flags: Flags forwarded to ``cv2.imread``.
        normalize: If true, divide the pixel values by 255.

    Raises:
        OSError: ``cv2.imread`` returned ``None`` (missing or undecodable file).
        ValueError: The decoded image has no channel axis (e.g. grayscale).
    """
    image = cv2.imread(path, flags)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read image file: {path!r}")
    if image.ndim != 3:
        raise ValueError(
            f"Expected an H x W x C image but {path!r} has shape {image.shape}"
        )
    # H x W x C -> C x H x W, then add the leading batch axis.
    tensor = torch.tensor(image.transpose(2, 0, 1)).to(device)
    if normalize:
        tensor = tensor / 255.
    return tensor.unsqueeze(0)


# EXR pairs are read at their native depth and left unscaled; everything else
# is read unchanged and divided by 255.
if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'):
    img0 = _load_frame(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH, False)
    img1 = _load_frame(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH, False)
else:
    img0 = _load_frame(args.img[0], cv2.IMREAD_UNCHANGED, True)
    img1 = _load_frame(args.img[1], cv2.IMREAD_UNCHANGED, True)

# Pad on the right and bottom so height and width become multiples of 32
# (presumably required by the network; confirm against the model). The
# original h and w are kept so results can be cropped back later.
n, c, h, w = img0.shape
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
img0 = F.pad(img0, padding)
img1 = F.pad(img1, padding)
# Build img_list, the ordered sequence of frames to save.
if args.ratio:
    # Single in-between frame at (approximately) args.ratio, found by
    # repeatedly interpolating the midpoint of a shrinking interval.
    half_tolerance = args.rthreshold / 2
    img0_ratio, img1_ratio = 0.0, 1.0
    if args.ratio <= img0_ratio + half_tolerance:
        # Close enough to the first input: reuse it.
        middle = img0
    elif args.ratio >= img1_ratio - half_tolerance:
        # Close enough to the second input: reuse it.
        middle = img1
    else:
        accept_low = args.ratio - half_tolerance
        accept_high = args.ratio + half_tolerance
        tmp_img0, tmp_img1 = img0, img1
        for _ in range(args.rmaxcycles):
            middle = model.inference(tmp_img0, tmp_img1)
            middle_ratio = (img0_ratio + img1_ratio) / 2
            if accept_low <= middle_ratio <= accept_high:
                break
            if args.ratio > middle_ratio:
                # Target lies in the upper half: midpoint becomes the new start.
                tmp_img0, img0_ratio = middle, middle_ratio
            else:
                # Target lies in the lower half: midpoint becomes the new end.
                tmp_img1, img1_ratio = middle, middle_ratio
    img_list = [img0, middle, img1]
else:
    # Each pass inserts one interpolated frame between every neighbouring
    # pair, so args.exp passes are applied to the two-frame starting list.
    img_list = [img0, img1]
    for _ in range(args.exp):
        doubled = []
        for left, right in zip(img_list, img_list[1:]):
            doubled.append(left)
            doubled.append(model.inference(left, right))
        doubled.append(img1)
        img_list = doubled
# Create OUTPUT_PATH itself (not a cwd-relative 'output' folder) so the writes
# below succeed regardless of the working directory the script is started in.
os.makedirs(OUTPUT_PATH, exist_ok=True)
print("Checkpoint reached! output folder ok")

inputs_are_exr = args.img[0].endswith('.exr') and args.img[1].endswith('.exr')
for i, frame in enumerate(img_list):
    filename_png = os.path.join(OUTPUT_PATH, f"img{i}.png")
    if inputs_are_exr:
        # Half-float EXR copy of the unscaled frame, cropped back to h x w.
        filename_exr = os.path.join(OUTPUT_PATH, f"img{i}.exr")
        cv2.imwrite(
            filename_exr,
            frame[0].cpu().numpy().transpose(1, 2, 0)[:h, :w],
            [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
        )
    # An 8-bit PNG is written for every frame, whatever the input format:
    # scale by 255, convert to bytes, restore H x W x C and crop the padding.
    success = cv2.imwrite(
        filename_png,
        (frame[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w],
    )
    print(f"Saving to {filename_png} β success: {success}")
    print("Saving to:", os.path.abspath(filename_png))
print("Checkpoint reached!")
|