ash-98 committed on
Commit
90c9a37
·
1 Parent(s): fcae37f

Initial test

Browse files
Files changed (4) hide show
  1. .streamlit/config.toml +5 -0
  2. requirements.txt +84 -0
  3. streamlit.py +176 -0
  4. utils.py +144 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#01d2fc"
3
+ backgroundColor="#252040"
4
+ secondaryBackgroundColor="#262626"
5
+ textColor="#f4f4f4"
requirements.txt ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ gradio
3
+ requests
4
+ python-dotenv
5
+ tokonomics
6
+ aiofiles==23.2.1
7
+ altair==5.5.0
8
+ annotated-types==0.7.0
9
+ anyenv==0.4.11
10
+ anyio==4.9.0
11
+ appdirs==1.4.4
12
+ attrs==25.3.0
13
+ audioop-lts==0.2.1
14
+ blinker==1.9.0
15
+ cachetools==5.5.2
16
+ certifi==2025.1.31
17
+ charset-normalizer==3.4.1
18
+ click==8.1.8
19
+ fastapi==0.115.12
20
+ ffmpy==0.5.0
21
+ filelock==3.18.0
22
+ fsspec==2025.3.0
23
+ gitdb==4.0.12
24
+ GitPython==3.1.44
25
+ gradio==5.23.0
26
+ gradio_client==1.8.0
27
+ groovy==0.1.2
28
+ h11==0.14.0
29
+ hishel==0.1.1
30
+ httpcore==1.0.7
31
+ httpx==0.28.1
32
+ huggingface-hub==0.29.3
33
+ idna==3.10
34
+ Jinja2==3.1.6
35
+ jsonschema==4.23.0
36
+ jsonschema-specifications==2024.10.1
37
+ markdown-it-py==3.0.0
38
+ MarkupSafe==3.0.2
39
+ mdurl==0.1.2
40
+ narwhals==1.32.0
41
+ numpy==2.2.4
42
+ orjson==3.10.16
43
+ packaging==24.2
44
+ pandas==2.2.3
45
+ pillow==11.1.0
46
+ platformdirs==4.3.7
47
+ protobuf==5.29.4
48
+ pyarrow==19.0.1
49
+ pydantic==2.10.6
50
+ pydantic_core==2.27.2
51
+ pydeck==0.9.1
52
+ pydub==0.25.1
53
+ Pygments==2.19.1
54
+ python-dateutil==2.9.0.post0
55
+ python-dotenv==1.0.1
56
+ python-multipart==0.0.20
57
+ pytz==2025.2
58
+ PyYAML==6.0.2
59
+ referencing==0.36.2
60
+ requests==2.32.3
61
+ rich==13.9.4
62
+ rpds-py==0.23.1
63
+ ruff==0.11.2
64
+ safehttpx==0.1.6
65
+ semantic-version==2.10.0
66
+ shellingham==1.5.4
67
+ six==1.17.0
68
+ smmap==5.0.2
69
+ sniffio==1.3.1
70
+ starlette==0.46.1
71
+ streamlit==1.44.0
72
+ tenacity==9.0.0
73
+ tokonomics==0.3.9
74
+ toml==0.10.2
75
+ tomlkit==0.13.2
76
+ tornado==6.4.2
77
+ tqdm==4.67.1
78
+ typer==0.15.2
79
+ typing_extensions==4.12.2
80
+ tzdata==2025.2
81
+ urllib3==2.3.0
82
+ uvicorn==0.34.0
83
+ watchdog==6.0.0
84
+ websockets==15.0.1
streamlit.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import asyncio
3
+ import tokonomics
4
+ from utils import create_model_hierarchy
5
+
6
+ st.set_page_config(page_title="LLM Pricing App", layout="wide")
7
+
8
+ # --------------------------
9
+ # Async Data Loading Function
10
+ # --------------------------
11
async def load_data():
    """Fetch the available models and their pricing via ``tokonomics``.

    The model list returned by ``tokonomics.get_available_models`` is
    organised with ``create_model_hierarchy``; every model id stored at a
    leaf of that hierarchy is then priced with
    ``tokonomics.get_model_costs``.

    Returns:
        A ``(model_ids, pricing, providers)`` tuple where ``model_ids`` is
        the list of leaf model ids in hierarchy order, ``pricing`` maps each
        of those ids to the value returned by ``get_model_costs``, and
        ``providers`` is the list of top-level keys of the hierarchy.
    """
    available = await tokonomics.get_available_models()
    hierarchy = create_model_hierarchy(available)

    model_ids = []
    pricing = {}
    providers = list(hierarchy)

    # Walk provider -> model family -> version -> region; the leaf value is
    # the full model id string.
    for provider in providers:
        for versions in hierarchy[provider].values():
            for regions in versions.values():
                for model_id in regions.values():
                    # Costs are awaited one model at a time, in walk order.
                    pricing[model_id] = await tokonomics.get_model_costs(model_id)
                    model_ids.append(model_id)

    return model_ids, pricing, providers
26
+
27
+ # --------------------------
28
+ # Provider Change Function
29
+ # --------------------------
30
def provider_change(provider, selected_type, all_types=("text", "vision", "video", "image")):
    """Filter the cached model ids by provider and model type.

    Matching is done by plain substring tests against each model id held in
    ``st.session_state["models"]``.

    Args:
        provider: Substring that must occur in a model id for it to be kept.
        selected_type: Type keyword chosen by the user (e.g. ``"text"``).
        all_types: Every known type keyword.  Ids that mention one of the
            *other* keywords (and not ``selected_type``) are dropped.  The
            default is an immutable tuple so it cannot be altered between
            calls; any iterable of strings is accepted.

    Returns:
        The list of matching model ids.  When nothing matches, the complete
        list from the session state is returned instead, so the result is
        only empty when no models are loaded at all.
    """
    all_models = st.session_state.get("models", [])
    other_types = [a_type for a_type in all_types if a_type != selected_type]

    matches = []
    for model_name in all_models:
        if provider not in model_name:
            continue
        # Keep ids that name the selected type explicitly, and ids that name
        # no type keyword at all; skip ids that only name a different type.
        names_selected = selected_type in model_name
        names_other = any(other in model_name for other in other_types)
        if names_selected or not names_other:
            matches.append(model_name)

    # Fall back to the unfiltered list rather than returning nothing.
    return matches if matches else all_models
44
+
45
+ # --------------------------
46
+ # Estimate Cost Function (Updated)
47
+ # --------------------------
48
def estimate_cost(num_alerts, input_size, output_size, model_id):
    """Build a markdown summary of the estimated cost for one model.

    Args:
        num_alerts: Number of alerts processed per day.
        input_size: Size of the input for a single alert.
        output_size: Size of the output for a single alert.
        model_id: Key into the pricing table in ``st.session_state["pricing"]``.

    Returns:
        A markdown string with the day, month and year price, or ``"NA"``
        when no (non-empty) pricing entry exists for ``model_id``.
    """
    rates = st.session_state.get("pricing", {}).get(model_id)
    if not rates:
        return "NA"

    # NOTE(review): the UI labels these sizes as characters, yet they are
    # scaled by 1.3 to obtain token counts — confirm the intended ratio.
    tokens_in = round(input_size * 1.3)
    tokens_out = round(output_size * 1.3)

    input_cost = rates.get("input_cost_per_token", 0) * tokens_in
    output_cost = rates.get("output_cost_per_token", 0) * tokens_out
    cost_per_alert = input_cost + output_cost
    cost_per_day = cost_per_alert * num_alerts

    # NOTE(review): a month is counted as 31 days while a year is 365 days.
    return (
        "## Estimated Cost:\n"
        "\n"
        f"Day Price: {cost_per_day:0.2f} USD\n"
        f"Month Price: {cost_per_day * 31:0.2f} USD\n"
        f"Year Price: {cost_per_day * 365:0.2f} USD\n"
    )
63
+
64
+ # --------------------------
65
+ # Load Data into Session State (only once)
66
+ # --------------------------
67
def _store_pricing_data():
    """Run ``load_data`` and publish its three results in the session state."""
    models, pricing, providers = asyncio.run(load_data())
    st.session_state["models"] = models
    st.session_state["pricing"] = pricing
    st.session_state["providers"] = providers


# Streamlit re-executes this script on every interaction; the flag makes sure
# the pricing data is only fetched on the first run of a session.
if "data_loaded" not in st.session_state:
    with st.spinner("Loading pricing data..."):
        _store_pricing_data()
    st.session_state["data_loaded"] = True

# --------------------------
# Sidebar
# --------------------------
with st.sidebar:
    st.image(
        "https://cdn.prod.website-files.com/630f558f2a15ca1e88a2f774/631f1436ad7a0605fecc5e15_Logo.svg",
        use_container_width=True,
    )
    st.divider()
    # Inside the ``with st.sidebar`` block plain ``st`` calls already render
    # in the sidebar.
    st.title("LLM Pricing Calculator")


# --------------------------
# Main Content Layout (Model Selection Tab)
# --------------------------
tab1, tab2 = st.tabs(["Model Selection", "About"])

with tab1:
    st.header("LLM Pricing App")

    # --- Row 1: Provider/Type and Model Selection ---
    col_left, col_right = st.columns(2)
    with col_left:
        selected_provider = st.selectbox("Select a provider", st.session_state["providers"])
        selected_type = st.radio("Select type", options=["text", "image"], index=0)
    with col_right:
        # Filter models based on the selected provider and type
        filtered_models = provider_change(selected_provider, selected_type)
        if filtered_models:
            selected_model = st.selectbox("Select a model", options=filtered_models)
        else:
            selected_model = None
            st.write("No models available")

    # --- Row 2: Alert Stats ---
    col1, col2, col3 = st.columns(3)
    with col1:
        num_alerts = st.number_input(
            "Security Alerts Per Day",
            value=100,
            min_value=1,
            step=1,
            help="Number of security alerts to analyze daily",
        )
    with col2:
        input_size = st.number_input(
            "Alert Content Size (characters)",
            value=1000,
            min_value=1,
            step=1,
            help="Include logs, metadata, and context per alert",
        )
    with col3:
        output_size = st.number_input(
            "Analysis Output Size (characters)",
            value=500,
            min_value=1,
            step=1,
            help="Expected length of security analysis and recommendations",
        )

    # --- Row 3: Buttons ---
    btn_col1, btn_col2 = st.columns(2)
    with btn_col1:
        if st.button("Estimate"):
            # The result is kept in the session state so it survives reruns
            # triggered by other widgets.
            if selected_model:
                st.session_state["result"] = estimate_cost(
                    num_alerts, input_size, output_size, selected_model
                )
            else:
                st.session_state["result"] = "No model selected."
    with btn_col2:
        if st.button("Refresh Pricing Data"):
            with st.spinner("Refreshing pricing data..."):
                _store_pricing_data()
            st.success("Pricing data refreshed!")

    st.divider()
    # --- Display Results ---
    st.markdown("### Results")
    if "result" in st.session_state:
        st.write(st.session_state["result"])
    else:
        st.write("Use the buttons above to estimate costs.")

    # --- Clear Button Below Results ---
    if st.button("Clear"):
        st.session_state.pop("result", None)
        st.rerun()

with tab2:
    st.markdown(
        """
        ## About This App

        This is based on the tokonomics package.

        - The app downloads the latest pricing from the LiteLLM repository.
        - Using simple maths to estimate the total tokens.
        - Version 0.1

        Website: [https://www.priam.ai](https://www.priam.ai)
        """
    )
utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List,Dict
2
+ import re
3
+
4
# Provider names recognised either as the first '/'-separated part of an entry
# or as a substring of the model name.  Order matters: the first hit wins.
_KNOWN_PROVIDERS = (
    'azure', 'bedrock', 'anthropic', 'openai', 'cohere', 'google',
    'mistral', 'meta', 'amazon', 'ai21', 'anyscale', 'stability',
    'cloudflare', 'databricks', 'cerebras', 'assemblyai',
)

# Substrings that mark an entry as an image model.
_IMAGE_INDICATORS = ('dall-e', 'stable-diffusion', 'image', 'canvas', 'x-', 'steps')

# Substrings that mark an entry as an audio model (checked after image).
_AUDIO_INDICATORS = ('whisper', 'tts', 'audio', 'voice')

# Trailing "v<number>" version such as "-v1", ".v2.1" or "-v1:0".
_VERSION_SUFFIX_RE = re.compile(r'[:.-]v(\d+(?:\.\d+)*(?:-\d+)?|\d+)(?::\d+)?$')

# Trailing date-based version such as "-2024-08-06".
_DATE_SUFFIX_RE = re.compile(r'-(\d{4}-\d{2}-\d{2})$')


def parse_model_entries(model_entries: List[str]) -> List[Dict[str, str]]:
    """
    Parse model entry strings into dictionaries of provider, name, version, region and type.

    Args:
        model_entries: List of model entry strings, e.g.
            'bedrock/us-east-1/anthropic.claude-v2' or 'gpt-4o-2024-08-06'.

    Returns:
        One dictionary per entry, in input order, with the keys:
            - provider: Known provider name, or '' when none is recognised.
            - model_name: Model name with version suffix and surrounding
              '.', '-' and ':' characters removed.
            - version: Trailing 'v<number>' or 'YYYY-MM-DD' version, or ''.
            - region: Second path part of 3+-part 'bedrock/...' and
              'azure/...' entries unless it contains 'commitment', else ''.
            - model_type: 'image', 'audio' or 'text' (the default), decided
              by keyword substrings in the whole entry.
    """
    parsed_models = []

    for entry in model_entries:
        entry_lower = entry.lower()

        # Image keywords take precedence over audio keywords.
        if any(indicator in entry_lower for indicator in _IMAGE_INDICATORS):
            model_type = 'image'
        elif any(indicator in entry_lower for indicator in _AUDIO_INDICATORS):
            model_type = 'audio'
        else:
            model_type = 'text'

        provider = ''
        region = ''
        parts = entry.split('/')

        if len(parts) >= 2:
            prefix = parts[0].lower()
            if prefix in _KNOWN_PROVIDERS:
                provider = prefix
            # For bedrock and azure the part after the provider is treated as
            # the region, except for "...commitment" pricing tiers.
            if (prefix in ('bedrock', 'azure') and len(parts) >= 3
                    and 'commitment' not in parts[1]):
                region = parts[1]

        # The last part holds the model name and possibly a version; for an
        # entry without '/' this is the entry itself.
        model_with_version = parts[-1]

        if not provider:
            # No provider prefix: look for a provider mentioned in the name.
            name_lower = model_with_version.lower()
            for candidate in _KNOWN_PROVIDERS:
                if candidate in name_lower:
                    provider = candidate
                    # Drop a leading "<provider>." from the model name.
                    if name_lower.startswith(f'{candidate}.'):
                        model_with_version = model_with_version[len(candidate) + 1:]
                    break

        # A "v<number>" suffix is preferred over a date suffix.
        suffix = (_VERSION_SUFFIX_RE.search(model_with_version)
                  or _DATE_SUFFIX_RE.search(model_with_version))
        if suffix:
            version = suffix.group(1)
            model_name = model_with_version[:suffix.start()]
        else:
            version = ''
            model_name = model_with_version

        parsed_models.append({
            'provider': provider,
            'model_name': model_name.strip('.-:'),
            'version': version,
            'region': region,
            'model_type': model_type,
        })

    return parsed_models
104
+
105
+
106
def create_model_hierarchy(model_entries: List[str]) -> Dict[str, Dict[str, Dict[str, Dict[str, str]]]]:
    """
    Group model entries into nested dictionaries keyed by provider, model, version and region.

    Args:
        model_entries: List of model entry strings accepted by
            ``parse_model_entries``.

    Returns:
        Nested dictionary with the structure:
            Provider -> Model -> Version -> Region = full model string
        An empty provider becomes 'unknown'; an empty version or region
        becomes 'NA'.  Entries that resolve to the same four keys overwrite
        each other, the last one winning.
    """
    hierarchy = {}

    # Parsed results come back in input order, so they pair up with the
    # original entry strings one-to-one.
    for full_entry, info in zip(model_entries, parse_model_entries(model_entries)):
        provider = info['provider'] or 'unknown'
        version = info['version'] or 'NA'

        # Azure entries are always filed under the 'NA' region, whatever
        # region was parsed for them.
        if provider == 'azure':
            region = 'NA'
        else:
            region = info['region'] or 'NA'

        # Create the intermediate levels on demand and store the full model
        # string at the leaf.
        versions = hierarchy.setdefault(provider, {}).setdefault(info['model_name'], {})
        versions.setdefault(version, {})[region] = full_entry

    return hierarchy