File size: 2,165 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from dataclasses import dataclass, field
from typing import Dict, ForwardRef, List, Optional, Type, Union


# Forward references to the metadata dataclasses declared further down in this module. The referenced
# names must match the real class names (`ParamId`, `CPInput`, `CPOutput`); a reference to a name that
# does not exist cannot be resolved when the aliases are evaluated (e.g. by `typing.get_type_hints`).
ParamIdentifierType = ForwardRef("ParamId")
ContextParallelInputMetadataType = ForwardRef("CPInput")
ContextParallelOutputMetadataType = ForwardRef("CPOutput")

# Input plan: maps a parameter identifier to one input-metadata entry, or to a list of such entries.
_ContextParallelInputType = Dict[
    ParamIdentifierType, Union[ContextParallelInputMetadataType, List[ContextParallelInputMetadataType]]
]
# Output plan: a list of output-metadata entries.
_ContextParallelOutputType = List[ContextParallelOutputMetadataType]
# A model plan is either an input plan or an output plan.
ContextParallelModelPlan = Union[_ContextParallelInputType, _ContextParallelOutputType]


@dataclass(frozen=True)
class ParamId:
    """
    Immutable identifier for a single parameter of a method.

    A parameter can be referred to by its `name`, by its `index`, or by both. An instance that carries
    neither is rejected at construction time.

    Attributes:
        name (`str`, *optional*):
            Name of the parameter.
        index (`int`, *optional*):
            Position of the parameter in the method signature, counted from 0 and not counting the
            `self` parameter of instance methods.

    Raises:
        ValueError: If both `name` and `index` are `None`.
    """

    name: Optional[str] = None
    index: Optional[int] = None

    def __post_init__(self):
        # Guard clause: having either field set is enough to identify the parameter.
        if self.name is not None or self.index is not None:
            return
        raise ValueError("At least one of `name` or `index` must be provided.")


@dataclass(frozen=True)
class CPInput:
    """
    Immutable metadata record for an input entry.

    NOTE(review): the meaning of the fields is not defined in this file. The descriptions below are
    inferred from the field names and should be confirmed against the code that consumes this record.

    Attributes:
        split_dim (`int`):
            Presumably the dimension along which the value is split.
        expected_dims (`int`, *optional*, defaults to `None`):
            Presumably the number of dimensions the value is expected to have.
        split_output (`bool`, defaults to `False`):
            Presumably selects splitting an output rather than the input itself.
    """

    split_dim: int
    expected_dims: Optional[int] = field(default=None)
    split_output: bool = field(default=False)


@dataclass(frozen=True)
class CPOutput:
    """
    Immutable metadata record for an output entry.

    NOTE(review): the meaning of the fields is not defined in this file. The descriptions below are
    inferred from the field names and should be confirmed against the code that consumes this record.

    Attributes:
        gather_dim (`int`):
            Presumably the dimension along which the value is gathered.
        expected_dims (`int`, *optional*, defaults to `None`):
            Presumably the number of dimensions the value is expected to have.
    """

    gather_dim: int
    expected_dims: Optional[int] = field(default=None)


@dataclass
class TransformerMetadata:
    """
    Mutable container for the parallelism plans attached to a model.

    Attributes:
        cp_plan (`Dict[str, ContextParallelModelPlan]`, defaults to an empty dict):
            Plans keyed by string. Each instance gets its own fresh dict when no value is passed.
    """

    # Keys are FQNs (presumably fully-qualified module names, TODO confirm); each value is the
    # ContextParallelModelPlan stored for that key.
    cp_plan: Dict[str, ContextParallelModelPlan] = field(default_factory=dict)

    # tp_plan  # TODO(aryan)


class TransformerRegistry:
    """
    Class-level lookup table associating model classes with their `TransformerMetadata`.

    The table is stored on the class itself and accessed only through classmethods, so no instance
    is needed and all callers share the same entries.
    """

    _registry = {}

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerMetadata):
        """
        Store `metadata` for `model_class`, replacing any entry already stored for that class.

        Args:
            model_class (`Type`):
                The class used as the lookup key.
            metadata (`TransformerMetadata`):
                The metadata to associate with `model_class`.
        """
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerMetadata:
        """
        Return the metadata stored for `model_class`.

        Args:
            model_class (`Type`):
                The class whose metadata is requested. Only an exact key match is looked up.

        Raises:
            ValueError: If `model_class` has not been registered.
        """
        try:
            return cls._registry[model_class]
        except KeyError:
            # Report a missing entry as ValueError; hide the internal KeyError from the traceback.
            raise ValueError(f"Model class {model_class} not registered.") from None