AkashDataScience committed
Commit b52bed7 · 1 Parent(s): 2d191f6

Added inference

Files changed (2):
  1. app.py +54 -1
  2. requirements.txt +2 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
 from model import Projections
 from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import gradio as gr
+import librosa
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 projections = Projections(512, 3072)
@@ -47,7 +48,59 @@ whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
 whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
 
 def infer(message, history):
-    return message.keys()
+    max_generate_length = 100
+    combined_embeds = []
+
+    with torch.no_grad():
+        if message['files']:
+            projected_image_embeds = None
+            audio_text_embeds = None
+            for path in message['files']:
+
+                if path.endswith(('.jpg', '.png', '.jpeg')):
+                    # Encode the image with CLIP and project it into the LM embedding space
+                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
+                    image_features = clip_model.encode_image(image)
+                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
+
+                elif path.endswith(('.mp3', '.wav')):
+                    # Load the audio and transcribe it with Whisper
+                    speech, rate = librosa.load(path, sr=16000)
+                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features
+                    predicted_ids = whisper_model.generate(input_features)
+                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}], tokenize=False, add_generation_prompt=True)
+                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids']
+                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)
+
+            if projected_image_embeds is not None:
+                combined_embeds.append(projected_image_embeds)
+
+            if audio_text_embeds is not None:
+                combined_embeds.append(audio_text_embeds)
+
+        if message['text']:
+            prompt = tokenizer.apply_chat_template([{"from": "human", "value": message['text']}], tokenize=False, add_generation_prompt=True)
+            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids']
+            text_embeds = model.get_input_embeddings()(prompt_tokens)
+            combined_embeds.append(text_embeds)
+
+        combined_embeds = torch.cat(combined_embeds, dim=1)
+
+        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)
+
+        # Greedy decoding: append the embedding of each predicted token and re-run the model
+        for g in range(max_generate_length):
+            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']        # 1, seq_len, vocab_size
+            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)    # 1, 1, vocab_size
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)  # 1, 1
+            predicted_caption[:, g] = predicted_word_token.view(1, -1)
+            next_token_embeds = model.get_input_embeddings()(predicted_word_token)    # 1, 1, hidden_dim
+            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+
+        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
+
+    return predicted_captions_decoded
+
 
 examples=[{'text':"I am planning to buy a dog and a cat. Suggest some breeds that get along with each other"},
           {'text':"Explain biased coin flip"},
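The diff does not show how infer is wired into the UI. A minimal sketch of the likely hookup, assuming a gr.ChatInterface with multimodal=True (so message arrives as a dict with 'text' and 'files' keys, matching the code above); the variable names demo and the launch block are illustrative, not taken from the commit:

# Sketch only (not part of the commit): expose infer as a multimodal chat app.
demo = gr.ChatInterface(
    fn=infer,
    multimodal=True,   # MultimodalTextbox: accepts text plus image/audio files
    examples=examples, # the examples list defined in app.py
)

if __name__ == "__main__":
    demo.launch()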
requirements.txt CHANGED
@@ -3,6 +3,8 @@ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8
 colorama==0.4.6
 datasets==3.0.0
 dill==0.3.8
+gradio==5.0.2
+librosa==0.10.2
 multiprocess==0.70.16
 numpy==1.26.4
 pandas==2.2.2
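The two new pins support the commit: gradio for the chat UI and librosa for loading and resampling audio to Whisper's expected 16 kHz. A standalone sanity check of that audio path, using a placeholder file name ("sample.wav") and a stand-in checkpoint ("openai/whisper-base") rather than the repo's whisper_model_name:

import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Placeholder audio path and checkpoint, for illustration only
speech, _ = librosa.load("sample.wav", sr=16000)  # librosa resamples to 16 kHz
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features
ids = model.generate(features)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])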