# --- Hugging Face file-viewer chrome captured along with the source; kept as
# --- a comment so the file remains valid Python:
# deepugaur's picture
# Update app.py
# c328e5c verified
# raw
# history blame
# 1.97 kB
import os
import tempfile
from datetime import datetime, time
from zoneinfo import ZoneInfo

import pytz
import torch
import torchaudio
from flask import Flask, request, jsonify
from pydub import AudioSegment
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer, pipeline
app = Flask(__name__)
# Load speech recognition model and tokenizer
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Load translation model explicitly
translation_model = "Helsinki-NLP/opus-mt-en-hi"
translation_pipeline = pipeline("translation", model=translation_model)
# Function to preprocess audio
def preprocess_audio(audio_file):
audio = AudioSegment.from_file(audio_file)
audio = audio.set_frame_rate(16000)
audio.export("processed.wav", format="wav")
waveform, sample_rate = torchaudio.load("processed.wav")
return waveform
# Function to check if the current time is after 6 PM IST
def is_after_6pm_ist():
ist = pytz.timezone('Asia/Kolkata')
current_time = datetime.now(ist).time()
return current_time >= time(18, 0)
@app.route('/translate', methods=['POST'])
def translate():
if not is_after_6pm_ist():
return jsonify({"error": "Service is available only after 6 PM IST"}), 403
if 'audio' not in request.files:
return jsonify({"error": "No audio file provided"}), 400
audio_file = request.files['audio']
waveform = preprocess_audio(audio_file)
input_values = tokenizer(waveform.squeeze().numpy(), return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]
translation = translation_pipeline(transcription)
translated_text = translation[0]['translation_text']
return jsonify({"transcription": transcription, "translation": translated_text})
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=8080)