alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
569 Bytes
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from jaxtyping import Float
from torch import Tensor, nn
from src.dataset.types import BatchedViews
T = TypeVar("T")
class Backbone(nn.Module, ABC, Generic[T]):
cfg: T
def __init__(self, cfg: T) -> None:
super().__init__()
self.cfg = cfg
@abstractmethod
def forward(
self,
context: BatchedViews,
) -> Float[Tensor, "batch view d_out height width"]:
pass
@property
@abstractmethod
def d_out(self) -> int:
pass