Upload model_stats.py

model_stats.py: ADDED (+314, -0)
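"""Benchmark approved Ollama models and record their performance.

The script pulls any missing models from the configured Ollama server, runs a
fixed creative-writing prompt against each one, and writes per-model stats
(tokens/second, VRAM usage, parameter size, quantization) to JSON files under
model_test_results/.
"""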
import json
import asyncio
from ollama import AsyncClient, ResponseError
from typing import Dict, List, Any
import time
from datetime import datetime
import logging
from tqdm import tqdm
import os

# Configure logging with a detailed format, writing to both a file and stdout
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('model_testing.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Ollama client configuration
OLLAMA_HOST = "http://20.7.14.25:11434"
client = AsyncClient(host=OLLAMA_HOST)

# Create the results directory if it doesn't exist
RESULTS_DIR = "model_test_results"
os.makedirs(RESULTS_DIR, exist_ok=True)

async def check_model_exists(model_name: str) -> bool:
    """Check if a model is already pulled."""
    try:
        # Try to get model info - if the model exists, show() returns successfully
        await client.show(model_name)
        return True
    except ResponseError:
        return False
    except Exception as e:
        logger.error(f"Error checking if model {model_name} exists: {e}")
        return False

async def load_approved_models() -> List[str]:
    """Load approved models from the JSON file."""
    try:
        with open('approved_models.json', 'r') as f:
            data = json.load(f)
        models = [model[0] for model in data['approved_models']]
        logger.info(f"Successfully loaded {len(models)} models from approved_models.json")
        return models
    except Exception as e:
        logger.error(f"Error loading approved models: {e}")
        return []
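
# A minimal sketch of the approved_models.json layout this loader assumes,
# inferred from the model[0] indexing above; the model names and the second
# field of each entry are hypothetical:
#
# {
#   "approved_models": [
#     ["llama3.1:8b", "approved 2024-06-01"],
#     ["mistral:7b", "approved 2024-06-01"]
#   ]
# }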

async def pull_model(model_name: str) -> bool:
    """Pull a model from Ollama."""
    try:
        logger.info(f"Starting to pull model: {model_name}")
        start_time = time.time()
        await client.pull(model_name)
        end_time = time.time()
        logger.info(f"Successfully pulled {model_name} in {end_time - start_time:.2f} seconds")
        return True
    except ResponseError as e:
        logger.error(f"Error pulling model {model_name}: {e}")
        return False
    except Exception as e:
        logger.error(f"Unexpected error while pulling {model_name}: {e}")
        return False
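
# A variant sketch: client.pull() also accepts stream=True and yields progress
# updates (status/completed/total fields) that could drive the tqdm bar used
# in main(). Kept commented out; pull_model_streaming is a hypothetical name.
#
# async def pull_model_streaming(model_name: str) -> bool:
#     try:
#         async for part in await client.pull(model_name, stream=True):
#             if part.get('total') and part.get('completed'):
#                 logger.debug(f"{part['status']}: {part['completed']}/{part['total']} bytes")
#         return True
#     except ResponseError as e:
#         logger.error(f"Error pulling model {model_name}: {e}")
#         return False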

async def check_loaded_models():
    """Log any models currently loaded in memory."""
    try:
        # Use the ps endpoint to list loaded models; an empty list means
        # nothing is loaded, so only warn when models are actually present
        ps_response = await client.ps()
        if ps_response and getattr(ps_response, 'models', None):
            logger.warning("Found loaded models in memory. Waiting for keep_alive to unload them...")
            # Just log the loaded models; keep_alive=0 on the next generation will unload them
            for model in ps_response.models:
                if model.name:
                    logger.info(f"Model currently loaded in memory: {model.name}")
                    logger.debug(f"Model details: size={model.size}, vram={model.size_vram}, params={model.details.parameter_size}")
    except Exception as e:
        logger.error(f"Error checking loaded models: {e}")

async def test_model(model_name: str) -> Dict[str, Any]:
    """Test a model and collect performance stats."""
    stats = {
        "model_name": model_name,
        "timestamp": datetime.now().isoformat(),
        "success": False,
        "error": None,
        "performance": {},
        "model_info": {}
    }

    try:
        logger.info(f"Starting performance test for model: {model_name}")

        # Test with a comprehensive prompt that should generate a longer response
        prompt = """You are a creative writing assistant. Write a short story about a futuristic city where:
1. The city is powered by a mysterious energy source
2. The inhabitants have developed unique abilities
3. There's a hidden conflict between different factions
4. The protagonist discovers a shocking truth about the city's origins

Make the story engaging and include vivid descriptions of the city's architecture and technology."""

        # First, run a tiny generation to make sure the model is loaded.
        # In Ollama's generate API, keep_alive is a top-level argument (not an
        # entry in options), and the option that caps output length is
        # num_predict; there is no max_tokens option.
        await client.generate(
            model=model_name,
            prompt="test",
            stream=False,
            keep_alive="5m",  # keep the model loaded for the ps() check below
            options={
                "num_predict": 1,
            },
        )

        # Get model info while it's loaded
        ps_response = await client.ps()
        if ps_response and getattr(ps_response, 'models', None):
            model_found = False
            for model in ps_response.models:
                if model.name == model_name:
                    model_found = True
                    stats["model_info"] = {
                        "size": model.size,
                        "size_vram": model.size_vram,
                        "parameter_size": model.details.parameter_size,
                        "quantization_level": model.details.quantization_level,
                        "format": model.details.format,
                        "family": model.details.family
                    }
                    logger.info(f"Found model info for {model_name}: {stats['model_info']}")
                    break
            if not model_found:
                logger.warning(f"Model {model_name} not found in ps response. Available models: {[m.name for m in ps_response.models]}")
        else:
            logger.warning("No models found in ps response")

        start_time = time.time()

        # Now generate the full response. As above, keep_alive is a top-level
        # argument, and Ollama's option names are num_predict and
        # repeat_penalty (not max_tokens / repetition_penalty).
        response = await client.generate(
            model=model_name,
            prompt=prompt,
            stream=False,
            keep_alive=0,  # unload the model as soon as generation finishes
            options={
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 40,
                "num_predict": 1000,
                "repeat_penalty": 1.0,
                "seed": 42,
            },
        )

        end_time = time.time()

        # Calculate performance metrics. Splitting on whitespace counts words,
        # which only approximates token counts; the server-reported eval_count
        # recorded below is the authoritative figure.
        total_tokens = len(response.get("response", "").split())
        total_time = end_time - start_time
        tokens_per_second = total_tokens / total_time if total_time > 0 else 0
        prompt_tokens = len(prompt.split())
        # The response text contains only generated text (the prompt is not
        # echoed back), so prefer the server-reported generation token count
        # and fall back to the word count of the response.
        generation_tokens = response.get("eval_count") or total_tokens

        # Collect detailed performance metrics
        stats["performance"] = {
            "response_time": total_time,
            "total_tokens": total_tokens,
            "tokens_per_second": tokens_per_second,
            "prompt_tokens": prompt_tokens,
            "generation_tokens": generation_tokens,
            "generation_tokens_per_second": generation_tokens / total_time if total_time > 0 else 0,
            "response": response.get("response", ""),
            "eval_count": response.get("eval_count", 0),
            "eval_duration": response.get("eval_duration", 0),
            "prompt_eval_duration": response.get("prompt_eval_duration", 0),
            "total_duration": response.get("total_duration", 0),
        }
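
        # Ollama reports durations in nanoseconds, so when eval_duration is
        # present the server-side throughput is eval_count / (eval_duration / 1e9);
        # record it as a cross-check on the word-count estimate above.
        if response.get("eval_duration"):
            stats["performance"]["server_tokens_per_second"] = (
                response["eval_count"] / (response["eval_duration"] / 1e9)
            )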

        stats["success"] = True
        logger.info(f"Successfully tested {model_name}: {tokens_per_second:.2f} tokens/second")

    except Exception as e:
        stats["error"] = str(e)
        logger.error(f"Error testing model {model_name}: {e}")

    return stats

async def save_results(results: List[Dict[str, Any]], timestamp: str):
    """Save results in multiple formats."""
    # Save detailed results
    detailed_path = os.path.join(RESULTS_DIR, f"model_stats_{timestamp}.json")
    with open(detailed_path, 'w') as f:
        json.dump(results, f, indent=2)
    logger.info(f"Saved detailed results to {detailed_path}")

    # Save summary results
    summary = []
    for result in results:
        if result["success"]:
            perf = result["performance"]
            model_info = result["model_info"]
            summary.append({
                "model_name": result["model_name"],
                "model_size": model_info.get("size", 0),
                "vram_size": model_info.get("size_vram", 0),
                "parameter_size": model_info.get("parameter_size", ""),
                "quantization": model_info.get("quantization_level", ""),
                "tokens_per_second": perf["tokens_per_second"],
                "generation_tokens_per_second": perf["generation_tokens_per_second"],
                "total_tokens": perf["total_tokens"],
                "response_time": perf["response_time"],
                "success": result["success"]
            })

    summary_path = os.path.join(RESULTS_DIR, f"model_stats_summary_{timestamp}.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    logger.info(f"Saved summary results to {summary_path}")

    # Log top performers
    successful_results = [r for r in results if r["success"]]
    if successful_results:
        top_performers = sorted(
            successful_results,
            key=lambda x: x["performance"]["tokens_per_second"],
            reverse=True
        )[:5]
        logger.info("\nTop 5 performers by tokens per second:")
        for r in top_performers:
            model_info = r["model_info"]
            logger.info(f"{r['model_name']}:")
            logger.info(f"  Tokens/second: {r['performance']['tokens_per_second']:.2f}")
            logger.info(f"  VRAM Usage: {model_info.get('size_vram', 0)/1024/1024/1024:.2f} GB")
            logger.info(f"  Parameter Size: {model_info.get('parameter_size', 'N/A')}")
            logger.info(f"  Quantization: {model_info.get('quantization_level', 'N/A')}")
            logger.info("  " + "-" * 30)

async def main():
    """Main function to run the model testing process."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    logger.info("Starting model testing process")

    # Load approved models
    models = await load_approved_models()
    if not models:
        logger.error("No models loaded. Exiting.")
        return

    # Log any models that are still loaded in memory
    await check_loaded_models()

    # Check which models need to be pulled
    models_to_pull = []
    for model in models:
        if not await check_model_exists(model):
            models_to_pull.append(model)

    if models_to_pull:
        logger.info(f"Found {len(models_to_pull)} models that need to be pulled:")
        for model in models_to_pull:
            logger.info(f"- {model}")

        # Ask the user whether to pull the missing models
        while True:
            response = input("\nDo you want to pull the missing models? (yes/no): ").lower()
            if response in ['yes', 'no']:
                break
            print("Please answer 'yes' or 'no'")

        if response == 'yes':
            # Pull missing models with a progress bar
            logger.info("Starting model pulling phase")
            for model in tqdm(models_to_pull, desc="Pulling models"):
                await pull_model(model)
        else:
            logger.warning("Skipping model pulling. Some models may not be available for testing.")
            # Filter out the models that weren't pulled
            models = [model for model in models if model not in models_to_pull]
            if not models:
                logger.error("No models available for testing. Exiting.")
                return
    else:
        logger.info("All models are already pulled. Skipping pulling phase.")

    # Test all models with a progress bar
    logger.info("Starting model testing phase")
    results = []
    for model in tqdm(models, desc="Testing models"):
        # Check for any loaded models before testing
        await check_loaded_models()

        stats = await test_model(model)
        results.append(stats)

        # Save intermediate results after each model
        await save_results(results, timestamp)

        # Sleep between tests to give the server time to unload the previous model
        logger.info("Waiting 3 seconds before next model test...")
        await asyncio.sleep(3)

    # Save final results
    await save_results(results, timestamp)

    # Log summary
    successful_tests = sum(1 for r in results if r["success"])
    logger.info(f"Model testing completed. {successful_tests}/{len(models)} models tested successfully")

if __name__ == "__main__":
    asyncio.run(main())
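
# Usage sketch (assumes approved_models.json sits next to this script and the
# OLLAMA_HOST configured above is reachable):
#
#   pip install ollama tqdm
#   python model_stats.py
#
# Results are written to model_test_results/model_stats_<timestamp>.json and
# model_stats_summary_<timestamp>.json, plus a model_testing.log file.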