akera committed on
Commit
c1926c2
·
verified ·
1 Parent(s): b78ec70

Rename src/model_loader.py to src/test_set.py

Browse files
Files changed (2) hide show
  1. src/model_loader.py +0 -125
  2. src/test_set.py +195 -0
src/model_loader.py DELETED
@@ -1,125 +0,0 @@
1
- # src/model_loader.py
2
- import torch
3
- import transformers
4
- import unsloth
5
- from typing import Tuple, Any
6
- import warnings
7
- warnings.filterwarnings("ignore")
8
-
9
def load_model(model_path: str, load_in_4bit: bool = True, use_unsloth: bool = True) -> Tuple[Any, Any]:
    """
    Load a model and tokenizer for evaluation.

    The loading strategy is selected from ``model_path``:

    * ``'google-translate'`` (exact match): nothing is loaded.
    * path contains ``'nllb'``: ``NllbTokenizer`` plus
      ``M2M100ForConditionalGeneration`` in bfloat16, moved to CUDA when
      available and to CPU otherwise.
    * path contains ``'4bit'``: ``AutoModelForCausalLM`` loaded with a
      bitsandbytes NF4 quantization config and a left-padding tokenizer.
    * anything else: unsloth ``FastModel`` when ``use_unsloth`` is true,
      with a fallback to plain ``transformers`` loading when unsloth fails
      or is disabled.

    Args:
        model_path: Model identifier, or the literal ``'google-translate'``.
        load_in_4bit: Accepted for interface compatibility; this function
            does not read it (4-bit loading is keyed on the path instead).
        use_unsloth: Try unsloth first for models that are neither NLLB nor
            4-bit.

    Returns:
        ``(model, tokenizer)``, or ``('google-translate', None)`` for
        Google Translate.

    Raises:
        Exception: If loading fails for any reason. The underlying error is
            attached as ``__cause__``.
    """
    print(f"Loading model from {model_path}...")

    # Google Translate "model": there is nothing to load.
    if model_path == 'google-translate':
        return 'google-translate', None

    try:
        path_lower = model_path.lower()

        # NLLB models
        if 'nllb' in path_lower:
            tokenizer = transformers.NllbTokenizer.from_pretrained(model_path)
            model = transformers.M2M100ForConditionalGeneration.from_pretrained(
                model_path, torch_dtype=torch.bfloat16
            ).to('cuda' if torch.cuda.is_available() else 'cpu')

        # Quantized models (4bit)
        elif '4bit' in path_lower:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_path,
                model_max_length=4096,
                padding_side='left'
            )
            # The BOS token is reused as the padding token.
            tokenizer.pad_token = tokenizer.bos_token

            bnb_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )

        # Standard models with unsloth optimization
        else:
            if use_unsloth:
                try:
                    model, tokenizer = unsloth.FastModel.from_pretrained(
                        model_name=model_path,
                        max_seq_length=1024,
                        load_in_4bit=False,
                        load_in_8bit=False,
                        full_finetuning=False,
                    )
                except Exception as e:
                    print(f"Unsloth loading failed: {e}. Falling back to standard loading.")
                    # Clearing the flag routes execution into the fallback below.
                    use_unsloth = False

            if not use_unsloth:
                tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='auto' if torch.cuda.is_available() else None,
                )

        print(f"Successfully loaded {model_path}")
        return model, tokenizer

    except Exception as e:
        print(f"Error loading model {model_path}: {str(e)}")
        # Chain explicitly so the root cause stays attached to the raised error.
        raise Exception(f"Failed to load model: {str(e)}") from e
81
-
82
def get_model_info(model_path: str) -> dict:
    """Get basic information about a model without loading it.

    Args:
        model_path: Model identifier, or the literal ``'google-translate'``.

    Returns:
        A dict with the keys ``'name'``, ``'type'``, ``'size'`` and
        ``'description'``. For ``'google-translate'`` a fixed description is
        returned without any lookup. If the lookup raises, ``'type'`` is
        ``'unknown'`` and ``'description'`` carries the error text; this
        function itself does not raise.
    """
    try:
        if model_path == 'google-translate':
            return {
                'name': 'Google Translate',
                'type': 'google-translate',
                'size': 'Unknown',
                'description': 'Google Cloud Translation API'
            }

        from huggingface_hub import model_info
        info = model_info(model_path)

        # `safetensors` can be absent, None, a plain dict, or an object that
        # exposes `total` as an attribute (this varies with the
        # huggingface_hub version). Calling `.get` on it unconditionally
        # would raise and turn the whole result into the error dict below.
        safetensors = getattr(info, 'safetensors', None)
        if safetensors is None:
            size = 'Unknown'
        elif isinstance(safetensors, dict):
            size = safetensors.get('total', 'Unknown')
        else:
            size = getattr(safetensors, 'total', 'Unknown')

        return {
            'name': model_path,
            'type': get_model_type(model_path),
            'size': size,
            'description': getattr(info, 'description', 'No description available')
        }
    except Exception as e:
        return {
            'name': model_path,
            'type': 'unknown',
            'size': 'Unknown',
            'description': f'Error getting info: {str(e)}'
        }
109
-
110
def get_model_type(model_path: str) -> str:
    """Classify a model path into a coarse family name.

    The exact string ``'google-translate'`` maps to ``'google-translate'``.
    Otherwise the lowercased path is checked for the substrings ``'gemma'``,
    ``'qwen'``, ``'llama'`` and ``'nllb'`` in that order, and the first one
    found is returned. Paths matching none of them yield ``'other'``.
    """
    lowered = model_path.lower()

    # Only the exact, case-sensitive identifier counts as Google Translate.
    if model_path == 'google-translate':
        return 'google-translate'

    # Tuple order is the match priority.
    for family in ('gemma', 'qwen', 'llama', 'nllb'):
        if family in lowered:
            return family

    return 'other'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/test_set.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/test_set.py
2
+ import pandas as pd
3
+ import yaml
4
+ from datasets import Dataset, load_dataset
5
+ from typing import Dict, Tuple
6
+ import salt.dataset
7
+ from config import *
8
+
9
def generate_test_set(max_samples_per_pair: int = MAX_TEST_SAMPLES) -> pd.DataFrame:
    """Generate the standardized test set from the SALT dataset.

    Loads the ``test`` split of ``SALT_DATASET`` (config ``text-all``) through
    ``salt.dataset.create`` for every ordered pair of distinct languages in
    ``ALL_UG40_LANGUAGES``, then samples rows per pair with a fixed
    ``random_state`` of 42.

    Args:
        max_samples_per_pair: Upper bound on the rows sampled for each
            (source, target) language pair.

    Returns:
        A DataFrame with the columns ``sample_id``, ``source_text``,
        ``target_text``, ``source_language``, ``target_language``, ``domain``
        and ``google_comparable``. The columns are present even when no
        sample was produced.
    """
    print("Generating SALT test set...")

    # Load full SALT dataset. The language lists are interpolated with their
    # Python repr, which the YAML parser reads as a flow sequence.
    dataset_config = f'''
    huggingface_load:
      path: {SALT_DATASET}
      name: text-all
      split: test
    source:
      type: text
      language: {ALL_UG40_LANGUAGES}
    target:
      type: text
      language: {ALL_UG40_LANGUAGES}
    allow_same_src_and_tgt_language: False
    '''

    config = yaml.safe_load(dataset_config)
    full_data = pd.DataFrame(salt.dataset.create(config))

    # Fixed schema so an empty result still exposes every column.
    columns = [
        'sample_id', 'source_text', 'target_text', 'source_language',
        'target_language', 'domain', 'google_comparable'
    ]

    test_samples = []
    sample_id_counter = 1
    pairs_with_samples = 0

    # Looked up once; the loops below only build boolean masks from them.
    source_langs = full_data['source.language']
    target_langs = full_data['target.language']

    for src_lang in ALL_UG40_LANGUAGES:
        src_mask = source_langs == src_lang
        src_on_google = src_lang in GOOGLE_SUPPORTED_LANGUAGES

        for tgt_lang in ALL_UG40_LANGUAGES:
            if src_lang == tgt_lang:
                continue

            # Filter for this language pair
            pair_data = full_data[src_mask & (target_langs == tgt_lang)]
            if len(pair_data) == 0:
                continue

            # Sample up to max_samples_per_pair
            n_samples = min(len(pair_data), max_samples_per_pair)
            sampled = pair_data.sample(n=n_samples, random_state=42)
            if len(sampled) > 0:
                pairs_with_samples += 1

            # Add to test set with unique IDs
            for _, row in sampled.iterrows():
                test_samples.append({
                    'sample_id': f"salt_{sample_id_counter:06d}",
                    'source_text': row['source'],
                    'target_text': row['target'],  # Hidden from public test set
                    'source_language': src_lang,
                    'target_language': tgt_lang,
                    'domain': row.get('domain', 'general'),
                    'google_comparable': (src_on_google and
                                          tgt_lang in GOOGLE_SUPPORTED_LANGUAGES)
                })
                sample_id_counter += 1

    test_df = pd.DataFrame(test_samples, columns=columns)

    # Report the pairs that actually contributed samples.
    print(f"Generated test set with {len(test_df)} samples across {pairs_with_samples} language pairs")

    return test_df
69
+
70
def get_public_test_set() -> pd.DataFrame:
    """Return the public view of the test set (sources only, no targets).

    The ``'train'`` split of ``TEST_SET_DATASET`` is loaded first. If that
    fails for any reason, a new test set is generated and passed to
    ``save_complete_test_set``. In both cases only the public columns are
    returned, as a copy.
    """
    try:
        # Prefer the already published test set.
        frame = load_dataset(TEST_SET_DATASET, split='train').to_pandas()
        print(f"Loaded existing test set with {len(frame)} samples")

    except Exception as exc:
        print(f"Could not load existing test set: {exc}")
        print("Generating new test set...")

        frame = generate_test_set()

        # Persist the complete set (with targets); the returned success flag
        # is not inspected here.
        save_complete_test_set(frame)

    # Everything except the reference translations.
    exposed_columns = [
        'sample_id',
        'source_text',
        'source_language',
        'target_language',
        'domain',
        'google_comparable',
    ]
    public_view = frame[exposed_columns]
    return public_view.copy()
96
+
97
def get_complete_test_set() -> pd.DataFrame:
    """Return the complete test set, targets included (for evaluation).

    Loads the ``'train'`` split of the ``TEST_SET_DATASET + "-private"``
    dataset. If that fails for any reason, the error is printed and a newly
    generated test set is returned instead.
    """
    try:
        private_repo = TEST_SET_DATASET + "-private"
        stored = load_dataset(private_repo, split='train')
        return stored.to_pandas()

    except Exception as exc:
        # Private copy unavailable: fall back to regenerating.
        print(f"Regenerating complete test set: {exc}")
        return generate_test_set()
108
+
109
def save_complete_test_set(test_df: pd.DataFrame) -> bool:
    """Push the test set to the Hugging Face Hub in two variants.

    The public variant (without ``target_text``) goes to
    ``TEST_SET_DATASET``. The full frame goes to
    ``TEST_SET_DATASET + "-private"`` with ``private=True``. Both pushes use
    ``HF_TOKEN``.

    Returns:
        ``True`` when both pushes complete. ``False`` if anything raises; the
        error is printed rather than propagated.
    """
    try:
        # Public version (no targets)
        visible_columns = [
            'sample_id',
            'source_text',
            'source_language',
            'target_language',
            'domain',
            'google_comparable',
        ]
        public_frame = test_df[visible_columns].copy()
        Dataset.from_pandas(public_frame).push_to_hub(
            TEST_SET_DATASET,
            token=HF_TOKEN,
            commit_message="Update public test set",
        )

        # Private version (with targets)
        Dataset.from_pandas(test_df).push_to_hub(
            TEST_SET_DATASET + "-private",
            token=HF_TOKEN,
            private=True,
            commit_message="Update private test set with targets",
        )

        print("Test sets saved successfully!")
        return True

    except Exception as exc:
        print(f"Error saving test sets: {exc}")
        return False
141
+
142
def create_test_set_download() -> Tuple[str, Dict]:
    """Write the public test set to a CSV file and summarize it.

    Returns:
        A ``(path, stats)`` tuple. ``path`` is the CSV file that was written
        (``"salt_test_set.csv"``, relative to the working directory).
        ``stats`` holds the sample count, the number of (source, target)
        language pairs, the number of Google-comparable samples, the distinct
        languages and the distinct domains.
    """
    public_test = get_public_test_set()

    # Write the downloadable file without the DataFrame index.
    download_path = "salt_test_set.csv"
    public_test.to_csv(download_path, index=False)

    # Collect the statistics one at a time.
    stats = {}
    stats['total_samples'] = len(public_test)

    by_pair = public_test.groupby(['source_language', 'target_language'])
    stats['language_pairs'] = len(by_pair)

    comparable_rows = public_test[public_test['google_comparable'] == True]
    stats['google_comparable_samples'] = len(comparable_rows)

    # A language counts whether it appears as a source or as a target.
    source_languages = set(public_test['source_language'].unique())
    target_languages = set(public_test['target_language'].unique())
    stats['languages'] = list(source_languages | target_languages)

    if 'domain' in public_test.columns:
        stats['domains'] = list(public_test['domain'].unique())
    else:
        stats['domains'] = ['general']

    return download_path, stats
161
+
162
def validate_test_set_integrity() -> Dict:
    """Check the public/complete test sets for alignment and pair coverage.

    Returns:
        On success, a dict with:

        * ``'alignment_check'``: ``True`` when every public ``sample_id``
          also exists in the complete test set.
        * ``'total_samples'``: number of rows in the public test set.
        * ``'coverage_by_pair'``: for each ordered pair of distinct languages
          in ``ALL_UG40_LANGUAGES``, keyed ``"<src>_<tgt>"``, the row count
          and whether it reaches ``MIN_SAMPLES_PER_PAIR``.
        * ``'missing_pairs'``: keys of the pairs below that threshold.

        If anything raises, ``{'error': <message>}`` is returned instead.
    """
    try:
        public_test = get_public_test_set()
        complete_test = get_complete_test_set()

        # IDs present on each side, used for the alignment check.
        public_ids = set(public_test['sample_id'])
        private_ids = set(complete_test['sample_id'])

        coverage_by_pair = {}
        for src in ALL_UG40_LANGUAGES:
            for tgt in ALL_UG40_LANGUAGES:
                if src == tgt:
                    continue

                matches_src = public_test['source_language'] == src
                matches_tgt = public_test['target_language'] == tgt
                pair_count = len(public_test[matches_src & matches_tgt])

                coverage_by_pair[f"{src}_{tgt}"] = {
                    'count': pair_count,
                    'has_samples': pair_count >= MIN_SAMPLES_PER_PAIR
                }

        orphaned_public_ids = public_ids - private_ids
        missing_pairs = [
            pair_key
            for pair_key, coverage in coverage_by_pair.items()
            if not coverage['has_samples']
        ]

        return {
            'alignment_check': len(orphaned_public_ids) == 0,
            'total_samples': len(public_test),
            'coverage_by_pair': coverage_by_pair,
            'missing_pairs': missing_pairs
        }

    except Exception as exc:
        return {'error': str(exc)}