Shivdutta committed on
Commit 60b820b · verified · 1 Parent(s): c70b1b3

Update app.py

Files changed (1)
  1. app.py +21 -65
app.py CHANGED
@@ -6,30 +6,17 @@ import torch
 from peft import PeftModel
 import torch.nn as nn
 import whisperx
-
-# Determine the appropriate device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Set compute_type based on device capabilities
-if device == "cuda" and torch.cuda.is_bf16_supported():
-    compute_type = "float16"
-elif device == "cuda":
-    compute_type = "float32"
-else:
-    compute_type = "int8"
-
-
+import os
 clip_model_name = "openai/clip-vit-base-patch32"
 phi_model_name = "microsoft/phi-2"
 tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
 processor = AutoProcessor.from_pretrained(clip_model_name)
 tokenizer.pad_token = tokenizer.eos_token
 IMAGE_TOKEN_ID = 23893 # token for word comment
-QA_TOKEN_ID = 50295 # token for qa
 device = "cuda" if torch.cuda.is_available() else "cpu"
 clip_embed = 768
 phi_embed = 2560
-compute_type = "float16"
+compute_type = "float32"
 audio_batch_size = 16

 class SimpleResBlock(nn.Module):
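The first hunk drops the device-aware compute_type selection in favour of a hard-coded compute_type = "float32". If adaptive selection were ever wanted again, a minimal sketch based on the removed block (an illustration, not part of this commit) could look like this:

import torch

# Sketch only (not part of this commit): choose a CTranslate2 compute type
# from the available hardware, mirroring the block removed above.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and torch.cuda.is_bf16_supported():
    compute_type = "float16"   # GPUs that report bf16 support
elif device == "cuda":
    compute_type = "float32"   # older GPUs without bf16 support
else:
    compute_type = "int8"      # quantized inference on CPU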
@@ -44,50 +31,20 @@ class SimpleResBlock(nn.Module):
     def forward(self, x):
         x = self.pre_norm(x)
         return x + self.proj(x)
-
-
-
+
 # models
 clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
 projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
 resblock = SimpleResBlock(phi_embed).to(device)
 phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
-# Load the model with the appropriate compute_type
-# Load the audio model with appropriate compute_type
-audio_model_size = "tiny"
-compute_type = "float32" # Ensure using a compatible compute type
-try:
-    audio_model = whisperx.load_model(
-        audio_model_size,
-        device,
-        compute_type=compute_type
-        # Removed unsupported parameters
-    )
-    print(f"Model loaded successfully with compute_type: {compute_type}")
-except ValueError as e:
-    print(f"Error loading model: {e}")
-    # Optionally, try loading with int8 if necessary
-    try:
-        audio_model = whisperx.load_model(
-            audio_model_size,
-            device,
-            compute_type="int8"
-            # Removed unsupported parameters
-        )
-        print("Fell back to int8 compute type successfully.")
-    except Exception as e:
-        print(f"Failed to load model with int8: {e}")
-
-
-
-
-
+# Assuming you have defined 'device' and 'compute_type' elsewhere
+audio_model = whisperx.load_model("tiny", device, compute_type=compute_type, asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25, 'hotwords': []})

 # load weights
-model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/lora_adaptor')
+model_to_merge = PeftModel.from_pretrained(phi_model,os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
 merged_model = model_to_merge.merge_and_unload()
-projection.load_state_dict(torch.load('./model_chkpt/finetunned_projection.pth',map_location=torch.device(device)))
-resblock.load_state_dict(torch.load('./model_chkpt/finetuned_resblock.pth',map_location=torch.device(device)))
+projection.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetunned_projection.pth'),map_location=torch.device(device)))
+resblock.load_state_dict(torch.load(os.path.join(os.getcwd(),'model_chkpt/finetuned_resblock.pth'),map_location=torch.device(device)))

 def model_generate_ans(img=None,img_audio=None,val_q=None):

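This hunk collapses the old try/except loader into a single whisperx.load_model("tiny", ...) call with explicit asr_options, and switches the checkpoint paths to os.path.join(os.getcwd(), ...). If the int8 fallback from the removed code is still useful (for example on machines whose CTranslate2 build rejects float32), a hedged sketch re-adding it around the new call might look like this; the printed message is illustrative and the asr_options are omitted for brevity, neither is part of the commit:

# Sketch only: optional int8 fallback, reusing the pattern this commit removes.
try:
    audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)
except ValueError as err:
    print(f"Loading with compute_type={compute_type} failed: {err}; retrying with int8")
    audio_model = whisperx.load_model("tiny", device, compute_type="int8")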
 
@@ -126,20 +83,20 @@ def model_generate_ans(img=None,img_audio=None,val_q=None):
         val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
         val_combined_embeds.append(val_q_embeds)

-
-    if img_audio is not None or len(val_q) != 0: # add QA Token
-
-        QA_token_tensor = torch.tensor(QA_TOKEN_ID).to(device)
-        QA_token_embeds = merged_model.model.embed_tokens(QA_token_tensor).unsqueeze(0).unsqueeze(0)
-        val_combined_embeds.append(QA_token_embeds)
-
     val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
-    predicted_caption = merged_model.generate(inputs_embeds=val_combined_embeds,
-                                              max_new_tokens=max_generate_length,
-                                              return_dict_in_generate = True)
+
+    #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
+    predicted_caption = torch.full((1,max_generate_length),50256).to(device)

-    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
-    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
+    for g in range(max_generate_length):
+        phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
+        predicted_caption[:,g] = predicted_word_token.view(1,-1)
+        next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
+        val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+
+    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]

     return predicted_captions_decoded
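This hunk swaps merged_model.generate(...) for a manual greedy-decoding loop: the caption tensor is pre-filled with token 50256 (phi-2's <|endoftext|>), each step takes the argmax over the last position's logits, and the chosen token's embedding is appended to the input before the next forward pass. The loop always runs for max_generate_length steps; a small variation (an assumption, not part of the commit) would stop early once <|endoftext|> is produced:

# Sketch only: the same greedy loop, but break as soon as the EOS token appears.
for g in range(max_generate_length):
    phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
    predicted_word_token = torch.argmax(phi_output_logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
    predicted_caption[:, g] = predicted_word_token.view(1, -1)
    if predicted_word_token.item() == 50256:  # <|endoftext|> for phi-2
        break
    next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)  # (1, 1, 2560)
    val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)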
 
@@ -165,5 +122,4 @@ with gr.Blocks() as demo:
     section_btn = gr.Button("Submit")
     section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])

-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
 