Spaces:
Running
Running
from dataclasses import dataclass, field | |
from typing import Dict, ForwardRef, List, Optional, Type, Union | |
# Forward references to the classes declared further down in this module. The
# referenced names must match the real class names (`ParamId`, `CPInput`,
# `CPOutput`); otherwise the references can never be resolved.
ParamIdentifierType = ForwardRef("ParamId")
ContextParallelInputMetadataType = ForwardRef("CPInput")
ContextParallelOutputMetadataType = ForwardRef("CPOutput")

# Input plan: maps a parameter identifier to one input spec, or to a list of them.
_ContextParallelInputType = Dict[
    ParamIdentifierType,
    Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]],
]

# Output plan: a list of output specs.
_ContextParallelOutputType = List[ContextParallelOutputMetadataType]

# A plan entry is either an input plan or an output plan.
ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType]
@dataclass(frozen=True)
class ParamId:
    """
    A class to identify a parameter of a method.

    At least one of `name` or `index` must be provided.

    Instances are immutable and hashable (the dataclass is frozen), so they can be used as dictionary keys.

    Attributes:
        name (`str`, *optional*):
            The name of the parameter.
        index (`int`, *optional*):
            The index of the parameter in the method signature. Indexing starts at 0 (ignore
            the `self` parameter for instance methods).

    Raises:
        ValueError: If both `name` and `index` are `None`.
    """

    name: Optional[str] = None
    index: Optional[int] = None

    def __post_init__(self):
        # Runs right after the generated `__init__`; it only reads the fields,
        # so it is compatible with `frozen=True`.
        if self.name is None and self.index is None:
            raise ValueError("At least one of `name` or `index` must be provided.")
@dataclass
class CPInput:
    """
    Specification of a single input in a context-parallel plan.

    NOTE(review): the field meanings below are inferred from the field names only; confirm them
    against the code that consumes this class.

    Attributes:
        split_dim (`int`):
            Presumably the dimension along which the value is split.
        expected_dims (`int`, *optional*):
            Presumably the number of dimensions the value is expected to have. Defaults to `None`.
        split_output (`bool`, defaults to `False`):
            Presumably selects splitting the output instead of the input.
    """

    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False
@dataclass
class CPOutput:
    """
    Specification of a single output in a context-parallel plan.

    NOTE(review): the field meanings below are inferred from the field names only; confirm them
    against the code that consumes this class.

    Attributes:
        gather_dim (`int`):
            Presumably the dimension along which the value is gathered.
        expected_dims (`int`, *optional*):
            Presumably the number of dimensions the value is expected to have. Defaults to `None`.
    """

    gather_dim: int
    expected_dims: Optional[int] = None
@dataclass
class TransformerMetadata:
    """
    Parallelism metadata attached to a transformer model class.

    Attributes:
        cp_plan (`Dict[str, ContextParallelModelPlan]`):
            Context-parallel plan, keyed by FQN (presumably the fully-qualified name of a submodule —
            confirm with the consumer). Defaults to a fresh empty dict for every instance.
    """

    # Mapping of FQN to ContextParallelModelPlan. `default_factory` gives each
    # instance its own dict instead of sharing one mutable default.
    cp_plan: Dict[str, "ContextParallelModelPlan"] = field(default_factory=dict)

    # tp_plan # TODO(aryan)
class TransformerRegistry:
    """
    Class-level registry that associates model classes with their `TransformerMetadata`.

    The registry is a single dict stored on the class, so every registration is shared
    process-wide; the class is used through its classmethods and is not meant to be instantiated.
    """

    _registry: Dict[Type, "TransformerMetadata"] = {}

    @classmethod
    def register(cls, model_class: Type, metadata: "TransformerMetadata"):
        """
        Associate `metadata` with `model_class`, replacing any previous registration.

        Args:
            model_class (`Type`): The model class used as the lookup key.
            metadata (`TransformerMetadata`): The metadata to store for that class.
        """
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> "TransformerMetadata":
        """
        Return the metadata registered for `model_class`.

        Args:
            model_class (`Type`): The model class to look up (exact key match).

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        try:
            return cls._registry[model_class]
        except KeyError:
            # Callers expect a ValueError, not the internal KeyError.
            raise ValueError(f"Model class {model_class} not registered.") from None