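"""app.py - Enhanced Multi-Indic ASR API.

FastAPI service exposing speech recognition for 22 scheduled Indian languages
plus English, with optional text normalization and language detection.
Primary model: ai4bharat/indic-conformer-600m-multilingual;
fallback: parthiv11/indic_whisper_nodcil.
"""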
import io
import os
import torch
import logging
import numpy as np
from typing import Optional, Dict, Any
# NEW: FastAPI imports
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# Keep Gradio imports in case you still want to run locally with UI
import gradio as gr
import librosa
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Import your custom modules with proper error handling
try:
from normalizers import get_normalizer, NORMALIZERS, normalize_hindi, normalize_bengali, normalize_tamil, get_language_info
NORMALIZERS_AVAILABLE = True
    logger.info("✅ Enhanced normalizers loaded successfully")
except ImportError as e:
logger.warning(f"Normalizers not available: {e}")
NORMALIZERS_AVAILABLE = False
NORMALIZERS = {}
try:
from language_detector import detect_language, IndicLanguageDetector, get_language_name
LANGUAGE_DETECTOR_AVAILABLE = True
    logger.info("✅ Enhanced language detector loaded successfully")
except ImportError as e:
logger.warning(f"Language detector not available: {e}")
LANGUAGE_DETECTOR_AVAILABLE = False
# Try to set up IndicNLP resources
try:
    from indicnlp import common  # AI4Bharat indic_nlp_library exposes the `indicnlp` package
INDIC_RESOURCES_PATH = "./indic_nlp_resources"
if os.path.exists(INDIC_RESOURCES_PATH):
common.set_resources_path(INDIC_RESOURCES_PATH)
        logger.info("✅ IndicNLP resources configured")
except ImportError:
logger.warning("IndicNLP not available")
# Global variables
conformer_model = None
models_loaded = False
language_detector_instance = None
# Constants
SAMPLE_RATE = 16000
MAX_FILE_SIZE = 25 * 1024 * 1024 # 25MB
SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a', '.flac', '.ogg'}
# 22 scheduled Indian languages plus English
SUPPORTED_LANGUAGES = {
'hi': 'Hindi', 'bn': 'Bengali', 'te': 'Telugu', 'ta': 'Tamil',
'mr': 'Marathi', 'gu': 'Gujarati', 'kn': 'Kannada', 'ml': 'Malayalam',
'pa': 'Punjabi', 'or': 'Odia', 'as': 'Assamese', 'ur': 'Urdu',
'sa': 'Sanskrit', 'ne': 'Nepali', 'ks': 'Kashmiri', 'sd': 'Sindhi',
'doi': 'Dogri', 'brx': 'Bodo', 'sat': 'Santali', 'mai': 'Maithili',
'mni': 'Manipuri', 'gom': 'Konkani', 'en': 'English'
}
class MultiIndicASR:
"""Enhanced Multi-language ASR system for all 22 Indian languages"""
def __init__(self):
self.model = None
self.tokenizer = None
self.device = self.get_device()
self.language_detector = None
if LANGUAGE_DETECTOR_AVAILABLE:
try:
self.language_detector = IndicLanguageDetector()
                logger.info("✅ Enhanced language detector initialized")
except Exception as e:
logger.warning(f"Language detector initialization failed: {e}")
def get_device(self):
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def load_models(self):
try:
            logger.info("🔄 Loading IndicConformer-600M-Multilingual...")
model_name = "ai4bharat/indic-conformer-600m-multilingual"
try:
from transformers import AutoModel, AutoTokenizer
self.model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.float32 if self.device == "cpu" else torch.float16,
trust_remote_code=True,
cache_dir="./models"
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
cache_dir="./models"
)
except Exception as e:
logger.warning(f"Primary model failed: {e}, trying fallback...")
model_name = "parthiv11/indic_whisper_nodcil"
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_name,
torch_dtype=torch.float32,
trust_remote_code=True,
cache_dir="./models"
)
self.tokenizer = AutoProcessor.from_pretrained(
model_name,
trust_remote_code=True,
cache_dir="./models"
)
self.model = self.model.to(self.device)
self.model.eval()
            logger.info(f"✅ Model loaded successfully on {self.device}")
return True
except Exception as e:
logger.error(f"❌ Model loading failed completely: {e}")
return False
def detect_language_enhanced(self, text: str, audio_duration: float = 0) -> Dict[str, Any]:
if self.language_detector:
try:
language = self.language_detector.detect_language(text)
confidence = self.language_detector.get_language_confidence(text, language)
                if audio_duration > 0:
                    # Longer clips give the detector more text to work with,
                    # so nudge confidence up slightly (capped at +0.1).
                    duration_boost = min(audio_duration / 10.0, 0.1)
                    confidence = min(confidence + duration_boost, 1.0)
return {
"language": language,
"confidence": confidence,
"language_name": get_language_name(language) if LANGUAGE_DETECTOR_AVAILABLE else SUPPORTED_LANGUAGES.get(language, 'Unknown'),
"detection_method": "multi_strategy"
}
except Exception as e:
logger.warning(f"Enhanced language detection failed: {e}")
return {
"language": "hi",
"confidence": 0.5,
"language_name": "Hindi",
"detection_method": "fallback"
}
    def preprocess_audio(self, audio_data: bytes) -> np.ndarray:
        try:
            # Decode to mono float32 at the model's 16 kHz sample rate.
            audio_array, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE, mono=True)
            # Peak-normalize to [-1, 1].
            if len(audio_array) > 0:
                max_val = np.max(np.abs(audio_array))
                if max_val > 0:
                    audio_array = audio_array / max_val
            # Zero-pad clips shorter than 3 seconds.
            min_samples = SAMPLE_RATE * 3
            if len(audio_array) < min_samples:
                padding = min_samples - len(audio_array)
                audio_array = np.pad(audio_array, (0, padding))
            return audio_array
        except Exception as e:
            logger.error(f"Audio preprocessing failed: {e}")
            raise
    def transcribe_with_model(self, audio_array: np.ndarray, language: str) -> Dict[str, Any]:
        try:
            if self.model is None:
                logger.error("Model is not loaded (None).")
                return {'text': "", 'confidence': 0.0, 'model': 'None', 'error': 'Model not loaded'}
            audio_tensor = torch.FloatTensor(audio_array).unsqueeze(0)
            if self.device != "cpu":
                audio_tensor = audio_tensor.to(self.device)
            with torch.no_grad():
                # First try the IndicConformer-style call (audio, language, decoder).
                try:
                    result = self.model(audio_tensor, language, "rnnt")
                    return {
                        'text': result,
                        'confidence': 0.95,
                        'model': 'IndicConformer-600M'
                    }
                except Exception as e:
                    logger.warning(f"IndicConformer-style call failed: {e}, trying seq2seq path...")
                # Fall back to the Whisper-style processor + generate path.
                if self.tokenizer is None or not callable(self.tokenizer):
                    logger.error("Tokenizer/processor is not loaded.")
                    return {'text': "", 'confidence': 0.0, 'model': 'None', 'error': 'Tokenizer not loaded'}
                inputs = self.tokenizer(
                    audio_array,
                    sampling_rate=SAMPLE_RATE,
                    return_tensors="pt"
                )
                input_features = inputs["input_features"].to(self.device)
                predicted_ids = self.model.generate(
                    input_features,
                    max_length=448,
                    num_beams=1
                )
                transcription = self.tokenizer.batch_decode(
                    predicted_ids,
                    skip_special_tokens=True
                )[0].strip()
                return {'text': transcription, 'confidence': 0.9, 'model': 'IndicWhisper'}
        except Exception as e:
            logger.error(f"Model transcription failed: {e}")
            return {'text': '', 'confidence': 0.0, 'model': 'Failed', 'error': str(e)}
def normalize_text_enhanced(self, text: str, language: str) -> str:
if not text.strip():
return ""
if NORMALIZERS_AVAILABLE:
try:
normalizer = get_normalizer(language)
normalized = normalizer.normalize(text)
return normalized
except Exception as e:
logger.warning(f"Normalization failed for {language}: {e}")
return text.strip()
def transcribe(self, audio_data: bytes, target_language: Optional[str] = None) -> Dict[str, Any]:
try:
audio_array = self.preprocess_audio(audio_data)
audio_duration = len(audio_array) / SAMPLE_RATE
            if not target_language:
                # No language specified: draft-transcribe in Hindi, then
                # detect the actual language from that draft text.
                quick_result = self.transcribe_with_model(audio_array, 'hi')
                if quick_result['text']:
                    lang_detection = self.detect_language_enhanced(quick_result['text'], audio_duration)
                    target_language = lang_detection['language']
                else:
                    target_language = 'hi'
if target_language not in SUPPORTED_LANGUAGES:
target_language = 'hi'
transcription_result = self.transcribe_with_model(audio_array, target_language)
raw_text = transcription_result['text']
normalized_text = self.normalize_text_enhanced(raw_text, target_language)
lang_detection = self.detect_language_enhanced(raw_text, audio_duration)
return {
"transcription": normalized_text,
"raw_transcription": raw_text,
"language": target_language,
"language_info": get_language_info(target_language) if NORMALIZERS_AVAILABLE else {"name": SUPPORTED_LANGUAGES.get(target_language, "Unknown")},
"detected_language": lang_detection.get("language", target_language),
"language_confidence": lang_detection.get("confidence", 0.5),
"confidence": transcription_result['confidence'],
"model": transcription_result['model'],
"audio_duration_seconds": audio_duration,
"normalization_applied": NORMALIZERS_AVAILABLE,
"detection_method": lang_detection.get("detection_method", "fallback"),
"status": "success"
}
except Exception as e:
logger.error(f"Complete transcription failed: {e}")
return {
"error": f"Transcription failed: {str(e)}",
"transcription": "",
"language": "unknown",
"status": "error"
}
# Initialize ASR engine globally
asr_engine = MultiIndicASR()
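
# Programmatic usage sketch (an illustration, not part of the API surface;
# assumes models have loaded and that a local file `clip.wav` exists):
#
#   with open("clip.wav", "rb") as f:
#       result = asr_engine.transcribe(f.read(), target_language="bn")
#   print(result["transcription"], result["language_confidence"])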
def load_models():
global models_loaded
try:
models_loaded = asr_engine.load_models()
if models_loaded:
            logger.info("✅ All models loaded successfully!")
else:
logger.error("❌ Model loading failed")
except Exception as e:
logger.error(f"❌ Model loading error: {e}")
models_loaded = False
# Load models at startup
logger.info("🚀 Loading models for API...")
load_models()
# ----------------- FASTAPI APP START -------------------
app = FastAPI(
title="Enhanced Multi-Indic ASR API",
description="Enhanced ASR for 22+ Indian languages with normalization and detection",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Allow all origins (CORS)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
async def root():
return """
<html>
<head><title>Enhanced Multi-Indic ASR API</title></head>
<body>
<h1>🎀 Enhanced Multi-Indic ASR API</h1>
<p>Go to <a href="/docs">/docs</a> for Swagger UI.</p>
</body>
</html>
"""
@app.get("/health")
async def health():
return {
"status": "healthy" if models_loaded else "loading",
"models_loaded": models_loaded,
"device": asr_engine.device,
"normalizers_available": NORMALIZERS_AVAILABLE,
"language_detector_available": LANGUAGE_DETECTOR_AVAILABLE
}
@app.post("/transcribe")
async def transcribe_api(file: UploadFile = File(...), language: Optional[str] = None):
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in SUPPORTED_FORMATS:
        raise HTTPException(status_code=415, detail=f"Unsupported format '{ext}'; supported: {sorted(SUPPORTED_FORMATS)}")
    audio_data = await file.read()
    if len(audio_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large (max 25MB)")
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Models are still loading; try again shortly")
    result = asr_engine.transcribe(audio_data, target_language=language)
    return JSONResponse(result)
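
# Example request (a sketch; assumes the server is running locally on the
# default port 7860 and that `sample.wav` exists, so adjust as needed):
#
#   curl -X POST "http://localhost:7860/transcribe?language=hi" \
#        -F "file=@sample.wav"
#
# The JSON response mirrors MultiIndicASR.transcribe(): "transcription",
# "raw_transcription", "language", "confidence", "model", "status", etc.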
# ----------------- FASTAPI APP END -------------------
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
uvicorn.run(
"app:app",
host="0.0.0.0",
port=port,
reload=False
)
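
# To serve locally (assumes uvicorn and the other dependencies are installed):
#   python app.py
# or, equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# then verify with: curl http://localhost:7860/health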