File size: 6,592 Bytes
313f83b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""
Token usage tracking for OpenAI API calls using tiktoken.
Provides accurate token counting and cost estimation.
"""
import tiktoken
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
from ankigen_core.logging import logger
@dataclass
class TokenUsage:
    """Usage record for one API request.

    Holds the prompt/completion token counts, their total, the cost (if
    known), the model name, and a creation timestamp.
    """

    # Tokens on the request (prompt / input) side.
    prompt_tokens: int
    # Tokens on the response (completion / output) side.
    completion_tokens: int
    # Combined token count for the request.
    total_tokens: int
    # Cost of the request; None when no cost is known.
    estimated_cost: Optional[float]
    # Name of the model the request was made against.
    model: str
    # Creation time; naive local time from datetime.now() unless supplied.
    timestamp: datetime = field(default_factory=datetime.now)
class TokenTracker:
    """Accumulate token usage (and cost, when known) across many requests.

    Every tracked request is appended to ``usage_history`` as a ``TokenUsage``
    record. ``total_tokens`` and ``total_cost`` hold running totals until
    ``reset_session`` clears them.
    """

    def __init__(self):
        self.usage_history: List["TokenUsage"] = []
        self.total_cost = 0.0
        self.total_tokens = 0

    @staticmethod
    def _get_encoding(model: str):
        """Return the tiktoken encoding for *model*, or ``o200k_base`` if unknown.

        ``tiktoken.encoding_for_model`` raises ``KeyError`` for model names it
        does not recognise; those fall back to the ``o200k_base`` encoding.
        """
        try:
            return tiktoken.encoding_for_model(model)
        except KeyError:
            return tiktoken.get_encoding("o200k_base")

    def count_tokens_for_messages(
        self, messages: List[Dict[str, str]], model: str
    ) -> int:
        """Count the prompt tokens for a list of chat messages.

        Every message value is converted with ``str()`` and encoded. On top of
        that, each message adds 3 tokens, a ``name`` key adds 1 more, and 3
        tokens are added once at the end.
        NOTE(review): these fixed overheads approximate the chat framing and may
        not be exact for every model.

        Args:
            messages: Chat messages as dicts (e.g. ``role``/``content``/``name``).
            model: Model name used to select the tokenizer.

        Returns:
            The token count.
        """
        encoding = self._get_encoding(model)
        tokens_per_message = 3  # fixed framing overhead per message
        tokens_per_name = 1  # extra overhead when a message carries "name"
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # fixed overhead added once per message list
        return num_tokens

    def count_tokens_for_text(self, text: str, model: str) -> int:
        """Return the number of tokens *text* encodes to for *model*."""
        return len(self._get_encoding(model).encode(text))

    def estimate_cost(
        self, prompt_tokens: int, completion_tokens: int, model: str
    ) -> Optional[float]:
        """Estimate a request's cost from its token counts.

        No pricing table is implemented, so this always returns ``None``.
        ``track_usage`` then records the request without a cost unless the
        caller supplied one.
        """
        return None

    def track_usage_from_response(
        self, response_data, model: str
    ) -> Optional["TokenUsage"]:
        """Record usage taken from an API response object.

        Token counts are read from ``response_data.usage``. The method uses
        ``prompt_tokens``/``completion_tokens`` when present and otherwise
        falls back to ``input_tokens``/``output_tokens``. A cost is taken from
        ``usage.total_cost`` or, if that is missing or None, ``usage.cost``.
        If neither is available, ``estimate_cost`` is used.

        Returns:
            The recorded ``TokenUsage``. Returns ``None`` when the response
            has no usage (attribute missing or ``None``). Also returns ``None``
            when reading the usage fails; that failure is logged.
        """
        try:
            usage = getattr(response_data, "usage", None)
            if usage is None:
                # e.g. streamed responses may carry usage=None; nothing to record
                return None

            prompt_tokens = getattr(usage, "prompt_tokens", None)
            if prompt_tokens is None:
                prompt_tokens = usage.input_tokens
            completion_tokens = getattr(usage, "completion_tokens", None)
            if completion_tokens is None:
                completion_tokens = usage.output_tokens

            actual_cost = getattr(usage, "total_cost", None)
            if actual_cost is None:
                actual_cost = getattr(usage, "cost", None)

            return self.track_usage(
                prompt_tokens, completion_tokens, model, actual_cost
            )
        except Exception as e:
            logger.error(f"Failed to track usage from response: {e}")
            return None

    def track_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        model: str,
        actual_cost: Optional[float] = None,
    ) -> "TokenUsage":
        """Record one request and update the running totals.

        Args:
            prompt_tokens: Tokens on the request side.
            completion_tokens: Tokens on the response side.
            model: Model name, used for per-model grouping in summaries.
            actual_cost: Known cost. When ``None``, ``estimate_cost`` is used.

        Returns:
            The ``TokenUsage`` record appended to ``usage_history``.
        """
        total_tokens = prompt_tokens + completion_tokens
        if actual_cost is not None:
            final_cost = actual_cost
        else:
            final_cost = self.estimate_cost(prompt_tokens, completion_tokens, model)

        usage = TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            estimated_cost=final_cost,
            model=model,
        )
        self.usage_history.append(usage)
        # A known cost of 0.0 is still a known cost: test for None, not truthiness.
        if final_cost is not None:
            self.total_cost += final_cost
        self.total_tokens += total_tokens

        if final_cost is not None:
            logger.info(
                f"π° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}, Cost: ${final_cost:.4f}"
            )
        else:
            logger.info(
                f"π° Token usage - Model: {model}, Prompt: {prompt_tokens}, Completion: {completion_tokens}"
            )
        return usage

    def get_session_summary(self) -> Dict[str, Any]:
        """Summarise the session: request count, tokens, cost, and a per-model breakdown.

        Returns:
            A dict with ``total_requests``, ``total_tokens``, ``total_cost`` and
            ``by_model``. ``by_model`` maps each model name to its
            ``requests``/``tokens``/``cost``. Requests without a cost add
            nothing to ``cost``.
        """
        if not self.usage_history:
            return {
                "total_requests": 0,
                "total_tokens": 0,
                "total_cost": 0.0,
                "by_model": {},
            }

        by_model: Dict[str, Dict[str, Any]] = {}
        for usage in self.usage_history:
            stats = by_model.setdefault(
                usage.model, {"requests": 0, "tokens": 0, "cost": 0.0}
            )
            stats["requests"] += 1
            stats["tokens"] += usage.total_tokens
            if usage.estimated_cost is not None:
                stats["cost"] += usage.estimated_cost

        return {
            "total_requests": len(self.usage_history),
            "total_tokens": self.total_tokens,
            "total_cost": self.total_cost,
            "by_model": by_model,
        }

    def get_session_usage(self) -> Dict[str, Any]:
        """Alias for ``get_session_summary``."""
        return self.get_session_summary()

    def reset_session(self):
        """Clear the usage history and zero the running totals."""
        self.usage_history.clear()
        self.total_cost = 0.0
        self.total_tokens = 0
        logger.info("π Token usage tracking reset")

    def track_usage_from_agents_sdk(
        self, usage_dict: Dict[str, Any], model: str
    ) -> Optional["TokenUsage"]:
        """Record usage from an OpenAI Agents SDK usage dict.

        The dict is read with the keys ``input_tokens``, ``output_tokens`` and
        ``total_tokens``. Missing or ``None`` counts are treated as 0. When
        ``total_tokens`` is absent, the sum of input and output is used.

        Returns:
            The recorded ``TokenUsage``. Returns ``None`` when the dict is
            empty or the total is 0, and also when recording fails (the
            failure is logged).
        """
        try:
            if not usage_dict:
                return None
            prompt_tokens = usage_dict.get("input_tokens") or 0
            completion_tokens = usage_dict.get("output_tokens") or 0
            total_tokens = usage_dict.get("total_tokens")
            if total_tokens is None:
                # Don't drop real usage just because the total was omitted.
                total_tokens = prompt_tokens + completion_tokens
            if total_tokens == 0:
                return None
            return self.track_usage(prompt_tokens, completion_tokens, model)
        except Exception as e:
            logger.error(f"Failed to track usage from agents SDK: {e}")
            return None
# Module-level TokenTracker returned by get_token_tracker() and used by the
# module-level track_* convenience functions.
# NOTE(review): TokenTracker performs no locking; confirm it is not updated
# concurrently from multiple threads.
_global_tracker: TokenTracker = TokenTracker()
def get_token_tracker() -> TokenTracker:
    """Return the module-level ``TokenTracker`` shared by this module's helpers.

    Every call returns the same instance, so usage recorded through any
    module-level ``track_*`` function accumulates in one place.
    """
    shared = _global_tracker
    return shared
def track_agent_usage(
    prompt_text: str,
    completion_text: str,
    model: str,
    actual_cost: Optional[float] = None,
) -> TokenUsage:
    """Tokenize a prompt/completion pair and record it on the shared tracker.

    Both texts are counted with ``TokenTracker.count_tokens_for_text`` for
    *model*. If *actual_cost* is not ``None``, it is recorded as the cost;
    otherwise the tracker's estimate is used.

    Returns:
        The ``TokenUsage`` record created by ``TokenTracker.track_usage``.
    """
    tracker = get_token_tracker()
    prompt_count, completion_count = (
        tracker.count_tokens_for_text(text, model)
        for text in (prompt_text, completion_text)
    )
    return tracker.track_usage(
        prompt_count, completion_count, model, actual_cost
    )
def track_usage_from_openai_response(
    response_data, model: str
) -> Optional[TokenUsage]:
    """Record the usage carried by an API response on the shared tracker.

    Thin wrapper around ``TokenTracker.track_usage_from_response`` applied to
    the instance returned by ``get_token_tracker()``.
    """
    return get_token_tracker().track_usage_from_response(response_data, model)
def track_usage_from_agents_sdk(
    usage_dict: Dict[str, Any], model: str
) -> Optional[TokenUsage]:
    """Record an OpenAI Agents SDK usage dict on the shared tracker.

    Thin wrapper around ``TokenTracker.track_usage_from_agents_sdk`` applied
    to the instance returned by ``get_token_tracker()``.
    """
    return get_token_tracker().track_usage_from_agents_sdk(usage_dict, model)
|