import copy

import numpy as np
import torch

def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
    """Group the TCL (text center line) RoIs per device group and force each
    group to hold exactly `tcl_bs` samples by repeating or randomly dropping
    entries."""
    pos_lists_, pos_masks_, label_lists_ = [], [], []
    img_bs = batch_size
    ngpu = int(batch_size / img_bs)  # kept for compatibility; always 1 here
    img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
    pos_lists_split, pos_masks_split, label_lists_split = [], [], []
    for i in range(ngpu):
        pos_lists_split.append([])
        pos_masks_split.append([])
        label_lists_split.append([])

    # assign every RoI to its group and rewrite column 0 with the in-group image id
    for i in range(img_ids.shape[0]):
        img_id = img_ids[i]
        gpu_id = int(img_id / img_bs)
        img_id = img_id % img_bs
        pos_list = pos_lists[i].copy()
        pos_list[:, 0] = img_id
        pos_lists_split[gpu_id].append(pos_list)
        pos_masks_split[gpu_id].append(pos_masks[i].copy())
        label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))

    # repeat or delete so that every group contains exactly tcl_bs samples
    for i in range(ngpu):
        vp_len = len(pos_lists_split[i])
        if vp_len <= tcl_bs:
            for j in range(0, tcl_bs - vp_len):
                pos_list = pos_lists_split[i][j].copy()
                pos_lists_split[i].append(pos_list)
                pos_mask = pos_masks_split[i][j].copy()
                pos_masks_split[i].append(pos_mask)
                label_list = copy.deepcopy(label_lists_split[i][j])
                label_lists_split[i].append(label_list)
        else:
            for j in range(0, vp_len - tcl_bs):
                c_len = len(pos_lists_split[i])
                pop_id = np.random.permutation(c_len)[0]
                pos_lists_split[i].pop(pop_id)
                pos_masks_split[i].pop(pop_id)
                label_lists_split[i].pop(pop_id)

    # merge the per-group lists back into flat lists
    for i in range(ngpu):
        pos_lists_.extend(pos_lists_split[i])
        pos_masks_.extend(pos_masks_split[i])
        label_lists_.extend(label_lists_split[i])
    return pos_lists_, pos_masks_, label_lists_

def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
                pad_num, tcl_bs):
    label_list = label_list.numpy()
    batch, _, _, _ = label_list.shape
    pos_list = pos_list.numpy()
    pos_mask = pos_mask.numpy()
    pos_list_t = []
    pos_mask_t = []
    label_list_t = []
    # keep only text instances whose mask contains at least one valid point
    for i in range(batch):
        for j in range(max_text_nums):
            if pos_mask[i, j].any():
                pos_list_t.append(pos_list[i][j])
                pos_mask_t.append(pos_mask[i][j])
                label_list_t.append(label_list[i][j])
    pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
                                                  label_list_t, tcl_bs)
    # length of each label sequence up to the first padding symbol
    label = []
    tt = [l.tolist() for l in label_list]
    for i in range(tcl_bs):
        k = 0
        for j in range(max_text_length):
            if tt[i][j][0] != pad_num:
                k += 1
            else:
                break
        label.append(k)
    label = torch.tensor(label).long()
    # stack the python lists into single arrays before building tensors
    pos_list = torch.tensor(np.array(pos_list))
    pos_mask = torch.tensor(np.array(pos_mask))
    label_list = torch.squeeze(torch.tensor(np.array(label_list)), dim=2)
    label_list = label_list.int()
    return pos_list, pos_mask, label_list, label
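

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: the shapes and
    # sizes below are illustrative assumptions chosen only to exercise
    # pre_process() end to end on synthetic data.
    batch, max_text_nums, max_text_length = 2, 3, 5
    max_point_num, pad_num, tcl_bs = 4, 36, 4

    # labels: (batch, max_text_nums, max_text_length, 1), padded with pad_num
    label_list = torch.full(
        (batch, max_text_nums, max_text_length, 1), pad_num, dtype=torch.int64)
    label_list[:, :, :2, 0] = 1  # pretend every text instance has two characters

    # pos_list: (batch, max_text_nums, max_point_num, 3); column 0 is the image index
    pos_list = torch.zeros(
        (batch, max_text_nums, max_point_num, 3), dtype=torch.int64)
    for b in range(batch):
        pos_list[b, :, :, 0] = b

    # pos_mask: (batch, max_text_nums, max_point_num, 1); all points marked valid
    pos_mask = torch.ones(
        (batch, max_text_nums, max_point_num, 1), dtype=torch.float32)

    out_pos, out_mask, out_labels, out_lens = pre_process(
        label_list, pos_list, pos_mask, max_text_length, max_text_nums,
        pad_num, tcl_bs)
    print(out_pos.shape, out_mask.shape, out_labels.shape, out_lens)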