akera committed (verified)
Commit 93b9d03 · Parent: d0ca936

Create model_loader.py

Files changed (1):
  src/model_loader.py  +137 -0
src/model_loader.py ADDED
@@ -0,0 +1,137 @@
+ # src/model_loader.py
+ import torch
+ import transformers
+ from typing import Tuple, Any
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # unsloth is optional here: it is only needed for the fast-loading path in
+ # load_model(), which already falls back to plain transformers otherwise.
+ try:
+     import unsloth
+ except ImportError:
+     unsloth = None
+
+ def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
+     """
+     Load a model for evaluation. Supports multiple model types.
+     Returns (model, tokenizer), or ('google-translate', None) for Google Translate.
+     Note: load_in_4bit is currently unused; 4-bit loading is inferred from the path name.
+     """
+     print(f"Loading model from {model_path}...")
+
+     # Google Translate "model"
+     if model_path == 'google-translate':
+         return 'google-translate', None
+
+     try:
+         # NLLB models
+         if 'nllb' in model_path.lower():
+             tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
+             model = transformers.M2M100ForConditionalGeneration.from_pretrained(
+                 model_path, torch_dtype=torch.bfloat16
+             ).to('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Quantized models (4-bit)
+         elif '4bit' in model_path.lower():
+             tokenizer = transformers.AutoTokenizer.from_pretrained(
+                 model_path,
+                 model_max_length=4096,
+                 padding_side='left'
+             )
+             tokenizer.pad_token = tokenizer.bos_token
+
+             bnb_config = transformers.BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_compute_dtype=torch.bfloat16,
+                 bnb_4bit_use_double_quant=True,
+             )
+
+             model = transformers.AutoModelForCausalLM.from_pretrained(
+                 model_path,
+                 quantization_config=bnb_config,
+                 device_map="auto",
+                 torch_dtype=torch.bfloat16,
+                 trust_remote_code=True,
+             )
+
+         # Standard models, with unsloth optimization when available
+         else:
+             if use_unsloth and unsloth is None:
+                 print("unsloth is not installed. Falling back to standard loading.")
+                 use_unsloth = False
+
+             if use_unsloth:
+                 try:
+                     model, tokenizer = unsloth.FastModel.from_pretrained(
+                         model_name=model_path,
+                         max_seq_length=1024,
+                         load_in_4bit=False,
+                         load_in_8bit=False,
+                         full_finetuning=False,
+                     )
+                 except Exception as e:
+                     print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
+                     use_unsloth = False
+
+             if not use_unsloth:
+                 tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+                 model = transformers.AutoModelForCausalLM.from_pretrained(
+                     model_path,
+                     torch_dtype=torch.bfloat16,
+                     device_map='auto' if torch.cuda.is_available() else None,
+                 )
+
+         print(f"Successfully loaded {model_path}")
+         return model, tokenizer
+
+     except Exception as e:
+         print(f"Error loading model {model_path}: {str(e)}")
+         raise RuntimeError(f"Failed to load model: {str(e)}") from e
+
+ def get_model_info(model_path: str) -> dict:
+     """Get basic information about a model without loading it."""
+     try:
+         if model_path == 'google-translate':
+             return {
+                 'name': 'Google Translate',
+                 'type': 'google-translate',
+                 'size': 'Unknown',
+                 'description': 'Google Cloud Translation API'
+             }
+
+         from huggingface_hub import model_info
+         info = model_info(model_path)
+
+         return {
+             'name': model_path,
+             'type': get_model_type(model_path),
+             # info.safetensors is an object (or None), not a dict, so chain getattr
+             'size': getattr(getattr(info, 'safetensors', None), 'total', 'Unknown'),
+             'description': getattr(info, 'description', 'No description available')
+         }
+     except Exception as e:
+         return {
+             'name': model_path,
+             'type': 'unknown',
+             'size': 'Unknown',
+             'description': f'Error getting info: {str(e)}'
+         }
+
+ def get_model_type(model_path: str) -> str:
+     """Determine model type from path."""
+     model_path_lower = model_path.lower()
+
+     if model_path == 'google-translate':
+         return 'google-translate'
+     elif 'gemma' in model_path_lower:
+         return 'gemma'
+     elif 'qwen' in model_path_lower:
+         return 'qwen'
+     elif 'llama' in model_path_lower:
+         return 'llama'
+     elif 'nllb' in model_path_lower:
+         return 'nllb'
+     else:
+         return 'other'
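
For reference, a minimal sketch of how these helpers might be driven from an evaluation script. It is not part of the commit: the checkpoint ids are placeholders chosen for illustration, and the generation call assumes the plain-transformers fallback path (use_unsloth=False).

# sketch: exercising the loader (checkpoint ids are illustrative)
from src.model_loader import load_model, get_model_type

for path in ('google-translate', 'facebook/nllb-200-distilled-600M', 'Qwen/Qwen2.5-0.5B-Instruct'):
    print(path, '->', get_model_type(path))  # google-translate / nllb / qwen

# 'google-translate' is a sentinel: load_model returns ('google-translate', None)
model, tokenizer = load_model('Qwen/Qwen2.5-0.5B-Instruct', use_unsloth=False)
inputs = tokenizer('Hello, world!', return_tensors='pt').to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])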