# src/model_loader.py
import warnings
from typing import Any, Tuple

import torch
import transformers

# Unsloth is optional; without this guard the whole module fails to import
# on machines where unsloth is not installed, even when use_unsloth=False.
try:
    import unsloth
except ImportError:
    unsloth = None

warnings.filterwarnings("ignore")


def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
    """
    Load a model for evaluation. Supports several model families.

    Returns (model, tokenizer), or ('google-translate', None) when the
    Google Translate backend is selected.
    """
    print(f"Loading model from {model_path}...")

    # Google Translate is not a local model; signal it with a sentinel value.
    if model_path == 'google-translate':
        return 'google-translate', None

    try:
        # NLLB models use a dedicated tokenizer and a seq2seq architecture.
        if 'nllb' in model_path.lower():
            tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
            model = transformers.M2M100ForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch.bfloat16
            ).to('cuda' if torch.cuda.is_available() else 'cpu')

        # Quantized models, loaded with bitsandbytes NF4 4-bit quantization.
        # The load_in_4bit flag is honored here (previously it was accepted
        # but never used).
        elif load_in_4bit and '4bit' in model_path.lower():
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_path,
                model_max_length=4096,
                padding_side='left',
            )
            tokenizer.pad_token = tokenizer.bos_token
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )

        # Standard models, with optional Unsloth optimization.
        else:
            if use_unsloth and unsloth is not None:
                try:
                    model, tokenizer = unsloth.FastModel.from_pretrained(
                        model_name=model_path,
                        max_seq_length=1024,
                        load_in_4bit=False,
                        load_in_8bit=False,
                        full_finetuning=False,
                    )
                except Exception as e:
                    print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
                    use_unsloth = False
            else:
                use_unsloth = False

            if not use_unsloth:
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto' if torch.cuda.is_available() else None,
                )

        print(f"Successfully loaded {model_path}")
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        raise RuntimeError(f"Failed to load model: {e}") from e
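

# A minimal usage sketch, not part of the original loader: how a (model,
# tokenizer) pair returned by load_model might be driven for one generation.
# It applies to the causal-LM branches only (not NLLB or Google Translate),
# and the prompt template and generation settings below are assumptions,
# not the evaluation harness's actual prompting scheme.
def _example_generate(model: Any, tokenizer: Any, text: str) -> str:
    prompt = f"Translate to English: {text}\n"  # hypothetical prompt template
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens and decode only the newly generated text.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)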


def get_model_info(model_path: str) -> dict:
    """Get basic information about a model without loading it."""
    try:
        if model_path == 'google-translate':
            return {
                'name': 'Google Translate',
                'type': 'google-translate',
                'size': 'Unknown',
                'description': 'Google Cloud Translation API',
            }

        from huggingface_hub import model_info
        info = model_info(model_path)
        # info.safetensors is a SafeTensorsInfo object (or None), not a dict,
        # so read its .total attribute instead of calling .get() on it.
        safetensors = getattr(info, 'safetensors', None)
        return {
            'name': model_path,
            'type': get_model_type(model_path),
            'size': getattr(safetensors, 'total', 'Unknown') if safetensors else 'Unknown',
            'description': getattr(info, 'description', 'No description available'),
        }
    except Exception as e:
        return {
            'name': model_path,
            'type': 'unknown',
            'size': 'Unknown',
            'description': f'Error getting info: {e}',
        }
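

# Shape of the dict returned above (values are illustrative, taken from the
# field names in get_model_info, not from a real Hub response):
# {
#     'name': 'facebook/nllb-200-distilled-600M',
#     'type': 'nllb',
#     'size': 615000000,  # illustrative total parameter count from safetensors metadata
#     'description': 'No description available',
# }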


def get_model_type(model_path: str) -> str:
    """Determine the model family from its path."""
    if model_path == 'google-translate':
        return 'google-translate'

    model_path_lower = model_path.lower()
    if 'gemma' in model_path_lower:
        return 'gemma'
    elif 'qwen' in model_path_lower:
        return 'qwen'
    elif 'llama' in model_path_lower:
        return 'llama'
    elif 'nllb' in model_path_lower:
        return 'nllb'
    return 'other'
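

# Minimal smoke test (an assumed entry point, not part of the original module):
# inspect a model's metadata without downloading weights, then load it fully.
if __name__ == "__main__":
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "google-translate"
    print(get_model_type(path))   # e.g. 'nllb', 'gemma', 'other'
    print(get_model_info(path))   # metadata only; no weights are downloaded
    model, tokenizer = load_model(path)  # full load; large models need a GPU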