Spaces:
Paused
Paused
alessandro trinca tornidor
[refactor] prepare packaging moving all the modules under 'lisa_on_cuda' (renamed from 'model')
60fa201
"""GPT Blocks used for the GPT Model.""" | |
from typing import Dict, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
from .attention import ATTN_CLASS_REGISTRY | |
from .norm import NORM_CLASS_REGISTRY | |
class MPTMLP(nn.Module):
    """Feed-forward MLP: Linear -> exact GELU -> Linear.

    Args:
        d_model: Width of the input and output features.
        expansion_ratio: Multiplier applied to ``d_model`` to obtain the
            hidden width of the MLP.
        device: Optional device on which the linear layers' parameters
            are created.
    """

    def __init__(
        self, d_model: int, expansion_ratio: int, device: Optional[str] = None
    ):
        super().__init__()
        hidden_dim = expansion_ratio * d_model
        # Widen from d_model to the hidden dimension.
        self.up_proj = nn.Linear(d_model, hidden_dim, device=device)
        # Exact (erf-based) GELU, not the tanh approximation.
        self.act = nn.GELU(approximate="none")
        # Narrow back to d_model.
        self.down_proj = nn.Linear(hidden_dim, d_model, device=device)
        # Mark the projection whose output is added to the residual stream.
        # Only set here; presumably consumed by weight-init code elsewhere.
        self.down_proj._is_residual = True

    def forward(self, x):
        """Return ``down_proj(act(up_proj(x)))``."""
        hidden = self.up_proj(x)
        activated = self.act(hidden)
        return self.down_proj(activated)
class MPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers with residuals.

    The forward pass computes::

        x = x + resid_attn_dropout(attn(norm_1(x)))
        x = x + resid_ffn_dropout(ffn(norm_2(x)))

    Args:
        d_model: Model (hidden) width.
        n_heads: Number of attention heads, forwarded to the attention class.
        expansion_ratio: Hidden-width multiplier for the feed-forward MLP.
        attn_config: Attention settings. Keys read by this block are
            ``attn_type`` (selects the class from ``ATTN_CLASS_REGISTRY``),
            ``attn_impl``, ``clip_qkv``, ``qk_ln``, ``softmax_scale`` and
            ``attn_pdrop``. Keys missing from a caller-supplied mapping fall
            back to ``_DEFAULT_ATTN_CONFIG``; ``None`` (the default) uses the
            defaults unchanged. The remaining default keys (``prefix_lm``,
            ``attn_uses_sequence_id``, ``alibi``, ``alibi_bias_max``) are not
            read by this block.
        resid_pdrop: Dropout probability applied to each sub-layer output
            before it is added back to the residual stream.
        norm_type: Name of the normalization class; lower-cased before the
            lookup in ``NORM_CLASS_REGISTRY``.
        verbose: Verbosity level forwarded to the attention class.
        device: Optional device for parameter creation.
        **kwargs: Accepted so callers can pass a wider config; ignored.

    Unknown ``norm_type`` / ``attn_type`` values fail at the registry lookup
    (a ``KeyError`` when the registries are plain dicts).
    """

    # Default attention settings. This dict is only ever copied, never
    # mutated, so it is not shared state between instances (unlike the former
    # mutable default argument).
    _DEFAULT_ATTN_CONFIG: Dict = {
        "attn_type": "multihead_attention",
        "attn_pdrop": 0.0,
        "attn_impl": "triton",
        "qk_ln": False,
        "clip_qkv": None,
        "softmax_scale": None,
        "prefix_lm": False,
        "attn_uses_sequence_id": False,
        "alibi": False,
        "alibi_bias_max": 8,
    }

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        verbose: int = 0,
        device: Optional[str] = None,
        **kwargs
    ):
        del kwargs  # extra config keys are tolerated but unused
        super().__init__()
        # Overlay the caller's settings on a fresh copy of the defaults. Partial
        # configs then work, and neither dict is mutated.
        merged_attn_config = dict(self._DEFAULT_ATTN_CONFIG)
        if attn_config is not None:
            merged_attn_config.update(attn_config)

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[merged_attn_config["attn_type"]]
        # Submodules are created in a fixed order. That order determines
        # parameter-init RNG consumption and state_dict key order.
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=merged_attn_config["attn_impl"],
            clip_qkv=merged_attn_config["clip_qkv"],
            qk_ln=merged_attn_config["qk_ln"],
            softmax_scale=merged_attn_config["softmax_scale"],
            attn_pdrop=merged_attn_config["attn_pdrop"],
            d_model=d_model,
            n_heads=n_heads,
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model, expansion_ratio=expansion_ratio, device=device
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, ...]]
    ]:
        """Run the attention and MLP sub-layers, each with a residual add.

        Args:
            x: Hidden states. Presumably (batch, seq, d_model); TODO confirm.
            past_key_value: Cached key/value state, forwarded to ``self.attn``.
            attn_bias: Optional additive attention bias, forwarded to
                ``self.attn``.
            attention_mask: Optional mask, forwarded to ``self.attn``.
            is_causal: Forwarded to ``self.attn``.

        Returns:
            A 3-tuple ``(x, attn_weights, past_key_value)``: the updated
            hidden states, plus the attention weights and cache entry exactly
            as returned by ``self.attn``. NOTE(review): the annotated types of
            the last two items assume the attention implementation returns a
            tensor/None and a tuple of tensors/None; confirm.
        """
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)