import comet.utils.utils as utils
import comet.src.data.utils as data_utils
import comet.src.data.config as cfg

import pandas
import json
import random
import math
import torch

from tqdm import tqdm


def map_name(name):
    if name == "train":
        return "trn"
    elif name == "test":
        return "tst"
    else:
        return "dev"


class DataLoader(object):
    def __init__(self, opt):
        self.data = {}
        self.data["train"] = {}
        self.data["dev"] = {}
        self.data["test"] = {}

        self.sequences = {}
        self.sequences["train"] = {}
        self.sequences["dev"] = {}
        self.sequences["test"] = {}

        self.masks = {}
        self.masks["train"] = {}
        self.masks["dev"] = {}
        self.masks["test"] = {}

        self.offsets = {}
        self.offsets["train"] = {}
        self.offsets["dev"] = {}
        self.offsets["test"] = {}

    def offset_summary(self, split):
        return self.offsets[split]["total"]


def do_take_partial_dataset(data_opts):
    if data_opts.get("kr", None) is None:
        return False
    if data_opts.kr == 1:
        return False
    return True


def select_partial_dataset(data_opts, data):
    num_selections = math.ceil(data_opts.kr * len(data))
    return random.sample(data, num_selections)
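# Example of the "keep ratio" option above: with data_opts.kr = 0.1 and 10,000
# training tuples, select_partial_dataset keeps a random sample of
# ceil(0.1 * 10000) = 1,000 of them; kr == 1 or a missing "kr" option leaves
# the dataset untouched.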
class GenerationDataLoader(DataLoader):
    def __init__(self, opt, categories):
        super(GenerationDataLoader, self).__init__(opt)

        self.categories = categories
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_event = None
        self.max_effect = None

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = "v4_atomic_{}.csv".format(map_name(split))
            df = pandas.read_csv("{}/{}".format(path, file_name), index_col=0)
            df.iloc[:, :9] = df.iloc[:, :9].apply(
                lambda col: col.apply(json.loads))

            for cat in self.categories:
                attr = df[cat]
                self.data[split]["total"] += utils.zipped_flatten(zip(
                    attr.index, ["<{}>".format(cat)] * len(attr), attr.values))

        if do_take_partial_dataset(self.opt.data):
            self.data["train"]["total"] = select_partial_dataset(
                self.opt.data, self.data["train"]["total"])

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split] = get_generation_sequences(
                self.opt, self.data, split, text_encoder, test)

            self.masks[split]["total"] = [(len(i[0]), len(i[1])) for
                                          i in sequences[split]]

        self.max_event = max([max([l[0] for l in self.masks[split]["total"]])
                              for split in self.masks])
        self.max_effect = max([max([l[1] for l in self.masks[split]["total"]])
                               for split in self.masks])
        print(self.max_event)
        print(self.max_effect)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_event + self.max_effect).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                self.sequences[split]["total"][i, self.max_event:self.max_event + len(seq[1])] = \
                    torch.LongTensor(seq[1])
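    # Layout of each row of self.sequences[split]["total"] after make_tensors
    # (read off the two assignments above; unused positions stay zero-padded):
    #
    #   [ event tokens ... 0s | <category> effect tokens <END> ... 0s ]
    #     columns 0..max_event  columns max_event..max_event + max_effect
    #
    # self.masks[split]["total"][i] holds (len(event_i), len(effect_i)) so the
    # padding can be masked out later.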
    def sample_batch(self, split, bs, idxs=None):
        offset = self.offsets[split]["total"]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split]["total"].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split]["total"].device))
        else:
            seqs = self.sequences[split]["total"][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(
            seqs, self.max_event, 1)
        batch["key"] = ("total", offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split]["total"] = offset

        if split == "train" and offset + bs > len(self.sequences[split]["total"]):
            return batch, True
        elif offset >= len(self.sequences[split]["total"]):
            return batch, True
        else:
            return batch, False
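    # What a batch from sample_batch contains (loader initialisation is assumed
    # to have been done by the surrounding training code):
    #
    #   batch["sequences"]      LongTensor of padded [event | effect] rows
    #   batch["attention_mask"] 1.0 where a real token is present, 0.0 on padding
    #   batch["loss_mask"]      1.0 only on effect tokens, so no loss is taken
    #                           on the event or on the category delimiter
    #   batch["key"]            ("total", start_offset, end_offset) bookkeeping
    #
    # The second return value is True once the split is used up (for "train",
    # as soon as another full batch cannot be drawn).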
    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            idxs = list(range(len(self.data[split][key])))
            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp
            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp
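# Note: shuffle_sequences applies one permutation per key to self.data,
# self.sequences and self.masks together, so the three structures stay
# index-aligned; reset_offsets depends on this when it rewinds a split to
# offset 0 at the start of a new epoch.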
def prune_data_for_evaluation(data_loader, categories, split):
    indices = []
    for i, example in enumerate(data_loader.data[split]["total"]):
        if example[1] in categories:
            indices.append(i)

    data_loader.masks[split]["total"] = [data_loader.masks[split]["total"][i]
                                         for i in indices]
    data_loader.sequences[split]["total"] = \
        data_loader.sequences[split]["total"].index_select(
            0, torch.LongTensor(indices))
    data_loader.data[split]["total"] = [data_loader.data[split]["total"][i]
                                        for i in indices]


def make_attention_mask(sequences):
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event, num_delim_tokens):
    # print(num_delim_tokens)
    # print(sequences.size())
    mask = (sequences != 0).float()
    mask[:, :max_event + num_delim_tokens] = 0
    return mask[:, 1:].to(cfg.device)
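# Worked example for make_loss_mask with max_event = 3 and num_delim_tokens = 1
# (token names and lengths chosen purely for illustration):
#
#   sequence: [ e1  e2  0 | <cat>  s1  s2  <END>  0 ]
#   raw mask: [  1   1  0 |     1   1   1      1  0 ]
#   zeroed:   [  0   0  0 |     0   1   1      1  0 ]   # event + delimiter masked
#   returned: the zeroed mask shifted by one column (mask[:, 1:]), which lines
#             it up with next-token prediction targets.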
def find_underscore_length(seq):
    start = "_"
    while start in seq:
        start += "_"
    return start[:-1]


def handle_underscores(suffix, text_encoder, prefix=False):
    encoder = text_encoder.encoder
    if prefix:
        tok = "___"
    else:
        tok = find_underscore_length(suffix)

    suffix_parts = [i.strip() for i in suffix.split("{}".format(tok))]
    to_flatten = []
    for i, part in enumerate(suffix_parts):
        if part:
            to_flatten.append(text_encoder.encode([part], verbose=False)[0])

            if i != len(suffix_parts) - 1 and suffix_parts[i+1]:
                to_flatten.append([encoder["<blank>"]])
        else:
            to_flatten.append([encoder["<blank>"]])

    final_suffix = utils.flatten(to_flatten)

    return final_suffix
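# Example of the blank handling above, assuming "<blank>" is one of the special
# tokens registered with the text encoder:
#
#   suffix = "puts the ___ in the oven"
#   tok    = "___"                        # longest underscore run in the string
#   parts  = ["puts the", "in the oven"]
#   result = encode("puts the") + [encoder["<blank>"]] + encode("in the oven")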
def get_generation_sequences(opt, data, split, text_encoder, test):
    sequences = []
    count = 0

    final_prefix = None
    final_suffix = None

    for prefix, category, suffix in tqdm(data[split]["total"]):
        final_prefix, final_suffix = do_example(
            text_encoder, prefix, suffix, True, True)
        # if do_prefix:
        #     if "___" in prefix:
        #         final_prefix = handle_underscores(prefix, text_encoder, True)
        #     else:
        #         final_prefix = text_encoder.encode([prefix], verbose=False)[0]
        # if do_suffix:
        #     if "_" in suffix:
        #         final_suffix = handle_underscores(suffix, text_encoder)
        #     else:
        #         final_suffix = text_encoder.encode([suffix], verbose=False)[0]

        final = compile_final_sequence(
            opt, final_prefix, final_suffix, category, text_encoder)

        sequences.append(final)
        count += 1

        if count > 10 and test:
            break

    return sequences


def do_example(text_encoder, prefix, suffix, do_prefix, do_suffix):
    final_prefix = None
    final_suffix = None

    if do_prefix:
        if "___" in prefix:
            final_prefix = handle_underscores(prefix, text_encoder, True)
        else:
            final_prefix = text_encoder.encode([prefix], verbose=False)[0]

    if do_suffix:
        if "_" in suffix:
            final_suffix = handle_underscores(suffix, text_encoder)
        else:
            final_suffix = text_encoder.encode([suffix], verbose=False)[0]

    return final_prefix, final_suffix


def compile_final_sequence(opt, final_prefix, final_suffix, category, text_encoder):
    final = []

    final.append(final_prefix)
    final.append(
        [text_encoder.encoder[category]]
        + final_suffix)

    final[-1].append(text_encoder.encoder["<END>"])

    return final
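# Shape of the value compile_final_sequence returns, as consumed by
# make_tensors above:
#
#   final[0] = encoded event tokens                          (the "prefix")
#   final[1] = [<category> token id] + encoded effect + [<END> token id]
#
# e.g. for ("PersonX bakes bread", "<xEffect>", "gets covered in flour") the
# second element starts with the id of "<xEffect>" and ends with the id of
# "<END>". The triple itself is only an illustration of the format.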
num_delimiter_tokens = {
    "category": 1,
    "hierarchy": 3,
    "hierarchy+label": 4,
    "category+hierarchy": 4,
    "category+hierarchy+label": 5
}
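# The "category" entry (one delimiter token) matches the value hard-coded in
# sample_batch above; the other entries presumably describe alternative
# sequence formats used elsewhere in the COMET codebase.
#
# End-to-end usage sketch; opt, text_encoder and special come from the
# surrounding training configuration, so the values below are illustrative
# assumptions rather than the exact setup:
#
#   categories = ["oEffect", "oReact", "oWant", "xAttr", "xEffect",
#                 "xIntent", "xNeed", "xReact", "xWant"]
#   loader = GenerationDataLoader(opt, categories)
#   loader.load_data("path/to/atomic")          # directory with v4_atomic_*.csv
#   loader.make_tensors(text_encoder, special)  # encode, pad and tensorise
#   loader.reset_offsets("train", shuffle=True)
#   batch, epoch_done = loader.sample_batch("train", bs=32)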