File size: 1,801 Bytes
9b6ee83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""module for building the auto wrap policy for FSDP"""
import functools

from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

# Model types for which get_wrapping_policy_factory maps to a decoder-layer class to wrap.
SUPPORTED_AUTO_WRAP_MODEL_TYPES = ["llama", "mistral", "mixtral"]


def get_wrapping_policy_factory(model_type):
    """Return a factory that builds the FSDP auto-wrap policy for ``model_type``.

    The policy itself is built lazily, when the returned zero-argument callable
    is invoked.

    Args:
        model_type: One of ``"llama"``, ``"mistral"`` or ``"mixtral"``; selects
            the decoder-layer class that the transformer wrap policy targets.

    Returns:
        A callable taking no arguments that returns the combined auto-wrap
        policy (a ``functools.partial`` of ``_or_policy``).

    Raises:
        ValueError: If ``model_type`` is not one of the supported model types.
    """
    decoder_layer_by_model_type = {
        "llama": LlamaDecoderLayer,
        "mistral": MistralDecoderLayer,
        "mixtral": MixtralDecoderLayer,
    }
    layer_to_wrap = decoder_layer_by_model_type.get(model_type)
    if layer_to_wrap is None:
        # Fail fast with a clear message instead of leaving ``layer_to_wrap``
        # unbound and erroring only once the returned factory is called.
        raise ValueError(
            f"Unsupported model_type {model_type!r} for FSDP auto wrap; "
            f"expected one of {list(decoder_layer_by_model_type)}"
        )

    def get_wrapping_policy():
        """Build the auto-wrap policy by OR-ing two criteria via ``_or_policy``.

        A module is selected for wrapping if either:

        * it is a leaf module with a ``weight`` that requires grad (intended to
          catch trainable LoRA layers), or
        * it is an instance of a PEFT prompt-learning module
          (``PrefixEncoder``, ``PromptEncoder``, ``PromptEmbedding``) or of the
          decoder-layer class chosen for ``model_type``.
        """

        def lambda_policy_fn(module):
            # Leaf module (no children) that owns a weight requiring grad.
            return (
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        lambda_policy = functools.partial(
            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
        )
        transformer_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=(
                PrefixEncoder,
                PromptEncoder,
                PromptEmbedding,
                layer_to_wrap,
            ),
        )
        policies = [lambda_policy, transformer_wrap_policy]
        return functools.partial(_or_policy, policies=policies)

    return get_wrapping_policy