# src/model_loader.py
import warnings

# Suppress noisy import-time warnings from torch/transformers/unsloth.
warnings.filterwarnings("ignore")

import torch
import transformers
from typing import Tuple, Any

# Unsloth is an optional dependency; degrade gracefully when it is absent.
try:
    import unsloth
except ImportError:
    unsloth = None


def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
    """
    Load a model for evaluation. Supports multiple model families.

    Returns (model, tokenizer), or ('google-translate', None) when the
    Google Cloud Translation API is used instead of a local model.
    """
    print(f"Loading model from {model_path}...")

    # Google Translate is a pseudo-"model": nothing to load locally.
    if model_path == 'google-translate':
        return 'google-translate', None

    try:
        # NLLB models (seq2seq translation models on the M2M100 architecture)
        if 'nllb' in model_path.lower():
            tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
            model = transformers.M2M100ForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch.bfloat16
            ).to('cuda' if torch.cuda.is_available() else 'cpu')

        # Pre-quantized checkpoints, loaded with bitsandbytes NF4.
        # load_in_4bit lets callers opt out of the quantized path.
        elif load_in_4bit and '4bit' in model_path.lower():
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_path, model_max_length=4096, padding_side='left'
            )
            # Some quantized checkpoints ship without a pad token; reuse BOS.
            tokenizer.pad_token = tokenizer.bos_token
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )

        # Standard causal LMs, with optional Unsloth optimization
        else:
            if use_unsloth and unsloth is None:
                print("unsloth is not installed. Falling back to standard loading.")
                use_unsloth = False
            if use_unsloth:
                try:
                    model, tokenizer = unsloth.FastModel.from_pretrained(
                        model_name=model_path,
                        max_seq_length=1024,
                        load_in_4bit=False,
                        load_in_8bit=False,
                        full_finetuning=False,
                    )
                except Exception as e:
                    print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
                    use_unsloth = False
            if not use_unsloth:
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto' if torch.cuda.is_available() else None,
                )

        print(f"Successfully loaded {model_path}")
        return model, tokenizer

    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        # Chain the exception so the original traceback is preserved.
        raise RuntimeError(f"Failed to load model: {e}") from e


def get_model_info(model_path: str) -> dict:
    """Get basic information about a model without loading it."""
    try:
        if model_path == 'google-translate':
            return {
                'name': 'Google Translate',
                'type': 'google-translate',
                'size': 'Unknown',
                'description': 'Google Cloud Translation API',
            }

        from huggingface_hub import model_info
        info = model_info(model_path)
        # info.safetensors is a dataclass (or None), not a dict, so read the
        # parameter count via attribute access rather than .get().
        safetensors = getattr(info, 'safetensors', None)
        size = getattr(safetensors, 'total', 'Unknown') if safetensors else 'Unknown'
        return {
            'name': model_path,
            'type': get_model_type(model_path),
            'size': size,
            'description': getattr(info, 'description', 'No description available'),
        }
    except Exception as e:
        return {
            'name': model_path,
            'type': 'unknown',
            'size': 'Unknown',
            'description': f'Error getting info: {e}',
        }


def get_model_type(model_path: str) -> str:
    """Determine the model family from its path or Hub id."""
    model_path_lower = model_path.lower()
    if model_path == 'google-translate':
        return 'google-translate'
    elif 'gemma' in model_path_lower:
        return 'gemma'
    elif 'qwen' in model_path_lower:
        return 'qwen'
    elif 'llama' in model_path_lower:
        return 'llama'
    elif 'nllb' in model_path_lower:
        return 'nllb'
    else:
        return 'other'
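

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the evaluation pipeline.
    # Assumes network access to the Hugging Face Hub; the NLLB checkpoint id
    # below is a real public model but purely illustrative here. Substitute
    # whichever model you actually evaluate.
    demo_model = "facebook/nllb-200-distilled-600M"
    print(get_model_info(demo_model))
    model, tokenizer = load_model(demo_model, use_unsloth=False)
    print(type(model).__name__, type(tokenizer).__name__)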