k-mktr committed on
Commit 5e4594f · verified
1 Parent(s): 5d3a365

Upload model_stats.py

Files changed (1)
  1. model_stats.py +314 -0
model_stats.py ADDED
@@ -0,0 +1,314 @@
+ import json
+ import asyncio
+ from ollama import AsyncClient, ResponseError
+ from typing import Dict, List, Any
+ import time
+ from datetime import datetime
+ import logging
+ from tqdm import tqdm
+ import os
+
+ # Configure logging with more detailed format
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler('model_testing.log'),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Ollama client configuration
+ OLLAMA_HOST = "http://20.7.14.25:11434"
+ client = AsyncClient(host=OLLAMA_HOST)
+
+ # Create results directory if it doesn't exist
+ RESULTS_DIR = "model_test_results"
+ os.makedirs(RESULTS_DIR, exist_ok=True)
+
+ async def check_model_exists(model_name: str) -> bool:
+     """Check if a model is already pulled."""
+     try:
+         # Try to get model info - if it exists, it will return successfully
+         await client.show(model_name)
+         return True
+     except ResponseError:
+         return False
+     except Exception as e:
+         logger.error(f"Error checking if model {model_name} exists: {e}")
+         return False
+
+ async def load_approved_models() -> List[str]:
+     """Load approved models from the JSON file."""
+     try:
+         with open('approved_models.json', 'r') as f:
+             data = json.load(f)
+         models = [model[0] for model in data['approved_models']]
+         logger.info(f"Successfully loaded {len(models)} models from approved_models.json")
+         return models
+     except Exception as e:
+         logger.error(f"Error loading approved models: {e}")
+         return []
+
+ async def pull_model(model_name: str) -> bool:
+     """Pull a model from Ollama."""
+     try:
+         logger.info(f"Starting to pull model: {model_name}")
+         start_time = time.time()
+         await client.pull(model_name)
+         end_time = time.time()
+         logger.info(f"Successfully pulled {model_name} in {end_time - start_time:.2f} seconds")
+         return True
+     except ResponseError as e:
+         logger.error(f"Error pulling model {model_name}: {e}")
+         return False
+     except Exception as e:
+         logger.error(f"Unexpected error while pulling {model_name}: {e}")
+         return False
+
+ async def check_loaded_models():
+     """Check if there are any models currently loaded in memory."""
+     try:
+         # Use ollama ps to check loaded models
+         ps_response = await client.ps()
+         if ps_response and getattr(ps_response, 'models', None):
+             logger.warning("Found loaded models in memory. Waiting for keep_alive to unload them...")
+             # Just log the loaded models; they will be unloaded by keep_alive=0
+             for model in ps_response.models:
+                 if model.name:
+                     logger.info(f"Model currently loaded in memory: {model.name}")
+                     logger.debug(f"Model details: size={model.size}, vram={model.size_vram}, params={model.details.parameter_size}")
+     except Exception as e:
+         logger.error(f"Error checking loaded models: {e}")
+
+ async def test_model(model_name: str) -> Dict[str, Any]:
+     """Test a model and collect performance stats."""
+     stats = {
+         "model_name": model_name,
+         "timestamp": datetime.now().isoformat(),
+         "success": False,
+         "error": None,
+         "performance": {},
+         "model_info": {}
+     }
+
+     try:
+         logger.info(f"Starting performance test for model: {model_name}")
+
+         # Test with a comprehensive prompt that should generate a longer response
+         prompt = """You are a creative writing assistant. Write a short story about a futuristic city where:
+ 1. The city is powered by a mysterious energy source
+ 2. The inhabitants have developed unique abilities
+ 3. There's a hidden conflict between different factions
+ 4. The protagonist discovers a shocking truth about the city's origins
+
+ Make the story engaging and include vivid descriptions of the city's architecture and technology."""
+
+         # First, generate a small response to ensure the model is loaded
+         await client.generate(
+             model=model_name,
+             prompt="test",
+             stream=False,
+             keep_alive=1,  # Keep the model loaded
+             options={
+                 "num_predict": 1
+             }
+         )
+
+         # Get model info while it's loaded
+         ps_response = await client.ps()
+         if ps_response and getattr(ps_response, 'models', None):
+             model_found = False
+             for model in ps_response.models:
+                 if model.name == model_name:
+                     model_found = True
+                     stats["model_info"] = {
+                         "size": model.size,
+                         "size_vram": model.size_vram,
+                         "parameter_size": model.details.parameter_size,
+                         "quantization_level": model.details.quantization_level,
+                         "format": model.details.format,
+                         "family": model.details.family
+                     }
+                     logger.info(f"Found model info for {model_name}: {stats['model_info']}")
+                     break
+             if not model_found:
+                 logger.warning(f"Model {model_name} not found in ps response. Available models: {[m.name for m in ps_response.models]}")
+         else:
+             logger.warning("No models found in ps response")
+
+         start_time = time.time()
+
+         # Now generate the full response
+         response = await client.generate(
+             model=model_name,
+             prompt=prompt,
+             stream=False,
+             keep_alive=0,  # Ensure model is unloaded after generation
+             options={
+                 "temperature": 0.7,
+                 "top_p": 0.9,
+                 "top_k": 40,
+                 "num_predict": 1000,
+                 "repeat_penalty": 1.0,
+                 "seed": 42
+             }
+         )
+
+         end_time = time.time()
+
+         # Calculate performance metrics (word counts are a rough proxy for token counts)
+         generation_tokens = len(response.get("response", "").split())
+         prompt_tokens = len(prompt.split())
+         total_tokens = prompt_tokens + generation_tokens
+         total_time = end_time - start_time
+         tokens_per_second = total_tokens / total_time if total_time > 0 else 0
+
+         # Collect detailed performance metrics
+         stats["performance"] = {
+             "response_time": total_time,
+             "total_tokens": total_tokens,
+             "tokens_per_second": tokens_per_second,
+             "prompt_tokens": prompt_tokens,
+             "generation_tokens": generation_tokens,
+             "generation_tokens_per_second": generation_tokens / total_time if total_time > 0 else 0,
+             "response": response.get("response", ""),
+             "eval_count": response.get("eval_count", 0),
+             "eval_duration": response.get("eval_duration", 0),
+             "prompt_eval_duration": response.get("prompt_eval_duration", 0),
+             "total_duration": response.get("total_duration", 0),
+         }
+
+         stats["success"] = True
+         logger.info(f"Successfully tested {model_name}: {tokens_per_second:.2f} tokens/second")
+
+     except Exception as e:
+         stats["error"] = str(e)
+         logger.error(f"Error testing model {model_name}: {e}")
+
+     return stats
+
+ async def save_results(results: List[Dict[str, Any]], timestamp: str):
+     """Save results in multiple formats."""
+     # Save detailed results
+     detailed_path = os.path.join(RESULTS_DIR, f"model_stats_{timestamp}.json")
+     with open(detailed_path, 'w') as f:
+         json.dump(results, f, indent=2)
+     logger.info(f"Saved detailed results to {detailed_path}")
+
+     # Save summary results
+     summary = []
+     for result in results:
+         if result["success"]:
+             perf = result["performance"]
+             model_info = result["model_info"]
+             summary.append({
+                 "model_name": result["model_name"],
+                 "model_size": model_info.get("size", 0),
+                 "vram_size": model_info.get("size_vram", 0),
+                 "parameter_size": model_info.get("parameter_size", ""),
+                 "quantization": model_info.get("quantization_level", ""),
+                 "tokens_per_second": perf["tokens_per_second"],
+                 "generation_tokens_per_second": perf["generation_tokens_per_second"],
+                 "total_tokens": perf["total_tokens"],
+                 "response_time": perf["response_time"],
+                 "success": result["success"]
+             })
+
+     summary_path = os.path.join(RESULTS_DIR, f"model_stats_summary_{timestamp}.json")
+     with open(summary_path, 'w') as f:
+         json.dump(summary, f, indent=2)
+     logger.info(f"Saved summary results to {summary_path}")
+
+     # Log top performers
+     successful_results = [r for r in results if r["success"]]
+     if successful_results:
+         top_performers = sorted(
+             successful_results,
+             key=lambda x: x["performance"]["tokens_per_second"],
+             reverse=True
+         )[:5]
+         logger.info("\nTop 5 performers by tokens per second:")
+         for r in top_performers:
+             model_info = r["model_info"]
+             logger.info(f"{r['model_name']}:")
+             logger.info(f" Tokens/second: {r['performance']['tokens_per_second']:.2f}")
+             logger.info(f" VRAM Usage: {model_info.get('size_vram', 0)/1024/1024/1024:.2f} GB")
+             logger.info(f" Parameter Size: {model_info.get('parameter_size', 'N/A')}")
+             logger.info(f" Quantization: {model_info.get('quantization_level', 'N/A')}")
+             logger.info(" " + "-" * 30)
+
+ async def main():
+     """Main function to run the model testing process."""
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     logger.info("Starting model testing process")
+
+     # Load approved models
+     models = await load_approved_models()
+     if not models:
+         logger.error("No models loaded. Exiting.")
+         return
+
+     # Check for any models that might already be loaded (keep_alive=0 unloads them after each test)
+     await check_loaded_models()
+
+     # Check which models need to be pulled
+     models_to_pull = []
+     for model in models:
+         if not await check_model_exists(model):
+             models_to_pull.append(model)
+
+     if models_to_pull:
+         logger.info(f"Found {len(models_to_pull)} models that need to be pulled:")
+         for model in models_to_pull:
+             logger.info(f"- {model}")
+
+         # Ask user if they want to pull missing models
+         while True:
+             response = input("\nDo you want to pull the missing models? (yes/no): ").lower()
+             if response in ['yes', 'no']:
+                 break
+             print("Please answer 'yes' or 'no'")
+
+         if response == 'yes':
+             # Pull missing models with progress bar
+             logger.info("Starting model pulling phase")
+             for model in tqdm(models_to_pull, desc="Pulling models"):
+                 await pull_model(model)
+         else:
+             logger.warning("Skipping model pulling. Some models may not be available for testing.")
+             # Filter out models that weren't pulled
+             models = [model for model in models if model not in models_to_pull]
+             if not models:
+                 logger.error("No models available for testing. Exiting.")
+                 return
+     else:
+         logger.info("All models are already pulled. Skipping pulling phase.")
+
+     # Test all models with progress bar
+     logger.info("Starting model testing phase")
+     results = []
+     for model in tqdm(models, desc="Testing models"):
+         # Check for any loaded models before testing
+         await check_loaded_models()
+
+         stats = await test_model(model)
+         results.append(stats)
+
+         # Save intermediate results after each model
+         await save_results(results, timestamp)
+
+         # Add sleep between model tests to ensure proper cleanup
+         logger.info("Waiting 3 seconds before next model test...")
+         await asyncio.sleep(3)
+
+     # Save final results
+     await save_results(results, timestamp)
+
+     # Log summary
+     successful_tests = sum(1 for r in results if r["success"])
+     logger.info(f"Model testing completed. {successful_tests}/{len(models)} models tested successfully")
+
+ if __name__ == "__main__":
+     asyncio.run(main())
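
For anyone trying the script: load_approved_models() expects an approved_models.json file in the working directory and uses element [0] of each entry in its approved_models list as the Ollama model tag. That file is not part of this commit, so the snippet below is only a hypothetical sketch of a compatible shape; the model names and the write-out step are illustrative assumptions, not repository contents.

import json

# Hypothetical example only: each entry is a sequence whose first element is the
# Ollama model tag; model_stats.py reads entry[0] and ignores anything after it.
approved = {
    "approved_models": [
        ["llama3.2:3b"],
        ["qwen2.5:7b"],
    ]
}

with open("approved_models.json", "w") as f:
    json.dump(approved, f, indent=2)

With such a file in place, running python model_stats.py benchmarks each model against the Ollama server configured in OLLAMA_HOST and writes detailed and summary JSON results under model_test_results/.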