from header import *
import torch.nn.functional as F
from .ImageBind import *
from .ImageBind import data
from .modeling_llama import LlamaForCausalLM
from .AnomalyGPT_models import LinearLayer, PromptLearner
from transformers import StoppingCriteria, StoppingCriteriaList
from utils.loss import FocalLoss, BinaryDiceLoss
import kornia as K

import torch
from torch.nn.utils import rnn
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map

CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
               'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum']

prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']

prompt_state = [prompt_normal, prompt_abnormal]
prompt_templates = ['a photo of a {}.', 'a photo of the {}.']
# prompt_templates = [
#     'a cropped photo of the {}.', 'a cropped photo of a {}.', 'a close-up photo of a {}.', 'a close-up photo of the {}.',
#     'a bright photo of the {}.', 'a bright photo of a {}.', 'a dark photo of a {}.', 'a dark photo of the {}.',
#     'a dark photo of the {}.', 'a dark photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a jpeg corrupted photo of the {}.',
#     'a blurry photo of the {}.', 'a blurry photo of a {}.', 'a photo of a {}.', 'a photo of the {}.',
#     'a photo of the small {}.', 'a photo of a small {}.', 'a photo of the large {}.', 'a photo of a large {}.',
#     'a photo of the {} for visual inspection.', 'a photo of a {} for visual inspection.',
#     'a photo of the {} for anomaly detection.', 'a photo of a {} for anomaly detection.'
# ]

objs = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
        'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', 'pcb4', 'capsules']
prompt_sentences = {}

for obj in objs:
    prompt_sentence_obj = []
    for i in range(len(prompt_state)):
        prompted_state = [state.format(obj) for state in prompt_state[i]]
        prompted_sentence = []
        for s in prompted_state:
            for template in prompt_templates:
                prompted_sentence.append(template.format(s))
        prompted_sentence = data.load_and_transform_text(prompted_sentence, torch.cuda.current_device())
        prompt_sentence_obj.append(prompted_sentence)
    prompt_sentences[obj] = prompt_sentence_obj
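
# For every object name, prompt_sentences[obj] holds two tokenized prompt sets:
#   prompt_sentences[obj][0] -> normal prompts   (len(prompt_normal)   * len(prompt_templates) sentences)
#   prompt_sentences[obj][1] -> abnormal prompts (len(prompt_abnormal) * len(prompt_templates) sentences)
# e.g. for obj == 'bottle' the normal set contains sentences such as
# "a photo of a flawless bottle." and "a photo of the bottle without defect.".
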
def encode_text_with_prompt_ensemble(model, obj, device):
    global prompt_sentences
    normal_sentences = []
    abnormal_sentences = []
    for idx in range(len(obj)):
        sentence = prompt_sentences[obj[idx].replace('_', ' ')]
        normal_sentences.append(sentence[0])
        abnormal_sentences.append(sentence[1])

    normal_sentences = torch.cat(normal_sentences).to(device)
    abnormal_sentences = torch.cat(abnormal_sentences).to(device)

    class_embeddings_normal = model({ModalityType.TEXT: normal_sentences})[ModalityType.TEXT][0]
    class_embeddings_abnormal = model({ModalityType.TEXT: abnormal_sentences})[ModalityType.TEXT][0]
    # class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)

    class_embeddings_normal = class_embeddings_normal.reshape((len(obj), len(prompt_templates) * len(prompt_normal), 1024))
    class_embeddings_normal = class_embeddings_normal.mean(dim=1, keepdim=True)
    class_embeddings_normal = class_embeddings_normal / class_embeddings_normal.norm(dim=-1, keepdim=True)

    class_embeddings_abnormal = class_embeddings_abnormal.reshape((len(obj), len(prompt_templates) * len(prompt_abnormal), 1024))
    class_embeddings_abnormal = class_embeddings_abnormal.mean(dim=1, keepdim=True)
    class_embeddings_abnormal = class_embeddings_abnormal / class_embeddings_abnormal.norm(dim=-1, keepdim=True)

    text_features = torch.cat([class_embeddings_normal, class_embeddings_abnormal], dim=1)

    return text_features
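
# encode_text_with_prompt_ensemble returns one averaged, L2-normalized text embedding
# per state, i.e. a tensor of shape (len(obj), 2, 1024), where [:, 0, :] is the
# "normal" ensemble and [:, 1, :] is the "abnormal" ensemble. A minimal usage sketch
# (assuming an initialized ImageBind model on a CUDA device):
#   text_features = encode_text_with_prompt_ensemble(model, ['bottle'], 'cuda')
#   text_features.shape  # torch.Size([1, 2, 1024])
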
class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        stop_count = 0
        for stop in self.stops:
            stop_count = (stop == input_ids[0]).sum().item()

        if stop_count >= self.ENCOUNTERS:
            return True
        return False
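
# This criterion halts generation once the last entry of `stops` has appeared at least
# `encounters` times in the generated ids. generate() below uses stops=[2277],
# which is presumably the token id of the '###' turn separator that closes every
# assistant turn (an assumption based on how the output is trimmed, not verified here).
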
def build_one_instance(tokenizer, conversation):
    text_list = []
    turn_num = len(conversation)
    input_ids, target_ids = [], []
    for i in range(turn_num):
        turn = conversation[i]
        role = turn['from']
        if i == 0:  # the first human turn
            assert role == 'human'
            text = turn['value'] + '\n### Assistant:'
            one_input_id = tokenizer(text, add_special_tokens=False).input_ids
            input_ids += one_input_id
            target_ids += [-100] * len(one_input_id)  # do not compute loss on the human prompt
        else:  # subsequent turns
            if role == 'human':
                text = 'Human: ' + turn['value'] + '\n### Assistant:'
                one_input_id = tokenizer(text, add_special_tokens=False).input_ids
                input_ids += one_input_id
                target_ids += [-100] * len(one_input_id)
            elif role == 'gpt':
                text = turn['value'] + '\n###'
                one_input_id = tokenizer(text, add_special_tokens=False).input_ids
                input_ids += one_input_id
                target_ids += one_input_id
            else:
                raise Exception('Wrong Role!!!')
        text_list.append(text)
        assert len(input_ids) == len(target_ids)
    return text_list, input_ids, target_ids
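
# Illustrative input format for build_one_instance, inferred from the role handling
# above (the exact dataset schema is an assumption):
#   conversation = [
#       {'from': 'human', 'value': 'Is there any anomaly in the image?'},
#       {'from': 'gpt',   'value': 'No, the object looks normal.'},
#   ]
# Only the 'gpt' turns contribute to the loss; human turns are masked with -100.
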
def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len):
    batch_input_ids, batch_target_ids = [], []
    for conversation in batch_of_conversations:
        _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation)
        batch_input_ids.append(torch.LongTensor(one_input_ids))
        batch_target_ids.append(torch.LongTensor(one_target_ids))
    input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
    assert input_ids.size() == target_ids.size()
    input_ids = input_ids[:, :max_tgt_len]
    target_ids = target_ids[:, :max_tgt_len]
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    assert attention_mask.size() == input_ids.size()
    return input_ids, target_ids, attention_mask.long()
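
# process_batch_instance right-pads every conversation in the batch (inputs with
# pad_token_id, targets with -100 so padding is ignored by the loss) and truncates
# both to max_tgt_len; the attention mask simply marks the non-pad positions.
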
def find_first_file_in_directory(directory_path):
    try:
        file_list = os.listdir(directory_path)
        for item in file_list:
            item_path = os.path.join(directory_path, item)
            if os.path.isfile(item_path):
                return item_path
        return None
    except OSError as e:
        print(f"Error while accessing directory: {e}")
        return None


PROMPT_START = '### Human: <Img>'

class OpenLLAMAPEFTModel(nn.Module):

    '''LoRA for the LLaMA model'''

    def __init__(self, **args):
        super(OpenLLAMAPEFTModel, self).__init__()
        self.args = args

        imagebind_ckpt_path = args['imagebind_ckpt_path']
        vicuna_ckpt_path = args['vicuna_ckpt_path']
        max_tgt_len = args['max_tgt_len']
        stage = args['stage']

        self.device = torch.cuda.current_device()

        print(f'Initializing visual encoder from {imagebind_ckpt_path} ...')
        self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
        self.visual_encoder.to(torch.float16).to(self.device)
        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)

        self.iter = 0

        self.image_decoder = LinearLayer(1280, 1024, 4).to(torch.float16).to(self.device)
        self.prompt_learner = PromptLearner(1, 4096).to(torch.float16).to(self.device)

        self.loss_focal = FocalLoss()
        self.loss_dice = BinaryDiceLoss()

        # freeze the vision encoder
        for name, param in self.visual_encoder.named_parameters():
            param.requires_grad = False
        self.visual_encoder.eval()
        print('Visual encoder initialized.')

        print(f'Initializing language decoder from {vicuna_ckpt_path} ...')
        # add the LoRA module
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=self.args['lora_r'],
            lora_alpha=self.args['lora_alpha'],
            lora_dropout=self.args['lora_dropout'],
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
        )

        # config = AutoConfig.from_pretrained(vicuna_ckpt_path)
        # with init_empty_weights():
        #     self.llama_model = AutoModelForCausalLM.from_config(config)
        # device_map = infer_auto_device_map(self.llama_model, no_split_module_classes=["OPTDecoderLayer"], dtype="float16")
        # print(device_map)

        # hand-crafted device map kept for reference; it is unused because
        # from_pretrained below relies on device_map='auto'
        device_map = {
            'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0,
            'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0,
            'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0,
            'model.layers.10.self_attn': 0, 'model.layers.10.mlp.gate_proj': 0,
            'model.layers.10.mlp.down_proj': 'cpu', 'model.layers.10.mlp.up_proj': 'cpu',
            'model.layers.10.mlp.act_fn': 'cpu', 'model.layers.10.input_layernorm': 'cpu',
            'model.layers.10.post_attention_layernorm': 'cpu',
            'model.layers.11': 'cpu', 'model.layers.12': 'cpu', 'model.layers.13': 'cpu',
            'model.layers.14': 'cpu', 'model.layers.15': 'cpu', 'model.layers.16': 'cpu',
            'model.layers.17': 'cpu', 'model.layers.18': 'cpu', 'model.layers.19': 'cpu',
            'model.layers.20': 'cpu', 'model.layers.21': 'cpu', 'model.layers.22': 'cpu',
            'model.layers.23': 'cpu',
            'model.layers.24': 'disk', 'model.layers.25': 'disk', 'model.layers.26': 'disk',
            'model.layers.27': 'disk', 'model.layers.28': 'disk', 'model.layers.29': 'disk',
            'model.layers.30': 'disk',
            'model.layers.31.self_attn': 'disk', 'model.layers.31.mlp.gate_proj': 'disk',
            'model.layers.31.mlp.down_proj': 'disk', 'model.layers.31.mlp.up_proj': 'disk',
            'model.layers.31.mlp.act_fn': 'disk', 'model.layers.31.input_layernorm': 'disk',
            'model.layers.31.post_attention_layernorm': 'disk',
            'model.norm': 'disk', 'lm_head': 'disk'
        }

        # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict=True)
        # self.llama_model.to(torch.float16)
        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True)
        # print(self.llama_model.hf_device_map)

        self.llama_model = get_peft_model(self.llama_model, peft_config)
        # delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
        # self.llama_model.load_state_dict(delta_ckpt, strict=False)
        self.llama_model.print_trainable_parameters()

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.padding_side = "right"
        print('Language decoder initialized.')

        self.llama_proj = nn.Linear(
            self.visual_hidden_size, self.llama_model.config.hidden_size
        ).to(torch.float16).to(self.device)

        self.max_tgt_len = max_tgt_len
    def rot90_img(self, x, k):
        # k is 0, 1, 2, 3
        degreesarr = [0., 90., 180., 270., 360.]
        degrees = torch.tensor(degreesarr[k]).to(self.llama_model.dtype).to(self.device)
        x = K.geometry.transform.rotate(x, angle=degrees, padding_mode='reflection')
        return x
    def encode_video(self, video_paths):
        inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            video_embeds = embeddings[ModalityType.VISION][0]  # bsz x 1024
        inputs_llama = self.llama_proj(video_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama

    def encode_audio(self, audio_paths):
        inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            audio_embeds = embeddings[ModalityType.AUDIO][0]  # bsz x 1024
        inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama

    def encode_thermal(self, thermal_paths):
        inputs = {ModalityType.THERMAL: data.load_and_transform_thermal_data(thermal_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['thermal'][0]  # bsz x 1024
        inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama

    def encode_image(self, image_paths):
        inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['vision'][0]  # bsz x 1024
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
        patch_tokens = self.image_decoder(patch_features)  # bsz x h*w x 1024

        inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama, patch_tokens
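
    # encode_image returns three tensors:
    #   inputs_llama  - one projected visual token per image (bsz x 1 x hidden_size), fed to the LLM;
    #   atts_llama    - the matching attention mask (bsz x 1);
    #   patch_tokens  - per-layer patch features mapped to 1024-d by image_decoder, later compared
    #                   against the text prompt ensemble to build the anomaly maps.
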
    def encode_image_for_web_demo(self, image_paths):
        inputs = {ModalityType.VISION: data.load_and_transform_vision_data_for_web_demo(image_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['vision'][0]  # bsz x 1024
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
        patch_tokens = self.image_decoder(patch_features)  # bsz x h*w x 1024

        inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama, patch_tokens

    def encode_image_for_one_shot(self, image_paths):
        inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
            for i in range(len(patch_features)):
                patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :]
        return patch_features
    def encode_image_for_one_shot_from_tensor(self, image_tensors):
        if not isinstance(image_tensors, list):
            image_tensors = [image_tensors]
        inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
            for i in range(len(patch_features)):
                patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :]
        return patch_features

    def encode_image_for_one_shot_with_aug(self, image_paths):
        image_tensors = data.load_and_transform_vision_data(image_paths, self.device).to(self.llama_model.dtype)
        B, C, H, W = image_tensors.shape
        # print(B, C, H, W)

        rotated_images = torch.zeros((4, B, C, H, W)).to(self.llama_model.dtype).to(self.device)

        for j, degree in enumerate([0, 1, 2, 3]):
            rotated_img = self.rot90_img(image_tensors, degree)
            # store the rotated image
            rotated_images[j] = rotated_img

        image_tensors = rotated_images.transpose(0, 1).reshape(B * 4, C, H, W)

        inputs = {ModalityType.VISION: image_tensors}
        # convert into visual dtype
        inputs = {key: inputs[key] for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
            for i in range(len(patch_features)):
                patch_features[i] = patch_features[i].transpose(0, 1)[:, 1:, :].reshape(B, 4, 256, 1280).reshape(B, 4 * 256, 1280)
        return patch_features
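
    # encode_image_for_one_shot_with_aug builds the "normal" reference bank for one-shot
    # comparison: each reference image is rotated by 0/90/180/270 degrees, encoded, and
    # the leading token is dropped, giving per-layer features of shape B x (4 * 256) x 1280
    # that the query patches are matched against by cosine similarity.
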
    def encode_image_from_tensor(self, image_tensors):
        if not isinstance(image_tensors, list):
            image_tensors = [image_tensors]
        inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['vision'][0]  # bsz x 1024
            patch_features = embeddings['vision'][1]  # bsz x h*w x 1280
        patch_tokens = self.image_decoder(patch_features)  # bsz x h*w x 1024

        inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama, patch_tokens

    def encode_image_from_tensor_no_patch(self, image_tensors):
        if not isinstance(image_tensors, list):
            image_tensors = [image_tensors]
        inputs = {ModalityType.VISION: torch.stack(image_tensors, dim=0).to(self.device)}
        # convert into visual dtype
        inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            image_embeds = embeddings['vision'][0]  # bsz x 1024
        inputs_llama = self.llama_proj(image_embeds).unsqueeze(1)  # bsz x 1 x llama_size
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)  # bsz x 1
        return inputs_llama, atts_llama
    def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask, anomaly_embedding=None):
        '''
        input_ids, target_ids, attention_mask: bsz x s2
        '''
        input_ids = input_ids.to(self.device)  # bsz x s2
        target_ids = target_ids.to(self.device)  # bsz x s2
        attention_mask = attention_mask.to(self.device)  # bsz x s2

        batch_size = img_embeds.shape[0]

        p_before = PROMPT_START
        p_before_tokens = self.llama_tokenizer(p_before,
                                               return_tensors="pt", add_special_tokens=False).to(self.device)
        # the PEFT-wrapped model needs the deeper .model.model call to reach embed_tokens
        p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim

        p_middle = '</Img> '
        p_middle_tokens = self.llama_tokenizer(p_middle,
                                               return_tensors="pt", add_special_tokens=False).to(self.device)
        p_middle_embeds = self.llama_model.model.model.embed_tokens(p_middle_tokens.input_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim

        p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1)  # bsz x s2 x embed_dim

        bos = torch.ones([batch_size, 1],
                         dtype=p_before_tokens.input_ids.dtype,
                         device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id  # bsz x 1
        bos_embeds = self.llama_model.model.model.embed_tokens(bos)  # bsz x 1 x embed_dim

        if anomaly_embedding is not None:
            inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_middle_embeds, anomaly_embedding, p_after_embeds], dim=1)  # bsz x (1+s1+1+s2) x embed_dim
            # create targets
            empty_targets = (
                torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1 + p_middle_embeds.size()[1] + anomaly_embedding.size()[1]],  # 1 (bos) + s1 + 1 (image vector)
                           dtype=torch.long).to(self.device).fill_(-100)
            )  # bsz x (1 + s1 + 1)
            targets = torch.cat([empty_targets, target_ids], dim=1)  # bsz x (1 + s1 + 1 + s2)
            assert inputs_embeds.size()[1] == targets.size()[1]

            atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1 + p_middle_embeds.size()[1] + anomaly_embedding.size()[1]], dtype=torch.long).to(self.device)  # bsz x (1 + s1 + 1)
            attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
            assert attention_mask.size() == targets.size()  # bsz x (1 + s1 + 1 + s2)
            return inputs_embeds, targets, attention_mask
        else:
            inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_middle_embeds, p_after_embeds], dim=1)  # bsz x (1+s1+1+s2) x embed_dim
            # create targets
            empty_targets = (
                torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1 + p_middle_embeds.size()[1]],  # 1 (bos) + s1 + 1 (image vector)
                           dtype=torch.long).to(self.device).fill_(-100)
            )  # bsz x (1 + s1 + 1)
            targets = torch.cat([empty_targets, target_ids], dim=1)  # bsz x (1 + s1 + 1 + s2)
            assert inputs_embeds.size()[1] == targets.size()[1]

            atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1 + p_middle_embeds.size()[1]], dtype=torch.long).to(self.device)  # bsz x (1 + s1 + 1)
            attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
            assert attention_mask.size() == targets.size()  # bsz x (1 + s1 + 1 + s2)
            return inputs_embeds, targets, attention_mask
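
    # After prompt_wrap the LLM input is laid out as
    #   [bos] '### Human: <Img>' [image embedding] '</Img> ' [anomaly prompt embeddings] [conversation tokens]
    # and the target sequence carries -100 over the whole prefix, so only the conversation
    # tokens (specifically the assistant turns) contribute to the language-model loss.
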
    def forward(self, inputs):

        if 'masks' in inputs:

            image_paths = inputs['images']
            img_embeds, _, patch_tokens = self.encode_image_from_tensor(image_paths)
            class_name = inputs['class_names']

            loss_pixel = 0
            feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object' for _ in class_name], self.device)

            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
                # print(patch_tokens[layer].shape)
                # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
                anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2, -1))
                B, L, C = anomaly_map.shape
                H = int(np.sqrt(L))
                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                            size=224, mode='bilinear', align_corners=True)
                # anomaly_map_no_softmax = anomaly_map
                anomaly_map = torch.softmax(anomaly_map, dim=1)
                anomaly_maps.append(anomaly_map)
                # anomaly_maps_ns.append(anomaly_map_no_softmax)

            gt = inputs['masks']
            gt = torch.stack(gt, dim=0).to(self.device)
            gt = gt.squeeze()
            # print(gt.max(), gt.min())
            gt[gt > 0.3], gt[gt <= 0.3] = 1, 0

            for num in range(len(anomaly_maps)):
                f_loss = self.loss_focal(anomaly_maps[num], gt)
                d_loss = self.loss_dice(anomaly_maps[num][:, 1, :, :], gt)
                loss_pixel = loss_pixel + f_loss + d_loss

            for num in range(len(anomaly_maps)):
                anomaly_maps[num] = anomaly_maps[num][:, 1, :, :]

            anomaly_map_all = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)

            if random.randint(0, 1) == 0 and len(inputs['img_paths']) == len(image_paths):
                normal_paths = []
                for path in inputs['img_paths']:
                    normal_path = path.replace('test', 'train')
                    normal_path = find_first_file_in_directory("/".join(normal_path.split('/')[:-2]) + '/good')
                    normal_paths.append(normal_path)

                print(normal_paths)
                query_patch_tokens = self.encode_image_for_one_shot_from_tensor(image_paths)
                normal_patch_tokens = self.encode_image_for_one_shot_with_aug(normal_paths)
                sims = []
                B = len(image_paths)

                for i in range(len(query_patch_tokens)):
                    query_patch_tokens_reshaped = query_patch_tokens[i].view(B, 256, 1, 1280)
                    normal_tokens_reshaped = normal_patch_tokens[i].reshape(B, 1, -1, 1280)
                    cosine_similarity_matrix = F.cosine_similarity(query_patch_tokens_reshaped, normal_tokens_reshaped, dim=-1)
                    sim_max, _ = torch.max(cosine_similarity_matrix, dim=-1)
                    sims.append(sim_max)

                sim = torch.mean(torch.stack(sims, dim=0), dim=0).reshape(B, 1, 16, 16)
                sim = F.interpolate(sim, size=224, mode='bilinear', align_corners=True)
                anomaly_map_all = 1 - sim  # (anomaly_map_all + 1 - sim) / 2

            anomaly_map_prompts = self.prompt_learner(anomaly_map_all)
            # img_embeds = img_embeds + anomaly_map_prompts

            output_texts = inputs['texts']

            input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
            inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask, anomaly_map_prompts)

            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss
            # loss_l2 = torch.norm(anomaly_map_prompts / 2, p=2)
            # loss_l2 = nn.MSELoss()(img_embeds_origin, img_embeds)

            # calculate the token accuracy
            chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]  # [B, S-1]
            # print(self.llama_tokenizer.decode(chosen_tokens[0], skip_special_tokens=True))
            labels = targets[:, 2:]
            gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long)  # [B*S]
            valid_mask = (labels != -100).reshape(-1)
            # print(self.llama_tokenizer.decode(chosen_tokens.reshape(-1)[valid_mask], skip_special_tokens=True))
            valid_tokens = gen_acc & valid_mask  # [B*S]
            gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()

            return loss + loss_pixel, gen_acc
        else:
            image_paths = inputs['image_paths']
            img_embeds, _, patch_tokens = self.encode_image_from_tensor(image_paths)

            output_texts = inputs['output_texts']

            c_name = 'object'
            for name in CLASS_NAMES:
                if name in output_texts:
                    c_name = name
                    break

            feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object'] * len(image_paths), self.device)

            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
                # print(patch_tokens[layer].shape)
                # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
                anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2, -1))
                B, L, C = anomaly_map.shape
                H = int(np.sqrt(L))
                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                            size=224, mode='bilinear', align_corners=True)
                # anomaly_map_no_softmax = anomaly_map
                anomaly_map = torch.softmax(anomaly_map, dim=1)
                anomaly_maps.append(anomaly_map)

            for num in range(len(anomaly_maps)):
                anomaly_maps[num] = anomaly_maps[num][:, 1, :, :]

            anomaly_map_all = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)

            anomaly_map_prompts = self.prompt_learner(anomaly_map_all)
            # img_embeds = img_embeds + anomaly_map_prompts

            input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
            inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask, anomaly_map_prompts)

            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss

            # calculate the token accuracy
            chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]  # [B, S-1]
            labels = targets[:, 2:]
            gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long)  # [B*S]
            valid_mask = (labels != -100).reshape(-1)
            valid_tokens = gen_acc & valid_mask  # [B*S]
            gen_acc = valid_tokens.sum().item() / valid_mask.sum().item()

            return loss, gen_acc
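
    # forward() therefore trains two objectives at once. When pixel-level ground truth is
    # available ('masks' in inputs), the total loss is the language-model cross-entropy
    # plus the (focal + dice) losses over the per-layer anomaly maps; otherwise only the
    # language-model loss is used. gen_acc is the next-token accuracy measured on the
    # supervised (non -100) positions.
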
    def extract_multimodal_feature(self, inputs, web_demo):
        features = []
        if inputs['image_paths']:

            prompt = inputs['prompt']
            c_name = 'object'
            for name in CLASS_NAMES:
                if name in prompt:
                    c_name = name
                    break

            if not web_demo:
                image_embeds, _, patch_tokens = self.encode_image(inputs['image_paths'])
                feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, [c_name], self.device)
            else:
                image_embeds, _, patch_tokens = self.encode_image_for_web_demo(inputs['image_paths'])
                feats_text_tensor = encode_text_with_prompt_ensemble(self.visual_encoder, ['object'], self.device)

            anomaly_maps = []
            for layer in range(len(patch_tokens)):
                patch_tokens[layer] = patch_tokens[layer] / patch_tokens[layer].norm(dim=-1, keepdim=True)
                # print(patch_tokens[layer].shape)
                # anomaly_map = torch.bmm(patch_tokens[layer], feats_text_tensor.transpose(-2,-1))
                anomaly_map = (100.0 * patch_tokens[layer] @ feats_text_tensor.transpose(-2, -1))
                B, L, C = anomaly_map.shape
                H = int(np.sqrt(L))
                # anomaly_map = anomaly_map.to(torch.float16)
                anomaly_map = F.interpolate(anomaly_map.permute(0, 2, 1).view(B, 2, H, H),
                                            size=224, mode='bilinear', align_corners=True)
                # anomaly_map = anomaly_map.to(torch.bfloat16)
                anomaly_map = torch.softmax(anomaly_map, dim=1)
                anomaly_maps.append(anomaly_map[:, 1, :, :])

            anomaly_map_ret = torch.mean(torch.stack(anomaly_maps, dim=0), dim=0).unsqueeze(1)
            # anomaly_map_all = anomaly_map_ret.unsqueeze(1).repeat((1,3,1,1))
            # anomaly_map_feature, _, _ = self.encode_image_from_tensor(anomaly_map_all)
            # image_embeds = anomaly_map_feature + image_embeds

            if inputs['normal_img_paths']:
                query_patch_tokens = self.encode_image_for_one_shot(inputs['image_paths'])
                if 'mvtec' in str(inputs['normal_img_paths']):
                    # use rotation-augmented references for MVTec-style normal images
                    normal_patch_tokens = self.encode_image_for_one_shot_with_aug(inputs['normal_img_paths'])
                else:
                    normal_patch_tokens = self.encode_image_for_one_shot(inputs['normal_img_paths'])
                sims = []

                for i in range(len(query_patch_tokens)):
                    query_patch_tokens_reshaped = query_patch_tokens[i].view(256, 1, 1280)
                    normal_tokens_reshaped = normal_patch_tokens[i].reshape(1, -1, 1280)
                    cosine_similarity_matrix = F.cosine_similarity(query_patch_tokens_reshaped, normal_tokens_reshaped, dim=2)
                    sim_max, _ = torch.max(cosine_similarity_matrix, dim=1)
                    sims.append(sim_max)

                sim = torch.mean(torch.stack(sims, dim=0), dim=0).reshape(1, 1, 16, 16)
                # anomaly_map = anomaly_map.to(torch.float16)
                sim = F.interpolate(sim, size=224, mode='bilinear', align_corners=True)
                # anomaly_map = anomaly_map.to(torch.bfloat16)
                anomaly_map_ret = 1 - sim  # (anomaly_map_ret + 1 - sim) / 2

            features.append(image_embeds)
        if inputs['audio_paths']:
            audio_embeds, _ = self.encode_audio(inputs['audio_paths'])
            features.append(audio_embeds)
        if inputs['video_paths']:
            video_embeds, _ = self.encode_video(inputs['video_paths'])
            features.append(video_embeds)
        if inputs['thermal_paths']:
            thermal_embeds, _ = self.encode_thermal(inputs['thermal_paths'])
            features.append(thermal_embeds)

        feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0)
        return feature_embeds, anomaly_map_ret
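
    # At inference time extract_multimodal_feature returns
    #   feature_embeds  - the summed modality embeddings (1 x 1 x hidden_size in the single-image case);
    #   anomaly_map_ret - a 1 x 224 x 224 localization map per image, either derived from the text
    #                     prompt ensemble or, when a normal reference image is supplied,
    #                     computed as 1 minus the maximum patch-wise cosine similarity.
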
    def prepare_generation_embedding(self, inputs, web_demo):
        prompt = inputs['prompt']
        # if len(inputs['modality_embeds']) == 1:
        #     feature_embeds = inputs['modality_embeds'][0]
        # else:
        feature_embeds, anomaly_map = self.extract_multimodal_feature(inputs, web_demo)
        # print(anomaly_map.shape)
        inputs['modality_embeds'].append(feature_embeds)

        batch_size = feature_embeds.shape[0]
        p_before = PROMPT_START
        p_before_tokens = self.llama_tokenizer(p_before,
                                               return_tensors="pt", add_special_tokens=False).to(self.device)
        p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim

        p_middle = '</Img> '
        p_middle_tokens = self.llama_tokenizer(p_middle,
                                               return_tensors="pt", add_special_tokens=False).to(self.device)
        # the PEFT-wrapped model needs the deeper .model.model call to reach embed_tokens
        p_middle_embeds = self.llama_model.model.model.embed_tokens(p_middle_tokens.input_ids).expand(batch_size, -1, -1)  # bsz x s1 x embed_dim

        # self.prompt_learner.eval()
        anomaly_map_prompts = self.prompt_learner(anomaly_map)

        text = prompt + '\n### Assistant:'
        p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
        p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)  # bsz x s2 x embed_dim

        bos = torch.ones([batch_size, 1],
                         dtype=p_before_tokens.input_ids.dtype,
                         device=p_before_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id  # bsz x 1
        bos_embeds = self.llama_model.model.model.embed_tokens(bos)  # bsz x 1 x embed_dim

        inputs_embeds = torch.cat([bos_embeds, p_before_embeds, feature_embeds, p_middle_embeds, anomaly_map_prompts, p_after_embeds], dim=1)  # bsz x (1+s1+1+s2) x embed_dim

        return inputs_embeds, anomaly_map
    def generate(self, inputs, web_demo=False):
        '''
        inputs = {
            'image_paths': optional,
            'audio_paths': optional,
            'video_paths': optional,
            'thermal_paths': optional,
            'mode': generation mode,
            'prompt': human input prompt,
            'max_tgt_len': generation length,
            'top_p': top_p,
            'temperature': temperature,
            'modality_embeds': None or torch.tensor,
            'modality_cache': save the image cache
        }
        '''
        # self.prompt_learner.eval()
        # self.llama_model.eval()
        # self.llama_proj.eval()
        # self.image_decoder.eval()
        input_embeds, pixel_output = self.prepare_generation_embedding(inputs, web_demo)
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
        outputs = self.llama_model.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=inputs['max_tgt_len'],
            top_p=inputs['top_p'],
            temperature=inputs['temperature'],
            do_sample=True,
            use_cache=True,
            stopping_criteria=stopping_criteria,
        )
        output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
        return output_text, pixel_output
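

# A minimal usage sketch. It is illustrative only: the checkpoint paths, LoRA values and
# image path below are placeholders, not files shipped with this repository; they simply
# show which keys __init__ and generate() actually read.
if __name__ == '__main__':
    model = OpenLLAMAPEFTModel(
        imagebind_ckpt_path='./pretrained_ckpt/imagebind_huge.pth',  # placeholder path
        vicuna_ckpt_path='./pretrained_ckpt/vicuna-7b',              # placeholder path
        max_tgt_len=512,
        stage=2,
        lora_r=32,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    response, anomaly_map = model.generate({
        'prompt': 'Is there any anomaly in the image?',
        'image_paths': ['./example.jpg'],   # placeholder query image
        'normal_img_paths': [],             # optional one-shot normal reference(s)
        'audio_paths': [],
        'video_paths': [],
        'thermal_paths': [],
        'max_tgt_len': 256,
        'top_p': 0.9,
        'temperature': 1.0,
        'modality_embeds': [],
        'modality_cache': None,
    })
    print(response)          # textual answer
    print(anomaly_map.shape)  # pixel-level anomaly localization map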