File size: 561 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from jaxtyping import Float
from torch import Tensor

# Type parameters for EncoderVisualizer: the configuration type and the
# encoder type it is paired with. Each TypeVar's name matches its binding.
T_cfg, T_encoder = TypeVar("T_cfg"), TypeVar("T_encoder")


class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base for objects that produce visualizations of an encoder.

    Holds a configuration object and an encoder instance; concrete
    subclasses implement :meth:`visualize`.

    Type parameters:
        T_cfg: type of the configuration object.
        T_encoder: type of the encoder being visualized.
    """

    # Configuration object passed to the constructor.
    cfg: T_cfg
    # Encoder passed to the constructor.
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Store references to ``cfg`` and ``encoder`` (no copies are made)."""
        self.cfg, self.encoder = cfg, encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Build named visualizations for the given context.

        Args:
            context: input dictionary; its expected keys are not defined
                here — NOTE(review): confirm against concrete subclasses.
            global_step: integer step counter, presumably the training
                step — confirm with callers.

        Returns:
            Mapping from a name to a float tensor whose annotated shape is
            ``(3, _, _)`` — presumably a 3-channel (C, H, W) image; TODO
            confirm.
        """
        ...