# src/model_loader.py
import warnings
from typing import Any, Tuple

# Unsloth is optional and patches transformers at import time; fall back to
# plain transformers loading when it is not installed.
try:
    import unsloth
except ImportError:
    unsloth = None

import torch
import transformers
warnings.filterwarnings("ignore")

def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
    """
    Load model for evaluation. Supports multiple model types.
    Returns (model, tokenizer) or ('google-translate', None) for Google Translate.
    """
    print(f"Loading model from {model_path}...")

    # Google Translate "model"
    if model_path == 'google-translate':
        return 'google-translate', None
    
    try:
        # NLLB models (served by the M2M100 architecture classes in transformers)
        if 'nllb' in model_path.lower():
            tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
            model = transformers.M2M100ForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch.bfloat16
            ).to('cuda' if torch.cuda.is_available() else 'cpu')
            
        # Quantized models (4bit)
        elif '4bit' in model_path.lower():
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_path,
                model_max_length=4096,
                padding_side='left'
            )
            # Decoder-only checkpoints often lack a pad token; reuse BOS so left-padded batches work
            tokenizer.pad_token = tokenizer.bos_token

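            # NF4 4-bit quantization with double quantization; compute runs in bfloat16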
            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            
        # Standard causal LMs, optionally loaded through Unsloth's fast path
        else:
            model, tokenizer = None, None
            if use_unsloth and unsloth is not None:
                try:
                    model, tokenizer = unsloth.FastModel.from_pretrained(
                        model_name=model_path,
                        max_seq_length=1024,
                        load_in_4bit=False,
                        load_in_8bit=False,
                        full_finetuning=False,
                    )
                except Exception as e:
                    print(f"Unsloth loading failed: {e}. Falling back to standard loading.")

            # Plain transformers loading when Unsloth is unavailable or failed
            if model is None:
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto' if torch.cuda.is_available() else None,
                )

        print(f"Successfully loaded {model_path}")
        return model, tokenizer
        
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        raise RuntimeError(f"Failed to load model {model_path}: {e}") from e

def get_model_info(model_path: str) -> dict:
    """Get basic information about a model without loading it."""
    try:
        if model_path == 'google-translate':
            return {
                'name': 'Google Translate',
                'type': 'google-translate',
                'size': 'Unknown',
                'description': 'Google Cloud Translation API'
            }
        
        from huggingface_hub import model_info
        info = model_info(model_path)

        # info.safetensors is a SafeTensorsInfo object (or None), not a dict;
        # its .total field holds the total parameter count.
        safetensors = getattr(info, 'safetensors', None)
        return {
            'name': model_path,
            'type': get_model_type(model_path),
            'size': safetensors.total if safetensors is not None else 'Unknown',
            'description': getattr(info, 'description', 'No description available')
        }
    except Exception as e:
        return {
            'name': model_path,
            'type': 'unknown',
            'size': 'Unknown',
            'description': f'Error getting info: {str(e)}'
        }

def get_model_type(model_path: str) -> str:
    """Determine model type from path."""
    model_path_lower = model_path.lower()
    
    if model_path == 'google-translate':
        return 'google-translate'
    elif 'gemma' in model_path_lower:
        return 'gemma'
    elif 'qwen' in model_path_lower:
        return 'qwen'
    elif 'llama' in model_path_lower:
        return 'llama'
    elif 'nllb' in model_path_lower:
        return 'nllb'
    else:
        return 'other'
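

# Example usage (illustrative sketch only): shows how load_model and
# get_model_info are intended to be called. The checkpoint below is a
# placeholder, not a path shipped with this module.
if __name__ == "__main__":
    example_path = "facebook/nllb-200-distilled-600M"  # hypothetical example checkpoint
    print(get_model_info(example_path))
    print(get_model_type(example_path))

    model, tokenizer = load_model(example_path)
    if model == 'google-translate':
        print("Google Translate backend selected; no local model/tokenizer.")
    else:
        print(f"Loaded {type(model).__name__} with {type(tokenizer).__name__}")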