# Scraped from the Hugging Face file viewer:
#   path:   leaderboard/src/model_loader.py
#   commit: 93b9d03 "Create model_loader.py" (author: akera, verified)
#   size:   4.49 kB
# src/model_loader.py
import torch
import transformers
import unsloth
from typing import Tuple, Any
import warnings
# Silence all warnings (every category, every module) as soon as this module
# is imported. The filter is installed in the global `warnings` filter list,
# so it also affects code outside this module.
# NOTE(review): consider narrowing to specific categories/modules so that
# unrelated warnings are not hidden.
warnings.filterwarnings("ignore")
def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
    """
    Load a model and its tokenizer for evaluation.

    The loading strategy is picked from ``model_path``:

    * ``'google-translate'`` -- nothing is loaded; the sentinel pair
      ``('google-translate', None)`` is returned.
    * contains ``'nllb'`` (case-insensitive) -- NLLB tokenizer plus an
      ``M2M100ForConditionalGeneration`` model in bfloat16, moved to CUDA when
      available, otherwise CPU.
    * contains ``'4bit'`` (case-insensitive) -- causal LM loaded with a
      bitsandbytes NF4 quantization config and a left-padded tokenizer.
    * anything else -- causal LM via Unsloth when ``use_unsloth`` is true,
      falling back to plain ``transformers`` loading if Unsloth raises.

    Args:
        model_path: Hub repo id or local path, or the literal
            ``'google-translate'``.
        load_in_4bit: Accepted for API compatibility but not read by this
            function; 4-bit loading is selected by ``'4bit'`` appearing in
            ``model_path``.
        use_unsloth: Try Unsloth first in the default causal-LM branch.

    Returns:
        ``(model, tokenizer)``, or ``('google-translate', None)``.

    Raises:
        Exception: if any loading step fails. The underlying error is
            attached as ``__cause__`` so its traceback is preserved.
    """
    print(f"Loading model from {model_path}...")
    # Google Translate is not a local model; callers dispatch on this sentinel.
    if model_path == 'google-translate':
        return 'google-translate', None
    try:
        path_lower = model_path.lower()
        if 'nllb' in path_lower:
            model, tokenizer = _load_nllb(model_path)
        elif '4bit' in path_lower:
            model, tokenizer = _load_4bit(model_path)
        else:
            model, tokenizer = _load_causal_lm(model_path, use_unsloth)
        print(f"Successfully loaded {model_path}")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        # Chain the original error so the real failure is not lost.
        raise Exception(f"Failed to load model: {e}") from e


def _load_nllb(model_path: str) -> Tuple[Any, Any]:
    """Load an NLLB tokenizer and a bfloat16 M2M100 model on CUDA if available, else CPU."""
    tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
    model = transformers.M2M100ForConditionalGeneration.from_pretrained(
        model_path, torch_dtype=torch.bfloat16
    ).to('cuda' if torch.cuda.is_available() else 'cpu')
    return model, tokenizer


def _load_4bit(model_path: str) -> Tuple[Any, Any]:
    """Load a left-padded tokenizer and an NF4 (double-quantized) bitsandbytes causal LM."""
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_path,
        model_max_length=4096,
        padding_side='left',
    )
    # Pad with the BOS token. NOTE(review): presumably these checkpoints have
    # no dedicated pad token -- confirm against the models actually used.
    tokenizer.pad_token = tokenizer.bos_token
    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    return model, tokenizer


def _load_causal_lm(model_path: str, use_unsloth: bool) -> Tuple[Any, Any]:
    """Load a causal LM via Unsloth when requested, else (or on Unsloth failure) via transformers."""
    if use_unsloth:
        try:
            model, tokenizer = unsloth.FastModel.from_pretrained(
                model_name=model_path,
                max_seq_length=1024,
                load_in_4bit=False,
                load_in_8bit=False,
                full_finetuning=False,
            )
        except Exception as e:
            # Best effort: any Unsloth failure falls through to plain transformers.
            print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
        else:
            return model, tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map='auto' if torch.cuda.is_available() else None,
    )
    return model, tokenizer
def get_model_info(model_path: str) -> dict:
    """Get basic information about a model without loading its weights.

    ``'google-translate'`` gets a fixed descriptor. Any other path is looked up
    with ``huggingface_hub.model_info``. Any failure during the lookup,
    including ``huggingface_hub`` not being importable, is reported in the
    returned dict instead of being raised.

    Returns:
        A dict with the keys ``'name'``, ``'type'``, ``'size'`` and
        ``'description'``.
    """
    try:
        if model_path == 'google-translate':
            return {
                'name': 'Google Translate',
                'type': 'google-translate',
                'size': 'Unknown',
                'description': 'Google Cloud Translation API'
            }
        from huggingface_hub import model_info
        info = model_info(model_path)
        # `safetensors` may be absent or None (e.g. no safetensors metadata).
        # Treat that as an unknown size rather than failing the whole lookup.
        safetensors = getattr(info, 'safetensors', None)
        if isinstance(safetensors, dict):
            size = safetensors.get('total', 'Unknown')
        else:
            # None -> 'Unknown'; attribute-style objects expose `.total`.
            size = getattr(safetensors, 'total', 'Unknown')
        return {
            'name': model_path,
            'type': get_model_type(model_path),
            'size': size,
            'description': getattr(info, 'description', 'No description available')
        }
    except Exception as e:
        # Best effort: this is informational, so report the error instead of raising.
        return {
            'name': model_path,
            'type': 'unknown',
            'size': 'Unknown',
            'description': f'Error getting info: {e}'
        }
def get_model_type(model_path: str) -> str:
    """Classify a model identifier into a coarse model family.

    The exact string ``'google-translate'`` maps to itself. Otherwise the
    first family keyword found case-insensitively in ``model_path`` wins.
    Keywords are checked in the order gemma, qwen, llama, nllb, and anything
    unmatched is ``'other'``.
    """
    if model_path == 'google-translate':
        return 'google-translate'
    haystack = model_path.lower()
    for family in ('gemma', 'qwen', 'llama', 'nllb'):
        if family in haystack:
            return family
    return 'other'