from abc import ABC, abstractmethod from typing import Generic, TypeVar from jaxtyping import Float from torch import Tensor, nn from src.dataset.types import BatchedViews T = TypeVar("T") class Backbone(nn.Module, ABC, Generic[T]): cfg: T def __init__(self, cfg: T) -> None: super().__init__() self.cfg = cfg @abstractmethod def forward( self, context: BatchedViews, ) -> Float[Tensor, "batch view d_out height width"]: pass @property @abstractmethod def d_out(self) -> int: pass