import io
import os
import torch
import logging
import tempfile
import numpy as np
from typing import Optional, Dict, Any

# NEW: FastAPI imports
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

# Keep Gradio imports in case you still want to run locally with UI
import gradio as gr
import librosa

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import your custom modules with proper error handling
try:
    from normalizers import (
        get_normalizer,
        NORMALIZERS,
        normalize_hindi,
        normalize_bengali,
        normalize_tamil,
        get_language_info,
    )
    NORMALIZERS_AVAILABLE = True
    logger.info("✅ Enhanced normalizers loaded successfully")
except ImportError as e:
    logger.warning(f"Normalizers not available: {e}")
    NORMALIZERS_AVAILABLE = False
    NORMALIZERS = {}

try:
    from language_detector import detect_language, IndicLanguageDetector, get_language_name
    LANGUAGE_DETECTOR_AVAILABLE = True
    logger.info("✅ Enhanced language detector loaded successfully")
except ImportError as e:
    logger.warning(f"Language detector not available: {e}")
    LANGUAGE_DETECTOR_AVAILABLE = False

# Try to set up IndicNLP resources
try:
    from indic_nlp import common
    INDIC_RESOURCES_PATH = "./indic_nlp_resources"
    if os.path.exists(INDIC_RESOURCES_PATH):
        common.set_resources_path(INDIC_RESOURCES_PATH)
        logger.info("✅ IndicNLP resources configured")
except ImportError:
    logger.warning("IndicNLP not available")

# Global variables
conformer_model = None
models_loaded = False
language_detector_instance = None

# Constants
SAMPLE_RATE = 16000
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB
SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a', '.flac', '.ogg'}

# All 22+ Indian languages
SUPPORTED_LANGUAGES = {
    'hi': 'Hindi', 'bn': 'Bengali', 'te': 'Telugu', 'ta': 'Tamil',
    'mr': 'Marathi', 'gu': 'Gujarati', 'kn': 'Kannada', 'ml': 'Malayalam',
    'pa': 'Punjabi', 'or': 'Odia', 'as': 'Assamese', 'ur': 'Urdu',
    'sa': 'Sanskrit', 'ne': 'Nepali', 'ks': 'Kashmiri', 'sd': 'Sindhi',
    'doi': 'Dogri', 'brx': 'Bodo', 'sat': 'Santali', 'mai': 'Maithili',
    'mni': 'Manipuri', 'gom': 'Konkani', 'en': 'English'
}


class MultiIndicASR:
    """Enhanced multi-language ASR system for the 22+ supported Indian languages."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = self.get_device()
        self.language_detector = None

        if LANGUAGE_DETECTOR_AVAILABLE:
            try:
                self.language_detector = IndicLanguageDetector()
                logger.info("✅ Enhanced language detector initialized")
            except Exception as e:
                logger.warning(f"Language detector initialization failed: {e}")

    def get_device(self):
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def load_models(self):
        try:
            logger.info("🔄 Loading IndicConformer-600M-Multilingual...")
            model_name = "ai4bharat/indic-conformer-600m-multilingual"
            try:
                from transformers import AutoModel, AutoTokenizer
                self.model = AutoModel.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32 if self.device == "cpu" else torch.float16,
                    trust_remote_code=True,
                    cache_dir="./models",
                )
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="./models",
                )
            except Exception as e:
                logger.warning(f"Primary model failed: {e}, trying fallback...")
                model_name = "parthiv11/indic_whisper_nodcil"
                from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                    cache_dir="./models",
                )
                self.tokenizer = AutoProcessor.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="./models",
                )

            self.model = self.model.to(self.device)
            self.model.eval()
            logger.info(f"✅ Model loaded successfully on {self.device}")
            return True
        except Exception as e:
            logger.error(f"❌ Model loading failed completely: {e}")
            return False

    def detect_language_enhanced(self, text: str, audio_duration: float = 0) -> Dict[str, Any]:
        if self.language_detector:
            try:
                language = self.language_detector.detect_language(text)
                confidence = self.language_detector.get_language_confidence(text, language)
                # Longer audio gives the detector more evidence; apply a small confidence boost.
                if audio_duration > 0:
                    duration_boost = min(audio_duration / 10.0, 0.1)
                    confidence = min(confidence + duration_boost, 1.0)
                return {
                    "language": language,
                    "confidence": confidence,
                    "language_name": get_language_name(language)
                    if LANGUAGE_DETECTOR_AVAILABLE
                    else SUPPORTED_LANGUAGES.get(language, 'Unknown'),
                    "detection_method": "multi_strategy",
                }
            except Exception as e:
                logger.warning(f"Enhanced language detection failed: {e}")
        # Fallback: default to Hindi with moderate confidence
        return {
            "language": "hi",
            "confidence": 0.5,
            "language_name": "Hindi",
            "detection_method": "fallback",
        }

    def preprocess_audio(self, audio_data: bytes) -> np.ndarray:
        try:
            audio_array, sr = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE, mono=True)
            if len(audio_array) > 0:
                # Peak-normalize the signal
                max_val = np.max(np.abs(audio_array))
                if max_val > 0:
                    audio_array = audio_array / max_val
            # Pad very short clips to at least 3 seconds
            min_samples = SAMPLE_RATE * 3
            if len(audio_array) < min_samples:
                padding = min_samples - len(audio_array)
                audio_array = np.pad(audio_array, (0, padding))
            return audio_array
        except Exception as e:
            logger.error(f"Audio preprocessing failed: {e}")
            raise e

    def transcribe_with_model(self, audio_array: np.ndarray, language: str) -> Dict[str, Any]:
        try:
            audio_tensor = torch.FloatTensor(audio_array).unsqueeze(0)
            if self.device != "cpu":
                audio_tensor = audio_tensor.to(self.device)

            with torch.no_grad():
                # IndicConformer path: the remote-code model is called directly with (audio, lang, "rnnt")
                if hasattr(self.model, '__call__') and hasattr(self.model, '__module__'):
                    try:
                        if self.model is not None:
                            result = self.model(audio_tensor, language, "rnnt")
                            return {
                                'text': result,
                                'confidence': 0.95,
                                'model': 'IndicConformer-600M',
                            }
                        else:
                            logger.error("Model is not loaded (None).")
                            return {'text': "", 'confidence': 0.0, 'model': 'None',
                                    'error': 'Model not loaded'}
                    except Exception:
                        # Fall through to the Whisper-style path below
                        pass

                # Whisper-style path: processor + generate + batch_decode
                if self.tokenizer is not None and hasattr(self.tokenizer, '__call__'):
                    if self.model is None:
                        logger.error("Model is not loaded (None).")
                        return {'text': "", 'confidence': 0.0, 'model': 'None',
                                'error': 'Model not loaded'}
                    inputs = self.tokenizer(
                        audio_array,
                        sampling_rate=SAMPLE_RATE,
                        return_tensors="pt",
                    )
                    input_features = inputs["input_features"].to(self.device)
                    predicted_ids = self.model.generate(
                        input_features,
                        max_length=448,
                        num_beams=1,
                        temperature=0.0,
                    )
                    transcription = self.tokenizer.batch_decode(
                        predicted_ids,
                        skip_special_tokens=True,
                    )[0].strip()
                    return {'text': transcription, 'confidence': 0.9, 'model': 'IndicWhisper'}
                elif self.tokenizer is None:
                    logger.error("Tokenizer is not loaded (None).")
                    return {'text': "", 'confidence': 0.0, 'model': 'None',
                            'error': 'Tokenizer not loaded'}

            return {'text': "", 'confidence': 0.0, 'model': 'Unknown',
                    'error': 'Model type not recognized'}
        except Exception as e:
            logger.error(f"Model transcription failed: {e}")
            return {'text': '', 'confidence': 0.0, 'model': 'Failed', 'error': str(e)}
    def normalize_text_enhanced(self, text: str, language: str) -> str:
        if not text.strip():
            return ""
        if NORMALIZERS_AVAILABLE:
            try:
                normalizer = get_normalizer(language)
                normalized = normalizer.normalize(text)
                return normalized
            except Exception as e:
                logger.warning(f"Normalization failed for {language}: {e}")
        return text.strip()

    def transcribe(self, audio_data: bytes, target_language: Optional[str] = None) -> Dict[str, Any]:
        try:
            audio_array = self.preprocess_audio(audio_data)
            audio_duration = len(audio_array) / SAMPLE_RATE

            if not target_language:
                quick_result = self.transcribe_with_model(audio_array, 'hi')
                if quick_result['text']:
                    lang_detection = self.detect_language_enhanced(quick_result['text'], audio_duration)
                    target_language = lang_detection['language']
                else:
                    target_language = 'hi'

            if target_language not in SUPPORTED_LANGUAGES:
                target_language = 'hi'

            transcription_result = self.transcribe_with_model(audio_array, target_language)
            raw_text = transcription_result['text']
            normalized_text = self.normalize_text_enhanced(raw_text, target_language)
            lang_detection = self.detect_language_enhanced(raw_text, audio_duration)

            return {
                "transcription": normalized_text,
                "raw_transcription": raw_text,
                "language": target_language,
                "language_info": get_language_info(target_language)
                if NORMALIZERS_AVAILABLE
                else {"name": SUPPORTED_LANGUAGES.get(target_language, "Unknown")},
                "detected_language": lang_detection.get("language", target_language),
                "language_confidence": lang_detection.get("confidence", 0.5),
                "confidence": transcription_result['confidence'],
                "model": transcription_result['model'],
                "audio_duration_seconds": audio_duration,
                "normalization_applied": NORMALIZERS_AVAILABLE,
                "detection_method": lang_detection.get("detection_method", "fallback"),
                "status": "success",
            }
        except Exception as e:
            logger.error(f"Complete transcription failed: {e}")
            return {
                "error": f"Transcription failed: {str(e)}",
                "transcription": "",
                "language": "unknown",
                "status": "error",
            }


# Initialize ASR engine globally
asr_engine = MultiIndicASR()


def load_models():
    global models_loaded
    try:
        models_loaded = asr_engine.load_models()
        if models_loaded:
            logger.info("✅ All models loaded successfully!")
        else:
            logger.error("❌ Model loading failed")
    except Exception as e:
        logger.error(f"❌ Model loading error: {e}")
        models_loaded = False


# Load models at startup
logger.info("🚀 Loading models for API...")
load_models()

# ----------------- FASTAPI APP START -------------------
app = FastAPI(
    title="Enhanced Multi-Indic ASR API",
    description="Enhanced ASR for 22+ Indian languages with normalization and detection",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Allow all origins (CORS)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", response_class=HTMLResponse)
async def root():
    return """
    <html>
        <head>
            <title>Enhanced Multi-Indic ASR API</title>
        </head>
        <body>
            <h1>🎤 Enhanced Multi-Indic ASR API</h1>
            <p>Go to <a href="/docs">/docs</a> for Swagger UI.</p>
        </body>
    </html>
    """


@app.get("/health")
async def health():
    return {
        "status": "healthy" if models_loaded else "loading",
        "models_loaded": models_loaded,
        "device": asr_engine.device,
        "normalizers_available": NORMALIZERS_AVAILABLE,
        "language_detector_available": LANGUAGE_DETECTOR_AVAILABLE,
    }


@app.post("/transcribe")
async def transcribe_api(file: UploadFile = File(...), language: Optional[str] = None):
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded")

    audio_data = await file.read()
    if len(audio_data) > MAX_FILE_SIZE:
        raise HTTPException(status_code=413, detail="File too large")

    result = asr_engine.transcribe(audio_data, target_language=language)
    return JSONResponse(result)

# ----------------- FASTAPI APP END -------------------

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,
    )