Spaces:
Sleeping
Sleeping
File size: 1,972 Bytes
8dcb583 c328e5c 8dcb583 6ff3417 c328e5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from flask import Flask, request, jsonify
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer, pipeline
from pydub import AudioSegment
import torch
import torchaudio
from datetime import datetime, time
import pytz
# Flask application object; the route decorator further down registers on it.
app = Flask(__name__)

# Speech recognition: the tokenizer and the CTC model are both loaded from the
# same pretrained checkpoint, so the id is spelled out once.
_ASR_CHECKPOINT = "facebook/wav2vec2-base-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(_ASR_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_ASR_CHECKPOINT)

# Translation: the model id is kept in its own module-level name and handed to
# the pipeline factory explicitly rather than relying on the task default.
translation_model = "Helsinki-NLP/opus-mt-en-hi"
translation_pipeline = pipeline(task="translation", model=translation_model)
def preprocess_audio(audio_file):
    """Convert an uploaded audio file into a 16 kHz mono waveform tensor.

    Args:
        audio_file: A path or file-like object that
            ``AudioSegment.from_file`` can read (the Flask upload object is
            passed in by the caller).

    Returns:
        The waveform tensor produced by ``torchaudio.load`` for the
        resampled, single-channel WAV rendering of the input.

    Raises:
        Whatever ``AudioSegment.from_file`` or ``torchaudio.load`` raise when
        the input cannot be decoded.
    """
    # Local imports: stdlib helpers needed only by this function.
    import os
    import tempfile

    audio = AudioSegment.from_file(audio_file)
    # Resample to 16 kHz and downmix to one channel so the caller's
    # ``squeeze()`` always yields a 1-D signal, even for stereo uploads.
    audio = audio.set_frame_rate(16000).set_channels(1)

    # Use a private temporary directory instead of a fixed "processed.wav" in
    # the working directory, so concurrent requests cannot overwrite each
    # other's audio and no file is left behind afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "processed.wav")
        # export() hands back the open output file; close it so the handle is
        # not leaked and the data is flushed before it is read back.
        audio.export(wav_path, format="wav").close()
        waveform, _sample_rate = torchaudio.load(wav_path)
    return waveform
def is_after_6pm_ist():
    """Return True when the current Asia/Kolkata wall-clock time is 18:00 or later."""
    opening_time = time(18, 0)
    now_in_kolkata = datetime.now(pytz.timezone('Asia/Kolkata'))
    # Compare time-of-day only; the date part is irrelevant here.
    return opening_time <= now_in_kolkata.time()
@app.route('/translate', methods=['POST'])
def translate():
    """Transcribe an uploaded audio clip and translate the transcription.

    Expects a multipart POST carrying the recording in the ``audio`` file
    field.

    Returns:
        On success, a JSON body with ``transcription`` and ``translation``
        keys. Otherwise a JSON body with an ``error`` message and status 403
        (outside service hours) or 400 (audio missing or undecodable).
    """
    # Local import: the decode error type is only needed by this handler.
    from pydub.exceptions import CouldntDecodeError

    if not is_after_6pm_ist():
        return jsonify({"error": "Service is available only after 6 PM IST"}), 403
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400

    audio_file = request.files['audio']
    try:
        waveform = preprocess_audio(audio_file)
    except CouldntDecodeError:
        # A bad upload is a client error, not an internal server error.
        return jsonify({"error": "Could not decode audio file"}), 400

    input_values = tokenizer(waveform.squeeze().numpy(), return_tensors="pt").input_values
    # Inference only: disable autograd so no gradient graph is built and kept
    # alive for every request.
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]

    if transcription.strip():
        translation = translation_pipeline(transcription)
        translated_text = translation[0]['translation_text']
    else:
        # Nothing was recognised, so there is no text to hand to the
        # translation model.
        translated_text = ""

    return jsonify({"transcription": transcription, "translation": translated_text})
if __name__ == '__main__':
    # Local import: only the script entry point needs the environment.
    import os

    # The server listens on every interface, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port run code.
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host='0.0.0.0', port=8080)
|