Blending_mmodel / tps_warp.py
gaur3009's picture
Update tps_warp.py
5a64740 verified
raw
history blame contribute delete
974 Bytes
import numpy as np
import cv2
from scipy.interpolate import Rbf
def apply_tps_warp(image, mask, warp_strength=5.0):
    """Randomly warp *image* with a thin-plate radial-basis deformation.

    Up to 100 control points are sampled from the pixels where ``mask > 0``.
    Each control point gets a random offset of at most ``warp_strength``
    pixels per axis, and two ``Rbf`` interpolants (one per coordinate) are
    fitted to those pairs.  The interpolants are evaluated on the full pixel
    grid and the result is used as the sampling map for ``cv2.remap``.

    Args:
        image: Array whose first two axes are (height, width).
        mask: Array; pixels with a value greater than zero are candidate
            control points.  A 3-D mask is collapsed over its last axis
            (a pixel is active if any channel is greater than zero).
        warp_strength: Maximum absolute per-axis offset, in pixels, applied
            to each control point.  Must be non-negative.

    Returns:
        The warped image.  When the mask has no active pixel or
        ``warp_strength`` is zero there is nothing to warp and an unmodified
        copy of ``image`` is returned.

    Raises:
        ValueError: If ``warp_strength`` is negative.

    Note:
        The result is random; it draws from NumPy's global random state, so
        call ``np.random.seed`` beforehand for reproducible output.
    """
    if warp_strength < 0:
        raise ValueError("warp_strength must be non-negative")

    active = np.asarray(mask) > 0
    if active.ndim == 3:
        # Multi-channel mask: treat a pixel as active if any channel is set.
        active = active.any(axis=2)
    ys, xs = np.nonzero(active)

    # No control points (empty mask) or no displacement requested: fitting
    # an interpolant would fail or be pointless, so hand back a copy.
    if xs.size == 0 or warp_strength == 0:
        return image.copy()

    height, width = image.shape[:2]

    num_points = min(100, xs.size)
    indices = np.random.choice(xs.size, num_points, replace=False)
    src_points = np.column_stack((xs[indices], ys[indices])).astype(np.float64)

    # Symmetric, continuous offsets in [-warp_strength, warp_strength).  The
    # previous integer sampling excluded +warp_strength (biasing the warp
    # towards negative offsets) and could not express strengths below 1.
    offsets = np.random.uniform(-warp_strength, warp_strength, src_points.shape)
    dst_points = src_points + offsets

    # Fit one thin-plate interpolant per output coordinate.
    # NOTE(review): the system can presumably be singular for degenerate
    # control-point sets (e.g. a single mask pixel) -- confirm and guard if
    # such masks are expected.
    rbf_x = Rbf(src_points[:, 0], src_points[:, 1], dst_points[:, 0], function='thin_plate')
    rbf_y = Rbf(src_points[:, 0], src_points[:, 1], dst_points[:, 1], function='thin_plate')

    # Evaluate on every pixel; cv2.remap requires float32 maps.
    grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))
    map_x = rbf_x(grid_x, grid_y).astype(np.float32)
    map_y = rbf_y(grid_x, grid_y).astype(np.float32)

    warped_image = cv2.remap(
        image,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )
    return warped_image