import torch


def forward(model_name, model, input_ids, past, device='cpu'):
    """Run one forward pass of `model`, dispatching on the model family."""
    if "gpt2" in model_name or "ctrl" in model_name:
        if past is not None:
            # With a key/value cache available, only the last token is fed.
            return model(input_ids[:, -1], past=past)
        return model(input_ids)
    elif "xlnet" in model_name:
        # Append a dummy token whose value the model will predict.
        input_ids = torch.cat((
            input_ids,
            torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=device)
        ), dim=1)
        # Permutation mask: no token may attend to the dummy (last) position.
        perm_mask = torch.zeros(
            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
            dtype=torch.float,
            device=device
        )
        perm_mask[:, :, -1] = 1.0
        # Target mapping: predict only the dummy (last) position.
        target_mapping = torch.zeros(
            (input_ids.shape[0], 1, input_ids.shape[1]),
            dtype=torch.float,
            device=device
        )
        target_mapping[:, 0, -1] = 1.0
        return model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
    elif "transfo-xl" in model_name:
        # Transformer-XL carries its recurrence memory in `mems`.
        return model(input_ids, mems=past)
    else:
        return model(input_ids)
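

def _example_xlnet_next_token():
    # Hypothetical usage sketch, not part of the original file: a single call to
    # `forward` for an XLNet checkpoint. The class names below assume the
    # pytorch-transformers-era API this module appears to target.
    from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer

    model_name = "xlnet-base-cased"
    tokenizer = XLNetTokenizer.from_pretrained(model_name)
    model = XLNetLMHeadModel.from_pretrained(model_name)
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("The weather today is")], dtype=torch.long)
    with torch.no_grad():
        outputs = forward(model_name, model, input_ids, past=None)

    # Because `forward` appended one dummy token and pointed target_mapping at
    # it, the returned logits have shape (batch, 1, vocab_size): a single
    # next-token distribution.
    next_token_id = outputs[0][0, -1, :].argmax().item()
    return tokenizer.convert_ids_to_tokens([next_token_id])[0]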


def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
    """Encode `initial_text` and return (context_tokens, eot_token, dot_token)."""
    if not len(initial_text) and "gpt2" in model_name:
        initial_text = "<|endoftext|>"
    if "xlnet" in model_name or "transfo-xl" in model_name:
        # These models work better with a long prompt, so prepend padding text.
        initial_text = padding_text + initial_text
    if "transfo-xl" in model_name:
        # Use a smaller context budget for Transformer-XL.
        max_tokens = int(max_tokens / 2)
    # Keep only the last `max_tokens` tokens of the encoded text as context.
    context_tokens = tokenizer.encode(initial_text)[-max_tokens:]
    if "gpt2" in model_name:
        eot_token = tokenizer.encoder["<|endoftext|>"]
        if len(context_tokens) == 0:
            context_tokens = [tokenizer.encoder["<|endoftext|>"]]
    elif "xlnet" in model_name:
        eot_token = tokenizer.convert_tokens_to_ids("<eop>")
    else:
        eot_token = None
    dot_token = tokenizer.encode(".")[-1]
    return context_tokens, eot_token, dot_token
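

if __name__ == "__main__":
    # Hypothetical demo, not part of the original file: a greedy decoding loop
    # tying `create_context` and `forward` together for GPT-2. The class names
    # assume the pytorch-transformers-era API this module appears to target.
    from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

    model_name = "gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()

    context_tokens, eot_token, dot_token = create_context(
        model_name, tokenizer, initial_text="Hugging Face is"
    )
    input_ids = torch.tensor([context_tokens], dtype=torch.long)

    past = None
    with torch.no_grad():
        for _ in range(20):
            outputs = forward(model_name, model, input_ids, past)
            logits, past = outputs[0], outputs[1]
            # First call: logits cover the whole context, shape (1, seq, vocab).
            # Later calls feed only the last token, so logits are (1, vocab).
            next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
            next_token = next_logits.argmax(dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, next_token), dim=1)
            if eot_token is not None and next_token.item() == eot_token:
                break

    print(tokenizer.decode(input_ids[0].tolist()))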