vumichien commited on
Commit
e8e18ec
·
1 Parent(s): c955430

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -0
main.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from fastapi import FastAPI, File
3
+ from faster_whisper import WhisperModel
4
+ from utils import ffmpeg_read, stt
5
+ from sentence_transformers import SentenceTransformer, util
6
+ import torch
7
+
8
+ app = FastAPI()
9
+
10
+ whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]
11
+ audio_model = WhisperModel("base", compute_type="int8", device="cpu")
12
+ text_model = SentenceTransformer('all-MiniLM-L6-v2')
13
+ corpus_embeddings = torch.load('corpus_embeddings.pt')
14
+
15
+
16
+ def speech_to_text(upload_audio, model_type="whisper"):
17
+ """
18
+ Transcribe audio using whisper model.
19
+ """
20
+ audio_path = ffmpeg_read(upload_audio, sampling_rate=16000)
21
+ # Transcribe audio
22
+ if model_type == "whisper":
23
+ transcribe_options = dict(task="transcribe", language="ja", beam_size=5, best_of=5, vad_filter=True)
24
+ segments_raw, info = audio_model.transcribe(audio_path, **transcribe_options)
25
+ segments = [segment.text for segment in segments_raw]
26
+ return ' '.join(segments)
27
+ else:
28
+ text = stt(audio_path)
29
+ return text
30
+
31
+
32
+ @app.get("/")
33
+ def read_root():
34
+ return {"Message": "Application startup complete"}
35
+
36
+
37
+ @app.post("/voice_detect/")
38
+ async def voice_detect_api(
39
+ voice_input: bytes = File(None),
40
+ threshold: float = 0.8,
41
+ model_type: str = "whisper"
42
+ ):
43
+ """
44
+ API to detect voice from audio file.
45
+ """
46
+ start = time.time()
47
+ text = speech_to_text(voice_input, model_type)
48
+ query_embedding = text_model.encode(text, convert_to_tensor=True)
49
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
50
+ if hits[0]['score'] > threshold:
51
+ similar = 1
52
+ else:
53
+ similar = 0
54
+ end = time.time()
55
+ return {"text": text,
56
+ "similar": similar,
57
+ "time_taken": end - start}