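"""Dataset utilities for masked language-model pretraining over spatial entities.

A pivot entity and its spatial neighbors are flattened into a "pseudo
sentence"; every subtoken carries the neighbor's relative (lng, lat) distance
to the pivot, normalized by distance_norm_factor, and entities or individual
tokens are masked out with 15% probability.
"""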
import numpy as np
import torch
from torch.utils.data import Dataset


class SpatialDataset(Dataset):
    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        self.tokenizer = tokenizer                          # a BERT-style tokenizer with cls/sep/mask/pad tokens
        self.max_token_len = max_token_len                  # fixed output length, including CLS and SEP
        self.distance_norm_factor = distance_norm_factor    # divisor for relative (lng, lat) distances
        self.sep_between_neighbors = sep_between_neighbors  # insert SEP tokens between neighbor entities

    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list,
                              neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill=0):
        sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
        cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)

        max_token_len = self.max_token_len

        # process the pivot entity
        pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)
        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # prepare the entity-level mask: one random number per entity
        # (pivot plus neighbors); True means mask out, False means keep
        entity_mask_arr = []
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)

        # the pivot entity is masked out with 15% probability
        if rand_entity[0] < 0.15:
            entity_mask_arr.extend([True] * pivot_token_len)
        else:
            entity_mask_arr.extend([False] * pivot_token_len)
        # process neighbors
        neighbor_token_list = []
        neighbor_lng_list = []
        neighbor_lat_list = []

        # add a separator between the pivot and neighbor tokens;
        # skipped when pivot_dist_fill != 0, a trick that avoids placing a separator
        # after the class name (for class-name encoding in the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            neighbor_lng_list.append(spatial_dist_fill)
            neighbor_lat_list.append(spatial_dist_fill)
            neighbor_token_list.append(sep_token_id)
            entity_mask_arr.append(False)  # keep the mask aligned with the tokens; separators are never masked

        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
            # only consider neighbors whose names start with a letter
            if not neighbor_name[0].isalpha():
                continue

            neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # compute the relative distance from neighbor to pivot,
            # normalize it by distance_norm_factor,
            # and assign the same distance to every subtoken of the neighbor
            if 'coordinates' in neighbor_geometry:  # handle different JSON dict structures
                neighbor_lng, neighbor_lat = neighbor_geometry['coordinates'][0], neighbor_geometry['coordinates'][1]
            else:
                neighbor_lng, neighbor_lat = neighbor_geometry[0], neighbor_geometry[1]
            neighbor_lng_list.extend([(neighbor_lng - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_lat_list.extend([(neighbor_lat - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_token_list.extend(neighbor_token)

            # True: mask out, False: keep original token; each neighbor entity
            # is masked out with 15% probability
            if rnd < 0.15:
                entity_mask_arr.extend([True] * neighbor_token_len)
            else:
                entity_mask_arr.extend([False] * neighbor_token_len)

            # the separator mask bit comes after the neighbor's mask bits,
            # matching the token order in the pseudo sentence
            if self.sep_between_neighbors:
                neighbor_lng_list.append(spatial_dist_fill)
                neighbor_lat_list.append(spatial_dist_fill)
                neighbor_token_list.append(sep_token_id)
                entity_mask_arr.append(False)

        pseudo_sentence = pivot_name_tokens + neighbor_token_list
        dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
        dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list

        sent_len = len(pseudo_sentence)  # not yet including CLS and SEP
        max_token_len_middle = max_token_len - 2  # reserve 2 slots for the CLS and SEP tokens

        # truncation and padding to exactly max_token_len (including CLS and SEP)
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [0] + [1] * max_token_len_middle + [0]  # CLS and SEP are not attended to
        else:
            pad_len = max_token_len_middle - sent_len
            assert pad_len >= 0
            pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
            dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [0] + [1] * sent_len + [0] * pad_len + [0]  # CLS, SEP, and padding are not attended to

        norm_lng_list = np.array(dist_lng_list)
        norm_lat_list = np.array(dist_lat_list)
        # entity-level masking of the pseudo sentence;
        # the +1 offset accounts for the CLS token prepended above
        entity_mask_indices = np.where(entity_mask_arr)[0] + 1
        masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i]
                               for i in range(max_token_len)]

        # token-level masking of the pseudo sentence;
        # CLS, SEP, and PAD are never masked. True: mask out, False: keep original token
        rand_token = np.random.uniform(size=len(pseudo_sentence))
        token_mask_arr = (rand_token < 0.15) & (np.array(pseudo_sentence) != cls_token_id) \
            & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
        token_mask_indices = np.where(token_mask_arr)[0]
        masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i]
                              for i in range(max_token_len)]

        # yield the entity-masked input with 50% probability, the token-masked input otherwise
        if np.random.rand() > 0.5:
            masked_input = torch.tensor(masked_entity_input)
        else:
            masked_input = torch.tensor(masked_token_input)

        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data

    def __len__(self):
        raise NotImplementedError  # implemented by task-specific subclasses

    def __getitem__(self, index):
        raise NotImplementedError  # implemented by task-specific subclasses
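

# A minimal usage sketch (illustration only, not part of the original module):
# it assumes a Hugging Face `transformers` BERT tokenizer; the place names,
# coordinates, and fill/normalization values below are made-up examples.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = SpatialDataset(
        tokenizer=tokenizer,
        max_token_len=64,             # output length including CLS and SEP
        distance_norm_factor=0.0001,  # divisor for relative distances (made-up value)
        sep_between_neighbors=True,
    )

    # neighbor geometries may be plain (lng, lat) pairs or GeoJSON-style
    # dicts with a 'coordinates' key; both branches are exercised here
    train_data = dataset.parse_spatial_context(
        pivot_name="Los Angeles",
        pivot_pos=(-118.2437, 34.0522),
        neighbor_name_list=["Pasadena", "Long Beach"],
        neighbor_geometry_list=[(-118.1445, 34.1478),
                                {"coordinates": (-118.1937, 33.7701)}],
        spatial_dist_fill=100,        # distance fill for CLS/SEP/PAD positions
    )
    print(train_data["masked_input"].shape)    # torch.Size([64])
    print(train_data["attention_mask"].sum())  # number of attended positions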