# greeting_api/main.py
import time

import torch
from fastapi import FastAPI, File
from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer, util

from utils import ffmpeg_read, stt

app = FastAPI()

# Whisper model sizes available through faster-whisper.
whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]

# Speech-to-text model: "base" Whisper, int8-quantized, running on CPU.
audio_model = WhisperModel("base", compute_type="int8", device="cpu")

# Sentence-embedding model used to compare transcripts against the corpus.
text_model = SentenceTransformer("all-MiniLM-L6-v2")

# Precomputed embeddings of the reference corpus.
corpus_embeddings = torch.load("corpus_embeddings.pt")
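
# A minimal sketch of how corpus_embeddings.pt could have been produced
# offline; the actual reference phrases are not part of this file, so the
# corpus below is hypothetical:
#
#     greetings = ["おはようございます", "こんにちは", "こんばんは"]
#     embeddings = text_model.encode(greetings, convert_to_tensor=True)
#     torch.save(embeddings, "corpus_embeddings.pt")
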
def speech_to_text(upload_audio, model_type="whisper"):
    """
    Transcribe uploaded audio bytes with faster-whisper, or fall back to
    the stt pipeline from utils for any other model_type.
    """
    # Decode the raw bytes into a 16 kHz waveform.
    audio = ffmpeg_read(upload_audio, sampling_rate=16000)
    if model_type == "whisper":
        transcribe_options = dict(
            task="transcribe", language="ja", beam_size=5, best_of=5, vad_filter=True
        )
        segments_raw, info = audio_model.transcribe(audio, **transcribe_options)
        segments = [segment.text for segment in segments_raw]
        return " ".join(segments)
    else:
        return stt(audio)
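
# Example use with a local file (assumed path; not executed at import time):
#
#     with open("sample.wav", "rb") as f:
#         print(speech_to_text(f.read(), model_type="whisper"))
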
@app.get("/")
def read_root():
return {"Message": "Application startup complete"}

@app.post("/voice_detect/")
async def voice_detect_api(
    voice_input: bytes = File(None),
    threshold: float = 0.8,
    model_type: str = "whisper",
):
    """
    Transcribe an uploaded audio file and report whether the transcript is
    semantically similar to the reference corpus (score above threshold).
    """
    start = time.time()
    text = speech_to_text(voice_input, model_type)
    # Embed the transcript and retrieve its nearest neighbor in the corpus.
    query_embedding = text_model.encode(text, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
    similar = 1 if hits[0]["score"] > threshold else 0
    end = time.time()
    return {
        "text": text,
        "similar": similar,
        "time_taken": end - start,
    }