from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from jaxtyping import Float
from torch import Tensor
# Type parameters of EncoderVisualizer: the type of the configuration object
# and the type of the encoder it is constructed with. Each TypeVar is assigned
# directly to a variable of the same name, as static type checkers require.
T_cfg = TypeVar("T_cfg")
T_encoder = TypeVar("T_encoder")
class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base class for objects that visualize an encoder.

    The class is generic over the configuration type (``T_cfg``) and the
    encoder type (``T_encoder``). Both objects are stored unchanged on the
    instance so that concrete subclasses can read them in ``visualize``.
    """

    cfg: T_cfg  # configuration object passed to the constructor
    encoder: T_encoder  # encoder object passed to the constructor

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Store the configuration and the encoder on the instance.

        Args:
            cfg: Configuration object; kept as ``self.cfg``.
            encoder: Encoder object; kept as ``self.encoder``.
        """
        self.cfg = cfg
        self.encoder = encoder

    # Marked abstract so the base class cannot be instantiated and every
    # concrete subclass has to supply an implementation that really returns
    # the annotated dictionary (the base body returns None).
    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualizations; must be overridden by subclasses.

        Args:
            context: Dictionary of inputs for the visualization. Its schema
                is not defined here -- NOTE(review): confirm against callers.
            global_step: Integer step counter (presumably the training step
                -- TODO confirm).

        Returns:
            A dictionary mapping a name to a floating-point tensor annotated
            as ``"3 _ _"``: three dimensions, the first of size 3 (presumably
            channel-first RGB images -- TODO confirm).
        """