|
import threading |
|
from langchain_core.messages import BaseMessage |
|
from langchain_aws import ChatBedrock |
|
|
|
from app.llm.token.token_manager import TokenManager |
|
from app.core.config import settings |
|
from app.llm.llm_interface import LLMInterface |
|
|
|
|
|
class BedrockProvider(LLMInterface):
    """Thread-safe singleton provider for AWS Bedrock chat models.

    Wraps langchain's ``ChatBedrock`` client, constructed once per process from
    ``app.core.config.settings``, and funnels every response through a shared
    ``TokenManager`` so token usage is tracked globally across all callers.
    """

    _instance = None
    _lock = threading.Lock()
    # Class-level on purpose: one shared token budget for the whole process,
    # created at import time alongside the class.
    token_manager = TokenManager(token_limit=50000, reset_interval=30)

    def __new__(cls):
        # Double-checked locking: the fast path skips the lock entirely once
        # the singleton exists; the inner re-check closes the race between
        # the unlocked check and lock acquisition.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize the Bedrock client exactly once, even under contention.

        FIX: the original guarded initialization with a bare flag check, which
        ran OUTSIDE the lock — two threads racing through ``BedrockProvider()``
        could both initialize, or one could receive an instance whose
        ``__init__`` was still mid-flight on another thread (e.g. ``self.llm``
        not yet assigned). Re-checking ``_initialized`` under the class lock
        makes construction safe.
        """
        if self._initialized:
            # Fast path: already fully initialized; flag is only ever set to
            # True after construction completes under the lock below.
            return
        with self._lock:
            if self._initialized:
                return
            self.model_id = settings.BEDROCK_MODEL_ID
            self.aws_access_key = settings.AWS_ACCESS_KEY
            self.aws_secret_key = settings.AWS_SECRET_KEY
            self.aws_region = settings.AWS_REGION
            self.provider = settings.BEDROCK_PROVIDER

            self.llm = ChatBedrock(
                model_id=self.model_id,
                region_name=self.aws_region,
                aws_access_key_id=self.aws_access_key,
                aws_secret_access_key=self.aws_secret_key,
                provider=self.provider,
                streaming=False,
                model_kwargs={
                    "temperature": 0.7,
                    "max_tokens": 2000
                }
            )

            # Set last so readers of the fast path never see a partial object.
            self._initialized = True

    def query(self, messages: list[BaseMessage]) -> BaseMessage:
        """Synchronously query AWS Bedrock.

        Args:
            messages: Conversation history to send to the model.

        Returns:
            The model's response message; its token usage is recorded in the
            shared ``token_manager`` as a side effect.
        """
        response = self.llm.invoke(messages)
        self._track_tokens(response)
        return response

    async def aquery(self, messages: list[BaseMessage]) -> BaseMessage:
        """Asynchronous counterpart of :meth:`query`.

        Args:
            messages: Conversation history to send to the model.

        Returns:
            The model's response message; token usage is tracked identically
            to the synchronous path.
        """
        response = await self.llm.ainvoke(messages)
        self._track_tokens(response)
        return response

    def _track_tokens(self, response: BaseMessage) -> None:
        """Record the response's total token count in the shared manager.

        Defensive on purpose: ``response_metadata`` may be absent or ``None``
        on some message types, and the ``token_usage`` entry may be missing;
        in every such case 0 tokens are recorded rather than raising.
        """
        metadata = getattr(response, "response_metadata", None) or {}
        token_usage = metadata.get("token_usage", {})
        total_tokens = token_usage.get("total_tokens", 0)
        self.token_manager.track_tokens(total_tokens)