AnySplat / src /model /encoder /visualization /encoder_visualizer.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
561 Bytes
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from jaxtyping import Float
from torch import Tensor
# Type variable standing for the configuration type an EncoderVisualizer is
# parameterized with (the type of its ``cfg`` attribute).
T_cfg = TypeVar("T_cfg")
# Type variable standing for the encoder type an EncoderVisualizer is
# parameterized with (the type of its ``encoder`` attribute).
T_encoder = TypeVar("T_encoder")
class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base class for visualizers attached to an encoder.

    The class is generic over the configuration type (``T_cfg``) and the
    encoder type (``T_encoder``). It only stores both objects; concrete
    subclasses supply the actual rendering logic in :meth:`visualize`.
    """

    # Configuration object received by the constructor.
    cfg: T_cfg
    # Encoder object received by the constructor.
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Keep references to the given configuration and encoder."""
        self.cfg, self.encoder = cfg, encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualizations; must be overridden by subclasses.

        Args:
            context: Dictionary of inputs for the visualization. Its schema
                is not defined in this module.
            global_step: Integer step counter (presumably the training step;
                confirm against the caller).

        Returns:
            A dictionary mapping names to float tensors annotated with the
            shape ``"3 _ _"``: a leading dimension of size 3 followed by two
            unconstrained dimensions (presumably channel-first images;
            confirm against implementations).
        """