Commit 77180e4
1 Parent(s): ff15dff

feat: remove mpi4py

Files changed:
- requirements.txt +0 -0
- src/improved_diffusion/dist_util.py +20 -21
- src/improved_diffusion/text_datasets.py +429 -273
requirements.txt
CHANGED

Binary files a/requirements.txt and b/requirements.txt differ
src/improved_diffusion/dist_util.py
CHANGED

@@ -8,7 +8,6 @@ import socket

 import blobfile as bf

-from mpi4py import MPI
 import torch as th
 import torch.distributed as dist

@@ -46,26 +45,26 @@ def setup_dist(rank, world_size, port="12145"):
     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


-def dev():
-    """
-    Get the device to use for torch.distributed.
-    """
-    if th.cuda.is_available():
-        return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
-    return th.device("cpu")
+# def dev():
+#     """
+#     Get the device to use for torch.distributed.
+#     """
+#     if th.cuda.is_available():
+#         return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
+#     return th.device("cpu")


-def load_state_dict(path, **kwargs):
-    """
-    Load a PyTorch file without redundant fetches across MPI ranks.
-    """
-    if MPI.COMM_WORLD.Get_rank() == 0:
-        with bf.BlobFile(path, "rb") as f:
-            data = f.read()
-    else:
-        data = None
-    data = MPI.COMM_WORLD.bcast(data)
-    return th.load(io.BytesIO(data), **kwargs)
+# def load_state_dict(path, **kwargs):
+#     """
+#     Load a PyTorch file without redundant fetches across MPI ranks.
+#     """
+#     if MPI.COMM_WORLD.Get_rank() == 0:
+#         with bf.BlobFile(path, "rb") as f:
+#             data = f.read()
+#     else:
+#         data = None
+#     data = MPI.COMM_WORLD.bcast(data)
+#     return th.load(io.BytesIO(data), **kwargs)


 def sync_params(params):
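Note that `dev()` and `load_state_dict()` are only commented out here, not ported off MPI, so any caller that still uses them will now fail. A minimal mpi4py-free sketch of the same two helpers, assuming the NCCL process group created by `setup_dist` above and a `LOCAL_RANK` environment variable set by the launcher (both assumptions, not part of this commit):

import io
import os

import blobfile as bf
import torch as th
import torch.distributed as dist


def dev():
    # Map this process to a GPU via the launcher-provided local rank.
    if th.cuda.is_available():
        return th.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    # Rank 0 reads the checkpoint once; the raw bytes are then broadcast
    # to every other rank over the existing process group.
    objects = [None]
    if dist.get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            objects[0] = f.read()
    dist.broadcast_object_list(objects, src=0)
    return th.load(io.BytesIO(objects[0]), **kwargs)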
src/improved_diffusion/text_datasets.py
CHANGED

@@ -1,13 +1,21 @@
 # from PIL import Image
 # import blobfile as bf
+# from mpi4py import MPI
 import numpy as np
 from torch.utils.data import DataLoader, Dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoConfig,
+    AutoTokenizer,
+    default_data_collator,
+    PreTrainedTokenizerFast,
+    PreTrainedTokenizer,
+)
+
 # from datasets import load_dataset
 import sys, os
 import torch
+
 # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
 # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
 from collections import Counter, defaultdict

@@ -16,8 +24,18 @@ from itertools import chain


 def load_data_text(
+    *,
+    data_dir,
+    batch_size,
+    image_size,
+    class_cond=False,
+    deterministic=False,
+    data_args=None,
+    task_mode="roc",
+    model=None,
+    padding_mode="block",
+    split="train",
+    load_vocab=None,
 ):
     """
     For a dataset, create a generator over (images, kwargs) pairs.

@@ -35,29 +53,34 @@
     exception will be raised.
     :param deterministic: if True, yield results in a deterministic order.
     """
+    print("hello loading text data. ")

+    if data_args.experiment.startswith("random") and model is None:
         model = None
     # elif data_args.experiment.startswith('random') and model is not None:
     #     print('loading initialized random embeddings. ')

+    if task_mode == "roc" or task_mode == "roc-aug":
         pass
         # training_data, model = get_corpus_rocstory(data_args, model, image_size,
         #                                            padding_mode=padding_mode, split=split,
+        #                                            load_vocab=load_vocab)
+    elif task_mode == "simple-wiki":
         pass
         # training_data, model = get_corpus_rocstory(data_args, model, image_size,
+        #                                            padding_mode=padding_mode, split=split,
+        #                                            load_vocab=load_vocab)
+
+    elif task_mode == "e2e-tgt":
+        print("hello loading e2e-tgt. ")
+        training_data, model = get_corpus_rocstory(
+            data_args,
+            model,
+            image_size,
+            padding_mode=padding_mode,
+            split=split,
+            load_vocab=load_vocab,
+        )
     # elif task_mode == 'yelp':
     #     print('hello loading yelp ')
     #     training_data, model = get_corpus_rocstory(data_args, model, image_size,

@@ -80,8 +103,12 @@
     # training_data, model = get_corpus_book(data_args, tokenizer, model, image_size,
     #                                        padding_mode=padding_mode, split=split,)

+    if (
+        data_args.modality
+        in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
+        and data_args.cache_mode == "no"
+    ):
+        pass  # dataset = TextDataset_NoCache(
         #     training_data,
         #     image_size,
         #     data_args,

@@ -98,7 +125,7 @@

     if deterministic:

+        pass  # data_loader = DataLoader(
         #     dataset,
         #     batch_size=batch_size,  # 20,
         #     drop_last=True,

@@ -117,64 +144,83 @@
     while True:
         yield from data_loader

+
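For orientation, a sketch of how this infinite generator is typically consumed. Everything here is illustrative: `data_args` is assumed to be the same namespace object the branches above read (`experiment`, `modality`, `cache_mode`, ...), and the argument values are made up.

data = load_data_text(
    data_dir="",            # unused by the e2e-tgt branch above
    batch_size=64,
    image_size=8,           # downstream seqlen is image_size**2
    data_args=data_args,    # assumed namespace, as in the trainer scripts
    task_mode="e2e-tgt",
    padding_mode="pad",
    split="train",
)
batch, cond = next(data)    # cond["input_ids"] holds the padded token ids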
 def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
     result_train_lst = []
     group_lst = defaultdict(list)
     with torch.no_grad():
+        for src_ids, input_ids in sentence_lst:
+            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
+            tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
             input_ids = [0] + tokenized_ + [1]
+            group_lst["word_ids"].append(input_ids)
+            group_lst["src_ids"].append(tokenized_src)

+        print(group_lst["word_ids"][:2])
+        print("padding mode is pad")
         max_length = seqlen
+        group_lst["word_ids"] = _collate_batch_helper(
+            group_lst["word_ids"], vocab_dict["PAD"], max_length
+        )
+        max_src_length = max([len(xx) for xx in group_lst["src_ids"]])
         print(max_src_length, seqlen)
         max_src_length = min(seqlen, max_src_length)
+        group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
+            group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
+        )

+        for input_ids, src_ids, src_mask in zip(
+            group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
+        ):
+            if data_args.experiment.startswith("random"):
                 hidden_state = model(torch.tensor(input_ids))
+            elif data_args.experiment == "gpt2_pre_compress":
                 input_ids2 = torch.tensor(input_ids).to(model.device)
                 input_embs = model.transformer.wte(input_ids2)  # input_embs
                 hidden_state = model.down_proj(input_embs)
                 hidden_state = hidden_state * data_args.emb_scale_factor
+            result_train_lst.append(
+                {
+                    "input_ids": input_ids,
+                    "hidden_states": hidden_state.cpu().tolist(),
+                    "src_ids": src_ids,
+                    "src_mask": src_mask,
+                }
+            )

     return result_train_lst

+
|
| 195 |
+
sentence_lst,
|
| 196 |
+
vocab_dict,
|
| 197 |
+
model,
|
| 198 |
+
seqlen,
|
| 199 |
+
data_args,
|
| 200 |
+
padding_mode,
|
| 201 |
+
):
|
| 202 |
import psutil
|
| 203 |
+
|
| 204 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
| 205 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
| 206 |
from datasets import Dataset as Dataset2
|
| 207 |
+
|
| 208 |
+
raw_datasets = Dataset2.from_dict({"text": sentence_lst})
|
| 209 |
print(raw_datasets)
|
| 210 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
| 211 |
|
|
|
|
| 212 |
def tokenize_function(examples):
|
| 213 |
if isinstance(vocab_dict, dict):
|
| 214 |
+
input_ids = [
|
| 215 |
+
[0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
|
| 216 |
+
for seq in examples["text"]
|
| 217 |
+
]
|
| 218 |
elif isinstance(vocab_dict, PreTrainedTokenizerFast):
|
| 219 |
+
examples["text"] = [" ".join(seq) for seq in examples["text"]]
|
| 220 |
+
input_ids = vocab_dict(examples["text"], add_special_tokens=True)[
|
| 221 |
+
"input_ids"
|
| 222 |
+
]
|
| 223 |
+
result_dict = {"input_ids": input_ids}
|
| 224 |
# clm input could be much much longer than block_size
|
| 225 |
return result_dict
|
| 226 |
|
|
|
|
| 228 |
tokenize_function,
|
| 229 |
batched=True,
|
| 230 |
num_proc=4,
|
| 231 |
+
remove_columns=["text"],
|
| 232 |
load_from_cache_file=True,
|
| 233 |
desc="Running tokenizer on dataset",
|
| 234 |
)
|
| 235 |
print(tokenized_datasets)
|
| 236 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
| 237 |
|
| 238 |
+
if padding_mode == "block":
|
| 239 |
block_size = seqlen
|
| 240 |
+
|
| 241 |
def group_texts(examples):
|
| 242 |
+
concatenated_examples = {
|
| 243 |
+
k: list(chain(*examples[k])) for k in examples.keys()
|
| 244 |
+
}
|
| 245 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
| 246 |
if total_length >= block_size:
|
| 247 |
total_length = (total_length // block_size) * block_size
|
| 248 |
result = {
|
| 249 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
| 250 |
for k, t in concatenated_examples.items()
|
| 251 |
}
|
| 252 |
result["labels"] = result["input_ids"].copy()
|
| 253 |
return result
|
| 254 |
|
|
|
|
| 255 |
lm_datasets = tokenized_datasets.map(
|
| 256 |
group_texts,
|
| 257 |
batched=True,
|
|
|
|
| 260 |
desc=f"Grouping texts in chunks of {block_size}",
|
| 261 |
)
|
| 262 |
else:
|
| 263 |
+
|
| 264 |
def pad_function(group_lst):
|
| 265 |
max_length = seqlen
|
| 266 |
if isinstance(vocab_dict, dict):
|
| 267 |
+
group_lst["input_ids"] = _collate_batch_helper(
|
| 268 |
+
group_lst["input_ids"], vocab_dict["PAD"], max_length
|
| 269 |
+
)
|
| 270 |
else:
|
| 271 |
+
group_lst["input_ids"] = _collate_batch_helper(
|
| 272 |
+
group_lst["input_ids"], vocab_dict.pad_token_id, max_length
|
| 273 |
+
)
|
| 274 |
return group_lst
|
| 275 |
|
| 276 |
# Process.memory_info is expressed in bytes, so convert to megabytes
|
|
|
|
| 283 |
desc=f"padding",
|
| 284 |
)
|
| 285 |
|
| 286 |
+
print(lm_datasets, "padded dataset")
|
|
|
|
| 287 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
| 288 |
import datasets
|
| 289 |
+
|
| 290 |
raw_datasets = datasets.DatasetDict()
|
| 291 |
+
raw_datasets["train"] = lm_datasets
|
| 292 |
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
|
| 293 |
return raw_datasets
|
| 294 |
|
| 295 |
+
|
| 296 |
+
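A small worked example of the `group_texts` mapping defined above (token values invented): two tokenized sequences are concatenated and re-cut into fixed blocks, and the ragged remainder is dropped.

examples = {"input_ids": [[0, 11, 12, 1], [0, 13, 1]]}
# concatenation -> [0, 11, 12, 1, 0, 13, 1]          (total_length = 7)
# with block_size = 3: usable length = (7 // 3) * 3 = 6
# chunks -> [[0, 11, 12], [1, 0, 13]]; "labels" is a copy of "input_ids"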
+def helper_tokenize_encode(
+    sentence_lst,
+    vocab_dict,
+    model,
+    seqlen,
+    data_args,
+    padding_mode,
+):
     result_train_lst = []
     group_lst = defaultdict(list)
     with torch.no_grad():
         for input_ids in sentence_lst:
+            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
             input_ids = [0] + tokenized_ + [1]
+            group_lst["word_ids"].append(input_ids)
+        print(group_lst["word_ids"][:2])

+        if padding_mode == "block":
+            print("padding mode is block")
             concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
             total_length = len(concatenated_examples[list(group_lst.keys())[0]])
             block_size = seqlen
             total_length = (total_length // block_size) * block_size
             # Split by chunks of max_len.
             group_lst = {
+                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                 for k, t in concatenated_examples.items()
             }
+        elif padding_mode == "pad":
+            print("padding mode is pad")
             max_length = seqlen
+            group_lst["word_ids"] = _collate_batch_helper(
+                group_lst["word_ids"], vocab_dict["PAD"], max_length
+            )

+        for input_ids in group_lst["word_ids"]:
+            if data_args.experiment.startswith("random"):
                 hidden_state = model(torch.tensor(input_ids))
+            elif data_args.experiment == "gpt2_pre_compress":
                 input_ids2 = torch.tensor(input_ids).to(model.device)
                 input_embs = model.transformer.wte(input_ids2)  # input_embs
                 hidden_state = model.down_proj(input_embs)
                 hidden_state = hidden_state * data_args.emb_scale_factor
+            elif data_args.experiment == "glove":
                 hidden_state = model(torch.tensor(input_ids))
+            result_train_lst.append(
+                {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
+            )

     return result_train_lst

+
 def load_glove_model(File):
     print("Loading Glove Model")
     glove_model = {}
+    with open(File, "r") as f:
         for line in f:
             split_line = line.split()
             word = split_line[0]

@@ -292,9 +358,10 @@ def load_glove_model(File):
     print(f"{len(glove_model)} words loaded!")
     return glove_model

+
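A note on inputs: `load_glove_model` parses the standard GloVe text format, one token per line followed by its vector components. An illustrative line (values invented):

# the 0.418 0.24968 -0.41242 0.1217 0.30927 ...   <- token, then 50 floats for glove.6B.50d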
 def load_glove(vocab):
     model = torch.nn.Embedding(len(vocab), 50)
+    glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt")
     array_lst = []
     count_ = 0
     for word, idx in vocab.items():

@@ -303,20 +370,21 @@ def load_glove(vocab):
     else:
         count_ += 1
         array_lst.append(torch.randn(50))
+    print(f"{count_} out of {len(vocab)} is initialized. ")
     array_lst = torch.stack(array_lst)
     print(torch.norm(array_lst, dim=-1).mean())
     model.weight.data = array_lst
     return model


+def get_corpus_rocstory(
+    data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
+):
     import csv, torch, json
     from spacy.lang.en import English

+    if data_args.experiment_mode == "lm":
+        if data_args.modality == "roc":
             pass
             # print('loading dataset from ROCStory')
             # nlp = English()

@@ -347,7 +415,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
             # # sentence_lst.append(word_lst)
             # # sentence_lst = sentence_lst[1:]
             # print(sentence_lst[:2])
+        if data_args.modality == "roc-aug":
             pass
             # print('loading dataset from ROCStory')
             # nlp = English()

@@ -381,7 +449,7 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
             #     word_lst = [x.text for x in tokenizer(sentences)]
             #     sentence_lst.append(word_lst)
             # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst))
+        elif data_args.modality == "simple-wiki":
             pass
             # print('loading dataset from simple wikipedia')
             # sentence_lst = []

@@ -390,57 +458,62 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
             #     word_lst = row.lower().split()
             #     sentence_lst.append(word_lst)
             # print(sentence_lst[:2])
+        elif data_args.modality == "e2e-tgt":
+            print("loading dataset from simple e2e dataset")
             sentence_lst = []
             nlp = English()
             tokenizer = nlp.tokenizer
+            if split == "train":
+                print("loading form the TRAIN set")
+                path = (
+                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
+                )
                 # path = f'../{data_args.e2e_train}/src1_train.txt'
+            elif split == "valid":
+                print("loading form the VALID set")
+                path = f"../{data_args.e2e_train}/src1_valid.txt"
+                path = (
+                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
+                )
+            elif split == "test":
+                print("loading form the TEST set")
+                path = f"../{data_args.e2e_train}/src1_test.txt"
+                path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
+            elif split == "debug":
+                print("loading form the DEBUG set")
                 path = data_args.debug_path
                 import json
+
+                with open(path, "r") as ff:
                     for line in ff:
+                        sentence_lst.append(json.loads(line)[0].split(" "))
                 sentence_lst = sentence_lst + sentence_lst
+            if split in ["train", "valid", "test"]:
+                with open(path, "r") as ff:
                     for row in ff:
+                        word_lst = row.split("||")[1]
                         word_lst = [x.text for x in tokenizer(word_lst)]
                         sentence_lst.append(word_lst)
                 print(sentence_lst[:2])

+        elif data_args.modality == "yelp":
+            print("loading dataset from simple YelpNLG dataset")
             sentence_lst = []
             nlp = English()
             tokenizer = nlp.tokenizer
+            if split == "train":
+                print("loading form the TRAIN set")
+                path = f"{data_args.yelp_train}/yelpnlg-train.csv"
+            elif split == "valid":
+                print("loading form the VALID set")
+                path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
+            elif split == "test":
+                print("loading form the TEST set")
+                path = f"{data_args.yelp_train}/yelpnlg-test.csv"
+            if split in ["train", "valid", "test"]:
+
+                with open(path, "r") as csvfile:
+                    yelp_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
                     for row in yelp_reader:
                         sentences = row[1]
                         word_lst = [x.text for x in tokenizer(sentences)]

@@ -448,175 +521,188 @@ def get_corpus_rocstory(data_args, model, image_size, padding_mode='block',
                 sentence_lst = sentence_lst[1:]
             print(sentence_lst[:2])

+        elif data_args.modality == "commonGen":
+            print("loading dataset from simple YelpNLG dataset")
             sentence_lst = []
             nlp = English()
             tokenizer = nlp.tokenizer
+            if split == "train":
+                print("loading form the TRAIN set")
+                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
+            elif split == "valid":
+                print("loading form the VALID set")
+                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
+            elif split == "test":
+                print("loading form the TEST set")
+                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
+            if split in ["train", "valid", "test"]:
+                with open(path, "r") as ff:
                     for line in ff:
                         line = json.loads(line)
+                        for sentences in line["scene"]:
                             word_lst = [x.text for x in tokenizer(sentences)]
                             sentence_lst.append(word_lst)
             print(sentence_lst[:2])

+        elif data_args.modality == "commonGen-aug":
+            print("loading dataset from simple YelpNLG dataset")
             sentence_lst = []
             nlp = English()
             tokenizer = nlp.tokenizer
+            if split == "train":
+                print("loading form the TRAIN set")
+                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
+                path_lst = [f"{data_args.roc_train}/roc_train.json"]
+                path_lst.append(
+                    "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
+                )
+            elif split == "valid":
+                print("loading form the VALID set")
+                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
                 path_lst = []
+            elif split == "test":
+                print("loading form the TEST set")
+                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
                 path_lst = []

+            if split in ["train", "valid", "test"]:
+                with open(path, "r") as ff:
                     for line in ff:
                         line = json.loads(line)
+                        for sentences in line["scene"]:
                             word_lst = [x.text for x in tokenizer(sentences)]
                             sentence_lst.append(word_lst)
                 print(sentence_lst[:2])
             import itertools
+
             for path in path_lst:
+                if path.endswith("txt"):
+                    with open(path, "r") as roc_reader:
                         for row in roc_reader:
                             sentences = row.strip()
                             word_lst = [x.text for x in tokenizer(sentences)]
                             spl = [[]]
+                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                 spl[-1].extend(y)
+                                if x:
+                                    spl.append([])
                             sentence_lst.extend(spl[:-1])
                 else:
+                    with open(path, "r") as roc_reader:
                         for row in roc_reader:
                             sentences = json.loads(row)[0].strip()
                             word_lst = [x.text for x in tokenizer(sentences)]
                             spl = [[]]
+                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                 spl[-1].extend(y)
+                                if x:
+                                    spl.append([])
                             sentence_lst.extend(spl[:-1])

             print(sentence_lst[-2:])

         # get tokenizer.
         if load_vocab is None:
             counter = Counter()
             for input_ids in sentence_lst:
                 counter.update(input_ids)

+    if data_args.experiment_mode == "conditional_gen":
+        if data_args.modality == "e2e":
+            print("loading dataset from simple e2e dataset")
             sentence_lst = []
             nlp = English()
             tokenizer = nlp.tokenizer
+            if split == "train":
+                path = f"{data_args.e2e_train}/src1_train.txt"
+                with open(path, "r") as ff:
                     for row in ff:
+                        src_lst, word_lst = row.split("||")
                         word_lst = [x.text for x in tokenizer(word_lst)]
                         src_lst = [x.text for x in tokenizer(src_lst)]
                         sentence_lst.append((src_lst, word_lst))
+            elif split == "valid":
+                path = f"{data_args.e2e_train}/src1_valid.txt"
                 sentence_lst = read_e2e_files(path, data_args, tokenizer)
             print(sentence_lst[:2])
         # get tokenizer.
         if load_vocab is None:
             counter = Counter()
+            for src_ids, input_ids in sentence_lst:
                 counter.update(input_ids)
                 counter.update(src_ids)

     if load_vocab is None:
+        vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
         for k, v in counter.items():
             if v > 10:
                 vocab_dict[k] = len(vocab_dict)
         print(len(counter), len(vocab_dict))

+        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
+        print(f"save the vocab to {path_save_vocab}")
+        with open(path_save_vocab, "w") as f:
             json.dump(vocab_dict, f)
     else:
         vocab_dict = load_vocab
+        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
         if not os.path.exists(path_save_vocab):
+            print(f"save the vocab to {path_save_vocab}")
             if isinstance(vocab_dict, dict):
+                with open(path_save_vocab, "w") as f:
                     json.dump(vocab_dict, f)
+                assert vocab_dict["START"] == 0
             elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                 vocab_dict.save_pretrained(data_args.checkpoint_path)
             else:
                 assert False, "invalid type of vocab_dict"

+    if model is None and data_args.experiment == "random":
         model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
+        print("initializing the random embeddings", model)
         torch.nn.init.normal_(model.weight)
+        path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
+        print(
+            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
+        )
         torch.save(model.state_dict(), path_save)

     # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
     # if not os.path.exists(path_save) and data_args.experiment == 'random':
     #     torch.save(model.state_dict(), path_save)

+    if (
+        data_args.experiment_mode == "lm"
+        and data_args.modality
+        in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
+        and data_args.cache_mode == "no"
+    ):
+        train_dataset = helper_tokenize_stream(
+            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
+        )
         return train_dataset, model
+    elif data_args.experiment_mode == "lm":
+        result_train_lst = helper_tokenize_encode(
+            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
+        )
+    elif data_args.experiment_mode == "conditional_gen":
+        result_train_lst = helper_tokenize_encode_cond(
+            sentence_lst, vocab_dict, model, image_size**2, data_args
+        )
+    return {"train": result_train_lst}, model
+

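The vocabulary assembled above always places the four special tokens first and then every token whose corpus frequency exceeds 10, so a saved vocab.json begins like this (the word entries are invented):

# ids 0-3 are reserved; real tokens are appended in iteration order
vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3, "the": 4, "name": 5}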
 def write_e2e_corr(prompt_lst, file_dict, corr_path):
     print(len(prompt_lst))
+    with open(corr_path, "w") as f:
         for x in prompt_lst:
             for line in file_dict[x]:
                 print(" ".join(line), file=f)
+            print("", file=f)


 def write_e2e_src(prompt_lst, corr_path):
+    with open(corr_path, "w") as f:
         for x in prompt_lst:
             print(" ".join(x), file=f)
     return

@@ -624,48 +710,55 @@ def write_e2e_src(prompt_lst, corr_path):

 def read_e2e_files(path, args, tokenizer):
     file_dict = {}
+    with open(path, "r") as f:
         for line in f:
+            src_lst, word_lst = line.strip().split("||")
             tgt = tuple([x.text for x in tokenizer(word_lst)])
             src = tuple([x.text for x in tokenizer(src_lst)])
             if src not in file_dict:
                 file_dict[src] = []
             file_dict[src].append(tgt)
+    temp = "1"
     prompt_text_dict = file_dict
     prompt_text_lst = list(prompt_text_dict.keys())
+    gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
     print("gold dir", gold_dir)
     write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
+    src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
     write_e2e_src(prompt_text_lst, src_dir)
     final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
     return final_lst


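`read_e2e_files` and the `row.split("||")` calls above assume one `source||target` pair per line, in the style of the E2E NLG data files. An illustrative line (field contents invented):

# name : The Eagle | Type : coffee shop || The Eagle is a coffee shop .
# -> src = ("name", ":", "The", "Eagle", ...), tgt = ("The", "Eagle", "is", ...)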
def get_corpus_book(
|
| 734 |
+
data_args,
|
| 735 |
+
tokenizer,
|
| 736 |
+
model,
|
| 737 |
+
image_size,
|
| 738 |
+
padding_mode="block",
|
| 739 |
+
split="train",
|
| 740 |
+
):
|
| 741 |
+
max_length = image_size**2
|
| 742 |
import os
|
| 743 |
+
|
| 744 |
+
assert padding_mode == "block"
|
| 745 |
+
raw_datasets = load_dataset("bookcorpus")
|
| 746 |
if "validation" not in raw_datasets.keys():
|
| 747 |
raw_datasets["validation"] = load_dataset(
|
| 748 |
+
"bookcorpus",
|
| 749 |
split=f"train[:1%]",
|
| 750 |
)
|
| 751 |
raw_datasets["train"] = load_dataset(
|
| 752 |
+
"bookcorpus",
|
| 753 |
split=f"train[1%:]",
|
| 754 |
)
|
| 755 |
print(raw_datasets)
|
| 756 |
column_names = raw_datasets["train"].column_names
|
| 757 |
|
| 758 |
def tokenize_function(examples):
|
| 759 |
+
output = tokenizer(examples["text"], add_special_tokens=False)
|
| 760 |
return output
|
| 761 |
|
|
|
|
| 762 |
tokenized_datasets = raw_datasets.map(
|
| 763 |
tokenize_function,
|
| 764 |
batched=True,
|
|
|
|
| 779 |
         if total_length >= block_size:
             total_length = (total_length // block_size) * block_size
         result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
         return result
@@ ... @@
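# ---- editor's note (not part of this commit): a self-contained sketch of the
# ---- block chunking visible above, with toy values standing in for real tokens.
def _demo_block_chunking(block_size=4):
    concatenated = {"input_ids": list(range(10))}  # 10 tokens, blocks of 4
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }  # -> {"input_ids": [[0, 1, 2, 3], [4, 5, 6, 7]]}; the tail 8, 9 is dropped
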
     print(lm_datasets)

     if model is None:
+        if data_args.training_mode.startswith("e2e"):
+            print("since it's e2e, initialize a dummy embedding")
             model = torch.nn.Embedding(len(tokenizer), 1)
         else:
             model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
+        print("initializing the random embeddings", model)
         torch.nn.init.normal_(model.weight)
+        path_save = f"{data_args.checkpoint_path}/random_emb.torch"
+        print(
+            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
+        )
         torch.save(model.state_dict(), path_save)

+    if split == "train":
         return lm_datasets, model
     else:
+        lm_datasets["train"] = lm_datasets["validation"]
         return lm_datasets, model

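# ---- editor's note (not part of this commit): the randomly initialized embedding
# ---- saved above can be restored from random_emb.torch; a minimal sketch, assuming
# ---- vocab_size and in_channel match the shapes used at save time.
def _demo_load_random_emb(vocab_size, in_channel, checkpoint_path):
    emb = torch.nn.Embedding(vocab_size, in_channel)
    emb.load_state_dict(torch.load(f"{checkpoint_path}/random_emb.torch"))
    return emb
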
 class TextDataset(Dataset):
+    def __init__(
+        self,
+        text_datasets,
+        resolution,
+        data_args,
+        model_arch="conv-unet",
+        classes=None,
+        shard=0,
+        num_shards=1,
+        eigen_transform=None,
+        mapping_func=None,
+        model_emb=None,
+    ):
         super().__init__()
         self.resolution = resolution
         self.text_datasets = text_datasets
+        self.length = len(self.text_datasets["train"])
         self.model_arch = model_arch
         self.data_args = data_args
         print(self.resolution)
@@ ... @@
         # We are not on a new enough PIL to support the `reducing_gap`
         # argument, which uses BOX downsampling at powers of two first.
         # Thus, we do it by hand to improve downsample quality.
+        if self.model_arch == "conv-unet":
+            pass  # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
             # dtype=np.float32).reshape(self.resolution, self.resolution, -1)
             # # print(self.eigen_transform.shape)
             # if self.eigen_transform is not None:
@@ ... @@
             # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
             #     arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)

             # out_dict = {}
             # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
             # # if self.local_classes is not None:
             # #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
             # # print(out_dict.keys())
             # return np.transpose(arr, [2, 0, 1]), out_dict
+        elif self.model_arch == "1d-unet":
+            pass  # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
             # dtype=np.float32)  # seqlen, dim
             # if self.eigen_transform is not None:
             #     old_shape = arr.shape
@@ ... @@
             # # print(arr.shape)
             # return arr, out_dict
         else:
+            arr = np.array(
+                self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32
+            )
+            if self.eigen_transform is not None:
                 old_shape = arr.shape
                 # arr = arr.reshape(1, -1) @ self.eigen_transform
+                arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
+                arr = arr @ self.eigen_transform["map"]
                 arr = arr.reshape(old_shape)
+
+            if (
+                hasattr(self.data_args, "noise_level")
+                and self.data_args.noise_level > 0
+            ):
                 # print(arr.dtype)
                 # print(self.data_args.noise_level, 'using the noise level.')
+                arr = arr + self.data_args.noise_level * np.random.randn(
+                    *arr.shape
+                ).astype(arr.dtype)
                 # print(arr.dtype)

             out_dict = {}
+            out_dict["input_ids"] = np.array(
+                self.text_datasets["train"][idx]["input_ids"]
+            )
             # out_dict['mapping_func'] = self.mapping_func
+            if self.data_args.experiment_mode == "conditional_gen":
+                out_dict["src_ids"] = np.array(
+                    self.text_datasets["train"][idx]["src_ids"]
+                )
+                out_dict["src_mask"] = np.array(
+                    self.text_datasets["train"][idx]["src_mask"]
+                )
             # if self.local_classes is not None:
             #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
             return arr, out_dict
@@ ... @@
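# ---- editor's note (not part of this commit): the eigen_transform applied in
# ---- __getitem__ is an affine (whitening-style) map; a toy sketch with assumed
# ---- shapes, where "mean" is (1, d) and "map" is (d, d) for d flattened features.
def _demo_eigen_transform(arr, eigen_transform):
    old_shape = arr.shape
    flat = arr.reshape(1, -1) - eigen_transform["mean"]  # center
    flat = flat @ eigen_transform["map"]  # project
    return flat.reshape(old_shape)
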

 class TextDataset_NoCache(Dataset):
+    def __init__(
+        self,
+        text_datasets,
+        resolution,
+        data_args,
+        model_arch="conv-unet",
+        classes=None,
+        shard=0,
+        num_shards=1,
+        eigen_transform=None,
+        mapping_func=None,
+        model_emb=None,
+    ):
         super().__init__()
         self.resolution = resolution
         self.text_datasets = text_datasets
+        self.length = len(self.text_datasets["train"])
         self.model_arch = model_arch
         self.data_args = data_args
         print(self.resolution)
@@ ... @@
             # argument, which uses BOX downsampling at powers of two first.
             # Thus, we do it by hand to improve downsample quality.
         with torch.no_grad():
+            input_ids = self.text_datasets["train"][idx]["input_ids"]
             model = self.model_emb
+            if self.data_args.experiment.startswith("random"):
                 hidden_state = model(torch.tensor(input_ids))
+            elif self.data_args.experiment == "gpt2_pre_compress":
                 input_ids2 = torch.tensor(input_ids).to(model.device)
                 input_embs = model.transformer.wte(input_ids2)  # input_embs
                 hidden_state = model.down_proj(input_embs)
                 hidden_state = hidden_state * self.data_args.emb_scale_factor  # fixed: bare `data_args` is undefined inside this method

+            if self.model_arch == "conv-unet":
+                arr = np.array(hidden_state, dtype=np.float32).reshape(
+                    self.resolution, self.resolution, -1
+                )
                 # print(self.eigen_transform.shape)
                 if self.eigen_transform is not None:
                     old_shape = arr.shape
+                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
+                    arr = arr @ self.eigen_transform["map"]
                     arr = arr.reshape(old_shape)
+                if (
+                    hasattr(self.data_args, "noise_level")
+                    and self.data_args.noise_level > 0
+                ):
+                    arr = arr + self.data_args.noise_level * np.random.randn(
+                        *arr.shape
+                    ).astype(arr.dtype)

                 out_dict = {}
+                out_dict["input_ids"] = np.array(
+                    self.text_datasets["train"][idx]["input_ids"]
+                )
                 # if self.local_classes is not None:
                 #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                 # print(out_dict.keys())
                 return np.transpose(arr, [2, 0, 1]), out_dict
+            elif self.model_arch == "1d-unet":
+                arr = np.array(hidden_state, dtype=np.float32)  # seqlen, dim
                 if self.eigen_transform is not None:
                     old_shape = arr.shape
+                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
+                    arr = arr @ self.eigen_transform["map"]
                     arr = arr.reshape(old_shape)
+                if (
+                    hasattr(self.data_args, "noise_level")
+                    and self.data_args.noise_level > 0
+                ):
+                    arr = arr + self.data_args.noise_level * np.random.randn(
+                        *arr.shape
+                    ).astype(arr.dtype)
                 arr = np.transpose(arr, [1, 0])
                 out_dict = {}
+                out_dict["input_ids"] = np.array(
+                    self.text_datasets["train"][idx]["input_ids"]
+                )
                 # out_dict['mapping_func'] = self.mapping_func
                 # if self.local_classes is not None:
                 #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                 # print(arr.shape)
                 return arr, out_dict
             else:
+                arr = np.array(hidden_state, dtype=np.float32)
                 if self.eigen_transform is not None:
                     old_shape = arr.shape
                     # arr = arr.reshape(1, -1) @ self.eigen_transform
+                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
+                    arr = arr @ self.eigen_transform["map"]
                     arr = arr.reshape(old_shape)

+                if (
+                    hasattr(self.data_args, "noise_level")
+                    and self.data_args.noise_level > 0
+                ):
                     # print(arr.dtype)
                     # print(self.data_args.noise_level, 'using the noise level.')
+                    arr = arr + self.data_args.noise_level * np.random.randn(
+                        *arr.shape
+                    ).astype(arr.dtype)
                     # print(arr.dtype)

                 out_dict = {}
+                out_dict["input_ids"] = np.array(
+                    self.text_datasets["train"][idx]["input_ids"]
+                )
                 # out_dict['mapping_func'] = self.mapping_func
+                if self.data_args.experiment_mode == "conditional_gen":
+                    out_dict["src_ids"] = np.array(
+                        self.text_datasets["train"][idx]["src_ids"]
+                    )
+                    out_dict["src_mask"] = np.array(
+                        self.text_datasets["train"][idx]["src_mask"]
+                    )
                 # if self.local_classes is not None:
                 #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                 return arr, out_dict

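# ---- editor's note (not part of this commit): TextDataset_NoCache embeds token
# ---- ids on the fly instead of reading cached hidden states; a toy sketch of the
# ---- "random" branch above, with arbitrary sizes.
def _demo_nocache_embed(input_ids, vocab_size=100, dim=16):
    emb = torch.nn.Embedding(vocab_size, dim)
    with torch.no_grad():  # data loading needs no gradients
        hidden_state = emb(torch.tensor(input_ids))  # (seqlen, dim)
    return np.array(hidden_state, dtype=np.float32)
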
+
 def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
+    result = torch.full(
+        [len(examples), max_length], pad_token_id, dtype=torch.int64
+    ).tolist()
+    mask_ = torch.full(
+        [len(examples), max_length], pad_token_id, dtype=torch.int64
+    ).tolist()
     for i, example in enumerate(examples):
         curr_len = min(len(example), max_length)
         result[i][:curr_len] = example[:curr_len]
@@ ... @@
         return result, mask_
     return result

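# ---- editor's note (not part of this commit): a toy call illustrating the
# ---- padding behaviour of _collate_batch_helper; ids and pad token are arbitrary.
def _demo_collate():
    padded = _collate_batch_helper([[5, 6, 7], [8]], pad_token_id=0, max_length=4)
    return padded  # [[5, 6, 7, 0], [8, 0, 0, 0]]
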
+
 def _torch_collate_batch(examples, pad_token_id, max_length):
     """Collate `examples` into a batch, padding with `pad_token_id` if necessary."""
     import numpy as np
@@ ... @@
             result[i, : example.shape[0]] = example
         else:
             result[i, -example.shape[0] :] = example
+    return result
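
# ---- editor's note (not part of this commit): the visible branches above pad each
# ---- row on the right or on the left; a minimal right-padding sketch over 1-D
# ---- torch tensors, mirroring the loop body.
def _demo_right_pad(examples, pad_token_id, max_length):
    result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.long)
    for i, example in enumerate(examples):
        result[i, : example.shape[0]] = example  # left-aligned, pad on the right
    return result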