import threading

from langchain_core.messages import BaseMessage
from langchain_aws import ChatBedrock

from app.llm.token.token_manager import TokenManager
from app.core.config import settings
from app.llm.llm_interface import LLMInterface


class BedrockProvider(LLMInterface):
    """Thread-safe singleton LLM provider backed by AWS Bedrock via LangChain.

    A single ``ChatBedrock`` client is lazily constructed on first
    instantiation and shared by every subsequent ``BedrockProvider()`` call.
    Token usage reported by the model is fed into a shared ``TokenManager``.
    """

    _instance = None
    _lock = threading.Lock()
    # Class-level budget tracker shared by the singleton instance.
    # NOTE(review): presumably reset_interval is in minutes — confirm against
    # TokenManager's definition.
    token_manager = TokenManager(token_limit=50000, reset_interval=30)

    def __new__(cls):
        # Double-checked locking: cheap unguarded read first, then a locked
        # re-check so only one thread ever allocates the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Build the Bedrock client exactly once, under the class lock.

        BUG FIX: ``__init__`` runs on *every* ``BedrockProvider()`` call and
        previously checked ``_initialized`` outside any lock — two threads
        constructing concurrently could both see ``False`` and both build a
        ``ChatBedrock`` client. Guard with the same lock used in ``__new__``.
        """
        if self._initialized:
            return
        with self._lock:
            if self._initialized:
                return
            self.model_id = settings.BEDROCK_MODEL_ID
            self.aws_access_key = settings.AWS_ACCESS_KEY
            self.aws_secret_key = settings.AWS_SECRET_KEY
            self.aws_region = settings.AWS_REGION
            self.provider = settings.BEDROCK_PROVIDER

            # Non-streaming chat client; generation parameters are fixed here.
            self.llm = ChatBedrock(
                model_id=self.model_id,
                region_name=self.aws_region,
                aws_access_key_id=self.aws_access_key,
                aws_secret_access_key=self.aws_secret_key,
                provider=self.provider,
                streaming=False,
                model_kwargs={
                    "temperature": 0.7,
                    "max_tokens": 2000,
                },
            )
            self._initialized = True

    def query(self, messages: list[BaseMessage]) -> BaseMessage:
        """Synchronously query AWS Bedrock and record token usage.

        Args:
            messages: Conversation history to send to the model.

        Returns:
            The model's response message.
        """
        response = self.llm.invoke(messages)
        self._track_tokens(response)
        return response

    async def aquery(self, messages: list[BaseMessage]) -> BaseMessage:
        """Asynchronously query AWS Bedrock and record token usage.

        Args:
            messages: Conversation history to send to the model.

        Returns:
            The model's response message.
        """
        response = await self.llm.ainvoke(messages)
        self._track_tokens(response)
        return response

    def _track_tokens(self, response: BaseMessage) -> None:
        """Record the total token count reported in *response*'s metadata.

        Prefers the provider-reported ``response_metadata["token_usage"]``
        (the field the original code read); falls back to LangChain's
        standardized ``AIMessage.usage_metadata`` when the provider-specific
        field is absent or empty, instead of silently recording 0.
        """
        metadata = getattr(response, "response_metadata", {}) or {}
        total_tokens = metadata.get("token_usage", {}).get("total_tokens", 0)
        if not total_tokens:
            usage = getattr(response, "usage_metadata", None) or {}
            total_tokens = usage.get("total_tokens", 0)
        self.token_manager.track_tokens(total_tokens)