Spaces:
Running
Running
File size: 1,099 Bytes
506a2b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import Union
import torch
from torch import nn, Tensor
class LearnedPositionEmbeddings(nn.Module):
    """Learned position embeddings stored in an ``nn.Embedding`` table.

    Args:
        seq_len: number of rows (positions) in the embedding table.
        model_dim: width of each position embedding vector.
        init: standard deviation of the zero-mean normal distribution used
            to initialise the table.
    """

    def __init__(self, seq_len, model_dim, init=.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        nn.init.normal_(self.emb.weight, mean=0.0, std=init)

    def forward(self, x):
        """Return positional embeddings for index 0 up to the length of x.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``.

        Args:
            x: tensor whose second dimension is taken as the sequence length.

        Returns:
            Tensor of shape (x.shape[1], dim) holding the embeddings of
            positions ``0 .. x.shape[1] - 1``, looked up with indices created
            on ``x.device``.

        Raises:
            IndexError: if ``x.shape[1]`` is larger than the number of
                positions in the table.
        """
        sl = x.shape[1]
        max_len = self.emb.num_embeddings
        # Fail early with a readable message: both values are plain Python
        # ints, so this costs nothing, and it avoids the opaque out-of-range
        # failure raised from inside the embedding lookup.
        if sl > max_len:
            raise IndexError(
                f"sequence length {sl} exceeds the {max_len} positions "
                f"available in the position embedding table"
            )
        return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, idx: 'Union[int, Tensor]'):
        """Look up the embeddings of explicitly given position indices.

        Args:
            idx: scalar int or an integer tensor of shape (T,) or (B, T)

        Returns:
            positional embeddings for given indices, shape (B, T, dim),
            ie (1, 1, dim) for int input

        Raises:
            AssertionError: if ``idx`` has more than two dimensions.
        """
        # Indices are moved to (or created on) the device of the table.
        device = self.emb.weight.device
        if torch.is_tensor(idx):
            idx = idx.to(device)
        else:
            idx = torch.tensor(idx, device=device)
        # Scalars become (1, 1) and 1-D inputs become (1, T).
        idx = torch.atleast_2d(idx)
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert idx.ndim == 2, (
            f"idx must have at most 2 dimensions, got {idx.ndim}"
        )
        return self.emb(idx)  # (B, T, dim)
|