import os

import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    CLIPVisionModel,
)

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token

IMAGE_TOKEN_ID = 23893  # phi-2 token id for the word "comment", used as the image separator token
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768   # CLIP ViT-B/32 hidden size
phi_embed = 2560   # phi-2 hidden size
compute_type = "float32"
audio_batch_size = 16


class SimpleResBlock(nn.Module):
    """Residual MLP block applied to the projected image embeddings."""

    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


# models
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model(
    "tiny",
    device,
    compute_type=compute_type,
    asr_options={
        "max_new_tokens": 2048,
        "clip_timestamps": True,
        "hallucination_silence_threshold": 0.25,
    },
)

# load fine-tuned weights: LoRA adapter merged into phi-2, plus projection and resblock checkpoints
model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(os.getcwd(), "model_chkpt/lora_adaptor"))
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(
    torch.load(os.path.join(os.getcwd(), "model_chkpt/finetunned_projection.pth"), map_location=torch.device(device))
)
resblock.load_state_dict(
    torch.load(os.path.join(os.getcwd(), "model_chkpt/finetuned_resblock.pth"), map_location=torch.device(device))
)


def model_generate_ans(img=None, img_audio=None, val_q=None):
    max_generate_length = 100
    val_combined_embeds = []

    with torch.no_grad():
        # image: CLIP patch embeddings -> linear projection -> residual block
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]  # drop the CLS token
            val_image_embeds = projection(clip_val_outputs)
            val_image_embeds = resblock(val_image_embeds).to(merged_model.dtype)  # match the LM's embedding dtype

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe with whisperx, then embed the transcript tokens
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio, batch_size=audio_batch_size)
            audio_text = "".join(seg["text"] for seg in audio_result["segments"]).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)["input_ids"].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)

        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)["input_ids"].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)  # (1, seq_len, phi_embed)

        # greedy decoding: append the embedding of each predicted token and re-run the model
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)  # pre-fill with EOS (50256)
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)["logits"]  # (1, seq_len, vocab)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)  # (1, 1, vocab)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)  # (1, 1)
            if predicted_word_token.item() == tokenizer.eos_token_id:
                break  # stop early once the model emits EOS
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)  # (1, 1, phi_embed)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

    # skip_special_tokens drops the EOS (50256) padding from the decoded text
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
    return predicted_captions_decoded


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Engage with MultiModal GPT!
        A seamless AI experience combining CLIP and Phi-2 models.
        """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Image", type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=["microphone", "upload"], type="filepath")
            img_question = gr.Text(label="Text Query")
        with gr.Column():
            img_answer = gr.Text(label="Answer")

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])

demo.launch()
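
# Usage note (illustrative sketch, not part of the app): model_generate_ans can
# also be called directly in Python, bypassing the Gradio UI. The file name
# "sample.jpg" and the question string below are hypothetical placeholders.
#
#     from PIL import Image
#     answer = model_generate_ans(img=Image.open("sample.jpg"),
#                                 img_audio=None,
#                                 val_q="What is shown in this image?")
#     print(answer)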