"""
Model Preloader for Multilingual Audio Intelligence System

This module downloads and initializes all AI models before the application
starts. It provides progress tracking, caching, and error handling for model
loading.

Models loaded:
- pyannote.audio for speaker diarization
- faster-whisper for speech recognition
- mBART50 for many-to-many neural machine translation
- Helsinki-NLP Opus-MT models for dedicated language-pair translation
"""

import os
import sys
import json
import time
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional

import torch
import psutil
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
from rich.panel import Panel

# Make the project's src/ directory importable when this script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

console = Console()


class ModelPreloader:
    """Comprehensive model preloader with progress tracking and caching."""

    def __init__(self, cache_dir: str = "./model_cache", device: str = "auto"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.models = {}
        self.model_info = {}

        # Registry of models to preload. size_mb values are approximate
        # download sizes in MB, kept for information only.
        self.model_configs = {
            "speaker_diarization": {
                "name": "pyannote/speaker-diarization-3.1",
                "type": "pyannote",
                "description": "Speaker Diarization Pipeline",
                "size_mb": 32
            },
            "whisper_small": {
                "name": "small",
                "type": "whisper",
                "description": "Whisper Speech Recognition (Small)",
                "size_mb": 484
            },
            "mbart_translation": {
                "name": "facebook/mbart-large-50-many-to-many-mmt",
                "type": "mbart",
                "description": "mBART Neural Machine Translation",
                "size_mb": 2440
            },
            "opus_mt_ja_en": {
                "name": "Helsinki-NLP/opus-mt-ja-en",
                "type": "opus_mt",
                "description": "Japanese to English Translation",
                "size_mb": 303
            },
            "opus_mt_es_en": {
                "name": "Helsinki-NLP/opus-mt-es-en",
                "type": "opus_mt",
                "description": "Spanish to English Translation",
                "size_mb": 303
            },
            "opus_mt_fr_en": {
                "name": "Helsinki-NLP/opus-mt-fr-en",
                "type": "opus_mt",
                "description": "French to English Translation",
                "size_mb": 303
            }
        }
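
    # Preloading an additional language pair should only require another
    # entry above, since "opus_mt" entries are dispatched generically to
    # load_opus_mt_model() with config["name"]. A hypothetical example, not
    # part of the original configuration:
    #
    #   "opus_mt_de_en": {
    #       "name": "Helsinki-NLP/opus-mt-de-en",
    #       "type": "opus_mt",
    #       "description": "German to English Translation",
    #       "size_mb": 303
    #   }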

    def get_system_info(self) -> Dict[str, Any]:
        """Get system information for optimal model loading."""
        return {
            "cpu_count": psutil.cpu_count(),
            "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2),
            "available_memory_gb": round(psutil.virtual_memory().available / (1024**3), 2),
            "device": self.device,
            "torch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
        }
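
    # Shape of the dict get_system_info() returns (values are illustrative):
    #
    #   {"cpu_count": 8, "memory_gb": 16.0, "available_memory_gb": 9.5,
    #    "device": "cpu", "torch_version": "2.1.0", "cuda_available": False,
    #    "gpu_name": None}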

    def check_model_cache(self, model_key: str) -> bool:
        """Return True if a recent, successful load is recorded for this model.

        Note: this only inspects the metadata sidecar written by
        save_model_cache(), not the model files themselves.
        """
        cache_file = self.cache_dir / f"{model_key}_info.json"
        if not cache_file.exists():
            return False

        try:
            with open(cache_file, 'r') as f:
                cache_info = json.load(f)

            # Refresh the cache entry once it is more than a week old.
            cache_time = datetime.fromisoformat(cache_info['timestamp'])
            days_old = (datetime.now() - cache_time).days

            if days_old > 7:
                logger.info(f"Cache for {model_key} is {days_old} days old, will refresh")
                return False

            return cache_info.get('status') == 'success'
        except Exception as e:
            logger.warning(f"Error reading cache for {model_key}: {e}")
            return False

    def save_model_cache(self, model_key: str, status: str, info: Dict[str, Any]):
        """Save model loading information to cache."""
        cache_file = self.cache_dir / f"{model_key}_info.json"
        cache_data = {
            "timestamp": datetime.now().isoformat(),
            "status": status,
            "device": self.device,
            "info": info
        }

        try:
            with open(cache_file, 'w') as f:
                json.dump(cache_data, f, indent=2)
        except Exception as e:
            logger.warning(f"Error saving cache for {model_key}: {e}")
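
    # A sidecar written by save_model_cache() looks roughly like this
    # (illustrative values), e.g. model_cache/whisper_small_info.json:
    #
    #   {
    #     "timestamp": "2024-01-01T12:00:00.000000",
    #     "status": "success",
    #     "device": "cpu",
    #     "info": {"load_time": 42.5, "device": "cpu", "model_name": "small"}
    #   }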

    def load_pyannote_pipeline(self, task_id: str) -> Optional[Pipeline]:
        """Load pyannote speaker diarization pipeline."""
        try:
            console.print("[yellow]Loading pyannote.audio pipeline...[/yellow]")

            # pyannote/speaker-diarization-3.1 is a gated model; a Hugging Face
            # token is required to download it.
            hf_token = os.getenv('HUGGINGFACE_TOKEN')
            if not hf_token:
                console.print("[red]Warning: HUGGINGFACE_TOKEN not found. Some models may not be accessible.[/red]")

            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token
            )

            # Pipelines load on CPU by default; move to the selected device so
            # the success message below is accurate.
            pipeline.to(torch.device(self.device))

            console.print(f"[green]✓ pyannote.audio pipeline loaded successfully on {self.device}[/green]")

            return pipeline

        except Exception as e:
            console.print(f"[red]✗ Failed to load pyannote.audio pipeline: {e}[/red]")
            logger.error(f"Pyannote loading failed: {e}")
            return None

    def load_whisper_model(self, task_id: str) -> Optional[WhisperModel]:
        """Load Whisper speech recognition model."""
        try:
            console.print("[yellow]Loading Whisper model (small)...[/yellow]")

            # int8 keeps memory low on CPU; float16 is faster on GPU.
            compute_type = "int8" if self.device == "cpu" else "float16"

            model = WhisperModel(
                "small",
                device=self.device,
                compute_type=compute_type,
                download_root=str(self.cache_dir / "whisper")
            )

            # Warm up on one second of silence; transcribe() returns a lazy
            # generator, so drain it to force the model to actually run.
            import numpy as np
            dummy_audio = np.zeros(16000, dtype=np.float32)
            segments, _ = model.transcribe(dummy_audio, language="en")
            list(segments)

            console.print(f"[green]✓ Whisper model loaded successfully on {self.device} with {compute_type}[/green]")

            return model

        except Exception as e:
            console.print(f"[red]✗ Failed to load Whisper model: {e}[/red]")
            logger.error(f"Whisper loading failed: {e}")
            return None

    def load_mbart_model(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Load mBART translation model."""
        try:
            console.print("[yellow]Loading mBART translation model...[/yellow]")

            model_name = "facebook/mbart-large-50-many-to-many-mmt"
            cache_path = self.cache_dir / "mbart"
            cache_path.mkdir(exist_ok=True)

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=str(cache_path)
            )

            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=str(cache_path),
                torch_dtype=torch.float32 if self.device == "cpu" else torch.float16
            )

            if self.device != "cpu":
                model = model.to(self.device)

            # Smoke test only: real mBART-50 translation also needs
            # tokenizer.src_lang and a forced_bos_token_id for the target
            # language; here we just confirm generate() runs.
            test_input = tokenizer("Hello world", return_tensors="pt")
            if self.device != "cpu":
                test_input = {k: v.to(self.device) for k, v in test_input.items()}

            with torch.no_grad():
                model.generate(**test_input, max_length=10)

            console.print(f"[green]✓ mBART model loaded successfully on {self.device}[/green]")

            return {
                "model": model,
                "tokenizer": tokenizer
            }

        except Exception as e:
            console.print(f"[red]✗ Failed to load mBART model: {e}[/red]")
            logger.error(f"mBART loading failed: {e}")
            return None

    def load_opus_mt_model(self, task_id: str, model_name: str) -> Optional[Dict[str, Any]]:
        """Load Opus-MT translation model."""
        try:
            console.print(f"[yellow]Loading Opus-MT model: {model_name}...[/yellow]")

            cache_path = self.cache_dir / "opus_mt" / model_name.replace("/", "--")
            cache_path.mkdir(parents=True, exist_ok=True)

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=str(cache_path)
            )

            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=str(cache_path),
                torch_dtype=torch.float32 if self.device == "cpu" else torch.float16
            )

            if self.device != "cpu":
                model = model.to(self.device)

            # Smoke test: confirm one forward pass through generate() works.
            test_input = tokenizer("Hello world", return_tensors="pt")
            if self.device != "cpu":
                test_input = {k: v.to(self.device) for k, v in test_input.items()}

            with torch.no_grad():
                model.generate(**test_input, max_length=10)

            console.print(f"[green]✓ {model_name} loaded successfully on {self.device}[/green]")

            return {
                "model": model,
                "tokenizer": tokenizer
            }

        except Exception as e:
            console.print(f"[red]✗ Failed to load {model_name}: {e}[/red]")
            logger.error(f"Opus-MT loading failed: {e}")
            return None
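
    # Minimal sketch of how a loaded Opus-MT bundle could be used downstream
    # (illustrative; the actual translation happens in application code, not
    # in this preloader):
    #
    #   bundle = preloader.get_models()["opus_mt_fr_en"]
    #   inputs = bundle["tokenizer"]("Bonjour le monde", return_tensors="pt")
    #   ids = bundle["model"].generate(**inputs, max_length=64)
    #   text = bundle["tokenizer"].batch_decode(ids, skip_special_tokens=True)[0]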

    def preload_all_models(self) -> Dict[str, Any]:
        """Preload all models with progress tracking."""

        sys_info = self.get_system_info()

        info_panel = Panel.fit(
            f"""🖥️ System Information

• CPU Cores: {sys_info['cpu_count']}
• Total Memory: {sys_info['memory_gb']} GB
• Available Memory: {sys_info['available_memory_gb']} GB
• Device: {sys_info['device'].upper()}
• PyTorch: {sys_info['torch_version']}
• CUDA Available: {sys_info['cuda_available']}
{f"• GPU: {sys_info['gpu_name']}" if sys_info['gpu_name'] else ""}""",
            title="[bold blue]Audio Intelligence System[/bold blue]",
            border_style="blue"
        )
        console.print(info_panel)
        console.print()

        results = {
            "system_info": sys_info,
            "models": {},
            "total_time": 0,
            "success_count": 0,
            "total_count": len(self.model_configs)
        }
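
        # Shape of results on completion (illustrative values):
        #
        #   {"system_info": {...},
        #    "models": {"whisper_small": {"status": "success", "time": 42.0}, ...},
        #    "total_time": 123.4, "success_count": 6, "total_count": 6}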

        start_time = time.time()

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
            console=console
        ) as progress:

            main_task = progress.add_task("[cyan]Loading AI Models...", total=len(self.model_configs))

            for model_key, config in self.model_configs.items():
                task_id = progress.add_task(f"[yellow]{config['description']}", total=100)

                # Cache hit: a recent successful load is on record, so skip
                # re-loading. Note that cached models are not re-instantiated
                # into self.models, so get_models() only returns models loaded
                # in this run.
                if self.check_model_cache(model_key):
                    console.print(f"[green]✓ {config['description']} found in cache[/green]")
                    progress.update(task_id, completed=100)
                    progress.update(main_task, advance=1)
                    results["models"][model_key] = {"status": "cached", "time": 0}
                    results["success_count"] += 1
                    continue

                model_start_time = time.time()
                progress.update(task_id, completed=10)

                # Dispatch to the loader for this model type.
                if config["type"] == "pyannote":
                    model = self.load_pyannote_pipeline(task_id)
                elif config["type"] == "whisper":
                    model = self.load_whisper_model(task_id)
                elif config["type"] == "mbart":
                    model = self.load_mbart_model(task_id)
                elif config["type"] == "opus_mt":
                    model = self.load_opus_mt_model(task_id, config["name"])
                else:
                    model = None

                model_time = time.time() - model_start_time

                if model is not None:
                    self.models[model_key] = model
                    progress.update(task_id, completed=100)
                    results["models"][model_key] = {"status": "success", "time": model_time}
                    results["success_count"] += 1

                    self.save_model_cache(model_key, "success", {
                        "load_time": model_time,
                        "device": self.device,
                        "model_name": config["name"]
                    })
                else:
                    progress.update(task_id, completed=100)
                    results["models"][model_key] = {"status": "failed", "time": model_time}

                    self.save_model_cache(model_key, "failed", {
                        "load_time": model_time,
                        "device": self.device,
                        "error": "Model loading failed"
                    })

                progress.update(main_task, advance=1)

        results["total_time"] = time.time() - start_time

        console.print()
        if results["success_count"] == results["total_count"]:
            status_text = "[bold green]✓ All models loaded successfully![/bold green]"
            status_color = "green"
        elif results["success_count"] > 0:
            status_text = f"[bold yellow]⚠ {results['success_count']}/{results['total_count']} models loaded[/bold yellow]"
            status_color = "yellow"
        else:
            status_text = "[bold red]✗ No models loaded successfully[/bold red]"
            status_color = "red"

        summary_panel = Panel.fit(
            f"""{status_text}

• Loading Time: {results['total_time']:.1f} seconds
• Device: {self.device.upper()}
• Memory Usage: {psutil.virtual_memory().percent:.1f}%
• Models Ready: {results['success_count']}/{results['total_count']}""",
            title="[bold]Model Loading Summary[/bold]",
            border_style=status_color
        )
        console.print(summary_panel)

        return results

    def get_models(self) -> Dict[str, Any]:
        """Get models loaded during this run (cache hits are not included)."""
        return self.models

    def cleanup(self):
        """Cleanup resources."""
        # Release any GPU memory held by the CUDA caching allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
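
# Typical usage, assuming this module is saved as model_preloader.py
# (a sketch; adapt paths and environment setup to your deployment):
#
#   HUGGINGFACE_TOKEN=hf_... python model_preloader.py
#
# or from application startup code:
#
#   from model_preloader import ModelPreloader
#   preloader = ModelPreloader(cache_dir="./model_cache")
#   results = preloader.preload_all_models()
#   models = preloader.get_models()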


def main():
    """Main function to run model preloading."""
    console.print(Panel.fit(
        "[bold blue]🎵 Multilingual Audio Intelligence System[/bold blue]\n[yellow]Model Preloader[/yellow]",
        border_style="blue"
    ))
    console.print()

    preloader = ModelPreloader()

    try:
        results = preloader.preload_all_models()

        if results["success_count"] > 0:
            console.print("\n[bold green]✓ Model preloading completed![/bold green]")
            console.print(f"[dim]Models cached in: {preloader.cache_dir}[/dim]")
            return True
        else:
            console.print("\n[bold red]✗ Model preloading failed![/bold red]")
            return False

    except KeyboardInterrupt:
        console.print("\n[yellow]Model preloading interrupted by user[/yellow]")
        return False
    except Exception as e:
        console.print(f"\n[bold red]✗ Model preloading failed: {e}[/bold red]")
        logger.error(f"Preloading failed: {e}")
        return False
    finally:
        preloader.cleanup()


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)