File size: 561 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from jaxtyping import Float
from torch import Tensor
# Type variable for the configuration object type; EncoderVisualizer uses it
# to annotate its `cfg` attribute and constructor argument.
T_cfg = TypeVar("T_cfg")
# Type variable for the encoder object type; EncoderVisualizer uses it to
# annotate its `encoder` attribute and constructor argument.
T_encoder = TypeVar("T_encoder")
class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base class for objects that visualize an encoder.

    The class is generic over the configuration type (``T_cfg``) and the
    encoder type (``T_encoder``). It only stores the two objects it is given;
    concrete subclasses supply the actual visualization logic by implementing
    :meth:`visualize`.
    """

    # Configuration object handed to the constructor.
    cfg: T_cfg
    # Encoder object handed to the constructor.
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Keep references to the configuration and the encoder.

        Args:
            cfg: Configuration object; stored unchanged as ``self.cfg``.
            encoder: Encoder object; stored unchanged as ``self.encoder``.
        """
        self.cfg, self.encoder = cfg, encoder

    @abstractmethod
    def visualize(
        self, context: dict, global_step: int
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualizations; must be overridden by subclasses.

        Args:
            context: Dictionary of inputs for the visualization. Its schema is
                not defined here -- NOTE(review): confirm the expected keys
                against the callers and the concrete subclasses.
            global_step: Integer step value -- presumably the current training
                step used to tag the output; confirm against the caller.

        Returns:
            A mapping from a string name to a floating-point tensor annotated
            as ``"3 _ _"``, i.e. three dimensions with a leading dimension of
            size 3 (presumably a channel-first RGB image -- TODO confirm).
        """
|