import os
import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor

clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name  = "microsoft/phi-2"
tokenizer  = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor  = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893 # phi-2 token id for the word "comment", appended after the image embeddings as a boundary marker
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768
phi_embed  = 2560
compute_type = "float32"
audio_batch_size = 16

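# SimpleResBlock: a small residual adapter applied after the linear CLIP -> phi-2
# projection. It layer-normalises the projected image embeddings and adds the output
# of a two-layer MLP back to them, keeping the 2560-dim phi-2 width unchanged.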
class SimpleResBlock(nn.Module):
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
        
# models
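# - clip_model: CLIP vision encoder, produces 768-dim patch embeddings for the image
# - projection + resblock: map CLIP embeddings into phi-2's 2560-dim embedding space
# - phi_model: phi-2 language model (LoRA adapters are merged into it below)
# - audio_model: whisperx "tiny" model that transcribes the spoken query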
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model(
    "tiny", device, compute_type=compute_type,
    asr_options={'max_new_tokens': 2048, 'clip_timestamps': True,
                 'hallucination_silence_threshold': 0.25, 'hotwords': []},
)

# load fine-tuned weights: merge the LoRA adapters into phi-2, then load the projection/resblock checkpoints
model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetunned_projection.pth'), map_location=torch.device(device)))
resblock.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetuned_resblock.pth'), map_location=torch.device(device)))

def model_generate_ans(img=None, img_audio=None, val_q=None):
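    """Generate an answer from an image, a spoken query, and/or a text query.

    The image is encoded with CLIP and projected into phi-2's embedding space,
    the audio is transcribed with whisperx and embedded as text, and the pieces
    are concatenated as input embeddings for greedy decoding with the merged
    phi-2 model.
    """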

    max_generate_length = 100
    val_combined_embeds = []
    
    with torch.no_grad():
    
        # image: CLIP patch embeddings (CLS token dropped), projected into phi-2 space
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
            val_image_embeds = projection(clip_val_outputs)
            # keep the language model's dtype so the torch.cat below does not mix dtypes
            val_image_embeds = resblock(val_image_embeds).to(merged_model.dtype)

            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)

        # audio: transcribe the spoken query with whisperx, then embed the text
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio, batch_size=audio_batch_size)
            audio_text = ''.join(seg['text'] for seg in audio_result['segments']).strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)
            
        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)

        if not val_combined_embeds:
            return "Please provide an image, an audio query, or a text query."

        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)  # (1, seq_len, 2560)

        # greedy decoding: re-run the full sequence each step and append the
        # embedding of the predicted token (no KV cache); stop at EOS
        predicted_caption = torch.full((1, max_generate_length), tokenizer.eos_token_id).to(device)

        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, vocab)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)  # (1, 1, vocab)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)  # (1, 1)
            if predicted_word_token.item() == tokenizer.eos_token_id:
                break
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)  # (1, 1, 2560)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
    
    return predicted_captions_decoded
    

with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Chat with MultiModal GPT!
    Built by combining the CLIP vision encoder with the Phi-2 language model.
    """
    )

    # app GUI
    with gr.Row():
        with gr.Column():
            img_input    = gr.Image(label='Image',type="pil")
            img_audio    = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer   = gr.Text(label='Answer')

    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])
    
demo.launch()