Spaces:
Running
Running
gpt-omni
committed on
Commit
·
58227c7
1
Parent(s):
369b919
update
Browse files - inference.py +4 -3
inference.py
CHANGED
|
@@ -138,6 +138,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
|
|
| 138 |
return torch.stack([audio_feature, audio_feature]), stacked_inputids
|
| 139 |
|
| 140 |
|
|
|
|
| 141 |
def load_audio(path):
|
| 142 |
audio = whisper.load_audio(path)
|
| 143 |
duration_ms = (len(audio) / 16000) * 1000
|
|
@@ -357,7 +358,7 @@ def load_model(ckpt_dir, device):
|
|
| 357 |
config.post_adapter = False
|
| 358 |
|
| 359 |
with fabric.init_module(empty_init=False):
|
| 360 |
-
model = GPT(config)
|
| 361 |
|
| 362 |
# model = fabric.setup(model)
|
| 363 |
state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
|
|
@@ -401,8 +402,8 @@ class OmniInference:
|
|
| 401 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 402 |
model = self.model
|
| 403 |
|
| 404 |
-
with self.fabric.init_tensor():
|
| 405 |
-
|
| 406 |
|
| 407 |
mel, leng = load_audio(audio_path)
|
| 408 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
|
| 138 |
return torch.stack([audio_feature, audio_feature]), stacked_inputids
|
| 139 |
|
| 140 |
|
| 141 |
+
@spaces.GPU
|
| 142 |
def load_audio(path):
|
| 143 |
audio = whisper.load_audio(path)
|
| 144 |
duration_ms = (len(audio) / 16000) * 1000
|
|
|
|
| 358 |
config.post_adapter = False
|
| 359 |
|
| 360 |
with fabric.init_module(empty_init=False):
|
| 361 |
+
model = GPT(config, device=device)
|
| 362 |
|
| 363 |
# model = fabric.setup(model)
|
| 364 |
state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
|
|
|
|
| 402 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
| 403 |
model = self.model
|
| 404 |
|
| 405 |
+
# with self.fabric.init_tensor():
|
| 406 |
+
model.set_kv_cache(batch_size=2)
|
| 407 |
|
| 408 |
mel, leng = load_audio(audio_path)
|
| 409 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|