# MATH-LLM-7B / modeling_mcmd.py
# basic packages
import os
import copy
import warnings
from PIL import Image
from typing import Optional, Tuple, Union, List, Callable
# torch and transformers
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.distributions.categorical import Categorical
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
# mcmd components
from .configuration_mcmd import mcmdConfig
from .Vision_Tower import clip_vit_large_patch14_336, DFN5B_CLIP_ViT_H_14_378
from .Vision_Project import mlp2x_gelu
def build_lm_model_tokenizer(lm_model_name: str, lm_tokenizer_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        lm_model_name,
        torch_dtype="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(lm_tokenizer_name)
    return model, tokenizer
def build_vision_projector(vision_config):
    if vision_config == 'mlp2x_gelu':
        return mlp2x_gelu(vision_config)
    raise ValueError(f'Unsupported vision projector: {vision_config}')
def build_vision_tower(vision_tower_name=''):
    if vision_tower_name.endswith('clip-vit-large-patch14-336'):
        return clip_vit_large_patch14_336(vision_tower_name, use_resize_pos=True)
    elif vision_tower_name.endswith('DFN5B-CLIP-ViT-H-14-378'):
        return DFN5B_CLIP_ViT_H_14_378(vision_tower_name)
    raise ValueError(f'Unsupported vision tower: {vision_tower_name}')
class mcmdPreTrainedModel(PreTrainedModel):
# config_class = mcmdConfig
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class mcmdForCausalLM(mcmdPreTrainedModel):
_auto_class = 'AutoModelForCausalLM'
def __init__(self, config):
super().__init__(config)
        # Initialize the language model
self.max_length = config.max_length
self.vocab_size = config.lm_model['vocab_size']
self.lm_model,self.lm_tokenizer = build_lm_model_tokenizer(config.lm_path,config.lm_path)
        # Initialize the vision tower (ViT) and the vision projector
self.vit = build_vision_tower(config.clip_path)
self.vision_proj = build_vision_projector(config.vision_config)
        # Initialize vis_processor for image preprocessing. DFN5B and CLIP-ViT share
        # the same (OpenAI CLIP) normalization mean and std.
self.vis_processor = transforms.Compose([
transforms.Resize((config.input_img_size, config.input_img_size),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
self.eos_token_id = self.lm_tokenizer.eos_token_id # 151645 <|im_end|>
def print_trainable_parameters(self):
        print('Trainable parameters:')
trainable_params = 0
all_param = 0
for _, param in self.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}")
        print('Trainable modules:')
for name, param in self.named_parameters():
if param.requires_grad:
print(name, param.shape)
def print_model_layers_and_parameters(self):
        print('Model parameters:')
for name, module in self.named_modules():
if hasattr(module, 'weight'):
num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
print(f"Layer: {name}, Type: {module.__class__.__name__}, Trainable Parameters: {num_params}")
else:
print(f"Layer: {name}, Type: {module.__class__.__name__}, No trainable parameters")
    def print_tokens_labels(self, tokens: List[int], target: List[int]):
        assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
        print("Sanity Check >>>>>>>>>>>>>")
        temp_tokens = copy.deepcopy(tokens[0].tolist())
        temp_target = copy.deepcopy(target[0].tolist())
        save_name = 'check_token_target.txt'
        # Start from a fresh file each time.
        if os.path.exists(save_name):
            os.remove(save_name)
        with open(save_name, 'a+') as ff:
            for t, m in zip(temp_tokens, temp_target):
                if t < 0:
                    # Negative ids mark image embedding positions, not vocabulary tokens.
                    decoded = '<Image Data>'
                else:
                    decoded = self.lm_tokenizer.batch_decode([t], skip_special_tokens=False)[0]
                print("%20s: %6d -> %6d" % (repr(decoded), t, m))
                ff.write("%20s: %6d -> %6d\n" % (repr(decoded), t, m))
        print("<<<<<<<<<<<<< Sanity Check")
    def img2emb(self, image):
        # Encode an image batch into LM-space embeddings, plus an attention mask
        # and an all -100 target so the image positions are ignored by the loss.
        image = image.bfloat16()
        img_embeds = self.vision_proj(self.vit(image.to(self.device)))
        atts_img = torch.ones(
            img_embeds.size()[:-1], dtype=torch.long).to(img_embeds.device)
        img_target = torch.ones(
            img_embeds.size()[:2], dtype=torch.long).to(
                img_embeds.device) * -100
        return img_embeds, atts_img, img_target
    def encode_img(self, image):
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
            # Preprocess and add a batch dimension at dim 0, giving a tensor of
            # shape [1, 3, input_img_size, input_img_size] (e.g. [1, 3, 490, 490]).
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)
        # img_embeds has shape [1, num_visual_tokens, lm_hidden_size] (e.g. 1225 visual tokens);
        # the attention mask and -100 image targets from img2emb are not needed here.
        img_embeds, _, _ = self.img2emb(image)
        return img_embeds
    def get_tensor_image(self, fns):
        image_data = []
        for one in fns:
            t_one = self.encode_img(one)
            image_data.append(t_one)
        image = torch.cat(image_data, dim=0)
        return image
def interleav_wrap_chat(self, messages, image):
        # Build the prompt with the Qwen2 chat template (apply_chat_template, see transformers/tokenization_utils_base.py).
prompt = self.lm_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
'''
repr(prompt) add_generation_prompt=True : '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n比较一下下面这两张图片,第一张<ImageHere>,\n第二张<ImageHere><|im_end|>\n<|im_start|>assistant\n'
repr(prompt) add_generation_prompt=False: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n比较一下下面这两张图片,第一张<ImageHere>,\n第二张<ImageHere><|im_end|>\n'
'''
        if image is None:
            im_len = 0
            image_nums = 0
            parts = prompt.split('<ImageHere>')
            assert len(parts) == 1, 'No image was given, but the prompt contains <ImageHere>.'
else:
            im_len = image.shape[1]  # visual tokens per image (e.g. 1225 or 730, depending on the vision tower)
image_nums = len(image)
parts = prompt.split('<ImageHere>')
wrap_embeds = []
temp_len = 0
if len(parts) != image_nums + 1:
raise ValueError('Invalid <ImageHere> prompt format.')
for idx, part in enumerate(parts):
if len(part) > 0:
part_tokens = self.lm_tokenizer(part, return_tensors='pt').to(self.device)
part_embeds = self.lm_model.model.embed_tokens(
part_tokens.input_ids)
wrap_embeds.append(part_embeds)
temp_len += part_embeds.shape[1]
if idx < image_nums:
wrap_embeds.append(image[idx].unsqueeze(0))
temp_len += im_len
if temp_len > self.max_length:
break
wrap_embeds = torch.cat(wrap_embeds, dim=1) #torch.Size([1, 2481, 3584])
wrap_embeds = wrap_embeds[:, :self.max_length].to(self.device)
inputs = {
'inputs_embeds': wrap_embeds
}
return inputs
    def mask_user_targets(self, input_ids):
        # Mask everything except the assistant responses with -100 so that only
        # assistant tokens contribute to the loss. 151644 is the Qwen2
        # <|im_start|> token id that opens every chat turn.
        target_batch = []
        for bs in range(input_ids.shape[0]):
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            im_round = 0
            id_im_start = 0
            for i, temp_id in enumerate(ids):
                if temp_id == 151644:
                    im_round += 1
                    if im_round == 2:
                        # First user turn: mask the system turn and this opener token.
                        id_im_start = 0
                        targets[id_im_start:i + 1] = -100
                        id_im_start = i
                    elif im_round % 2 == 0:
                        # Later user turns: remember where the turn starts.
                        id_im_start = i
                    elif im_round % 2 == 1:
                        # Assistant (and the initial system) opener: mask from id_im_start
                        # through the three-token turn header (<|im_start|>, role, "\n").
                        targets[id_im_start:i + 3] = -100
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        return target_batch
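
    # Worked example of the masking above (a sketch, assuming the Qwen2 chat
    # template, where 151644 is <|im_start|> and the assistant header tokenizes
    # as the three tokens <|im_start|>, "assistant", "\n"):
    #
    #   <|im_start|>system ... <|im_end|>\n   -> -100 (masked)
    #   <|im_start|>user ... <|im_end|>\n     -> -100 (masked)
    #   <|im_start|>assistant\n               -> -100 (masked)
    #   "The answer is 42."<|im_end|>         -> kept as supervision targets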
def interleav_wrap(self, img_list, text_list):
# Initialize lists to store the processed embeddings, attention masks, and targets.
wrap_embeds_list, wrap_atts_list = [], []
wrap_target_list = []
# Iterate over pairs of images and texts.
for image, text in zip(img_list, text_list):
# Convert the image to embeddings using the method `img2emb`.
img_embeds, atts_img, img_target = self.img2emb(image)
# Get the first element of the text (assuming it's a list).
text = text[0]
# Split the text into parts where `<ImageHere>` is found.
parts = text.split('<ImageHere>')
# Initialize lists to store tokens, embeddings, and attention masks for the current item.
wrap_tokens, wrap_embeds, wrap_atts = [], [], []
# Track the total length of the sequence being built.
temp_len = 0
# Get the number of images and the length of each image embedding.
image_nums, im_len = img_embeds.shape[:2]
# Process each part of the split text.
for idx, part in enumerate(parts):
# If the part is not empty, process it as text.
if len(part) > 0:
# Tokenize the text part.
part_tokens = self.lm_tokenizer(
part,
return_tensors='pt',
padding='longest').to(self.device)
# Append the token IDs, embeddings, and attention mask to their respective lists.
wrap_tokens.append(part_tokens.input_ids)
part_embeds = self.lm_model.model.embed_tokens(part_tokens.input_ids)
wrap_embeds.append(part_embeds)
wrap_atts.append(part_tokens.attention_mask)
# Update the total length of the sequence.
temp_len += part_embeds.shape[1]
# If there are more images, append the image target, embeddings, and attention mask.
if idx < image_nums:
wrap_tokens.append(img_target[idx].unsqueeze(0))
wrap_embeds.append(img_embeds[idx].unsqueeze(0))
wrap_atts.append(atts_img[idx].unsqueeze(0))
# Update the total length of the sequence.
temp_len += im_len
# Break the loop if the total length exceeds the maximum length.
if temp_len > self.max_length:
break
# Concatenate the tokens, embeddings, and attention masks.
wrap_tokens = torch.cat(wrap_tokens, dim=1)
wrap_embeds = torch.cat(wrap_embeds, dim=1)
wrap_atts = torch.cat(wrap_atts, dim=1)
# print('wrap_tokens',wrap_tokens.shape)
# print('wrap_embeds',wrap_embeds.shape)
# print('wrap_atts',wrap_atts.shape)
# Mask the targets for the tokens.
wrap_target = self.mask_user_targets(wrap_tokens).to(self.device)
# Truncate the concatenated tensors to the max length.
wrap_embeds = wrap_embeds[:, :self.max_length].to(self.device)
wrap_atts = wrap_atts[:, :self.max_length].to(self.device)
wrap_target = wrap_target[:, :self.max_length].to(self.device)
# self.print_tokens_labels(wrap_tokens, wrap_target)
# Add the processed data to the corresponding lists.
wrap_embeds_list.append(wrap_embeds)
wrap_atts_list.append(wrap_atts)
wrap_target_list.append(wrap_target)
# Concatenate all the processed data from different items.
wrap_embeds = torch.cat(wrap_embeds_list)
wrap_atts = torch.cat(wrap_atts_list)
wrap_target = torch.cat(wrap_target_list)
# Return the concatenated embeddings, attention masks, and targets.
return wrap_embeds, wrap_atts, wrap_target
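
    # Layout sketch for one sample whose text contains two <ImageHere> markers
    # (a sketch; the actual token counts depend on the tokenizer and vision tower):
    #
    #   wrap_embeds: [part_0 text][image_0][part_1 text][image_1][part_2 text]
    #   wrap_target: -100 at every image position (from img_target) and at every
    #                non-assistant text position (from mask_user_targets); real
    #                token ids remain only where the assistant responses live.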
    def text2emb(self, text, add_special=False):
        to_regress_tokens = self.lm_tokenizer(
            text,
            return_tensors='pt',
            padding='longest').to(self.device)
        targets = self.mask_user_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        # self.print_tokens_labels(to_regress_tokens.input_ids, targets)
        return to_regress_tokens, targets
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
        # prepare inputs for training mode (driven by the `samples` kwarg)
samples = kwargs.get('samples', None)
if samples:
if samples['data_type'][0] == 'text':
has_img = False
elif samples['data_type'][0] == 'multi':
has_img = True
else:
raise NotImplementedError
# encode text
text = samples['text_input']
# encode image
if has_img:
image = samples['image']
to_regress_embeds, attention_mask, targets = self.interleav_wrap(
image, text)
else:
                to_regress_tokens, targets = self.text2emb(
                    text, add_special=True)
                to_regress_embeds = self.lm_model.model.embed_tokens(
                    to_regress_tokens.input_ids)
attention_mask = to_regress_tokens.attention_mask
inputs_embeds = to_regress_embeds[:, :self.max_length]
attention_mask = attention_mask[:, :self.max_length]
targets = targets[:, :self.max_length]
labels = targets
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.lm_model.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_model.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def chat(
self,
messages,
images: List[str] = None,
streamer: Optional[BaseStreamer] = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
num_beams: int = 1,
temperature: float = 1.0,
top_p: float = 0.8,
repetition_penalty: float=1.005,
**kwargs,
):
        if images:
            image_pt = self.get_tensor_image(images)
        else:
            image_pt = None
        inputs = self.interleav_wrap_chat(messages, image_pt)
inputs = {
k: v.to(self.device)
for k, v in inputs.items() if torch.is_tensor(v)
}
        # stop generation at the end-of-turn token (<|im_end|>) to avoid unnecessary output
        eos_token_id = [
            self.eos_token_id
        ]
outputs = self.lm_model.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
**kwargs,
)
response = self.lm_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        messages += [{"role": "assistant", "content": response}]
return response, messages
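
# Minimal usage sketch (not part of the modeling code, run from a separate script).
# It assumes this file is shipped as Hugging Face remote code; the repo id and the
# image path below are hypothetical placeholders to replace with real ones.
#
#   import torch
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       'ALmonster/MATH-LLM-7B',      # hypothetical repo id
#       trust_remote_code=True,
#       torch_dtype=torch.bfloat16,
#   ).eval().to('cuda')
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Describe this image: <ImageHere>"},
#   ]
#   response, messages = model.chat(messages, images=['example.jpg'])  # placeholder path
#   print(response)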