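"""Enhanced Multi-Indic ASR service.

Loads ai4bharat/indic-conformer-600m-multilingual (with
parthiv11/indic_whisper_nodcil as a fallback), applies optional text
normalization and language detection for 22 scheduled Indian languages plus
English, and exposes the engine through a FastAPI app with /health and
/transcribe endpoints.
"""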
import io
import os
import torch
import logging
import tempfile
import numpy as np
from typing import Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

import gradio as gr
import librosa

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from normalizers import get_normalizer, NORMALIZERS, normalize_hindi, normalize_bengali, normalize_tamil, get_language_info
    NORMALIZERS_AVAILABLE = True
    logger.info("Enhanced normalizers loaded successfully")
except ImportError as e:
    logger.warning(f"Normalizers not available: {e}")
    NORMALIZERS_AVAILABLE = False
    NORMALIZERS = {}

try:
    from language_detector import detect_language, IndicLanguageDetector, get_language_name
    LANGUAGE_DETECTOR_AVAILABLE = True
    logger.info("Enhanced language detector loaded successfully")
except ImportError as e:
    logger.warning(f"Language detector not available: {e}")
    LANGUAGE_DETECTOR_AVAILABLE = False

try:
    from indic_nlp import common
    INDIC_RESOURCES_PATH = "./indic_nlp_resources"
    if os.path.exists(INDIC_RESOURCES_PATH):
        common.set_resources_path(INDIC_RESOURCES_PATH)
        logger.info("IndicNLP resources configured")
except ImportError:
    logger.warning("IndicNLP not available")

conformer_model = None
models_loaded = False
language_detector_instance = None

SAMPLE_RATE = 16000
MAX_FILE_SIZE = 25 * 1024 * 1024
SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a', '.flac', '.ogg'}

SUPPORTED_LANGUAGES = {
    'hi': 'Hindi', 'bn': 'Bengali', 'te': 'Telugu', 'ta': 'Tamil',
    'mr': 'Marathi', 'gu': 'Gujarati', 'kn': 'Kannada', 'ml': 'Malayalam',
    'pa': 'Punjabi', 'or': 'Odia', 'as': 'Assamese', 'ur': 'Urdu',
    'sa': 'Sanskrit', 'ne': 'Nepali', 'ks': 'Kashmiri', 'sd': 'Sindhi',
    'doi': 'Dogri', 'brx': 'Bodo', 'sat': 'Santali', 'mai': 'Maithili',
    'mni': 'Manipuri', 'gom': 'Konkani', 'en': 'English'
}

class MultiIndicASR:
    """Multi-language ASR engine covering the 22 scheduled Indian languages plus English."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = self.get_device()
        self.language_detector = None

        if LANGUAGE_DETECTOR_AVAILABLE:
            try:
                self.language_detector = IndicLanguageDetector()
                logger.info("Enhanced language detector initialized")
            except Exception as e:
                logger.warning(f"Language detector initialization failed: {e}")

    def get_device(self):
        """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def load_models(self):
        """Load the primary IndicConformer model, falling back to IndicWhisper if it fails."""
        try:
            logger.info("Loading IndicConformer-600M-Multilingual...")
            model_name = "ai4bharat/indic-conformer-600m-multilingual"
            try:
                from transformers import AutoModel, AutoTokenizer
                self.model = AutoModel.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32 if self.device == "cpu" else torch.float16,
                    trust_remote_code=True,
                    cache_dir="./models"
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="./models"
                )
            except Exception as e:
                logger.warning(f"Primary model failed: {e}, trying fallback...")
                model_name = "parthiv11/indic_whisper_nodcil"
                from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                    cache_dir="./models"
                )
                self.tokenizer = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="./models"
                )
            self.model = self.model.to(self.device)
            self.model.eval()
            logger.info(f"Model loaded successfully on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Model loading failed completely: {e}")
            return False

    def detect_language_enhanced(self, text: str, audio_duration: float = 0) -> Dict[str, Any]:
        """Detect the language of transcribed text, with a small confidence boost for longer audio."""
        if self.language_detector:
            try:
                language = self.language_detector.detect_language(text)
                confidence = self.language_detector.get_language_confidence(text, language)
                if audio_duration > 0:
                    duration_boost = min(audio_duration / 10.0, 0.1)
                    confidence = min(confidence + duration_boost, 1.0)
                return {
                    "language": language,
                    "confidence": confidence,
                    "language_name": get_language_name(language) if LANGUAGE_DETECTOR_AVAILABLE else SUPPORTED_LANGUAGES.get(language, 'Unknown'),
                    "detection_method": "multi_strategy"
                }
            except Exception as e:
                logger.warning(f"Enhanced language detection failed: {e}")
        return {
            "language": "hi",
            "confidence": 0.5,
            "language_name": "Hindi",
            "detection_method": "fallback"
        }

    def preprocess_audio(self, audio_data: bytes) -> np.ndarray:
        """Decode raw audio bytes to 16 kHz mono, peak-normalize, and pad to at least 3 seconds."""
        try:
            audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE, mono=True)
            if len(audio_array) > 0:
                max_val = np.max(np.abs(audio_array))
                if max_val > 0:
                    audio_array = audio_array / max_val
            min_samples = SAMPLE_RATE * 3
            if len(audio_array) < min_samples:
                padding = min_samples - len(audio_array)
                audio_array = np.pad(audio_array, (0, padding))
            return audio_array
        except Exception as e:
            logger.error(f"Audio preprocessing failed: {e}")
            raise

    def transcribe_with_model(self, audio_array: np.ndarray, language: str) -> Dict[str, Any]:
        """Run inference, trying the native IndicConformer call first and a processor-based seq2seq path second."""
        try:
            audio_tensor = torch.FloatTensor(audio_array).unsqueeze(0)
            if self.device != "cpu":
                audio_tensor = audio_tensor.to(self.device)
            with torch.no_grad():
                if self.model is None:
                    logger.error("Model is not loaded (None).")
                    return {'text': "", 'confidence': 0.0, 'model': 'None', 'error': 'Model not loaded'}
                # Path 1: IndicConformer-style models accept the audio tensor, a language
                # code, and the decoder type ("rnnt") directly.
                try:
                    result = self.model(audio_tensor, language, "rnnt")
                    return {
                        'text': result,
                        'confidence': 0.95,
                        'model': 'IndicConformer-600M'
                    }
                except Exception:
                    # Not a conformer-style model; fall through to the processor path.
                    pass
                # Path 2: Whisper-style seq2seq model driven through its processor.
                if self.tokenizer is None:
                    logger.error("Tokenizer is not loaded (None).")
                    return {'text': "", 'confidence': 0.0, 'model': 'None', 'error': 'Tokenizer not loaded'}
                if callable(self.tokenizer):
                    inputs = self.tokenizer(
                        audio_array,
                        sampling_rate=SAMPLE_RATE,
                        return_tensors="pt"
                    )
                    input_features = inputs["input_features"].to(self.device)
                    predicted_ids = self.model.generate(
                        input_features,
                        max_length=448,
                        num_beams=1,
                        temperature=0.0
                    )
                    transcription = self.tokenizer.batch_decode(
                        predicted_ids,
                        skip_special_tokens=True
                    )[0].strip()
                    return {'text': transcription, 'confidence': 0.9, 'model': 'IndicWhisper'}
                return {'text': "", 'confidence': 0.0, 'model': 'Unknown', 'error': 'Model type not recognized'}
        except Exception as e:
            logger.error(f"Model transcription failed: {e}")
            return {'text': '', 'confidence': 0.0, 'model': 'Failed', 'error': str(e)}

    def normalize_text_enhanced(self, text: str, language: str) -> str:
        """Apply the language-specific text normalizer when available, otherwise return the stripped text."""
        if not text.strip():
            return ""
        if NORMALIZERS_AVAILABLE:
            try:
                normalizer = get_normalizer(language)
                normalized = normalizer.normalize(text)
                return normalized
            except Exception as e:
                logger.warning(f"Normalization failed for {language}: {e}")
        return text.strip()

    def transcribe(self, audio_data: bytes, target_language: Optional[str] = None) -> Dict[str, Any]:
        """Full pipeline: preprocess audio, resolve the target language, transcribe, and normalize the text."""
        try:
            audio_array = self.preprocess_audio(audio_data)
            audio_duration = len(audio_array) / SAMPLE_RATE
            if not target_language:
                # No language given: do a quick Hindi pass and detect the language from its output.
                quick_result = self.transcribe_with_model(audio_array, 'hi')
                if quick_result['text']:
                    lang_detection = self.detect_language_enhanced(quick_result['text'], audio_duration)
                    target_language = lang_detection['language']
                else:
                    target_language = 'hi'
            if target_language not in SUPPORTED_LANGUAGES:
                target_language = 'hi'
            transcription_result = self.transcribe_with_model(audio_array, target_language)
            raw_text = transcription_result['text']
            normalized_text = self.normalize_text_enhanced(raw_text, target_language)
            lang_detection = self.detect_language_enhanced(raw_text, audio_duration)
            return {
                "transcription": normalized_text,
                "raw_transcription": raw_text,
                "language": target_language,
                "language_info": get_language_info(target_language) if NORMALIZERS_AVAILABLE else {"name": SUPPORTED_LANGUAGES.get(target_language, "Unknown")},
                "detected_language": lang_detection.get("language", target_language),
                "language_confidence": lang_detection.get("confidence", 0.5),
                "confidence": transcription_result['confidence'],
                "model": transcription_result['model'],
                "audio_duration_seconds": audio_duration,
                "normalization_applied": NORMALIZERS_AVAILABLE,
                "detection_method": lang_detection.get("detection_method", "fallback"),
                "status": "success"
            }
        except Exception as e:
            logger.error(f"Complete transcription failed: {e}")
            return {
                "error": f"Transcription failed: {str(e)}",
                "transcription": "",
                "language": "unknown",
                "status": "error"
            }

asr_engine = MultiIndicASR()


def load_models():
    """Load the ASR models into the module-level engine and record whether loading succeeded."""
    global models_loaded
    try:
        models_loaded = asr_engine.load_models()
        if models_loaded:
            logger.info("All models loaded successfully!")
        else:
            logger.error("Model loading failed")
    except Exception as e:
        logger.error(f"Model loading error: {e}")
        models_loaded = False


logger.info("Loading models for API...")
load_models()

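# Illustrative sketch of using the engine directly (not executed here; assumes
# a local "sample.wav" exists):
#
#   with open("sample.wav", "rb") as f:
#       result = asr_engine.transcribe(f.read(), target_language="hi")
#   print(result["transcription"], result["language_confidence"])
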
app = FastAPI(
    title="Enhanced Multi-Indic ASR API",
    description="Enhanced ASR for 22+ Indian languages with normalization and detection",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/", response_class=HTMLResponse) |
|
async def root(): |
|
return """ |
|
<html> |
|
<head><title>Enhanced Multi-Indic ASR API</title></head> |
|
<body> |
|
<h1>π€ Enhanced Multi-Indic ASR API</h1> |
|
<p>Go to <a href="/docs">/docs</a> for Swagger UI.</p> |
|
</body> |
|
</html> |
|
""" |
|
|
|
@app.get("/health") |
|
async def health(): |
|
return { |
|
"status": "healthy" if models_loaded else "loading", |
|
"models_loaded": models_loaded, |
|
"device": asr_engine.device, |
|
"normalizers_available": NORMALIZERS_AVAILABLE, |
|
"language_detector_available": LANGUAGE_DETECTOR_AVAILABLE |
|
} |
|
|
|
@app.post("/transcribe")
async def transcribe_api(file: UploadFile = File(...), language: Optional[str] = None):
    """Transcribe an uploaded audio file; `language` is optional and auto-detected when omitted."""
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in SUPPORTED_FORMATS:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {ext}")
    audio_data = await file.read()
    if len(audio_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large")
    result = asr_engine.transcribe(audio_data, target_language=language)
    return JSONResponse(result)

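# Illustrative client sketch (assumes a server running locally on the default
# port and the third-party `requests` package; not executed here):
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/transcribe",
#           files={"file": ("sample.wav", f, "audio/wav")},
#           params={"language": "hi"},
#       )
#   print(resp.json()["transcription"])
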
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False
    )