from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip
import torch

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance
import logging
import random
import time
def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    location_token_num: int = 1000,
    checkpoint_activations: bool = False,
    freeze_vision_encoder: bool = False,
    lora: bool = False,
    lora_r: int = 16,
    fix_ffn: bool = False,
    add_visual_token: bool = False,
    add_box: bool = False,
    add_pe: bool = False,
    add_relation: bool = False,
    use_format_v2: bool = False,
    use_sam: str = None,
    enhance_data: bool = False,
    roi_align: bool = False,
    roi_output_size: int = 4,
    apply_mask: bool = False,
    **flamingo_kwargs,
):
| """ | |
| Initialize a Flamingo model from a pretrained vision encoder and language encoder. | |
| Appends special tokens to the tokenizer and freezes backbones. | |
| Args: | |
| clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") | |
| clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") | |
| lang_encoder_path (str): path to pretrained language encoder | |
| tokenizer_path (str): path to pretrained tokenizer | |
| cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. | |
| use_local_files (bool, optional): whether to use local files. Defaults to False. | |
| decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. | |
| Returns: | |
| Flamingo: Flamingo model from pretrained vision and language encoders | |
| Image processor: Pipeline to preprocess input images | |
| Tokenizer: A tokenizer for the language model | |
| """ | |
    if use_sam is None:
        no_success = True
        while no_success:
            try:
                vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
                    clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
                )
                no_success = False
            except Exception:
                logging.info("retry creating vision_encoder")
                time.sleep(random.random() * 5)
        # set the vision encoder to output the visual (patch) tokens
        vision_encoder.visual.output_tokens = True
        # delete the text encoder part; only the visual tower is kept
        del vision_encoder.transformer
        del vision_encoder.text_projection
        del vision_encoder.token_embedding
        del vision_encoder.ln_final
        del vision_encoder.positional_embedding
        del vision_encoder.logit_scale
        # drop the CLIP projection and final layer norm so raw patch features are returned
        vision_encoder.visual.proj = None
        vision_encoder.visual.ln_post = torch.nn.Identity()
    else:
        from segment_anything import SamPredictor, sam_model_registry
        assert use_sam == "vit_l"
        sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
        del sam.prompt_encoder
        del sam.mask_decoder
        sam.image_encoder.neck = torch.nn.Identity()
        vision_encoder = sam.image_encoder
        from open_clip.transform import image_transform
        image_processor = image_transform(
            256,
            is_train=False,
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
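    # Both branches above end with the same interface: a backbone that maps an image to a
    # grid of patch features (the CLIP projection/final layer norm, or the SAM neck, is
    # removed) plus a matching image preprocessing pipeline.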
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, local_files_only=use_local_files
    )
    # add Flamingo special tokens to the tokenizer
    additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
    if add_visual_token:
        additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
    if add_box:
        additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
    if use_format_v2:
        additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
    if enhance_data:
        additional_special_tokens += ["<|#NOTHING#|>"]
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": additional_special_tokens}
    )
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
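    # The IDs of these special tokens are looked up again when the Flamingo module is
    # constructed below, so the language model's embedding table must be resized
    # (via resize_token_embeddings) once the tokens are registered.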
    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path, local_files_only=use_local_files
    )
    extend_instance(lang_encoder, FlamingoLMMixin)
    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))
    lang_encoder_name = lang_encoder.__class__.__name__.lower()
    if checkpoint_activations:
        from fairscale.nn.checkpoint import checkpoint_wrapper
        if use_sam is None:
            for i in range(len(vision_encoder.visual.transformer.resblocks)):
                vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
                    vision_encoder.visual.transformer.resblocks[i],
                    offload_to_cpu=False,
                )
        else:
            for i in range(len(vision_encoder.blocks)):
                vision_encoder.blocks[i] = checkpoint_wrapper(
                    vision_encoder.blocks[i],
                    offload_to_cpu=False,
                )
        if "opt" in lang_encoder_name:
            for i in range(len(lang_encoder.model.decoder.layers)):
                lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
                    lang_encoder.model.decoder.layers[i],
                    offload_to_cpu=False,
                )
        elif "codegen" in lang_encoder_name:
            for i in range(len(lang_encoder.transformer.h)):
                lang_encoder.transformer.h[i] = checkpoint_wrapper(
                    lang_encoder.transformer.h[i],
                    offload_to_cpu=False,
                )
        elif "llama" in lang_encoder_name:
            for i in range(len(lang_encoder.model.layers)):
                lang_encoder.model.layers[i] = checkpoint_wrapper(
                    lang_encoder.model.layers[i],
                    offload_to_cpu=False,
                )
        elif "gptneox" in lang_encoder_name:
            # GPT-NeoX style models (e.g. Pythia) keep their blocks under gpt_neox.layers
            for i in range(len(lang_encoder.gpt_neox.layers)):
                lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
                    lang_encoder.gpt_neox.layers[i],
                    offload_to_cpu=False,
                )
        else:
            raise ValueError(f"unknown model {lang_encoder_name}")
    if use_sam is None:
        vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
        image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
        patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
    else:
        # SAM ViT-L config
        vis_dim = 1024
        image_size = 256
        patch_size = 16
    assert image_size % patch_size == 0
    vis_embed_size = (image_size // patch_size) ** 2
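    # Worked example: for "ViT-L-14" the open_clip vision_cfg has width=1024,
    # image_size=224 and patch_size=14, so vis_embed_size = (224 // 14) ** 2 = 256
    # visual tokens per image; the SAM ViT-L branch (256px inputs, 16px patches)
    # likewise gives (256 // 16) ** 2 = 256 tokens.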
    if lora:
        from peft import LoraConfig, TaskType
        from peft import get_peft_model
        if "codegen" in lang_encoder_name:
            lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
        elif "opt" in lang_encoder_name:
            lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
        elif "llama" in lang_encoder_name:
            lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
        else:
            raise NotImplementedError
        lang_peft_config = LoraConfig(
            task_type="CAUSAL_LM",
            r=lora_r, lora_alpha=16,  # use the lora_r argument rather than a hardcoded rank
            target_modules=lang_target_modules,
            lora_dropout=0.05, bias="none",
        )
        lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
        lang_encoder.print_trainable_parameters()
    if fix_ffn:
        if "opt" in lang_encoder_name:
            for i in range(len(lang_encoder.model.decoder.layers)):
                lang_encoder.model.decoder.layers[i].requires_grad_(False)
                lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
        else:
            raise NotImplementedError
    lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
    if hasattr(lang_encoder.config, "word_embed_proj_dim"):
        hidden_state_dim = lang_encoder.config.word_embed_proj_dim
    else:
        hidden_state_dim = lang_encoder.config.hidden_size
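    # OPT checkpoints expose word_embed_proj_dim when hidden states are projected to a
    # narrower embedding width (e.g. opt-350m: hidden_size=1024, word_embed_proj_dim=512);
    # for most other models the two coincide, so hidden_state_dim falls back to hidden_size.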
    model = Flamingo(
        vision_encoder=vision_encoder,
        lang_encoder=lang_encoder,
        eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
        media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
        image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
        visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
        previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
        box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
        prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
        nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
        endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
        vis_dim=vis_dim,
        vis_embed_size=vis_embed_size,
        lang_dim=lang_dim,
        image_size=image_size,
        patch_size=patch_size,
        hidden_state_dim=hidden_state_dim,
        add_visual_token=add_visual_token,
        add_pe=add_pe,
        add_relation=add_relation,
        use_format_v2=use_format_v2,
        roi_align=roi_align,
        roi_output_size=roi_output_size,
        apply_mask=apply_mask,
        **flamingo_kwargs,
    )
    if freeze_vision_encoder:
        print("freeze vision encoder")
        model.vision_encoder.requires_grad_(False)
    print(
        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
    )
    return model, image_processor, text_tokenizer, vis_embed_size
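# Example usage (illustrative sketch; the encoder and tokenizer names below are
# placeholders, not defaults shipped with this repository, and any extra Flamingo
# keyword arguments would be passed through **flamingo_kwargs):
#
#   model, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
#       clip_vision_encoder_path="ViT-L-14",
#       clip_vision_encoder_pretrained="openai",
#       lang_encoder_path="facebook/opt-1.3b",
#       tokenizer_path="facebook/opt-1.3b",
#       add_visual_token=True,
#       add_box=True,
#       use_format_v2=True,
#   )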
def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )
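# Maps a lowercased substring of the language-model class name to the attribute path of
# its decoder block list, e.g. OPTForCausalLM -> "model.decoder.layers",
# LlamaForCausalLM -> "model.layers".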
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    # "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "gptneox": "gpt_neox.layers",
    "llama": "model.layers",
    "llamaforcausallm": "model.layers",
    "gpt2": "transformer.h",
    "codegen": "transformer.h",
}