# basic packages
import os
import copy
import warnings
from PIL import Image
from typing import Optional, Tuple, Union, List, Callable
# torch and transformers
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.distributions.categorical import Categorical
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
# mcmd
from .configuration_mcmd import mcmdConfig
from .Vision_Tower import clip_vit_large_patch14_336, DFN5B_CLIP_ViT_H_14_378
from .Vision_Project import mlp2x_gelu

# NOTE: the literal placeholder token that prompts are split on was lost
# (the code read `split('')`); '<ImageHere>' is an assumed value and may
# differ from the original.
IMAGE_PLACEHOLDER = '<ImageHere>'


def build_lm_model_tokenizer(lm_model_name: str, lm_tokenizer_name: str):
    model = AutoModelForCausalLM.from_pretrained(lm_model_name, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(lm_tokenizer_name)
    return model, tokenizer


def build_vision_projector(vision_config):
    if vision_config == 'mlp2x_gelu':
        return mlp2x_gelu(vision_config)


def build_vision_tower(vision_tower_name=''):
    if vision_tower_name.endswith('clip-vit-large-patch14-336'):
        return clip_vit_large_patch14_336(vision_tower_name, use_resize_pos=True)
    elif vision_tower_name.endswith('DFN5B-CLIP-ViT-H-14-378'):
        return DFN5B_CLIP_ViT_H_14_378(vision_tower_name)
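
# Illustrative sketch only (not used at runtime): the real `mlp2x_gelu` is
# imported from .Vision_Project above. A projector with this name is
# conventionally a Linear -> GELU -> Linear stack; the hidden sizes below are
# placeholders, not the values this model actually uses.
def _example_mlp2x_gelu(vision_hidden_size: int = 1024, lm_hidden_size: int = 4096) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(vision_hidden_size, lm_hidden_size),
        nn.GELU(),
        nn.Linear(lm_hidden_size, lm_hidden_size),
    )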
class mcmdPreTrainedModel(PreTrainedModel):
    config_class = mcmdConfig

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class mcmdForCausalLM(mcmdPreTrainedModel):
    _auto_class = 'AutoModelForCausalLM'

    def __init__(self, config):
        super().__init__(config)
        # Initialize the language model.
        self.max_length = config.max_length
        self.vocab_size = config.lm_model['vocab_size']
        self.lm_model, self.lm_tokenizer = build_lm_model_tokenizer(config.lm_path, config.lm_path)
        # Initialize the vision tower and vision projector.
        self.vit = build_vision_tower(config.clip_path)
        self.vision_proj = build_vision_projector(config.vision_config)
        # Initialize vis_processor for image preprocessing.
        # DFN5B and CLIP-ViT use the same normalization mean and std.
        self.vis_processor = transforms.Compose([
            transforms.Resize((config.input_img_size, config.input_img_size),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.eos_token_id = self.lm_tokenizer.eos_token_id  # 151645, i.e. <|im_end|>

    def print_trainable_parameters(self):
        print('Trainable parameters:')
        trainable_params = 0
        all_param = 0
        for _, param in self.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(f"trainable params: {trainable_params} || all params: {all_param} || "
              f"trainable%: {100 * trainable_params / all_param:.2f}")
        print('Trainable modules:')
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape)

    def print_model_layers_and_parameters(self):
        print('Model parameters:')
        for name, module in self.named_modules():
            if hasattr(module, 'weight'):
                num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
                print(f"Layer: {name}, Type: {module.__class__.__name__}, Trainable Parameters: {num_params}")
            else:
                print(f"Layer: {name}, Type: {module.__class__.__name__}, No trainable parameters")

    def print_tokens_labels(self, tokens: torch.Tensor, target: torch.Tensor):
        # Dump the first sample's token ids next to their training targets as a sanity check.
        print("Sanity Check >>>>>>>>>>>>>")
        temp_tokens = copy.deepcopy(tokens[0].tolist())
        temp_target = copy.deepcopy(target[0].tolist())
        save_name = 'check_token_target.txt'
        with open(save_name, 'w') as ff:
            for t, m in zip(temp_tokens, temp_target):
                if t < 0:
                    decoded = ''
                else:
                    decoded = self.lm_tokenizer.batch_decode([t], skip_special_tokens=False)[0]
                print("%20s: %6d -> %6d" % (repr(decoded), t, m))
                ff.write("%20s: %6d -> %6d\n" % (repr(decoded), t, m))
        print("<<<<<<<<<<<<< Sanity Check")
        assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"

    def img2emb(self, image):
        image = image.bfloat16()
        img_embeds = self.vision_proj(self.vit(image.to(self.device)))
        atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(img_embeds.device)
        img_target = torch.ones(img_embeds.size()[:2], dtype=torch.long).to(img_embeds.device) * -100
        return img_embeds, atts_img, img_target
    def encode_img(self, image):
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
            # Image preprocessing: unsqueeze inserts a batch dim in front,
            # so the tensor becomes [1, 3, 490, 490].
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)
        img_embeds, _, _ = self.img2emb(image)
        '''
        For a single image:
            img_embeds : [1, 1225, 4096]   (1225 presumably = (490 / 14) ** 2 patch tokens)
            atts_img   = torch.ones([1, 1225])
            img_target = torch.ones([1, 1225]) * -100
        '''
        return img_embeds

    def get_tensor_image(self, fns):
        image_data = []
        for one in fns:
            t_one = self.encode_img(one)
            image_data.append(t_one)
        image = torch.cat(image_data, dim=0)
        return image

    def interleav_wrap_chat(self, messages, image):
        # Build the prompt with the Qwen2 chat template
        # (apply_chat_template lives in transformers/tokenization_utils_base.py).
        prompt = self.lm_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        '''
        repr(prompt) add_generation_prompt=True : '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n比较一下下面这两张图片,第一张,\n第二张<|im_end|>\n<|im_start|>assistant\n'
        repr(prompt) add_generation_prompt=False: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n比较一下下面这两张图片,第一张,\n第二张<|im_end|>\n'
        '''
        parts = prompt.split(IMAGE_PLACEHOLDER)
        if image is None:
            im_len = 0
            image_nums = 0
            print(parts)
            assert len(parts) == 1
        else:
            im_len = image.shape[1]  # 1225 or 730 image tokens, depending on the vision tower
            image_nums = len(image)
        wrap_embeds = []
        temp_len = 0
        if len(parts) != image_nums + 1:
            raise ValueError('Invalid prompt format.')
        for idx, part in enumerate(parts):
            if len(part) > 0:
                part_tokens = self.lm_tokenizer(part, return_tensors='pt').to(self.device)
                part_embeds = self.lm_model.model.embed_tokens(part_tokens.input_ids)
                wrap_embeds.append(part_embeds)
                temp_len += part_embeds.shape[1]
            if idx < image_nums:
                wrap_embeds.append(image[idx].unsqueeze(0))
                temp_len += im_len
            if temp_len > self.max_length:
                break
        wrap_embeds = torch.cat(wrap_embeds, dim=1)  # e.g. torch.Size([1, 2481, 3584])
        wrap_embeds = wrap_embeds[:, :self.max_length].to(self.device)
        inputs = {'inputs_embeds': wrap_embeds}
        return inputs

    def mask_user_targets(self, input_ids):
        """Build training targets for Qwen2-style chat sequences.

        Every token belonging to the system prompt, the user turns and the
        '<|im_start|>assistant\\n' headers is set to -100, so the loss in
        `forward` is computed only on the assistant's replies
        (151644 is Qwen2's <|im_start|> token id).
        """
        target_batch = []
        for bs in range(input_ids.shape[0]):
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            im_round = 0      # number of <|im_start|> tokens seen so far
            id_im_start = 0   # index where the span to be masked starts
            for i, temp_id in enumerate(ids):
                if temp_id == 151644:  # <|im_start|>
                    im_round += 1
                    if im_round == 2:
                        # First user turn: mask everything up to and including
                        # this <|im_start|>.
                        id_im_start = 0
                        targets[id_im_start:i + 1] = -100
                        id_im_start = i
                    elif im_round % 2 == 0:
                        # Later user turns: remember where the span starts.
                        id_im_start = i
                    elif im_round % 2 == 1:
                        # System / assistant headers: mask the pending span plus
                        # the 3-token '<|im_start|>role\n' header, leaving the
                        # assistant's reply itself unmasked.
                        targets[id_im_start:i + 3] = -100
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        return target_batch
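
    # Worked example for `mask_user_targets` (toy ids; 151644/151645 are Qwen2's
    # <|im_start|>/<|im_end|>, the single-digit ids are made-up fillers):
    #
    #   ids = torch.tensor([[151644, 1, 2, 151645,        # system turn
    #                        151644, 3, 4, 151645,        # user turn
    #                        151644, 5, 6, 7, 151645]])   # assistant turn
    #   targets = model.mask_user_targets(ids)            # `model` = an mcmdForCausalLM instance
    #   # targets -> [[-100] * 11 + [7, 151645]]
    #
    # Only the assistant's answer token (7) and its closing <|im_end|> remain as
    # supervision targets; everything else is ignored by the loss.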
    def interleav_wrap(self, img_list, text_list):
        # Lists collecting the processed embeddings, attention masks, and targets.
        wrap_embeds_list, wrap_atts_list = [], []
        wrap_target_list = []
        # Iterate over pairs of images and texts.
        for image, text in zip(img_list, text_list):
            # Convert the images to embeddings with `img2emb`.
            img_embeds, atts_img, img_target = self.img2emb(image)
            # Take the first element of the text (it is stored as a list).
            text = text[0]
            # Split the text at every image placeholder.
            parts = text.split(IMAGE_PLACEHOLDER)
            # Lists collecting tokens, embeddings, and attention masks for this sample.
            wrap_tokens, wrap_embeds, wrap_atts = [], [], []
            # Track the total length of the sequence being built.
            temp_len = 0
            # Number of images and length of each image embedding.
            image_nums, im_len = img_embeds.shape[:2]
            # Process each part of the split text.
            for idx, part in enumerate(parts):
                # If the part is not empty, process it as text.
                if len(part) > 0:
                    # Tokenize the text part.
                    part_tokens = self.lm_tokenizer(
                        part, return_tensors='pt', padding='longest').to(self.device)
                    # Append the token IDs, embeddings, and attention mask.
                    wrap_tokens.append(part_tokens.input_ids)
                    part_embeds = self.lm_model.model.embed_tokens(part_tokens.input_ids)
                    wrap_embeds.append(part_embeds)
                    wrap_atts.append(part_tokens.attention_mask)
                    # Update the total length of the sequence.
                    temp_len += part_embeds.shape[1]
                # If there are images left, append the image target, embeddings, and attention mask.
                if idx < image_nums:
                    wrap_tokens.append(img_target[idx].unsqueeze(0))
                    wrap_embeds.append(img_embeds[idx].unsqueeze(0))
                    wrap_atts.append(atts_img[idx].unsqueeze(0))
                    # Update the total length of the sequence.
                    temp_len += im_len
                # Stop once the total length exceeds the maximum length.
                if temp_len > self.max_length:
                    break
            # Concatenate the tokens, embeddings, and attention masks.
            wrap_tokens = torch.cat(wrap_tokens, dim=1)
            wrap_embeds = torch.cat(wrap_embeds, dim=1)
            wrap_atts = torch.cat(wrap_atts, dim=1)
            # Mask the system/user tokens in the targets.
            wrap_target = self.mask_user_targets(wrap_tokens).to(self.device)
            # Truncate the concatenated tensors to the maximum length.
            wrap_embeds = wrap_embeds[:, :self.max_length].to(self.device)
            wrap_atts = wrap_atts[:, :self.max_length].to(self.device)
            wrap_target = wrap_target[:, :self.max_length].to(self.device)
            # self.print_tokens_labels(wrap_tokens, wrap_target)
            # Collect the processed sample.
            wrap_embeds_list.append(wrap_embeds)
            wrap_atts_list.append(wrap_atts)
            wrap_target_list.append(wrap_target)
        # Concatenate the processed samples along the batch dimension.
        wrap_embeds = torch.cat(wrap_embeds_list)
        wrap_atts = torch.cat(wrap_atts_list)
        wrap_target = torch.cat(wrap_target_list)
        # Return the concatenated embeddings, attention masks, and targets.
        return wrap_embeds, wrap_atts, wrap_target

    def text2emb(self, text, add_special=False):
        to_regress_tokens = self.lm_tokenizer(
            text, return_tensors='pt', padding='longest').to(self.device)
        targets = self.mask_user_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        # self.print_tokens_labels(to_regress_tokens.input_ids, targets)
        return to_regress_tokens, targets
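
    # Sketch of the training-time `samples` kwarg that `forward` consumes below.
    # The field names come from the code; the exact collation and tensor shapes
    # are assumptions about the data loader, not guaranteed:
    #
    #   samples = {
    #       'data_type': ['multi'],                 # or ['text'] for text-only batches
    #       'text_input': [[chat_text], ...],       # per-sample chat strings containing IMAGE_PLACEHOLDER
    #       'image': [imgs, ...],                   # per-sample image tensors of shape [n_img, 3, H, W]
    #   }
    #   loss = model(samples=samples).loss          # `model` = an mcmdForCausalLM instance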
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with
                indices set to `-100` are ignored (masked); the loss is only computed for tokens
                with labels in `[0, ..., config.vocab_size]`.

        Returns:
            `CausalLMOutputWithPast` or, when `return_dict=False`, a plain tuple.
        """
        # Prepare inputs for training mode.
        samples = kwargs.get('samples', None)
        if samples:
            if samples['data_type'][0] == 'text':
                has_img = False
            elif samples['data_type'][0] == 'multi':
                has_img = True
            else:
                raise NotImplementedError
            # Encode the text (and, for multimodal samples, the images).
            text = samples['text_input']
            if has_img:
                image = samples['image']
                to_regress_embeds, attention_mask, targets = self.interleav_wrap(image, text)
            else:
                to_regress_tokens, targets = self.text2emb(text, add_special=True)
                to_regress_embeds = self.lm_model.model.embed_tokens(to_regress_tokens.input_ids)
                attention_mask = to_regress_tokens.attention_mask
            inputs_embeds = to_regress_embeds[:, :self.max_length]
            attention_mask = attention_mask[:, :self.max_length]
            targets = targets[:, :self.max_length]
            labels = targets

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.lm_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_model.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def chat(
        self,
        messages,
        images: List[str] = None,
        streamer: Optional[BaseStreamer] = None,
        max_new_tokens: int = 1024,
        do_sample: bool = True,
        num_beams: int = 1,
        temperature: float = 1.0,
        top_p: float = 0.8,
        repetition_penalty: float = 1.005,
        **kwargs,
    ):
        if images:
            print('images ', images)
            image_pt = self.get_tensor_image(images)
        else:
            image_pt = None
        inputs = self.interleav_wrap_chat(messages, image_pt)
        inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
        # Use the end-of-assistant token as an eos id to avoid unnecessary generation.
        eos_token_id = [self.eos_token_id]
        outputs = self.lm_model.generate(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
        response = self.lm_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        messages += [{"role": "assistant", "content": response}]
        return response, messages
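
# Minimal usage sketch for inference (the checkpoint path and image file are
# placeholders, not real assets; the module is assumed to be loaded via
# `trust_remote_code=True` so that `AutoModelForCausalLM` resolves to
# mcmdForCausalLM):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       'path/to/mcmd-checkpoint', torch_dtype=torch.bfloat16, trust_remote_code=True,
#   ).cuda().eval()
#   messages = [
#       {'role': 'system', 'content': 'You are a helpful assistant.'},
#       {'role': 'user', 'content': f'Describe this image: {IMAGE_PLACEHOLDER}'},
#   ]
#   response, messages = model.chat(messages, images=['example.jpg'])
#   print(response)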