|
from abc import ABC, abstractmethod |
|
from typing import Generic, TypeVar |
|
|
|
from torch import nn |
|
from dataclasses import dataclass |
|
from src.dataset.types import BatchedViews, DataShim |
|
from ..types import Gaussians |
|
from jaxtyping import Float |
|
from torch import Tensor, nn |
|
|
|
# Type variable for an encoder's configuration object; Encoder is generic over it
# and stores the config as `cfg: T`.
T = TypeVar("T")
|
|
|
@dataclass()
class EncoderOutput:
    """Container for the results an encoder returns for one batch.

    Only ``gaussians`` is non-optional; every other field is annotated as
    ``... | None``. Field order defines the positional order of the generated
    ``__init__``, so it must not be rearranged.
    """

    # The Gaussians produced by the encoder.
    gaussians: Gaussians
    # List of pose encodings, each annotated as a "batch view 6" tensor; what
    # the individual list entries correspond to is not visible here.
    pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None
    # Presumably predicted poses for the context views; dict schema not shown here.
    pred_context_pose: dict | None
    # Presumably depth-related outputs; dict schema not shown here.
    depth_dict: dict | None
    # Free-form auxiliary information; contents not shown here.
    infos: dict | None
    # Auxiliary information, presumably for distillation; TODO confirm schema.
    distill_infos: dict | None
|
|
|
class Encoder(nn.Module, ABC, Generic[T]):
    """Abstract base for encoders, generic over their configuration type ``T``.

    Subclasses must implement :meth:`forward`. They may override
    :meth:`get_data_shim` to transform batches; the default leaves them as-is.
    """

    cfg: T  # configuration object passed to __init__

    def __init__(self, cfg: T) -> None:
        # Set up nn.Module's internal state before storing anything on self.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        context: BatchedViews,
    ) -> Gaussians:
        """Encode the given batch of context views.

        NOTE(review): annotated to return ``Gaussians`` although this module
        also defines ``EncoderOutput``; confirm what concrete subclasses
        actually return.
        """

    def get_data_shim(self) -> DataShim:
        """Return a shim that hands the batch back unchanged."""

        def identity_shim(batch):
            return batch

        return identity_shim
|
|