# alltracker/utils/improc.py
# Image-processing and TensorBoard visualization utilities.
import torch
import numpy as np
import utils.basic
import utils.py
from sklearn.decomposition import PCA
from matplotlib import cm
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
EPS = 1e-6
from skimage.color import (
rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)
def _convert(input_, type_):
return {
'float': input_.float(),
'double': input_.double(),
}.get(type_, input_)
def _generic_transform_sk_3d(transform, in_type='', out_type=''):
def apply_transform_individual(input_):
device = input_.device
input_ = input_.cpu()
input_ = _convert(input_, in_type)
input_ = input_.permute(1, 2, 0).detach().numpy()
transformed = transform(input_)
output = torch.from_numpy(transformed).float().permute(2, 0, 1)
output = _convert(output, out_type)
return output.to(device)
def apply_transform(input_):
to_stack = []
for image in input_:
to_stack.append(apply_transform_individual(image))
return torch.stack(to_stack)
return apply_transform
# Batched HSV->RGB converter: maps a B,3,H,W tensor through skimage's hsv2rgb.
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)
def flow2color(flow, clip=0.0):
    """Render a B,2,H,W flow field as a 1,3,H,W uint8 RGB image (batch item 0 only).

    clip==0 auto-scales by the max absolute component; otherwise values are
    clamped to [-clip, clip]. Hue encodes direction, value encodes magnitude.
    """
    bsz, ch, ht, wd = list(flow.size())
    assert ch == 2
    fl = flow[0:1].detach()
    # normalize into [-1, 1]
    limit = clip if clip != 0 else torch.max(torch.abs(fl)).item()
    fl = torch.clamp(fl, -limit, limit) / limit
    mag = torch.clamp(torch.sqrt(torch.sum(fl ** 2, dim=1, keepdim=True)), 0.0, 1.0)  # B,1,H,W
    ang = torch.atan2(-fl[:, 1:2], -fl[:, 0:1]) / np.pi  # B,1,H,W in [-1,1]
    hue = torch.clamp((ang + 1.0) / 2.0, 0.0, 1.0)
    sat = torch.ones_like(hue) * 0.75
    hsv = torch.cat([hue, sat, mag], dim=1)  # B,3,H,W
    rgb = hsv_to_rgb(hsv)
    return (rgb * 255.0).type(torch.ByteTensor)
# 2D colormap image used by ColorMap2d ("bremm" map). NOTE(review): the path
# is relative to the working directory — confirm it resolves at runtime.
COLORMAP_FILE = "./utils/bremm.png"
class ColorMap2d:
    """Map 2D points in [0,1]^2 to RGB colors by sampling a colormap image."""

    def __init__(self, filename=None):
        # Default to the packaged "bremm" colormap image.
        self._colormap_file = filename or COLORMAP_FILE
        self._img = (plt.imread(self._colormap_file) * 255).astype(np.uint8)
        self._height = self._img.shape[0]
        self._width = self._img.shape[1]

    def __call__(self, X):
        """X: N,2 array of (x,y) in [0,1]; returns N,3 uint8 colors."""
        assert len(X.shape) == 2
        out = np.zeros((X.shape[0], 3), dtype=np.uint8)
        for idx in range(X.shape[0]):
            u, v = X[idx, :]
            # scale to pixel indices, then clamp into the image bounds
            col = np.clip(int((self._width - 1) * u), 0, self._width - 1)
            row = np.clip(int((self._height - 1) * v), 0, self._height - 1)
            out[idx, :] = self._img[row, col]
        return out
def get_2d_colors(xys, H, W):
    """Assign each 2D point a color from the 2D "bremm" colormap.

    xys: N,2 array of pixel coordinates; H, W: image size used for
    normalization. Returns an N,3 uint8 array.

    Fix: normalization now happens on a float32 copy, so the caller's array
    is no longer mutated (and integer input no longer breaks the in-place
    true division).
    """
    N, D = xys.shape
    assert(D==2)
    bremm = ColorMap2d()
    xys = xys.astype(np.float32).copy()
    xys[:,0] /= float(W-1)
    xys[:,1] /= float(H-1)
    colors = bremm(xys)
    return colors
def get_n_colors(N, sequential=False):
    """Return a list of N RGB triplets (uint8-ish numpy arrays).

    sequential=True samples matplotlib's 'winter' colormap evenly over the
    N labels; otherwise colors are random with a minimum total brightness.

    Fix: guard the colormap denominator so N==1 with sequential=True no
    longer raises ZeroDivisionError.
    """
    label_colors = []
    denom = max(1, N - 1)  # avoid division by zero when N == 1
    for ii in range(N):
        if sequential:
            rgb = cm.winter(ii / denom)
            rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]
        else:
            rgb = np.zeros(3)
            while np.sum(rgb) < 128: # ensure min brightness
                rgb = np.random.randint(0, 256, 3)
        label_colors.append(rgb)
    return label_colors
def pca_embed(emb, keep, valid=None):
    """Per-image PCA reduction of a B,C,H,W embedding to B,keep,H,W.

    valid (optional): per-pixel mask, reshapeable to B,H*W; only valid
    pixels are used to fit each image's PCA, and invalid outputs are zeroed.
    Images containing NaNs (before or after projection) come back as zeros.

    Fixes vs. original: `if valid:` on an array is ambiguous (now
    `is not None`); H and W were referenced before being bound; the mask is
    reshaped per batch item; the zeroing multiply gets a column axis so it
    broadcasts over the `keep` components.
    """
    emb = emb + EPS
    emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy()  # B,H,W,C
    B, H, W, C = np.shape(emb)
    if valid is not None:
        # boolean pixel mask, flattened per batch item
        valid = valid.cpu().detach().numpy().reshape((B, H*W)).astype(bool)
    emb_reduced = list()
    for b, img in enumerate(emb):
        if np.isnan(img).any():
            emb_reduced.append(np.zeros([H, W, keep]))
            continue
        pixels_kd = np.reshape(img, (H*W, C))
        if valid is not None:
            pixels_kd_pca = pixels_kd[valid[b]]  # fit only on valid pixels
        else:
            pixels_kd_pca = pixels_kd
        P = PCA(keep)
        P.fit(pixels_kd_pca)
        if valid is not None:
            # project everything, then zero out the invalid pixels
            pixels3d = P.transform(pixels_kd) * valid[b][:, None]
        else:
            pixels3d = P.transform(pixels_kd)
        out_img = np.reshape(pixels3d, [H, W, keep]).astype(np.float32)
        if np.isnan(out_img).any():
            emb_reduced.append(np.zeros([H, W, keep]))
            continue
        emb_reduced.append(out_img)
    emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)
    return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)
def pca_embed_together(emb, keep):
    """Jointly PCA-reduce a B,C,H,W embedding to B,keep,H,W (one shared basis).

    All pixels of all batch items are pooled to fit a single PCA. If NaNs
    appear in the input or the projection, a zero tensor is returned instead.
    """
    emb = (emb + EPS).permute(0, 2, 3, 1).cpu().detach().float().numpy()  # B,H,W,C
    B, H, W, C = np.shape(emb)
    if np.isnan(emb).any():
        return torch.zeros(B, keep, H, W)
    flat = np.reshape(emb, (B*H*W, C))
    pca = PCA(keep)
    pca.fit(flat)
    projected = pca.transform(flat)
    reduced = np.reshape(projected, [B, H, W, keep]).astype(np.float32)
    if np.isnan(reduced).any():
        return torch.zeros(B, keep, H, W)
    return torch.from_numpy(reduced).permute(0, 3, 1, 2)
def reduce_emb(emb, valid=None, inbound=None, together=False):
    """PCA-reduce an S,C,H,W embedding for visualization.

    Keeps 4 principal components, drops the first, and normalizes the rest
    to roughly [-0.5, 0.5]. Returns (reduced_emb, emb_inbound) where
    emb_inbound is emb*inbound, or None when no inbound mask is given.
    """
    S, C, H, W = list(emb.size())
    keep = 4
    if together:
        reduced_emb = pca_embed_together(emb, keep)
    else:
        reduced_emb = pca_embed(emb, keep, valid) #not im
    reduced_emb = reduced_emb[:, 1:]  # drop the first component
    reduced_emb = utils.basic.normalize(reduced_emb) - 0.5
    emb_inbound = emb * inbound if inbound is not None else None
    return reduced_emb, emb_inbound
def get_feat_pca(feat, valid=None):
    """PCA-color a B,C,H,W feature map for visualization (returns B,3,H,W)."""
    B, C, D, W = list(feat.size())
    reduced, _ = reduce_emb(feat, valid=valid, inbound=None, together=True)
    return reduced
def gif_and_tile(ims, just_gif=False):
    """Stack a list of S frames into a B,S,... tensor for gif logging.

    When just_gif is False, every timestep additionally shows all S frames
    tiled along dim 2, concatenated after the sequential frame along dim 3 —
    gif on the left, tiled strip on the right.
    """
    S = len(ims)
    gif = torch.stack(ims, dim=1)  # sequential part
    if just_gif:
        return gif
    tiled = torch.cat(ims, dim=2)
    tiled = tiled.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)  # same tile at every step
    return torch.cat([gif, tiled], dim=3)
def preprocess_color(x):
    """Map uint8 colors in [0,255] to float colors in [-0.5, 0.5].

    Accepts either a numpy array or a torch tensor and returns the same kind.
    """
    if isinstance(x, np.ndarray):
        return x.astype(np.float32) * (1. / 255) - 0.5
    return x.float() * (1. / 255) - 0.5
def back2color(i, blacken_zeros=False):
    """Map float colors in [-0.5, 0.5] back to uint8 in [0, 255].

    blacken_zeros: first replace exactly-zero entries with -0.5 (black),
    then convert.

    Fix: the replacement constant is created on i's device instead of the
    cuda-or-cpu special case, so non-CUDA accelerators also work.
    """
    if blacken_zeros:
        const = torch.tensor([-0.5], device=i.device)
        i = torch.where(i == 0.0, const, i)
        return back2color(i)
    return ((i + 0.5) * 255).type(torch.ByteTensor)
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20, shadow=True):
    """Overlay a numeric frame id on a 1,C,H,W uint8 RGB tensor.

    Formats the id with utils.basic.strnum and delegates to
    draw_frame_str_on_vis, which performs the identical cv2 drawing —
    this removes ~20 duplicated lines; behavior is unchanged.
    """
    frame_str = utils.basic.strnum(frame_id)
    return draw_frame_str_on_vis(vis, frame_str, scale=scale, left=left, top=top, shadow=shadow)
def draw_frame_str_on_vis(vis, frame_str, scale=0.5, left=5, top=40, shadow=True):
    """Draw `frame_str` onto a 1,C,H,W uint8 RGB tensor; returns a new tensor.

    scale: cv2 font scale; (left, top): text origin in pixels from the
    top-left; shadow: paint a black rectangle behind the text first so it
    stays readable on any background.
    """
    rgb = vis.detach().cpu().numpy()[0]
    rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
    rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    color = (255, 255, 255)  # white text
    text_color_bg = (0,0,0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)
    text_w, text_h = text_size
    if shadow:
        # black backdrop sized to the text
        cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1)
    cv2.putText(
        rgb,
        frame_str,
        (left, top), # from left, from top
        font,
        scale, # font scale (float)
        color,
        1) # font thickness (int)
    rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)
    vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    return vis
class Summ_writer(object):
def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
self.writer = writer
self.global_step = global_step
self.log_freq = log_freq
self.scalar_freq = scalar_freq
self.fps = fps
self.just_gif = just_gif
self.maxwidth = 10000
self.save_this = (self.global_step % self.log_freq == 0)
self.scalar_freq = max(scalar_freq,1)
self.save_scalar = (self.global_step % self.scalar_freq == 0)
if self.save_this:
self.save_scalar = True
def summ_gif(self, name, tensor, blacken_zeros=False):
# tensor should be in B x S x C x H x W
assert tensor.dtype in {torch.uint8,torch.float32}
shape = list(tensor.shape)
if tensor.dtype == torch.float32:
tensor = back2color(tensor, blacken_zeros=blacken_zeros)
video_to_write = tensor[0:1]
S = video_to_write.shape[1]
if S==1:
# video_to_write is 1 x 1 x C x H x W
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step)
else:
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)
return video_to_write
def summ_rgbs(self, name, ims, frame_ids=None, frame_strs=None, blacken_zeros=False, only_return=False):
if self.save_this:
ims = gif_and_tile(ims, just_gif=self.just_gif)
vis = ims
assert vis.dtype in {torch.uint8,torch.float32}
if vis.dtype == torch.float32:
vis = back2color(vis, blacken_zeros)
B, S, C, H, W = list(vis.shape)
if frame_ids is not None:
assert(len(frame_ids)==S)
for s in range(S):
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
if frame_strs is not None:
assert(len(frame_strs)==S)
for s in range(S):
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
if int(W) > self.maxwidth:
vis = vis[:,:,:,:self.maxwidth]
if only_return:
return vis
else:
return self.summ_gif(name, vis, blacken_zeros)
    def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, frame_str=None, only_return=False, halfres=False, shadow=True):
        """Log one RGB image (B,C,H,W; only batch item 0 is kept).

        Accepts uint8 or float32 (in [-0.5,0.5]); floats go through
        back2color. Returns the 1,C,H,W vis when only_return, otherwise
        writes it as a one-frame gif. Returns None unless save_this.
        """
        if self.save_this:
            assert ims.dtype in {torch.uint8,torch.float32}
            if ims.dtype == torch.float32:
                ims = back2color(ims, blacken_zeros)
            #ims is B x C x H x W
            vis = ims[0:1] # just the first one
            B, C, H, W = list(vis.shape)
            if halfres:
                # 2x downsample for lighter logs
                vis = F.interpolate(vis, scale_factor=0.5)
            if frame_id is not None:
                vis = draw_frame_id_on_vis(vis, frame_id, shadow=shadow)
            if frame_str is not None:
                vis = draw_frame_str_on_vis(vis, frame_str, shadow=shadow)
            if int(W) > self.maxwidth:
                vis = vis[:,:,:,:self.maxwidth]
            if only_return:
                return vis
            else:
                return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)
def flow2color(self, flow, clip=0.0):
B, C, H, W = list(flow.size())
assert(C==2)
flow = flow[0:1].detach()
if False:
flow = flow[0].detach().cpu().permute(1,2,0).numpy() # H,W,2
if clip > 0:
clip_flow = clip
else:
clip_flow = None
im = utils.py.flow_to_image(flow, clip_flow=clip_flow, convert_to_bgr=True)
# im = utils.py.flow_to_image(flow, convert_to_bgr=True)
im = torch.from_numpy(im).permute(2,0,1).unsqueeze(0).byte() # 1,3,H,W
im = torch.flip(im, dims=[1]).clone() # BGR
# # i prefer black bkg
# white_pixels = (im == 255).all(dim=1, keepdim=True)
# im[white_pixels.expand(-1, 3, -1, -1)] = 0
return im
# flow_abs = torch.abs(flow)
# flow_mean = flow_abs.mean(dim=[1,2,3])
# flow_std = flow_abs.std(dim=[1,2,3])
if clip==0:
clip = torch.max(torch.abs(flow)).item()
# if clip:
flow = torch.clamp(flow, -clip, clip)/clip
# else:
# # # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)
# # flow_max = flow_mean + flow_std*2 + 1e-10
# # for b in range(B):
# # flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)
# flow_max = torch.max(flow_abs[b])
# for b in range(B):
# flow[b] = flow[b].clamp(-flow_max.item(), flow_max.item()) / flow_max[b].clamp(min=1)
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W
radius_clipped = torch.clamp(radius, 0.0, 1.0)
angle = torch.atan2(-flow[:, 1:2], -flow[:, 0:1]) / np.pi # B x 1 x H x W
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)
# hue = torch.mod(angle / (2 * np.pi) + 1.0, 1.0)
saturation = torch.ones_like(hue) * 0.75
value = radius_clipped
hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W
#flow = tf.image.hsv_to_rgb(hsv)
flow = hsv_to_rgb(hsv)
flow = (flow*255.0).type(torch.ByteTensor)
# flow = torch.flip(flow, dims=[1]).clone() # BGR
return flow
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None, frame_str=None, shadow=True):
# flow is B x C x D x W
if self.save_this:
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id, frame_str=frame_str, shadow=shadow)
else:
return None
def summ_oneds(self, name, ims, frame_ids=None, frame_strs=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False):
if self.save_this:
if bev:
B, C, H, _, W = list(ims[0].shape)
if reduce_max:
ims = [torch.max(im, dim=3)[0] for im in ims]
else:
ims = [torch.mean(im, dim=3) for im in ims]
elif fro:
B, C, _, H, W = list(ims[0].shape)
if reduce_max:
ims = [torch.max(im, dim=2)[0] for im in ims]
else:
ims = [torch.mean(im, dim=2) for im in ims]
if len(ims) != 1: # sequence
im = gif_and_tile(ims, just_gif=self.just_gif)
else:
im = torch.stack(ims, dim=1) # single frame
B, S, C, H, W = list(im.shape)
if logvis and max_val:
max_val = np.log(max_val)
im = torch.log(torch.clamp(im, 0)+1.0)
im = torch.clamp(im, 0, max_val)
im = im/max_val
norm = False
elif max_val:
im = torch.clamp(im, 0, max_val)
im = im/max_val
norm = False
if norm:
# normalize before oned2inferno,
# so that the ranges are similar within B across S
im = utils.basic.normalize(im)
im = im.view(B*S, C, H, W)
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)
vis = vis.view(B, S, 3, H, W)
if frame_ids is not None:
assert(len(frame_ids)==S)
for s in range(S):
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s])
if frame_strs is not None:
assert(len(frame_strs)==S)
for s in range(S):
vis[:,s] = draw_frame_str_on_vis(vis[:,s], frame_strs[s])
if W > self.maxwidth:
vis = vis[...,:self.maxwidth]
if only_return:
return vis
else:
self.summ_gif(name, vis)
    def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, frame_str=None, only_return=False, shadow=True):
        """Log one single-channel image as an inferno-colormapped RGB.

        im: B,1,H,W, or a B,C,D,H,W volume reduced over one axis with
        bev/fro (max or mean per max_along_y). logvis/max_val control
        log-scaling and clamping; norm rescales inside oned2inferno.
        Returns None unless save_this.
        """
        if self.save_this:
            if bev:
                # bird's-eye view: reduce the height axis
                B, C, H, _, W = list(im.shape)
                if max_along_y:
                    im = torch.max(im, dim=3)[0]
                else:
                    im = torch.mean(im, dim=3)
            elif fro:
                # frontal view: reduce the depth axis
                B, C, _, H, W = list(im.shape)
                if max_along_y:
                    im = torch.max(im, dim=2)[0]
                else:
                    im = torch.mean(im, dim=2)
            else:
                B, C, H, W = list(im.shape)
            im = im[0:1] # just the first one
            assert(C==1)
            if logvis and max_val:
                # log-scale then clamp to [0, log(max_val)]
                max_val = np.log(max_val)
                im = torch.log(im)
                im = torch.clamp(im, 0, max_val)
                im = im/max_val
                norm = False
            elif max_val:
                im = torch.clamp(im, 0, max_val)/max_val
                norm = False
            vis = oned2inferno(im, norm=norm)
            if W > self.maxwidth:
                vis = vis[...,:self.maxwidth]
            return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, frame_str=frame_str, only_return=only_return, shadow=shadow)
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None, frame_strs=None):
if self.save_this:
if valids is not None:
valids = torch.stack(valids, dim=1)
feats = torch.stack(feats, dim=1)
# feats leads with B x S x C
if feats.ndim==6:
# feats is B x S x C x D x H x W
if fro:
reduce_dim = 3
else:
reduce_dim = 4
if valids is None:
feats = torch.mean(feats, dim=reduce_dim)
else:
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)
B, S, C, D, W = list(feats.size())
if not pca:
# feats leads with B x S x C
feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)
# feats leads with B x S x 1
feats = torch.unbind(feats, dim=1)
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
else:
__p = lambda x: utils.basic.pack_seqdim(x, B)
__u = lambda x: utils.basic.unpack_seqdim(x, B)
feats_ = __p(feats)
if valids is None:
feats_pca_ = get_feat_pca(feats_)
else:
valids_ = __p(valids)
feats_pca_ = get_feat_pca(feats_, valids)
feats_pca = __u(feats_pca_)
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
    def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None, frame_str=None):
        """Log one feature map (B,C,H,W or B,C,D,H,W).

        5D input is reduced over depth (fro) or height (bev / default),
        optionally masked by `valid`. pca=True logs PCA colors, otherwise a
        mean-|.| heatmap via summ_oned. Returns None unless save_this.
        """
        if self.save_this:
            if feat.ndim==5: # B x C x D x H x W
                if bev:
                    reduce_axis = 3
                elif fro:
                    reduce_axis = 2
                else:
                    # default to bev
                    reduce_axis = 3
                if valid is None:
                    feat = torch.mean(feat, dim=reduce_axis)
                else:
                    # expand the 1-channel mask over all feature channels
                    valid = valid.repeat(1, feat.size()[1], 1, 1, 1)
                    feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)
            B, C, D, W = list(feat.shape)
            if not pca:
                feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)
                # feat is B x 1 x D x W
                return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
            else:
                feat_pca = get_feat_pca(feat, valid)
                return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def summ_scalar(self, name, value):
if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()):
value = value.detach().cpu().numpy()
if not np.isnan(value):
if (self.log_freq == 1):
self.writer.add_scalar(name, value, global_step=self.global_step)
elif self.save_this or self.save_scalar:
self.writer.add_scalar(name, value, global_step=self.global_step)
    def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap='coolwarm', vals=None, linewidth=1, max_show=1024):
        """Draw 2D point tracks onto a video and log it.

        trajs: B,S,N,2 pixel coords; rgbs: B,S,C,H,W float images; only batch
        item 0 is drawn. visibs/valids: optional B,S,N masks; vals: optional
        B,N per-track colormap values. At most max_show randomly-chosen
        tracks are drawn: each leaves a growing trail, then a circle marks
        the current position (shaded by visibility).
        """
        # trajs is B, S, N, 2
        # rgbs is B, S, C, H, W
        B, S, C, H, W = rgbs.shape
        B, S2, N, D = trajs.shape
        assert(S==S2)
        rgbs = rgbs[0] # S, C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        if visibs is None:
            visibs = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            visibs = visibs[0]
        if vals is not None:
            vals = vals[0] # N
        if N > max_show:
            # randomly subsample tracks to keep drawing fast
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
            if vals is not None:
                vals = vals[inds]
            N = trajs.shape[1]
        # tolerate slightly out-of-frame points; cv2 clips while drawing
        trajs = trajs.clamp(-16, W+16)
        rgbs_color = []
        for rgb in rgbs:
            rgb = back2color(rgb).detach().cpu().numpy()
            rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
            rgbs_color.append(rgb) # each element H x W x 3
        for i in range(min(N, max_show)):
            # 'onediff' highlights track 0 with a distinct colormap
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i].long().detach().cpu().numpy() # S, 2
            valid = valids[:,i].long().detach().cpu().numpy() # S
            if vals is not None:
                val = vals[i].float().detach().cpu().numpy() # scalar
            else:
                val = None
            for t in range(S):
                if valid[t]:
                    # draw the trail up to frame t onto frame t
                    rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj[:t+1], S=S, show_dots=show_dots, cmap=cmap_, val=val, linewidth=linewidth)
        for i in range(min(N, max_show)):
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S,2
            vis = visibs[:,i].round() # S
            valid = valids[:,i] # S
            # mark current positions with visibility-shaded circles
            rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
        rgbs = []
        for rgb in rgbs_color:
            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
            rgbs.append(preprocess_color(rgb))
        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
    def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, frame_strs=None, only_return=False, show_dots=True, cmap=None, linewidth=1, max_show=1024):
        """Variant of summ_traj2ds_on_rgbs with required per-step visibility.

        trajs: B,S,N,2; visibles: B,S,N; rgbs: B,S,C,H,W; batch item 0 only.
        Draws full trajectories (colored from the frame-0 image when cmap is
        None) plus visibility-shaded circles.
        NOTE(review): unlike summ_traj2ds_on_rgbs, tracks are not randomly
        subsampled when N > max_show — only the first max_show are drawn.
        """
        # trajs is B, S, N, 2
        # rgbs is B, S, C, H, W
        B, S, C, H, W = rgbs.shape
        B, S2, N, D = trajs.shape
        assert(S==S2)
        rgbs = rgbs[0] # S, C, H, W
        trajs = trajs[0] # S, N, 2
        visibles = visibles[0] # S, N
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        rgbs_color = []
        for rgb in rgbs:
            rgb = back2color(rgb).detach().cpu().numpy()
            rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
            rgbs_color.append(rgb) # each element H x W x 3
        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        visibles = visibles.float().detach().cpu().numpy() # S, N
        valids = valids.long().detach().cpu().numpy() # S, N
        for i in range(min(N, max_show)):
            # 'onediff' highlights track 0 with a distinct colormap
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S,2
            vis = visibles[:,i] # S
            valid = valids[:,i] # S
            rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth)
        for i in range(min(N, max_show)):
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S,2
            vis = visibles[:,i] # S
            valid = valids[:,i] # S
            # current-position circles; cmap=None samples the bremm colormap
            rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth)
        rgbs = []
        for rgb in rgbs_color:
            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
            rgbs.append(preprocess_color(rgb))
        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)
    def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=True, show_lines=True, frame_id=None, frame_str=None, only_return=False, cmap='coolwarm', linewidth=1, max_show=1024):
        """Draw all trajectories onto a single image and log it.

        trajs: B,S,N,2; rgb: B,C,H,W; batch item 0 only. At most max_show
        randomly-chosen tracks are drawn; a track is drawn only when it is
        valid at frame 0, using just its valid timesteps.
        """
        # trajs is B, S, N, 2
        # rgb is B, C, H, W
        B, C, H, W = rgb.shape
        B, S, N, D = trajs.shape
        rgb = rgb[0] # C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0])
        else:
            valids = valids[0]
        rgb_color = back2color(rgb).detach().cpu().numpy()
        rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last
        # using maxdist will dampen the colors for short motions
        maxdist = None
        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        valids = valids.long().detach().cpu().numpy() # S, N
        if N > max_show:
            # randomly subsample tracks
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            N = trajs.shape[1]
        for i in range(min(N, max_show)):
            # 'onediff' highlights track 0 with a distinct colormap
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S, 2
            valid = valids[:,i] # S
            if valid[0]==1:
                traj = traj[valid>0]  # keep only the valid timesteps
                rgb_color = self.draw_traj_on_image_py(
                    rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth)
        rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)
        rgb = preprocess_color(rgb_color)
        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None):
# all inputs are numpy tensors
# rgb is 3 x H x W
# traj is S x 2
H, W, C = rgb.shape
assert(C==3)
rgb = rgb.astype(np.uint8).copy()
S1, D = traj.shape
assert(D==2)
color_map = cm.get_cmap(cmap)
S1, D = traj.shape
for s in range(S1):
if val is not None:
color = np.array(color_map(val)[:3]) * 255 # rgb
else:
if maxdist is not None:
val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1)
color = np.array(color_map(val)[:3]) * 255 # rgb
else:
color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb
if show_lines and s<(S1-1):
cv2.line(rgb,
(int(traj[s,0]), int(traj[s,1])),
(int(traj[s+1,0]), int(traj[s+1,1])),
color,
linewidth,
cv2.LINE_AA)
if show_dots:
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1)
# if maxdist is not None:
# val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)
# color = np.array(color_map(val)[:3]) * 255 # rgb
# else:
# # draw the endpoint of traj, using the next color (which may be the last color)
# color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb
# # emphasize endpoint
# cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)
return rgb
    def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):
        """Draw one trajectory's growing trail across a list of frames.

        rgbs: list of H,W,3 uint8 images; traj: S,2 int coords. The track
        color is the image color under the track's start point in frame 0.
        Returns new (copied) images.
        NOTE(review): iterates range(S) using the S argument, not the
        trajectory length S1 — callers must pass S == len(rgbs); confirm.
        """
        H, W, C = rgbs[0].shape
        assert(C==3)
        rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
        S1, D = traj.shape
        assert(D==2)
        # sample the track color from frame 0 at the (clamped) start point
        x = int(np.clip(traj[0,0], 0, W-1))
        y = int(np.clip(traj[0,1], 0, H-1))
        color = rgbs[0][y,x]
        color = (int(color[0]),int(color[1]),int(color[2]))
        for s in range(S):
            # draw the polyline of all points up to frame s onto frame s
            cv2.polylines(rgbs[s],
                [traj[:s+1]],
                False,
                color,
                linewidth,
                cv2.LINE_AA)
        return rgbs
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):
# all inputs are numpy tensors
# rgbs is a list of 3,H,W
# xy is N,2
H, W, C = rgb.shape
assert(C==3)
rgb = rgb.astype(np.uint8).copy()
N, D = xy.shape
assert(D==2)
xy = xy.astype(np.float32)
xy[:,0] = np.clip(xy[:,0], 0, W-1)
xy[:,1] = np.clip(xy[:,1], 0, H-1)
xy = xy.astype(np.int32)
if colors is None:
colors = get_n_colors(N)
for n in range(N):
color = colors[n]
# print('color', color)
# color = (color[0]*255).astype(np.uint8)
color = (int(color[0]),int(color[1]),int(color[2]))
# x = int(np.clip(xy[0,0], 0, W-1))
# y = int(np.clip(xy[0,1], 0, H-1))
# color_ = rgbs[0][y,x]
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
# color_ = (int(color_[0]),int(color_[1]),int(color_[2]))
cv2.circle(rgb, (int(xy[n,0]), int(xy[n,1])), linewidth, color, 3)
# vis_color = int(np.squeeze(vis[s])*255)
# vis_color = (vis_color,vis_color,vis_color)
# cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)
return rgb
    def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):
        """Draw one track's per-frame position circles across a list of frames.

        rgbs: list of H,W,3 uint8; traj: S,2; vis: S visibility in [0,1],
        drawn as a grayscale inner dot inside the colored outer circle.
        cmap None keys one fixed color per track off its start position via
        the 2D bremm colormap; otherwise colors follow `cmap` over time.
        Returns new (copied) images.
        """
        H, W, C = rgbs[0].shape
        assert(C==3)
        rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]
        S1, D = traj.shape
        assert(D==2)
        if cmap is None:
            # fixed per-track color sampled from the bremm map at the start point
            bremm = ColorMap2d()
            traj_ = traj[0:1].astype(np.float32)
            traj_[:,0] /= float(W)
            traj_[:,1] /= float(H)
            color = bremm(traj_)
            # NOTE(review): bremm() already returns uint8 in [0,255]; this
            # *255 looks like a uint8 overflow (values wrap mod 256) — confirm.
            color = (color[0]*255).astype(np.uint8)
            color = (int(color[0]),int(color[1]),int(color[2]))
        for s in range(S):
            if cmap is not None:
                color_map = cm.get_cmap(cmap)
                color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # time-coded color
            # outer circle in the track color, inner dot shaded by visibility
            cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+2, color, -1)
            vis_color = int(np.squeeze(vis[s])*255)
            vis_color = (vis_color,vis_color,vis_color)
            cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)
        return rgbs
    def summ_pts_on_rgb(self, name, trajs, rgb, visibs=None, valids=None, frame_id=None, frame_str=None, only_return=False, show_dots=True, colors=None, cmap='coolwarm', linewidth=1, max_show=1024, already_sorted=False):
        """Scatter point tracks onto one image (filled dot = visible, ring = occluded).

        trajs: B,S,N,2; rgb: B,C,H,W; visibs/valids: optional B,S,N masks;
        batch item 0 only. colors: optional explicit per-point colors,
        otherwise points are colormapped by index after a mean-y sort
        (skipped when already_sorted) that keeps colors spatially smooth.
        """
        # trajs is B, S, N, 2
        # rgb is B, C, H, W
        B, C, H, W = rgb.shape
        B, S, N, D = trajs.shape
        rgb = rgb[0] # C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        if visibs is None:
            visibs = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            visibs = visibs[0]
        trajs = trajs.clamp(-16, W+16)  # tolerate slightly out-of-frame points
        if N > max_show:
            # randomly subsample tracks
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
            N = trajs.shape[1]
        if not already_sorted:
            # sort tracks by mean y so index-based colors vary smoothly
            inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
        rgb = back2color(rgb).detach().cpu().numpy()
        rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        valids = valids.long().detach().cpu().numpy() # S, N
        visibs = visibs.long().detach().cpu().numpy() # S, N
        rgb = rgb.astype(np.uint8).copy()
        for i in range(min(N, max_show)):
            # 'onediff' highlights track 0 with a distinct colormap
            if cmap=='onediff' and i==0:
                cmap_ = 'spring'
            elif cmap=='onediff':
                cmap_ = 'winter'
            else:
                cmap_ = cmap
            traj = trajs[:,i] # S,2
            valid = valids[:,i] # S
            visib = visibs[:,i] # S
            if colors is None:
                ii = i/(1e-4+N-1.0)
                color_map = cm.get_cmap(cmap)
                color = np.array(color_map(ii)[:3]) * 255 # rgb
            else:
                color = np.array(colors[i]).astype(np.int64)
            color = (int(color[0]),int(color[1]),int(color[2]))
            for s in range(S):
                if valid[s]:
                    if visib[s]:
                        thickness = -1  # filled dot when visible
                    else:
                        thickness = 2  # ring when occluded
                    cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, thickness)
        rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0)
        rgb = preprocess_color(rgb)
        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id, frame_str=frame_str)
    def summ_pts_on_rgbs(self, name, trajs, rgbs, visibs=None, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', colors=None, linewidth=1, max_show=1024, frame_strs=None):
        """Scatter point tracks onto a video (filled dot = visible, ring = occluded).

        trajs: B,S,N,2; rgbs: B,S,C,H,W; visibs/valids: optional B,S,N
        masks; batch item 0 only. At most max_show randomly-chosen tracks
        are drawn, sorted by mean y so index-based colors vary smoothly.
        """
        # trajs is B, S, N, 2
        # rgbs is B, S, C, H, W
        B, S, C, H, W = rgbs.shape
        B, S2, N, D = trajs.shape
        assert(S==S2)
        rgbs = rgbs[0] # S, C, H, W
        trajs = trajs[0] # S, N, 2
        if valids is None:
            valids = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            valids = valids[0]
        if visibs is None:
            visibs = torch.ones_like(trajs[:,:,0]) # S, N
        else:
            visibs = visibs[0]
        if N > max_show:
            # randomly subsample tracks
            inds = np.random.choice(N, max_show)
            trajs = trajs[:,inds]
            valids = valids[:,inds]
            visibs = visibs[:,inds]
            N = trajs.shape[1]
        # sort tracks by mean y so index-based colors vary smoothly
        inds = torch.argsort(torch.mean(trajs[:,:,1], dim=0))
        trajs = trajs[:,inds]
        valids = valids[:,inds]
        visibs = visibs[:,inds]
        rgbs_color = []
        for rgb in rgbs:
            rgb = back2color(rgb).detach().cpu().numpy()
            rgb = np.transpose(rgb, [1, 2, 0]) # put channels last
            rgbs_color.append(rgb) # each element H x W x 3
        trajs = trajs.long().detach().cpu().numpy() # S, N, 2
        valids = valids.long().detach().cpu().numpy() # S, N
        visibs = visibs.long().detach().cpu().numpy() # S, N
        rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]
        for i in range(min(N, max_show)):
            traj = trajs[:,i] # S,2
            valid = valids[:,i] # S
            visib = visibs[:,i] # S
            if colors is None:
                ii = i/(1e-4+N-1.0)
                color_map = cm.get_cmap(cmap)
                color = np.array(color_map(ii)[:3]) * 255 # rgb
            else:
                color = np.array(colors[i]).astype(np.int64)
            color = (int(color[0]),int(color[1]),int(color[2]))
            for s in range(S):
                if valid[s]:
                    if visib[s]:
                        thickness = -1  # filled dot when visible
                    else:
                        thickness = 2  # ring when occluded
                    cv2.circle(rgbs_color[s], (int(traj[s,0]), int(traj[s,1])), int(linewidth), color, thickness)
        rgbs = []
        for rgb in rgbs_color:
            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
            rgbs.append(preprocess_color(rgb))
        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids, frame_strs=frame_strs)