akera committed
Commit 97a3aa2 · verified · 1 Parent(s): 78997a5

Delete src/src/model_loader.py

Files changed (1)
  1. src/src/model_loader.py +0 -125
src/src/model_loader.py DELETED
@@ -1,125 +0,0 @@
-# src/model_loader.py
-import torch
-import transformers
-import unsloth
-from typing import Tuple, Any
-import warnings
-warnings.filterwarnings("ignore")
-
-def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
-    """
-    Load model for evaluation. Supports multiple model types.
-    Returns (model, tokenizer) or ('google-translate', None) for Google Translate.
-    """
-    print(f"Loading model from {model_path}...")
-
-    # Google Translate "model"
-    if model_path == 'google-translate':
-        return 'google-translate', None
-
-    try:
-        # NLLB models
-        if 'nllb' in model_path.lower():
-            tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
-            model = transformers.M2M100ForConditionalGeneration.from_pretrained(
-                model_path, torch_dtype=torch.bfloat16
-            ).to('cuda' if torch.cuda.is_available() else 'cpu')
-
-        # Quantized models (4bit)
-        elif '4bit' in model_path.lower():
-            tokenizer = transformers.AutoTokenizer.from_pretrained(
-                model_path,
-                model_max_length=4096,
-                padding_side='left'
-            )
-            tokenizer.pad_token = tokenizer.bos_token
-
-            bnb_config = transformers.BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16,
-                bnb_4bit_use_double_quant=True,
-            )
-
-            model = transformers.AutoModelForCausalLM.from_pretrained(
-                model_path,
-                quantization_config=bnb_config,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                trust_remote_code=True,
-            )
-
-        # Standard models with unsloth optimization
-        else:
-            if use_unsloth:
-                try:
-                    model, tokenizer = unsloth.FastModel.from_pretrained(
-                        model_name=model_path,
-                        max_seq_length=1024,
-                        load_in_4bit=False,
-                        load_in_8bit=False,
-                        full_finetuning=False,
-                    )
-                except Exception as e:
-                    print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
-                    use_unsloth = False
-
-            if not use_unsloth:
-                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-                model = transformers.AutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    torch_dtype=torch.bfloat16,
-                    device_map='auto' if torch.cuda.is_available() else None,
-                )
-
-        print(f"Successfully loaded {model_path}")
-        return model, tokenizer
-
-    except Exception as e:
-        print(f"Error loading model {model_path}: {str(e)}")
-        raise Exception(f"Failed to load model: {str(e)}")
-
-def get_model_info(model_path: str) -> dict:
-    """Get basic information about a model without loading it."""
-    try:
-        if model_path == 'google-translate':
-            return {
-                'name': 'Google Translate',
-                'type': 'google-translate',
-                'size': 'Unknown',
-                'description': 'Google Cloud Translation API'
-            }
-
-        from huggingface_hub import model_info
-        info = model_info(model_path)
-
-        return {
-            'name': model_path,
-            'type': get_model_type(model_path),
-            'size': getattr(info, 'safetensors', {}).get('total', 'Unknown'),
-            'description': getattr(info, 'description', 'No description available')
-        }
-    except Exception as e:
-        return {
-            'name': model_path,
-            'type': 'unknown',
-            'size': 'Unknown',
-            'description': f'Error getting info: {str(e)}'
-        }
-
-def get_model_type(model_path: str) -> str:
-    """Determine model type from path."""
-    model_path_lower = model_path.lower()
-
-    if model_path == 'google-translate':
-        return 'google-translate'
-    elif 'gemma' in model_path_lower:
-        return 'gemma'
-    elif 'qwen' in model_path_lower:
-        return 'qwen'
-    elif 'llama' in model_path_lower:
-        return 'llama'
-    elif 'nllb' in model_path_lower:
-        return 'nllb'
-    else:
-        return 'other'
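
For readers tracking what this commit removes: load_model dispatched on the checkpoint path ('nllb' → the NLLB seq2seq classes, '4bit' → a bitsandbytes NF4 config, anything else → Unsloth with a plain-transformers fallback) and reserved the literal string 'google-translate' as a sentinel that returns no tokenizer. A minimal sketch of how it was presumably driven (the import path and the NLLB checkpoint id are assumptions, not from this repo):

# Hypothetical usage of the deleted loader; the function names come
# from the file above, but the checkpoint id is illustrative.
from model_loader import load_model, get_model_type

for path in ("facebook/nllb-200-distilled-600M", "google-translate"):
    model, tokenizer = load_model(path, use_unsloth=False)
    if model == 'google-translate':
        # Sentinel branch: no local model or tokenizer is created.
        print("Google Translate sentinel; nothing loaded locally.")
    else:
        print(f"Loaded a {get_model_type(path)} model with {type(tokenizer).__name__}")

Note that the load_in_4bit parameter was accepted but never consulted; quantized loading was triggered solely by the substring '4bit' in the path.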
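
One detail worth flagging in the removed get_model_info: on recent huggingface_hub releases, ModelInfo.safetensors is a SafeTensorsInfo object (or None) rather than a dict, so getattr(info, 'safetensors', {}).get('total', 'Unknown') raises AttributeError whenever the attribute is present and the function falls through to its error branch. A sketch of a more tolerant lookup, under that assumption (the helper name is ours, not from the commit):

# Sketch only: assumes ModelInfo.safetensors is a SafeTensorsInfo
# dataclass exposing a .total parameter count, or None, as in recent
# huggingface_hub releases.
from huggingface_hub import model_info

def safetensors_total(model_path: str):
    """Return the safetensors parameter count for a repo, or 'Unknown'."""
    st = getattr(model_info(model_path), 'safetensors', None)
    return st.total if st is not None else 'Unknown'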