# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F

from megablocks.layers.arguments import Arguments


class FFN(torch.nn.Module):
    """Dense two-layer feed-forward network with a tanh-approximated GELU."""

    def __init__(self, args: Arguments):
        super().__init__()
        # Up-projection: (hidden_size, ffn_hidden_size).
        self.w1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )
        # Down-projection: (ffn_hidden_size, hidden_size).
        self.w2 = torch.nn.Parameter(
            torch.empty(
                args.ffn_hidden_size,
                args.hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        # x @ w1 -> GELU -> @ w2.
        return torch.matmul(
            F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
            self.w2,
        )


class GLU(FFN):
    """Gated linear unit variant: the GELU branch is gated by a second linear branch."""

    def __init__(self, args: Arguments):
        super().__init__(args)
        # Gating projection: (hidden_size, ffn_hidden_size).
        self.v1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        # Elementwise product of the activated branch and the gating branch,
        # followed by the shared down-projection.
        x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
        return torch.matmul(x1, self.w2)
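

# --- Usage sketch (illustrative only, not part of the library) --------------
# A minimal, hedged example of driving the layers above. It assumes that
# `Arguments` accepts the fields referenced in this file (hidden_size,
# ffn_hidden_size, fp16, device) as constructor keywords; anything beyond the
# attributes actually read above is an assumption, not documented API.
if __name__ == '__main__':
    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        fp16=False,
        device=torch.device('cpu'),
    )

    ffn = FFN(args)
    # The weights are allocated with torch.empty (uninitialized), so give them
    # values before running a forward pass.
    for p in ffn.parameters():
        torch.nn.init.normal_(p, std=0.02)

    x = torch.randn(8, args.hidden_size)  # (tokens, hidden_size)
    y = ffn(x)
    print(y.shape)  # expected: torch.Size([8, 1024])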