from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from jaxtyping import Float
from torch import Tensor, nn

from src.dataset.types import BatchedViews

# Type variable for the configuration object a Backbone is parameterised by.
T = TypeVar("T")
class Backbone(nn.Module, ABC, Generic[T]):
    """Abstract base class for backbone modules configured by a ``cfg`` of type ``T``.

    Concrete subclasses must override :meth:`forward` and :meth:`d_out`.
    Both are declared abstract so that an incomplete subclass fails at
    instantiation time instead of silently returning ``None``.
    """

    # Configuration object passed to the constructor.
    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialise the underlying ``nn.Module`` and store ``cfg``.

        Args:
            cfg: Configuration object; kept on the instance as ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute the backbone output for ``context``.

        Args:
            context: The batched views to process.

        Returns:
            A float tensor annotated as ``batch view d_out height width``.
        """

    @abstractmethod
    def d_out(self) -> int:
        """Return an integer describing the backbone output.

        NOTE(review): presumably the size of the ``d_out`` axis in the tensor
        returned by :meth:`forward` -- confirm against the concrete subclasses.
        """