vumichien committed on
Commit
f787cd1
·
1 Parent(s): 532a2ea

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -0
main.py CHANGED
@@ -1,4 +1,38 @@
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
1
  from fastapi import FastAPI
2
+ import datetime
3
+ import torch
4
+ import os
5
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM, AutoConfig
6
+ from huggingface_hub import hf_hub_download
7
+ from fuzzywuzzy import fuzz
8
+ from utils import ffmpeg_read, query_dummy, query_raw, find_different
9
+
10
+ ## config
11
+ API_TOKEN = os.environ["API_TOKEN"]
12
+ MODEL_PATH = os.environ["MODEL_PATH"]
13
+ PITCH_PATH = os.environ["PITCH_PATH"]
14
+ QUANTIZED_MODEL_PATH = hf_hub_download(repo_id=MODEL_PATH, filename='quantized_model.pt', token=API_TOKEN)
15
+ QUANTIZED_PITCH_MODEL_PATH = hf_hub_download(repo_id=PITCH_PATH, filename='quantized_model.pt', token=API_TOKEN)
16
+
17
+
18
+ ## word preprocessor
19
+ processor_with_lm = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
20
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
21
+
22
+ ### quantized model
23
+ config = AutoConfig.from_pretrained(MODEL_PATH, use_auth_token=API_TOKEN)
24
+ dummy_model = Wav2Vec2ForCTC(config)
25
+ quantized_model = torch.quantization.quantize_dynamic(dummy_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
26
+ quantized_model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH))
27
+
28
+ ## pitch preprocessor
29
+ processor_pitch = Wav2Vec2Processor.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
30
+
31
+ ### quantized pitch model
32
+ config = AutoConfig.from_pretrained(PITCH_PATH, use_auth_token=API_TOKEN)
33
+ dummy_pitch_model = Wav2Vec2ForCTC(config)
34
+ quantized_pitch_model = torch.quantization.quantize_dynamic(dummy_pitch_model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True)
35
+ quantized_pitch_model.load_state_dict(torch.load(QUANTIZED_PITCH_MODEL_PATH))
36
 
37
  app = FastAPI()
38