danieldk's picture
danieldk HF Staff
Build
c6d6c2e
import torch
import torch.nn as nn
from torch.nn import functional as F
class LinearImplicitBackward(nn.Module):
    """Linear layer variant that declares no ``has_backward`` attribute.

    The class defines no constructor and never creates ``weight`` or
    ``bias`` itself; ``forward`` reads both from the instance, so they have
    to be provided on the module by whatever sets it up.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` to ``input`` using the instance's parameters.

        Args:
            input: Tensor handed to ``F.linear`` as its first argument.

        Returns:
            The result of ``F.linear(input, self.weight, self.bias)``.
        """
        # Read the parameters in the same order the call consumes them.
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
class LinearBackward(nn.Module):
    """Linear layer variant whose ``has_backward`` class attribute is ``True``.

    The class defines no constructor and never creates ``weight`` or
    ``bias`` itself; ``forward`` reads both from the instance, so they have
    to be provided on the module by whatever sets it up.
    """

    # NOTE(review): presumably read by the code that selects this layer to
    # learn whether a backward pass is supported -- confirm with the consumer.
    has_backward = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` to ``input`` using the instance's parameters.

        Args:
            input: Tensor handed to ``F.linear`` as its first argument.

        Returns:
            The result of ``F.linear(input, self.weight, self.bias)``.
        """
        # Read the parameters in the same order the call consumes them.
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
class LinearNoBackward(nn.Module):
    """Linear layer variant whose ``has_backward`` class attribute is ``False``.

    The class defines no constructor and never creates ``weight`` or
    ``bias`` itself; ``forward`` reads both from the instance, so they have
    to be provided on the module by whatever sets it up.
    """

    # NOTE(review): presumably read by the code that selects this layer to
    # learn whether a backward pass is supported -- confirm with the consumer.
    has_backward = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` to ``input`` using the instance's parameters.

        Args:
            input: Tensor handed to ``F.linear`` as its first argument.

        Returns:
            The result of ``F.linear(input, self.weight, self.bias)``.
        """
        # Read the parameters in the same order the call consumes them.
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
# Names exported by ``from <module> import *``.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]