guangyil committed · verified
Commit 750f337 · 1 Parent(s): 4a8840b

Update infer.py

Files changed (1):
  1. infer.py +13 -10
infer.py CHANGED
@@ -50,9 +50,10 @@ def load_model(model_name, audio_tokenizer_path):
         use_flash_attention_2=True,
         use_cache=True,
     )
-    model = model.cuda()
+    model = model.to("cuda")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer_voila = VoilaTokenizer(model_path=audio_tokenizer_path, device="cuda")
+    tokenizer_voila = VoilaTokenizer(model_path=audio_tokenizer_path, device="cpu")
+    tokenizer_voila.to("cuda")
     return model, tokenizer, tokenizer_voila, model_type
 
 def is_audio_output_task(task_type):
@@ -90,11 +91,11 @@ def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history
             yield all_tokens[:,i]
 
     if model_type == "autonomous":
-        input_generator = get_input_generator(torch.as_tensor(streaming_user_input_audio_tokens).cuda())
-        input_ids = [torch.as_tensor([input]).transpose(1,2).cuda() for input in input_ids] # transpose to [bs, seq, num_codebooks]
+        input_generator = get_input_generator(torch.as_tensor(streaming_user_input_audio_tokens).to('cuda'))
+        input_ids = [torch.as_tensor([input]).transpose(1,2).to('cuda') for input in input_ids] # transpose to [bs, seq, num_codebooks]
         input_ids = torch.cat(input_ids, dim=2) # concat to [bs, seq, num_codebooks*2]
     else:
-        input_ids = torch.as_tensor([input_ids]).transpose(1,2).cuda() # transpose to [bs, seq, num_codebooks]
+        input_ids = torch.as_tensor([input_ids]).transpose(1,2).to('cuda') # transpose to [bs, seq, num_codebooks]
     gen_params = {
         "input_ids": input_ids,
         "ref_embs": ref_embs,
@@ -110,8 +111,8 @@ def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history
         "audio_top_k": 50,
     }
     if model_type == "audio":
-        audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).cuda()
-        audio_data_masks = torch.tensor([audio_data_masks]).cuda()
+        audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).to('cuda')
+        audio_data_masks = torch.tensor([audio_data_masks]).to('cuda')
         gen_params["audio_datas"] = audio_datas
         gen_params["audio_data_masks"] = audio_data_masks
     elif model_type == "autonomous":
@@ -141,7 +142,7 @@ def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history
         'text': tokenizer.decode(text_outputs),
     }
     if is_audio_output_task(task_type):
-        audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).cuda())
+        audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).to('cuda'))
         out['audio'] = (audio_values.detach().cpu().numpy(), 16000)
     return out
 
@@ -185,10 +186,12 @@ if __name__ == "__main__":
     # step2: encode ref
     ref_embs, ref_embs_mask = None, None
     if is_audio_output_task(args.task_type):
-        spkr_model = SpeakerEmbedding(device="cuda")
+        spkr_model = SpeakerEmbedding(device="cpu")
+        spkr_model.model.to("cuda")
+        spkr_model.device = "cuda"
         wav, sr = torchaudio.load(args.ref_audio)
         ref_embs = spkr_model(wav, sr)
-        ref_embs_mask = torch.tensor([1]).cuda()
+        ref_embs_mask = torch.tensor([1]).to('cuda')
 
     out = eval_model(model, tokenizer, tokenizer_voila, model_type, args.task_type, history, ref_embs, ref_embs_mask)
     print(f"Output str: {out['text']}")
 
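All thirteen additions follow one pattern: eager .cuda() calls are replaced with explicit .to('cuda') placement, and the two wrapper objects (VoilaTokenizer, SpeakerEmbedding) are now constructed on CPU and moved to the GPU afterwards, either via a .to() method or by moving the inner .model and updating the .device attribute by hand. A minimal sketch of that construct-then-move pattern, using a hypothetical Wrapper class in place of the real wrappers, whose internals may differ:

import torch
import torch.nn as nn

class Wrapper:
    # Hypothetical stand-in for VoilaTokenizer / SpeakerEmbedding.
    def __init__(self, device="cpu"):
        self.device = device
        self.model = nn.Linear(4, 4).to(device)  # placeholder for the real audio model

    def to(self, device):
        # Move the inner module and remember the device for later tensor creation.
        self.model.to(device)
        self.device = device
        return self

wrapper = Wrapper(device="cpu")       # construct on CPU first...
if torch.cuda.is_available():
    wrapper = wrapper.to("cuda")      # ...then move explicitly, as the diff does

For plain tensors the two spellings are equivalent; .to('cuda') just makes the target device an explicit argument that is easy to parameterize or guard with torch.cuda.is_available().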