Shivdutta committed
Commit 5dfce3a · verified · 1 Parent(s): 27fa6df

Update app.py

Files changed (1):
  app.py +57 -57
app.py CHANGED
@@ -1,118 +1,118 @@
app.py before this commit (removed or changed lines are marked "-"):

  import gradio as gr
  import peft
  from peft import LoraConfig
- from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
  import torch
  from peft import PeftModel
  import torch.nn as nn
  import whisperx
  import os

- # Load models
  clip_model_name = "openai/clip-vit-base-patch32"
  phi_model_name = "microsoft/phi-2"
  tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
  processor = AutoProcessor.from_pretrained(clip_model_name)
  tokenizer.pad_token = tokenizer.eos_token
- IMAGE_TOKEN_ID = 23893 # token for word comment
  device = "cuda" if torch.cuda.is_available() else "cpu"
  clip_embed = 768
  phi_embed = 2560
- compute_type = "float32"
- audio_batch_size = 16
-
- class SimpleResBlock(nn.Module):
-     def __init__(self, phi_embed):
-         super().__init__()
-         self.pre_norm = nn.LayerNorm(phi_embed)
-         self.proj = nn.Sequential(
-             nn.Linear(phi_embed, phi_embed),
              nn.GELU(),
              nn.Linear(phi_embed, phi_embed)
          )
-
      def forward(self, x):
          x = self.pre_norm(x)
          return x + self.proj(x)
-
- # Initialize models
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
  resblock = SimpleResBlock(phi_embed).to(device)
- phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
  audio_model = whisperx.load_model("tiny", device, compute_type=compute_type, asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25})

- # Load weights
- model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
- merged_model = model_to_merge.merge_and_unload()
- projection.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetunned_projection.pth'), map_location=device))
- resblock.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetuned_resblock.pth'), map_location=device))

- def model_generate_ans(img=None, img_audio=None, val_q=None, max_length=100):
      val_combined_embeds = []

      with torch.no_grad():
-         # Image processing
          if img is not None:
-             image_processed = processor(images=img, return_tensors="pt").to(device)
-             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
              val_image_embeds = projection(clip_val_outputs)
              val_image_embeds = resblock(val_image_embeds).to(torch.float16)

-             img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-             img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
-
              val_combined_embeds.append(val_image_embeds)
              val_combined_embeds.append(img_token_embeds)

-         # Audio processing
          if img_audio is not None:
              audio_result = audio_model.transcribe(img_audio)
-             audio_text = ' '.join(seg['text'] for seg in audio_result['segments']).strip()
              audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
              val_combined_embeds.append(audio_embeds)
-
-         # Text question processing
-         if val_q:
-             val_q_tokenized = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenized).unsqueeze(0)
              val_combined_embeds.append(val_q_embeds)

-         val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
-         predicted_caption = torch.full((1, max_length), 50256).to(device)

-         for g in range(max_length):
-             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
-             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
-             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
-             predicted_caption[:, g] = predicted_word_token.view(1, -1)
-             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)
-             val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

-         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

      return predicted_captions_decoded

- # Gradio Interface
  with gr.Blocks() as demo:
      gr.Markdown(
-         """
-         # Chat with MultiModal GPT!
-         Combining CLIP model and Phi-2 model for multimodal understanding.
-         """
      )

      with gr.Row():
          with gr.Column():
-             img_input = gr.Image(label='Image', type="pil", tool="editor")
-             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
-             img_question = gr.Textbox(label='Text Query', placeholder='Type your question here...')
-             max_length = gr.Slider(label='Max Length of Response', minimum=1, maximum=300, value=100, step=1)
          with gr.Column():
-             img_answer = gr.Textbox(label='Answer', interactive=False)

      section_btn = gr.Button("Submit")
-     section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question, max_length], outputs=[img_answer])

  demo.launch()
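
The torch.nn.Linear projection and the SimpleResBlock above are the multimodal bridge in this app: CLIP patch features (width clip_embed = 768) are mapped into Phi-2's embedding width (phi_embed = 2560) and then concatenated with ordinary token embeddings before generation. Below is a minimal self-contained sketch of that bridge; it is not part of the commit, uses random tensors in place of the real CLIP and Phi-2 outputs, and assumes the 49-patch count of CLIP ViT-B/32 at 224x224 with the CLS token dropped (as last_hidden_state[:, 1:, :] does above).

import torch
import torch.nn as nn

clip_embed, phi_embed = 768, 2560            # same constants as app.py

class SimpleResBlock(nn.Module):
    # same structure as the block defined in app.py
    def __init__(self, dim):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.proj = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)

projection = nn.Linear(clip_embed, phi_embed)
resblock = SimpleResBlock(phi_embed)

clip_patches = torch.randn(1, 49, clip_embed)       # stand-in for CLIP patch tokens (CLS dropped)
image_embeds = resblock(projection(clip_patches))   # (1, 49, 2560), now in Phi-2's embedding space
text_embeds = torch.randn(1, 10, phi_embed)         # stand-in for Phi-2 token embeddings
combined = torch.cat([image_embeds, text_embeds], dim=1)
print(combined.shape)                               # torch.Size([1, 59, 2560])

Only the single nn.Linear changes the dimensionality; the residual block refines the projected features at the target width.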
 
app.py after this commit (added or changed lines are marked "+"):

  import gradio as gr
  import peft
  from peft import LoraConfig
+ from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
  import torch
  from peft import PeftModel
  import torch.nn as nn
  import whisperx
  import os

+
  clip_model_name = "openai/clip-vit-base-patch32"
  phi_model_name = "microsoft/phi-2"
  tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
  processor = AutoProcessor.from_pretrained(clip_model_name)
  tokenizer.pad_token = tokenizer.eos_token
+ IMAGE_TOKEN_ID = 23893 # token for word comment
  device = "cuda" if torch.cuda.is_available() else "cpu"
  clip_embed = 768
  phi_embed = 2560
  compute_type = "float32"
  audio_batch_size = 16

  class SimpleResBlock(nn.Module):
      def __init__(self, phi_embed):
          super().__init__()
          self.pre_norm = nn.LayerNorm(phi_embed)
          self.proj = nn.Sequential(
              nn.Linear(phi_embed, phi_embed),
              nn.GELU(),
              nn.Linear(phi_embed, phi_embed)
          )
+
      def forward(self, x):
          x = self.pre_norm(x)
          return x + self.proj(x)
+
+ # models
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
  resblock = SimpleResBlock(phi_embed).to(device)
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
+ # Assuming you have defined 'device' and 'compute_type' elsewhere
  audio_model = whisperx.load_model("tiny", device, compute_type=compute_type, asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25})

+ # load weights
+ model_to_merge = PeftModel.from_pretrained(phi_model,os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
+ merged_model = model_to_merge.merge_and_unload()
+ projection.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetunned_projection.pth'),map_location=torch.device(device)))
+ resblock.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetuned_resblock.pth'),map_location=torch.device(device)))
+
+ def model_generate_ans(img=None,img_audio=None,val_q=None):

+     max_generate_length = 100
      val_combined_embeds = []

      with torch.no_grad():
+
+         # image
          if img is not None:
+             image_processed = processor(images=img, return_tensors="pt").to(device)
+             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
              val_image_embeds = projection(clip_val_outputs)
              val_image_embeds = resblock(val_image_embeds).to(torch.float16)

              img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
              img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)

              val_combined_embeds.append(val_image_embeds)
              val_combined_embeds.append(img_token_embeds)

+         # audio
          if img_audio is not None:
              audio_result = audio_model.transcribe(img_audio)
+             audio_text = ''
+             for seg in audio_result['segments']:
+                 audio_text += seg['text']
+             audio_text = audio_text.strip()
              audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
              val_combined_embeds.append(audio_embeds)
+
+         # text question
+         if len(val_q) != 0:
+             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
              val_combined_embeds.append(val_q_embeds)

+         val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
+
+         #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
+         predicted_caption = torch.full((1,max_generate_length),50256).to(device)

+         for g in range(max_generate_length):
+             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
+             predicted_caption[:,g] = predicted_word_token.view(1,-1)
+             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+             val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

+         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]

      return predicted_captions_decoded
+
+

  with gr.Blocks() as demo:
+
      gr.Markdown(
+         """
+         # Chat with MultiModal GPT !
+         Build using combining clip model and phi-2 model.
+         """
      )

+     # app GUI
      with gr.Row():
          with gr.Column():
+             img_input = gr.Image(label='Image',type="pil")
+             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
+             img_question = gr.Text(label ='Text Query')
+
          with gr.Column():
+             img_answer = gr.Text(label ='Answer')

      section_btn = gr.Button("Submit")
+     section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])

  demo.launch()
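
For reference, the generation loop inside model_generate_ans (in both versions) is plain greedy decoding driven through inputs_embeds: each step re-runs the model on the whole embedding sequence, takes the argmax of the last position's logits, looks up that token's embedding, and appends it before the next step; there is no KV cache and no early stop at the EOS token 50256, so it always generates the full max_length / max_generate_length tokens. A minimal sketch of that pattern with a toy stand-in language model (hypothetical vocabulary and sizes, not the app's merged Phi-2 + LoRA model):

import torch
import torch.nn as nn

vocab, dim, steps = 100, 32, 5                        # hypothetical toy sizes

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)  # plays the role of merged_model.model.embed_tokens
        self.head = nn.Linear(dim, vocab)             # plays the role of the LM head

    def forward(self, inputs_embeds):
        return self.head(inputs_embeds)               # (B, T, vocab) logits

model = ToyLM()
embeds = torch.randn(1, 4, dim)                       # stand-in for the concatenated image/audio/text embeddings
generated = []
with torch.no_grad():
    for _ in range(steps):
        logits = model(inputs_embeds=embeds)          # re-run on the full sequence each step
        next_token = logits[:, -1, :].argmax(dim=-1)  # greedy pick at the last position
        generated.append(next_token.item())
        next_embed = model.embed_tokens(next_token).unsqueeze(1)
        embeds = torch.cat([embeds, next_embed], dim=1)
print(generated)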