File size: 1,114 Bytes
ac7cda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from ..utils.load_model import load_model


class HubertStreaming:
    """Thin dispatcher around a streaming HuBERT model obtained from ``load_model``.

    ``load_model`` returns a ``(model, model_type)`` pair, and ``model_type``
    decides how a chunk is handed to the backend:

    * ``"onnx"``     -- ``model.run(None, feeds)[0]`` (looks like an
      onnxruntime ``InferenceSession``; TODO confirm).
    * ``"tensorrt"`` -- ``model.setup(feeds)``, ``model.infer()``, then the
      first element of ``model.buffer["encoding_out"]``.
    * ``"ori"``      -- delegated to the model's own ``forward_chunk``.

    Any other ``model_type`` raises ``ValueError`` when a chunk is processed.
    """

    def __init__(self, model_path, device="cuda", **kwargs):
        """Load the backend model and remember the target device.

        Args:
            model_path: Path handed to ``load_model``, both positionally and
                as the ``model_file`` keyword.
            device: Forwarded to ``load_model`` and stored as ``self.device``.
            **kwargs: Extra options forwarded to ``load_model``. The keys
                ``model_file``, ``module_name`` and ``package_name`` are always
                overwritten (see below).
        """
        # These loader settings are fixed for this wrapper and win over any
        # same-named entries supplied by the caller. How load_model uses
        # module_name/package_name is not visible here; the relative package
        # string presumably gets resolved inside load_model -- TODO confirm.
        kwargs.update(
            model_file=model_path,
            module_name="HubertStreamingONNX",
            package_name="..aux_models.modules",
        )
        loaded = load_model(model_path, device=device, **kwargs)
        self.model, self.model_type = loaded
        self.device = device

    def forward_chunk(self, audio_chunk):
        """Run one chunk through an ``"onnx"`` or ``"tensorrt"`` backend.

        Args:
            audio_chunk: Anything with ``.reshape(1, -1)``; presumably a numpy
                array of audio samples -- TODO confirm dtype/length.

        Returns:
            The backend output for the chunk, which was flattened into a
            single row before inference.

        Raises:
            ValueError: If ``model_type`` is neither ``"onnx"`` nor
                ``"tensorrt"``. This includes ``"ori"``, which only
                ``__call__`` handles.
        """
        backend = self.model_type
        # Reject unknown backends before touching the input.
        if backend not in ("onnx", "tensorrt"):
            raise ValueError(f"Unsupported model type: {self.model_type}")

        # Both supported backends take a single "input_values" row.
        feeds = {"input_values": audio_chunk.reshape(1, -1)}
        if backend == "onnx":
            return self.model.run(None, feeds)[0]

        # TensorRT path: bind inputs, execute, then read the host-side result.
        # NOTE(review): this returns whatever lives in model.buffer; if that
        # storage is reused across calls, callers may need to copy -- verify.
        self.model.setup(feeds)
        self.model.infer()
        return self.model.buffer["encoding_out"][0]

    def __call__(self, audio_chunk):
        """Process one chunk with whichever backend was loaded.

        ``"ori"`` models handle the chunk through their own ``forward_chunk``;
        every other ``model_type`` goes through ``self.forward_chunk``.
        """
        if self.model_type != "ori":
            return self.forward_chunk(audio_chunk)
        return self.model.forward_chunk(audio_chunk)