ash-98 committed on
Commit e6ef9f1 · 1 Parent(s): 14f59dd
Files changed (5)
  1. .gitattributes copy +0 -35
  2. .gitignore +1 -1
  3. README.md +1 -1
  4. app.py +111 -49
  5. utils_on.py +429 -0
.gitattributes copy DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore CHANGED
@@ -1 +1 @@
- Dockerfile
+ __pycache__
README.md CHANGED
@@ -1,5 +1,5 @@
---
- title: Llm Pricing Calculator
emoji: 🐢
colorFrom: indigo
colorTo: gray

---
+ title: LLM Pricing Calculator
emoji: 🐢
colorFrom: indigo
colorTo: gray
app.py CHANGED
@@ -2,12 +2,14 @@ import streamlit as st
import asyncio
import tokonomics
from utils import create_model_hierarchy

- st.set_page_config(page_title="LLM Pricing App", layout="wide")

# --------------------------
# Async Data Loading Function
# --------------------------
async def load_data():
"""Simulate loading data asynchronously."""
AVAILABLE_MODELS = await tokonomics.get_available_models()
@@ -43,7 +45,7 @@ def provider_change(provider, selected_type, all_types=["text", "vision", "video
return new_models if new_models else all_models

# --------------------------
- # Estimate Cost Function (Updated)
# --------------------------
def estimate_cost(num_alerts, input_size, output_size, model_id):
pricing = st.session_state.get("pricing", {})
@@ -79,21 +81,68 @@ if "data_loaded" not in st.session_state:
with st.sidebar:
st.image("https://cdn.prod.website-files.com/630f558f2a15ca1e88a2f774/631f1436ad7a0605fecc5e15_Logo.svg",
use_container_width=True)
- st.markdown(
- """ Visit: [https://www.priam.ai](https://www.priam.ai)
- """
- )
st.divider()
st.sidebar.title("LLM Pricing Calculator")

# --------------------------
- # Main Content Layout (Model Selection Tab)
# --------------------------
- tab1, tab2 = st.tabs(["Model Selection", "About"])

- with tab1:
- st.header("LLM Pricing App")

# --- Row 1: Provider/Type and Model Selection ---
col_left, col_right = st.columns(2)
with col_left:
@@ -103,50 +152,27 @@ with tab1:
index=st.session_state["providers"].index("azure") if "azure" in st.session_state["providers"] else 0
)
selected_type = st.radio("Select type", options=["text", "image"], index=0)
-
with col_right:
- # Filter models based on the selected provider and type
filtered_models = provider_change(selected_provider, selected_type)
-
if filtered_models:
- # Force "gpt-4-turbo" as default if available; otherwise, default to the first model.
default_model = "o1" if "o1" in filtered_models else filtered_models[0]
- selected_model = st.selectbox(
- "Select a model",
- options=filtered_models,
- index=filtered_models.index(default_model)
- )
else:
selected_model = None
st.write("No models available")
-
# --- Row 2: Alert Stats ---
col1, col2, col3 = st.columns(3)
with col1:
- num_alerts = st.number_input(
- "Security Alerts Per Day",
- value=100,
- min_value=1,
- step=1,
- help="Number of security alerts to analyze daily"
- )
with col2:
- input_size = st.number_input(
- "Alert Content Size (characters)",
- value=1000,
- min_value=1,
- step=1,
- help="Include logs, metadata, and context per alert"
- )
with col3:
- output_size = st.number_input(
- "Analysis Output Size (characters)",
- value=500,
- min_value=1,
- step=1,
- help="Expected length of security analysis and recommendations"
- )
-
# --- Row 3: Buttons ---
btn_col1, btn_col2 = st.columns(2)
with btn_col1:
@@ -163,21 +189,34 @@ with tab1:
st.session_state["pricing"] = pricing
st.session_state["providers"] = providers
st.success("Pricing data refreshed!")
-
st.divider()
- # --- Display Results ---
st.markdown("### Results")
if "result" in st.session_state:
st.write(st.session_state["result"])
else:
st.write("Use the buttons above to estimate costs.")
-
- # --- Clear Button Below Results ---
if st.button("Clear"):
st.session_state.pop("result", None)
- st.rerun()

- with tab2:
st.markdown(
"""
## About This App
@@ -186,8 +225,31 @@ with tab2:

- The app downloads the latest pricing from the LiteLLM repository.
- Using simple maths to estimate the total tokens.
- - Version 0.1

Website: [https://www.priam.ai](https://www.priam.ai)
"""
)

import asyncio
import tokonomics
from utils import create_model_hierarchy
+ from utils_on import analyze_hf_model # New import for On Premise Estimator functionality

+ st.set_page_config(page_title="LLM Pricing Calculator", layout="wide")

# --------------------------
# Async Data Loading Function
# --------------------------
+
async def load_data():
"""Simulate loading data asynchronously."""
AVAILABLE_MODELS = await tokonomics.get_available_models()
 
return new_models if new_models else all_models

# --------------------------
+ # Estimate Cost Function
# --------------------------
def estimate_cost(num_alerts, input_size, output_size, model_id):
pricing = st.session_state.get("pricing", {})
 
with st.sidebar:
st.image("https://cdn.prod.website-files.com/630f558f2a15ca1e88a2f774/631f1436ad7a0605fecc5e15_Logo.svg",
use_container_width=True)
+ st.markdown("Visit: [https://www.priam.ai](https://www.priam.ai)")
st.divider()
st.sidebar.title("LLM Pricing Calculator")

# --------------------------
+ # Pills Navigation (Using st.pills)
# --------------------------
+ # st.pills creates a pill-style selection widget.
+ page = st.pills("Head",
+ options=["Model Selection", "On Premise Estimator", "About"],selection_mode="single",default="Model Selection",label_visibility="hidden",
+ #index=0 # Change index if you want a different default
+ )

+ # --------------------------
+ # Helper: Format Analysis Report
+ # --------------------------
+ def format_analysis_report(analysis_result: dict) -> str:
+ """Convert the raw analysis_result dict into a human-readable report."""
+ if "error" in analysis_result:
+ return f"**Error:** {analysis_result['error']}"
+
+ lines = []
+ lines.append(f"### Model Analysis Report for `{analysis_result.get('model_id', 'Unknown Model')}`\n")
+ lines.append(f"**Parameter Size:** {analysis_result.get('parameter_size', 'N/A')} Billion parameters\n")
+ lines.append(f"**Precision:** {analysis_result.get('precision', 'N/A')}\n")
+
+ vram = analysis_result.get("vram_requirements", {})
+ lines.append("#### VRAM Requirements:")
+ lines.append(f"- Model Size: {vram.get('model_size_gb', 0):.2f} GB")
+ lines.append(f"- KV Cache: {vram.get('kv_cache_gb', 0):.2f} GB")
+ lines.append(f"- Activations: {vram.get('activations_gb', 0):.2f} GB")
+ lines.append(f"- Overhead: {vram.get('overhead_gb', 0):.2f} GB")
+ lines.append(f"- **Total VRAM:** {vram.get('total_vram_gb', 0):.2f} GB\n")
+
+ compatible_gpus = analysis_result.get("compatible_gpus", [])
+ lines.append("#### Compatible GPUs:")
+ if compatible_gpus:
+ for gpu in compatible_gpus:
+ lines.append(f"- {gpu}")
+ else:
+ lines.append("- None found")
+ lines.append(f"\n**Largest Compatible GPU:** {analysis_result.get('largest_compatible_gpu', 'N/A')}\n")
+
+ #gpu_perf = analysis_result.get("gpu_performance", {})
+ #if gpu_perf:
+ # lines.append("#### GPU Performance:")
+ # for gpu, perf in gpu_perf.items():
+ # lines.append(f"**{gpu}:**")
+ # lines.append(f" - Tokens per Second: {perf.get('tokens_per_second', 0):.2f}")
+ # lines.append(f" - FLOPs per Token: {perf.get('flops_per_token', 0):.2f}")
+ # lines.append(f" - Effective TFLOPS: {perf.get('effective_tflops', 0):.2f}\n")
+ #else:
+ # lines.append("#### GPU Performance: N/A\n")
+
+ return "\n".join(lines)

+ # --------------------------
+ # Render Content Based on Selected Pill
+ # --------------------------
+ if page == "Model Selection":
+ st.divider()
+ st.header("LLM Pricing App")
# --- Row 1: Provider/Type and Model Selection ---
col_left, col_right = st.columns(2)
with col_left:
 
index=st.session_state["providers"].index("azure") if "azure" in st.session_state["providers"] else 0
)
selected_type = st.radio("Select type", options=["text", "image"], index=0)
with col_right:
filtered_models = provider_change(selected_provider, selected_type)
if filtered_models:
default_model = "o1" if "o1" in filtered_models else filtered_models[0]
+ selected_model = st.selectbox("Select a model", options=filtered_models, index=filtered_models.index(default_model))
else:
selected_model = None
st.write("No models available")
+
# --- Row 2: Alert Stats ---
col1, col2, col3 = st.columns(3)
with col1:
+ num_alerts = st.number_input("Security Alerts Per Day", value=100, min_value=1, step=1,
+ help="Number of security alerts to analyze daily")
with col2:
+ input_size = st.number_input("Alert Content Size (characters)", value=1000, min_value=1, step=1,
+ help="Include logs, metadata, and context per alert")
with col3:
+ output_size = st.number_input("Analysis Output Size (characters)", value=500, min_value=1, step=1,
+ help="Expected length of security analysis and recommendations")
+
# --- Row 3: Buttons ---
btn_col1, btn_col2 = st.columns(2)
with btn_col1:
 
st.session_state["pricing"] = pricing
st.session_state["providers"] = providers
st.success("Pricing data refreshed!")
+
st.divider()
st.markdown("### Results")
if "result" in st.session_state:
st.write(st.session_state["result"])
else:
st.write("Use the buttons above to estimate costs.")
+
if st.button("Clear"):
st.session_state.pop("result", None)

+ elif page == "On Premise Estimator":
+ st.divider()
+ st.header("On Premise Estimator")
+ st.markdown("Enter a Hugging Face model ID to perform an on premise analysis using the provided estimator.")
+ hf_model_id = st.text_input("Hugging Face Model ID", value="meta-llama/Llama-4-Scout-17B-16E")
+
+ if st.button("Analyze Model"):
+ with st.spinner("Analyzing model..."):
+ analysis_result = analyze_hf_model(hf_model_id)
+ st.session_state["analysis_result"] = analysis_result
+
+ if "analysis_result" in st.session_state:
+ report = format_analysis_report(st.session_state["analysis_result"])
+ st.markdown(report)
+
+ elif page == "About":
+ st.divider()
st.markdown(
"""
## About This App
 

- The app downloads the latest pricing from the LiteLLM repository.
- Using simple maths to estimate the total tokens.
+ - Helps you estimate hardware requirements for running open-source large language models (LLMs) on-premise using only the model ID from Hugging Face.
+ - Latest Version 0.1
+
+ ---
+
+ ### 📌 Version History
+
+ | Version | Release Date | Key Feature Updates |
+ |--------|--------------|---------------------|
+ | `v1.1` | 2025-04-06 | Added On Premise Estimator Feature |
+ | `v1.0` | 2025-03-26 | Initial release with basic total tokens estimation |
+
+
+ ---

Website: [https://www.priam.ai](https://www.priam.ai)
"""
)
+ st.markdown(
+ """
+ ### Found a Bug?
+
+ If you encounter any issues or have feedback, please email to **[email protected]**
+
+ Your input helps us improve the app!
+ """
+ )
+
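
Note on the cost arithmetic: the hunks above change how the inputs are collected, but estimate_cost's body is not part of this commit. The sketch below is illustrative only, assuming a rough 4-characters-per-token heuristic and made-up per-token prices; it is not the app's actual implementation.

# Hypothetical sketch of the character -> token -> cost arithmetic the
# calculator performs; the real estimate_cost reads prices from the
# pricing data loaded via tokonomics.
def rough_daily_cost(num_alerts, input_chars, output_chars,
                     input_price_per_token, output_price_per_token,
                     chars_per_token=4.0):
    """Very rough daily spend: characters -> tokens -> price, summed over alerts."""
    input_tokens = input_chars / chars_per_token
    output_tokens = output_chars / chars_per_token
    per_alert = (input_tokens * input_price_per_token
                 + output_tokens * output_price_per_token)
    return num_alerts * per_alert

# With the app's default inputs (100 alerts, 1000/500 characters) and
# made-up prices of $15 / $60 per million tokens:
print(f"${rough_daily_cost(100, 1000, 500, 15e-6, 60e-6):.2f} per day")
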
utils_on.py ADDED
@@ -0,0 +1,429 @@
+ from typing import List, Dict, Tuple, Optional, Union
+ import re
+ import math
+ import requests
+ import numpy as np
+ from huggingface_hub import HfApi, ModelInfo
+ from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
+
+ def parse_model_entries(model_entries: List[str]) -> List[Dict[str, str]]:
+ """
+ Parse a list of model entries into structured dictionaries with provider, model name, version, region, and type.
+
+ Args:
+ model_entries: List of model entry strings as found in models.txt
+
+ Returns:
+ List of dictionaries with parsed model information containing keys:
+ - provider: Name of the provider (e.g., 'azure', 'openai', 'anthropic', etc.)
+ - model_name: Base name of the model
+ - version: Version of the model (if available)
+ - region: Deployment region (if available)
+ - model_type: Type of the model (text, image, audio based on pattern analysis)
+ """
+ parsed_models = []
+
+ # Common provider prefixes to identify
+ known_providers = [
+ 'azure', 'bedrock', 'anthropic', 'openai', 'cohere', 'google',
+ 'mistral', 'meta', 'amazon', 'ai21', 'anyscale', 'stability',
+ 'cloudflare', 'databricks', 'cerebras', 'assemblyai'
+ ]
+
+ # Image-related keywords to identify image models
+ image_indicators = ['dall-e', 'stable-diffusion', 'image', 'canvas', 'x-', 'steps']
+
+ # Audio-related keywords to identify audio models
+ audio_indicators = ['whisper', 'tts', 'audio', 'voice']
+
+ for entry in model_entries:
+ model_info = {
+ 'provider': '',
+ 'model_name': '',
+ 'version': '',
+ 'region': '',
+ 'model_type': 'text' # Default to text
+ }
+
+ # Check for image models
+ if any(indicator in entry.lower() for indicator in image_indicators):
+ model_info['model_type'] = 'image'
+
+ # Check for audio models
+ elif any(indicator in entry.lower() for indicator in audio_indicators):
+ model_info['model_type'] = 'audio'
+
+ # Parse the entry based on common patterns
+ parts = entry.split('/')
+
+ # Handle region and provider extraction
+ if len(parts) >= 2:
+ # Extract provider from the beginning (common pattern)
+ if parts[0].lower() in known_providers:
+ model_info['provider'] = parts[0].lower()
+
+ # For bedrock and azure, the region is often the next part
+ if parts[0].lower() in ['bedrock', 'azure'] and len(parts) >= 3:
+ # Skip commitment parts if present
+ if 'commitment' not in parts[1]:
+ model_info['region'] = parts[1]
+
+ # The last part typically contains the model name and possibly version
+ model_with_version = parts[-1]
+ else:
+ # For single-part entries
+ model_with_version = entry
+
+ # Extract provider from model name if not already set
+ if not model_info['provider']:
+ # Look for known providers within the model name
+ for provider in known_providers:
+ if provider in model_with_version.lower() or f'{provider}.' in model_with_version.lower():
+ model_info['provider'] = provider
+ # Remove provider prefix if it exists at the beginning
+ if model_with_version.lower().startswith(f'{provider}.'):
+ model_with_version = model_with_version[len(provider) + 1:]
+ break
+
+ # Extract version information
+ version_match = re.search(r'[:.-]v(\d+(?:\.\d+)*(?:-\d+)?|\d+)(?::\d+)?$', model_with_version)
+ if version_match:
+ model_info['version'] = version_match.group(1)
+ # Remove version from model name
+ model_name = model_with_version[:version_match.start()]
+ else:
+ # Look for date-based versions like 2024-08-06
+ date_match = re.search(r'-(\d{4}-\d{2}-\d{2})$', model_with_version)
+ if date_match:
+ model_info['version'] = date_match.group(1)
+ model_name = model_with_version[:date_match.start()]
+ else:
+ model_name = model_with_version
+
+ # Clean up model name by removing trailing/leading separators
+ model_info['model_name'] = model_name.strip('.-:')
+
+ parsed_models.append(model_info)
+
+ return parsed_models
+
+
+ def create_model_hierarchy(model_entries: List[str]) -> Dict[str, Dict[str, Dict[str, Dict[str, str]]]]:
+ """
+ Organize model entries into a nested dictionary structure by provider, model, version, and region.
+
+ Args:
+ model_entries: List of model entry strings as found in models.txt
+
+ Returns:
+ Nested dictionary with the structure:
+ Provider -> Model -> Version -> Region = full model string
+ If region or version is None, they are replaced with "NA".
+ """
+ # Parse the model entries to get structured information
+ parsed_models = parse_model_entries(model_entries)
+
+ # Create the nested dictionary structure
+ hierarchy = {}
+
+ for i, model_info in enumerate(parsed_models):
+ provider = model_info['provider'] if model_info['provider'] else 'unknown'
+ model_name = model_info['model_name']
+ version = model_info['version'] if model_info['version'] else 'NA'
+ # For Azure models, always use 'NA' as region since they are globally available
+ region = 'NA' if provider == 'azure' else (model_info['region'] if model_info['region'] else 'NA')
+
+ # Initialize nested dictionaries if they don't exist
+ if provider not in hierarchy:
+ hierarchy[provider] = {}
+
+ if model_name not in hierarchy[provider]:
+ hierarchy[provider][model_name] = {}
+
+ if version not in hierarchy[provider][model_name]:
+ hierarchy[provider][model_name][version] = {}
+
+ # Store the full model string at the leaf node
+ hierarchy[provider][model_name][version][region] = model_entries[i]
+
+ return hierarchy
+
+
+ # NVIDIA GPU specifications - Name: (VRAM in GB, FP16 TOPS)
+ NVIDIA_GPUS = {
+ "RTX 3050": (8, 18),
+ "RTX 3060": (12, 25),
+ "RTX 3070": (8, 40),
+ "RTX 3080": (10, 58),
+ "RTX 3090": (24, 71),
+ "RTX 4060": (8, 41),
+ "RTX 4070": (12, 56),
+ "RTX 4080": (16, 113),
+ "RTX 4090": (24, 165),
+ "RTX A2000": (6, 20),
+ "RTX A4000": (16, 40),
+ "RTX A5000": (24, 64),
+ "RTX A6000": (48, 75),
+ "A100 40GB": (40, 312),
+ "A100 80GB": (80, 312),
+ "H100 80GB": (80, 989),
+ }
+
+
+ def get_hf_model_info(model_id: str) -> Optional[ModelInfo]:
+ """
+ Retrieve model information from the Hugging Face Hub.
+
+ Args:
+ model_id: Hugging Face model ID (e.g., "facebook/opt-1.3b")
+
+ Returns:
+ ModelInfo object or None if model not found
+ """
+ try:
+ api = HfApi()
+ model_info = api.model_info(model_id)
+ return model_info
+ except (RepositoryNotFoundError, RevisionNotFoundError) as e:
+ print(f"Error fetching model info: {e}")
+ return None
+
+
+ def extract_model_size(model_info: ModelInfo) -> Optional[Tuple[float, str]]:
+ """
+ Extract the parameter size and precision from model information.
+
+ Args:
+ model_info: ModelInfo object from Hugging Face Hub
+
+ Returns:
+ Tuple of (parameter size in billions, precision) or None if not found
+ """
+ # Try to get parameter count from model card
+ if model_info.card_data is not None:
+ if "model-index" in model_info.card_data and isinstance(model_info.card_data["model-index"], list):
+ for item in model_info.card_data["model-index"]:
+ if "parameters" in item:
+ return float(item["parameters"]) / 1e9, "fp16" # Convert to billions and assume fp16
+
+ # Try to extract from model name
+ name = model_info.id.lower()
+ size_patterns = [
+ r"(\d+(\.\d+)?)b", # matches patterns like "1.3b" or "7b"
+ r"-(\d+(\.\d+)?)b", # matches patterns like "llama-7b"
+ r"(\d+(\.\d+)?)-b", # matches other formatting variations
+ ]
+
+ for pattern in size_patterns:
+ match = re.search(pattern, name)
+ if match:
+ size_str = match.group(1)
+ return float(size_str), "fp16" # Default to fp16
+
+ # Extract precision if available
+ precision = "fp16" # Default
+ precision_patterns = {"fp16": r"fp16", "int8": r"int8", "int4": r"int4", "fp32": r"fp32"}
+ for prec, pattern in precision_patterns.items():
+ if re.search(pattern, name):
+ precision = prec
+ break
+
+ # If couldn't determine size, check sibling models or readme
+ if model_info.siblings:
+ for sibling in model_info.siblings:
+ if sibling.rfilename == "README.md" and sibling.size < 100000: # reasonable size for readme
+ try:
+ content = requests.get(sibling.lfs.url).text
+ param_pattern = r"(\d+(\.\d+)?)\s*[Bb](illion)?\s*[Pp]arameters"
+ match = re.search(param_pattern, content)
+ if match:
+ return float(match.group(1)), precision
+ except:
+ pass
+
+ # As a last resort, try to analyze config.json if it exists
+ config_sibling = next((s for s in model_info.siblings if s.rfilename == "config.json"), None)
+ if config_sibling:
+ try:
+ config = requests.get(config_sibling.lfs.url).json()
+ if "n_params" in config:
+ return float(config["n_params"]) / 1e9, precision
+ # Calculate from architecture if available
+ if all(k in config for k in ["n_layer", "n_head", "n_embd"]):
+ n_layer = config["n_layer"]
+ n_embd = config["n_embd"]
+ n_head = config["n_head"]
+ # Transformer parameter estimation formula
+ params = 12 * n_layer * (n_embd**2) * (1 + 13 / (12 * n_embd))
+ return params / 1e9, precision
+ except:
+ pass
+
+ return None
+
+
+ def calculate_vram_requirements(param_size: float, precision: str = "fp16") -> Dict[str, float]:
+ """
+ Calculate VRAM requirements for inference using the EleutherAI transformer math formula.
+
+ Args:
+ param_size: Model size in billions of parameters
+ precision: Model precision ("fp32", "fp16", "int8", "int4")
+
+ Returns:
+ Dictionary with various memory requirements in GB
+ """
+ # Convert parameters to actual count
+ param_count = param_size * 1e9
+
+ # Size per parameter based on precision
+ bytes_per_param = {
+ "fp32": 4,
+ "fp16": 2,
+ "int8": 1,
+ "int4": 0.5, # 4 bits = 0.5 bytes
+ }[precision]
+
+ # Base model size (parameters * bytes per parameter)
+ model_size_gb = (param_count * bytes_per_param) / (1024**3)
+
+ # EleutherAI formula components for inference memory
+ # Layer activations - scales with sequence length
+ activation_factor = 1.2 # varies by architecture
+
+ # KV cache size (scales with batch size and sequence length)
+ # Estimate for single batch, 2048-token context
+ kv_cache_size_gb = (param_count * 0.0625 * bytes_per_param) / (1024**3) # ~6.25% of params for KV cache
+
+ # Total VRAM needed for inference
+ total_inference_gb = model_size_gb + (model_size_gb * activation_factor) + kv_cache_size_gb
+
+ # Add overhead for CUDA, buffers, and fragmentation
+ overhead_gb = 0.8 # 800 MB overhead
+
+ # Dynamic computation graph allocation
+ compute_overhead_factor = 0.1 # varies based on attention computation method
+
+ # Final VRAM estimate
+ total_vram_required_gb = total_inference_gb + overhead_gb + (total_inference_gb * compute_overhead_factor)
+
+ return {
+ "model_size_gb": model_size_gb,
+ "kv_cache_gb": kv_cache_size_gb,
+ "activations_gb": model_size_gb * activation_factor,
+ "overhead_gb": overhead_gb + (total_inference_gb * compute_overhead_factor),
+ "total_vram_gb": total_vram_required_gb
+ }
+
+
+ def find_compatible_gpus(vram_required: float) -> List[str]:
+ """
+ Find NVIDIA GPUs that can run a model requiring the specified VRAM.
+
+ Args:
+ vram_required: Required VRAM in GB
+
+ Returns:
+ List of compatible GPU names sorted by VRAM capacity (smallest first)
+ """
+ compatible_gpus = [(name, specs[0]) for name, specs in NVIDIA_GPUS.items() if specs[0] >= vram_required]
+ return [gpu[0] for gpu in sorted(compatible_gpus, key=lambda x: x[1])]
+
+
+ def estimate_performance(param_size: float, precision: str, gpu_name: str) -> Dict[str, float]:
+ """
+ Estimate token/second performance for a model on a specific GPU.
+
+ Args:
+ param_size: Model size in billions of parameters
+ precision: Model precision
+ gpu_name: Name of the NVIDIA GPU
+
+ Returns:
+ Dictionary with performance metrics
+ """
+ if gpu_name not in NVIDIA_GPUS:
+ return {"tokens_per_second": 0, "tflops_utilization": 0}
+
+ gpu_vram, gpu_tops = NVIDIA_GPUS[gpu_name]
+
+ # Calculate FLOPs per token (based on model size)
+ # Formula: ~6 * num_parameters FLOPs per token (inference)
+ flops_per_token = 6 * param_size * 1e9
+
+ # Convert TOPS to TFLOPS based on precision
+ precision_factor = 1.0 if precision == "fp32" else 2.0 if precision == "fp16" else 4.0 if precision in ["int8", "int4"] else 1.0
+ gpu_tflops = gpu_tops * precision_factor
+
+ # Practical utilization (GPUs rarely achieve 100% of theoretical performance)
+ practical_utilization = 0.6 # 60% utilization
+
+ # Calculate tokens per second
+ effective_tflops = gpu_tflops * practical_utilization
+ tokens_per_second = (effective_tflops * 1e12) / flops_per_token
+
+ return {
+ "tokens_per_second": tokens_per_second,
+ "flops_per_token": flops_per_token,
+ "tflops_utilization": practical_utilization,
+ "effective_tflops": effective_tflops
+ }
+
+
+ def analyze_hf_model(model_id: str) -> Dict[str, any]:
+ """
+ Comprehensive analysis of a Hugging Face model:
+ - Downloads model information
+ - Extracts parameter size and precision
+ - Estimates VRAM requirements
+ - Identifies compatible NVIDIA GPUs
+ - Estimates performance on these GPUs
+
+ Args:
+ model_id: Hugging Face model ID (e.g., "facebook/opt-1.3b")
+
+ Returns:
+ Dictionary with analysis results or error message
+ """
+ # Get model information
+ model_info = get_hf_model_info(model_id)
+ if not model_info:
+ return {"error": f"Model {model_id} not found on Hugging Face"}
+
+ # Extract model size and precision
+ size_info = extract_model_size(model_info)
+ if not size_info:
+ return {"error": f"Couldn't determine parameter count for {model_id}"}
+
+ param_size, precision = size_info
+
+ # Calculate VRAM requirements
+ vram_requirements = calculate_vram_requirements(param_size, precision)
+ total_vram_gb = vram_requirements["total_vram_gb"]
+
+ # Find compatible GPUs
+ compatible_gpus = find_compatible_gpus(total_vram_gb)
+
+ # Calculate performance for each compatible GPU
+ gpu_performance = {}
+ for gpu in compatible_gpus:
+ gpu_performance[gpu] = estimate_performance(param_size, precision, gpu)
+
+ # Determine the largest GPU that can run the model
+ largest_compatible_gpu = compatible_gpus[-1] if compatible_gpus else None
+
+ return {
+ "model_id": model_id,
+ "parameter_size": param_size, # in billions
+ "precision": precision,
+ "vram_requirements": vram_requirements,
+ "compatible_gpus": compatible_gpus,
+ "largest_compatible_gpu": largest_compatible_gpu,
+ "gpu_performance": gpu_performance,
+ #"model_info": {
+ #"description": model_info.description,
+ #"tags": model_info.tags,
+ #"downloads": model_info.downloads,
+ #"library": getattr(model_info, "library", None)
+ #}
+ }
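
For reference, a minimal usage sketch (not part of the commit) showing how the functions added in utils_on.py fit together outside Streamlit. It assumes utils_on.py from this commit is on the import path and that the referenced model card is publicly readable on the Hub.

# Minimal sketch: run the new on-premise estimator from a plain Python shell.
from utils_on import analyze_hf_model, calculate_vram_requirements

result = analyze_hf_model("facebook/opt-1.3b")  # model ID used in the docstrings above
if "error" in result:
    print(result["error"])
else:
    print(f"{result['parameter_size']:.1f}B parameters, precision {result['precision']}")
    print(f"Estimated total VRAM: {result['vram_requirements']['total_vram_gb']:.2f} GB")
    print("Compatible GPUs:", ", ".join(result["compatible_gpus"]) or "none")

# The VRAM formula on its own, as a worked example for a 1.3B fp16 model:
vram = calculate_vram_requirements(1.3, "fp16")
print(f"weights {vram['model_size_gb']:.2f} GB, total {vram['total_vram_gb']:.2f} GB")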