jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
from dataclasses import dataclass, field
from typing import Dict, ForwardRef, List, Optional, Type, Union
# Forward references to the dataclasses defined further down in this module.
# Each string must match a class name that actually exists here (`ParamId`,
# `CPInput`, `CPOutput`); otherwise resolving the aliases, e.g. through
# `typing.get_type_hints`, fails with a NameError.
ParamIdentifierType = ForwardRef("ParamId")
ContextParallelInputMetadataType = ForwardRef("CPInput")
ContextParallelOutputMetadataType = ForwardRef("CPOutput")

# Input plan: parameter identifier -> one `CPInput`, or a list of `CPInput`.
_ContextParallelInputType = Dict[
    ParamIdentifierType, Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]]
]

# Output plan: a list of `CPOutput`.
_ContextParallelOutputType = List[ContextParallelOutputMetadataType]

# A plan is either the input form or the output form.
ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType]
@dataclass(frozen=True)
class ParamId:
    """
    Immutable identifier for a single parameter of a method.

    A parameter can be addressed by name, by position, or by both. At least one of
    `name` or `index` must be provided.

    Attributes:
        name (`str`, *optional*):
            The name of the parameter.
        index (`int`, *optional*):
            The index of the parameter in the method signature. Indexing starts at 0 (ignore
            the `self` parameter for instance methods).

    Raises:
        ValueError: If both `name` and `index` are `None`.
    """

    name: Optional[str] = None
    index: Optional[int] = None

    def __post_init__(self):
        # An identifier that carries neither a name nor a position cannot
        # address anything, so reject it at construction time.
        has_name = self.name is not None
        has_index = self.index is not None
        if not (has_name or has_index):
            raise ValueError("At least one of `name` or `index` must be provided.")
@dataclass(frozen=True)
class CPInput:
    """
    Immutable record with one required field and two optional ones.

    NOTE(review): "CP" presumably stands for context parallel, and the field meanings
    below are inferred from their names only. Confirm against the code that consumes
    this record.

    Attributes:
        split_dim (`int`):
            Required. Presumably the dimension along which the input is split.
        expected_dims (`int`, *optional*):
            Presumably the number of dimensions the input is expected to have.
            Defaults to `None`.
        split_output (`bool`, defaults to `False`):
            Presumably selects splitting the output rather than the input; TODO confirm.
    """

    split_dim: int
    expected_dims: Optional[int] = field(default=None)
    split_output: bool = field(default=False)
@dataclass(frozen=True)
class CPOutput:
    """
    Immutable record with one required field and one optional field.

    NOTE(review): "CP" presumably stands for context parallel, and the field meanings
    below are inferred from their names only. Confirm against the code that consumes
    this record.

    Attributes:
        gather_dim (`int`):
            Required. Presumably the dimension along which the output is gathered.
        expected_dims (`int`, *optional*):
            Presumably the number of dimensions the output is expected to have.
            Defaults to `None`.
    """

    gather_dim: int
    expected_dims: Optional[int] = field(default=None)
@dataclass
class TransformerMetadata:
    """
    Mutable metadata record for a transformer model class.

    It currently carries a single field, `cp_plan`. Each instance gets its own empty
    dict when no plan is passed.
    """

    # Original author's note: mapping of FQN to mapping of input name to
    # ContextParallelModelPlan.
    # NOTE(review): the annotation has a single mapping level
    # (str -> ContextParallelModelPlan); confirm which description is intended.
    cp_plan: Dict[str, ContextParallelModelPlan] = field(default_factory=dict)

    # tp_plan  # TODO(aryan)
class TransformerRegistry:
    """
    Class-level lookup table from a model class to its `TransformerMetadata`.

    All state lives in the `_registry` dict on the class itself, so both operations
    are classmethods and no instance is required.
    """

    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerMetadata):
        """Store `metadata` under `model_class`, replacing any entry already stored for it."""
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerMetadata:
        """
        Return the metadata stored for `model_class`.

        The dictionary is keyed on the class object that was passed to `register`.

        Raises:
            ValueError: If nothing is stored under `model_class`.
        """
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")