Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
| import os | |
# Small epsilon added to denominators (normalize_single, reduce_masked_mean)
# so that a zero value range or an all-zero mask does not divide by zero.
EPS = 1e-6
def sub2ind(height, width, y, x):
    """Map (row ``y``, column ``x``) to a flat row-major index in a grid ``width`` wide.

    ``height`` is unused; it is accepted so the signature mirrors ``ind2sub``.
    Only ``*`` and ``+`` are used, so array/tensor inputs work elementwise.
    """
    return width * y + x
def ind2sub(height, width, ind):
    """Inverse of ``sub2ind``: split a flat row-major index into ``(y, x)``.

    ``height`` is unused; ``width`` is the row length.
    """
    return ind // width, ind % width
def get_lr_str(lr):
    """Format a learning rate as a compact tag, e.g. 5e-4 -> '5e-4'.

    The value is rendered with '%.1e' (e.g. '5.0e-04').
    - The fractional part of the mantissa is dropped, so 2.5e-4 -> '2e-4'.
    - Leading zeros are stripped from the exponent while its sign is kept,
      so '5.0e-04' -> '5e-4' and '1.0e+00' -> '1e+0'.

    The exponent is parsed rather than sliced at fixed positions. Slicing
    mangled multi-digit exponents (1e-10 came out as '1e-0').
    """
    mantissa, exponent = ("%.1e" % lr).split('e')
    sign = exponent[0]
    # '%e' pads the exponent to at least two digits; keep a single '0' for e+00.
    digits = exponent[1:].lstrip('0') or '0'
    # Take everything before the decimal point so a leading '-' survives.
    return mantissa.split('.')[0] + 'e' + sign + digits
def strnum(x):
    """Return a short string form of the number ``x``.

    Starts from '%g'. For values containing a decimal point:
      * the leading '0' of magnitudes below 1 is dropped (0.5 -> '.5',
        -0.5 -> '-.5');
      * at most 4 characters are kept, not counting a leading minus sign
        (12.345 -> '12.3', 0.12345 -> '.123', -12.5 -> '-12.5').
    Integers ('3') and exponent notation ('1.5e-05') are returned unchanged.
    Truncating exponent notation would cut off the exponent and change the
    value.
    """
    s = '%g' % x
    if '.' not in s or 'e' in s:
        return s
    # Handle the sign separately; the old code sliced from '.', which
    # discarded both the sign and the integer part of negative numbers.
    sign = ''
    if s.startswith('-'):
        sign, s = '-', s[1:]
    if s.startswith('0.'):
        s = s[1:]
    return sign + s[:4]
def assert_same_shape(t1, t2):
    """Raise AssertionError unless ``t1`` and ``t2`` have identical shapes.

    Accepts anything exposing ``.shape`` (e.g. torch tensors, numpy arrays).
    Both the rank and every dimension must match. The previous ``zip``-based
    check stopped at the shorter shape, so (2, 3) vs (2, 3, 4) passed.
    The error is raised explicitly so the check still runs under ``python -O``.
    """
    shape1 = tuple(t1.shape)
    shape2 = tuple(t2.shape)
    if shape1 != shape2:
        raise AssertionError('shape mismatch: %s vs %s' % (shape1, shape2))
def mkdir(path):
    """Create directory ``path``, including missing parents; no-op if it exists.

    ``exist_ok=True`` replaces the old ``os.path.exists`` pre-check. That
    check raced with other processes: if one created the directory in
    between, ``os.makedirs`` raised FileExistsError.

    Note: if ``path`` exists but is a regular file, ``os.makedirs`` raises
    FileExistsError. The old check silently returned in that case.
    """
    os.makedirs(path, exist_ok=True)
def print_stats(name, tensor):
    """Print the dtype, min, mean and max of ``tensor``, followed by its shape.

    The tensor is detached and copied to CPU/numpy before the statistics are
    computed, so the original tensor is not modified.
    """
    original_shape = tensor.shape
    arr = tensor.detach().cpu().numpy()
    summary = '%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (
        name,
        arr.dtype,
        np.min(arr),
        np.mean(arr),
        np.max(arr),
    )
    print(summary, original_shape)
def normalize_single(d):
    """Min-max rescale a tensor of any shape into [0, 1).

    The minimum and maximum are taken over all elements. EPS in the
    denominator keeps a constant tensor from dividing by zero; such a tensor
    maps to all zeros.
    """
    lo = d.min()
    hi = d.max()
    return (d - lo) / (EPS + (hi - lo))
def normalize(d):
    """Min-max normalize each batch element of ``d`` independently.

    d: tensor shaped B x (anything). Each ``d[b]`` is rescaled as
    ``(d[b] - min) / (EPS + (max - min))``, with min/max taken over all of
    ``d[b]``'s elements, so its values land in [0, 1).

    Returns a tensor with the same shape as ``d``. Floating-point inputs keep
    their dtype. Integer inputs now yield a floating tensor; the previous
    version wrote the float result into an integer buffer, which truncated
    every value to 0.
    """
    if d.numel() == 0:
        # Nothing to normalize; reshape(B, -1) below is ambiguous for empty input.
        return torch.zeros(d.size(), dtype=d.dtype, device=d.device)
    B = d.shape[0]
    flat = d.reshape(B, -1)
    # Per-batch min/max in one vectorized pass instead of a Python loop over B.
    dmin = flat.min(dim=1, keepdim=True)[0]
    dmax = flat.max(dim=1, keepdim=True)[0]
    out = (flat - dmin) / (EPS + (dmax - dmin))
    return out.reshape(d.shape)
def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
    """Build integer-valued pixel coordinate grids of size B x Y x X.

    Args:
        B, Y, X: batch size, grid height, grid width.
        stack: if False, return ``(grid_y, grid_x)``. If True, return a single
            tensor with x and y stacked in that order (the layout used by
            ``torch.nn.functional.grid_sample``, see
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample).
        norm: pass the grids through ``normalize_grid2d`` (defined elsewhere;
            presumably rescales to grid_sample's range -- confirm).
        device: device the grids are created on.
        on_chans: when stacking, put the xy pair on dim 1 (B x 2 x Y x X)
            instead of the last dim (B x Y x X x 2).
    """
    dev = torch.device(device)
    ys = torch.reshape(torch.linspace(0.0, Y - 1, Y, device=dev), [1, Y, 1]).repeat(B, 1, X)
    xs = torch.reshape(torch.linspace(0.0, X - 1, X, device=dev), [1, 1, X]).repeat(B, Y, 1)
    if norm:
        ys, xs = normalize_grid2d(ys, xs, Y, X)
    if not stack:
        return ys, xs
    stack_dim = 1 if on_chans else -1
    return torch.stack([xs, ys], dim=stack_dim)
def gridcloud2d(B, Y, X, norm=False, device='cuda'):
    """Return the (x, y) coordinate of every location in a Y x X grid.

    Output is B x N x 2 with N = Y*X, the last dim ordered (x, y).
    ``norm`` and ``device`` are forwarded to ``meshgrid2d``.
    """
    ys, xs = meshgrid2d(B, Y, X, norm=norm, device=device)
    # Flatten each batch's grid to N points (B x N), then pair them up.
    return torch.stack([torch.reshape(xs, [B, -1]), torch.reshape(ys, [B, -1])], dim=2)
def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
    """Mean of ``x`` weighted by ``mask``: sum(x * mask) / (EPS + sum(mask)).

    Args:
        x, mask: tensors of the same shape. With ``broadcast=True`` they only
            need to be broadcast-compatible.
        dim: None to reduce over all elements (0-dim result). Otherwise a dim
            or list of dims, forwarded to ``torch.sum``.
        keepdim: forwarded to ``torch.sum``.
        broadcast: skip the exact-shape check.

    Raises:
        AssertionError: ``broadcast`` is False and the shapes differ,
            including a difference in rank.

    EPS keeps an all-zero mask from dividing by zero; for finite ``x`` the
    result is then 0.
    """
    if not broadcast and tuple(x.shape) != tuple(mask.shape):
        # Compare full shapes. A zip over dims ignores a rank mismatch and
        # lets x * mask broadcast silently, which this check exists to prevent.
        # Raised explicitly so the check also runs under `python -O`.
        raise AssertionError('some shape mismatch: %s vs %s'
                             % (tuple(x.shape), tuple(mask.shape)))
    prod = x * mask
    if dim is None:
        numer = torch.sum(prod)
        denom = EPS + torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
    return numer / denom
def reduce_masked_median(x, mask, keep_batch=False):
    """Median of ``x`` over the entries where ``mask > 0``.

    x, mask: tensors of equal shape. The median is computed in numpy on
    detached CPU copies, so no gradient flows through the result.

    keep_batch=True: one median per batch element, returned as a float tensor
        of shape (B,) on ``x.device``.
    keep_batch=False: a single median over everything, returned as a float
        tensor of shape (1,) on ``x.device``.
    A slice whose mask sums to <= 0 yields NaN.
    """
    assert(x.size() == mask.size())
    device = x.device
    B = list(x.shape)[0]
    x_np = x.detach().cpu().numpy()
    mask_np = mask.detach().cpu().numpy()

    def _median(values, weights):
        # Median of `values` where `weights` > 0, or NaN if the weights sum to <= 0.
        if np.sum(weights) > 0:
            return np.median(values[weights > 0])
        return np.nan

    if not keep_batch:
        single = _median(np.reshape(x_np, [-1]), np.reshape(mask_np, [-1]))
        return torch.from_numpy(np.array([single], np.float32)).to(device).float()

    x_rows = np.reshape(x_np, [B, -1])
    mask_rows = np.reshape(mask_np, [B, -1])
    per_batch = np.zeros([B], np.float32)
    for idx, (row, row_mask) in enumerate(zip(x_rows, mask_rows)):
        per_batch[idx] = _median(row, row_mask)
    return torch.from_numpy(per_batch).to(device).float()