Shivdutta committed
Commit 27fa6df · verified · 1 Parent(s): d486040

Update app.py

Files changed (1)
  1. app.py +46 -53
app.py CHANGED
@@ -1,18 +1,20 @@
  import gradio as gr
  import peft
  from peft import LoraConfig
- from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
+ from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
  import torch
  from peft import PeftModel
  import torch.nn as nn
  import whisperx
  import os
+
+ # Load models
  clip_model_name = "openai/clip-vit-base-patch32"
  phi_model_name = "microsoft/phi-2"
  tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
  processor = AutoProcessor.from_pretrained(clip_model_name)
  tokenizer.pad_token = tokenizer.eos_token
- IMAGE_TOKEN_ID = 23893 # token for word comment
+ IMAGE_TOKEN_ID = 23893 # token for word comment
  device = "cuda" if torch.cuda.is_available() else "cpu"
  clip_embed = 768
  phi_embed = 2560
@@ -28,35 +30,32 @@ class SimpleResBlock(nn.Module):
              nn.GELU(),
              nn.Linear(phi_embed, phi_embed)
          )
+
      def forward(self, x):
          x = self.pre_norm(x)
          return x + self.proj(x)
-
- # models
+
+ # Initialize models
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
  resblock = SimpleResBlock(phi_embed).to(device)
- phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
- # Assuming you have defined 'device' and 'compute_type' elsewhere
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
  audio_model = whisperx.load_model("tiny", device, compute_type=compute_type, asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25})

- # load weights
- model_to_merge = PeftModel.from_pretrained(phi_model,os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
- merged_model = model_to_merge.merge_and_unload()
- projection.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetunned_projection.pth'),map_location=torch.device(device)))
- resblock.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetuned_resblock.pth'),map_location=torch.device(device)))
+ # Load weights
+ model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
+ merged_model = model_to_merge.merge_and_unload()
+ projection.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetunned_projection.pth'), map_location=device))
+ resblock.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetuned_resblock.pth'), map_location=device))

- def model_generate_ans(img=None,img_audio=None,val_q=None):
-
-     max_generate_length = 100
+ def model_generate_ans(img=None, img_audio=None, val_q=None, max_length=100):
      val_combined_embeds = []

      with torch.no_grad():
-
-         # image
+         # Image processing
          if img is not None:
-             image_processed = processor(images=img, return_tensors="pt").to(device)
-             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
+             image_processed = processor(images=img, return_tensors="pt").to(device)
+             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]
              val_image_embeds = projection(clip_val_outputs)
              val_image_embeds = resblock(val_image_embeds).to(torch.float16)

@@ -66,60 +65,54 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
              val_combined_embeds.append(val_image_embeds)
              val_combined_embeds.append(img_token_embeds)

-         # audio
+         # Audio processing
          if img_audio is not None:
              audio_result = audio_model.transcribe(img_audio)
-             audio_text = ''
-             for seg in audio_result['segments']:
-                 audio_text += seg['text']
-             audio_text = audio_text.strip()
+             audio_text = ' '.join(seg['text'] for seg in audio_result['segments']).strip()
              audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
+             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
              val_combined_embeds.append(audio_embeds)
-
-         # text question
-         if len(val_q) != 0:
-             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
+
+         # Text question processing
+         if val_q:
+             val_q_tokenized = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
+             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenized).unsqueeze(0)
              val_combined_embeds.append(val_q_embeds)

-         val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
-
-         #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
-         predicted_caption = torch.full((1,max_generate_length),50256).to(device)
+         val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
+         predicted_caption = torch.full((1, max_length), 50256).to(device)

-         for g in range(max_generate_length):
-             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
-             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
-             predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
-             predicted_caption[:,g] = predicted_word_token.view(1,-1)
-             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
-             val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+         for g in range(max_length):
+             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
+             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             predicted_caption[:, g] = predicted_word_token.view(1, -1)
+             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)
+             val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)

-         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

      return predicted_captions_decoded
-

+ # Gradio Interface
  with gr.Blocks() as demo:
-
      gr.Markdown(
-         """
-         # Chat with MultiModal GPT !
-         Build using combining clip model and phi-2 model.
-         """
+         """
+         # Chat with MultiModal GPT!
+         Combining CLIP model and Phi-2 model for multimodal understanding.
+         """
      )

-     # app GUI
      with gr.Row():
          with gr.Column():
-             img_input = gr.Image(label='Image',type="pil")
-             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
-             img_question = gr.Text(label ='Text Query')
+             img_input = gr.Image(label='Image', type="pil", tool="editor")
+             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
+             img_question = gr.Textbox(label='Text Query', placeholder='Type your question here...')
+             max_length = gr.Slider(label='Max Length of Response', minimum=1, maximum=300, value=100, step=1)
          with gr.Column():
-             img_answer = gr.Text(label ='Answer')
+             img_answer = gr.Textbox(label='Answer', interactive=False)

      section_btn = gr.Button("Submit")
-     section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])
+     section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question, max_length], outputs=[img_answer])

  demo.launch()
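
Review note on the updated app.py: both the old and the new version call `whisperx.load_model(...)` with a `compute_type` variable that is never defined in the file (the comment removed by this commit even said it was assumed to be defined elsewhere), so the app raises a `NameError` at startup. Below is a minimal sketch of one way to define it before the model is loaded; the specific values are an assumption based on the CTranslate2-style compute types whisperx accepts, not part of this commit.

```python
import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed defaults: half precision on GPU, int8 on CPU (CTranslate2 compute types).
compute_type = "float16" if device == "cuda" else "int8"

audio_model = whisperx.load_model(
    "tiny",
    device,
    compute_type=compute_type,
    asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25},
)
```

Two smaller observations, likewise hedged: `gr.Image(..., tool="editor")` and `gr.Audio(..., sources=[...])` target different Gradio major versions (the `tool` argument was removed in Gradio 4, which is also where `sources=` was introduced), so one of the two lines is likely to raise a `TypeError` depending on the pinned version; and the greedy decoding loop always runs `max_length` forward passes, so breaking when `predicted_word_token` equals `tokenizer.eos_token_id` would save compute and drop any tokens generated after the end-of-text marker.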