"""Modified from https://github.com/mlfoundations/open_flamingo"""
import open_clip
import torch
import torch.nn as nn
from bigmodelvis import Visualization
from peft import LoraConfig, get_peft_model
from transformers import LlamaForCausalLM, LlamaTokenizer

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    decoder_layers_attr_name: str = None,
    pretrained_model_path: str = None,
    tuning_config=None,
    **flamingo_kwargs,
):
"""
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
Appends special tokens to the tokenizer and freezes backbones.
Args:
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
lang_encoder_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
print("init clip vision encoder")
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
)
# set the vision encoder to output the visual features
vision_encoder.visual.output_tokens = True
print("init tokenizer")
text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
# add Flamingo special tokens to the tokenizer
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
if text_tokenizer.pad_token is None:
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
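    # hard-code LLaMA's standard special token ids (<s> = 1, </s> = 2) in case the
    # tokenizer config does not set them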
    text_tokenizer.bos_token_id = 1
    text_tokenizer.eos_token_id = 2
print("init llama")
lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)
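    # mix FlamingoLMMixin into the HF LLaMA instance so its decoder layers can later
    # be wrapped with gated cross-attention blocks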
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
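    # grow the input embedding matrix to cover the newly added special tokens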
    lang_encoder.resize_token_embeddings(len(text_tokenizer))
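
    # build the Flamingo wrapper: the positional args after the two encoders are the
    # token ids of <|endofchunk|> and <image> (taking the last id from encode() to skip
    # the tokenizer's BOS token); vis_dim is read from the open_clip vision config, and
    # gated cross-attention is inserted after every 4th decoder layer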
    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<image>")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=4,
        **flamingo_kwargs,
    )

    if pretrained_model_path is not None:
        print(f"loading pretrained model from {pretrained_model_path}")
        model.load_state_dict(torch.load(pretrained_model_path), strict=False)

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    if tuning_config is not None:
        model = prepare_model_for_tuning(model, tuning_config)
    else:
        raise ValueError("tuning_config must be provided")

    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )
    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )
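

# maps a model-family substring (matched against the class name) to the attribute
# path of the decoder's nn.ModuleList of transformer blocks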
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
}


def prepare_model_for_tuning(model: nn.Module, config):
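    # optionally wrap the (frozen) language model with LoRA adapters; only the
    # adapter weights become trainable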
    if config.lora:
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",  # won't use bias currently
            modules_to_save=[],  # TODO: might be helpful if save partial model
            task_type="CAUSAL_LM",
        )
        model.lang_encoder = get_peft_model(model.lang_encoder, peft_config=lora_config)

    # manually unfreeze modules, using substring matching against parameter names
    for name, param in model.named_parameters():
        if any(substr in name for substr in config.unfrozen):
            param.requires_grad = True
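
    # optionally print the module structure (rank 0 only) so the trainable subset can be inspected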
    if config.vis and is_rank0():
        Visualization(model).structure_graph()

    return model


# temporary workaround, should use a common utils in the future
def is_rank0():
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
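

# Minimal usage sketch (not part of the original module). The checkpoint names, local
# paths, and the ad-hoc SimpleNamespace tuning config below are illustrative
# assumptions; the real project supplies an equivalent config object.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_tuning_config = SimpleNamespace(
        lora=True,
        lora_r=16,
        lora_alpha=16,
        lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        unfrozen=["perceiver", "gated_cross_attn", "embed_tokens"],  # example substrings to unfreeze
        vis=False,
    )
    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="path/to/llama-7b-hf",  # hypothetical local path
        tokenizer_path="path/to/llama-7b-hf",  # hypothetical local path
        tuning_config=example_tuning_config,
    )
    print(type(model), type(image_processor), type(tokenizer))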