Shivdutta committed on
Commit 70d113f · verified · 1 Parent(s): 841336b

Upload 2 files

Files changed (2)
  1. app.py +148 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,148 @@
+ import gradio as gr
+ import peft
+ from peft import LoraConfig
+ from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
+ import torch
+ from peft import PeftModel
+ import torch.nn as nn
+ import whisperx
+
+ # Determine the appropriate device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Set compute_type based on device capabilities
+ if device == "cuda" and torch.cuda.is_bf16_supported():
+     compute_type = "float16"
+ elif device == "cuda":
+     compute_type = "float32"
+ else:
+     compute_type = "int8"
+
+
+ clip_model_name = "openai/clip-vit-base-patch32"
+ phi_model_name = "microsoft/phi-2"
+ tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(clip_model_name)
+ tokenizer.pad_token = tokenizer.eos_token
+ IMAGE_TOKEN_ID = 23893  # token id for the word "comment"
+ QA_TOKEN_ID = 50295  # token id for "qa"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ clip_embed = 768
+ phi_embed = 2560
+ compute_type = "float16"
+ audio_batch_size = 16
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, phi_embed):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(phi_embed)
+         self.proj = nn.Sequential(
+             nn.Linear(phi_embed, phi_embed),
+             nn.GELU(),
+             nn.Linear(phi_embed, phi_embed)
+         )
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+
+
+ # models
+ clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
+ projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+ resblock = SimpleResBlock(phi_embed).to(device)
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
+ # Load the model with the appropriate compute_type
+ audio_model_size = "tiny"
+ try:
+     audio_model = whisperx.load_model(audio_model_size, device, compute_type=compute_type)
+     print(f"Model loaded successfully with compute_type: {compute_type}")
+ except ValueError as e:
+     print(f"Error loading model: {e}")
+     print("Falling back to int8 compute type")
+     audio_model = whisperx.load_model(audio_model_size, device, compute_type="int8")
+
+ # load weights
+ model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/lora_adaptor')
+ merged_model = model_to_merge.merge_and_unload()
+ projection.load_state_dict(torch.load('./model_chkpt/step2_projection.pth', map_location=torch.device(device)))
+ resblock.load_state_dict(torch.load('./model_chkpt/step2_resblock.pth', map_location=torch.device(device)))
+
+ def model_generate_ans(img=None, img_audio=None, val_q=None):
+
+     max_generate_length = 100
+     val_combined_embeds = []
+
+     with torch.no_grad():
+
+         # image
+         if img is not None:
+             image_processed = processor(images=img, return_tensors="pt").to(device)
+             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
+             val_image_embeds = projection(clip_val_outputs)
+             val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
+             img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
+             img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
+
+             val_combined_embeds.append(val_image_embeds)
+             val_combined_embeds.append(img_token_embeds)
+
+         # audio
+         if img_audio is not None:
+             audio_result = audio_model.transcribe(img_audio)
+             audio_text = ''
+             for seg in audio_result['segments']:
+                 audio_text += seg['text']
+             audio_text = audio_text.strip()
+             audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+             val_combined_embeds.append(audio_embeds)
+
+         # text question
+         if len(val_q) != 0:
+             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+             val_combined_embeds.append(val_q_embeds)
+
+
+         if img_audio is not None or len(val_q) != 0:  # add QA Token
+
+             QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
+             QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
+             val_combined_embeds.append(QA_token_embeds)
+
+         val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
+         predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
+                                                   max_new_tokens=max_generate_length,
+                                                   return_dict_in_generate=True)
+
+     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
+     predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
+
+     return predicted_captions_decoded
+
+
+ with gr.Blocks() as demo:
+
+     gr.Markdown(
+         """
+         # Chat with MultiModal GPT!
+         Built by combining the CLIP model and the phi-2 model.
+         """
+     )
+
+     # app GUI
+     with gr.Row():
+         with gr.Column():
+             img_input = gr.Image(label='Image', type="pil")
+             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
+             img_question = gr.Text(label='Text Query')
+         with gr.Column():
+             img_answer = gr.Text(label='Answer')
+
+     section_btn = gr.Button("Submit")
+     section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])
+
+ if __name__ == "__main__":
+     demo.launch()
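
For quick local testing, the inference entry point defined in app.py can also be called directly, without going through the Gradio UI. The following is a minimal sketch and is not part of the commit: it assumes the checkpoints under ./model_chkpt are present next to app.py, and example.jpg is a hypothetical local image file. Importing app loads the CLIP, phi-2 and whisperx models at module level, so the import itself takes a while.

from PIL import Image
from app import model_generate_ans  # importing app also loads all models and checkpoints

# image plus a text question; no audio query
answer = model_generate_ans(
    img=Image.open("example.jpg"),   # hypothetical test image, not included in the repo
    img_audio=None,
    val_q="What is happening in this image?",
)
print(answer)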
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ peft
+ accelerate
+ transformers
+ einops
+ git+https://github.com/m-bain/whisperx.git
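
For a run outside the hosted environment, these dependencies can be installed with pip install -r requirements.txt; whisperx is installed straight from its GitHub repository. gradio, which app.py imports, is not listed here; on a Gradio-based Hugging Face Space the runtime provides it, so a purely local setup would additionally need pip install gradio.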