# mailpilot/app/llm/provider/bedrock_provider.py
# Committed by Yadav122 — "Initial deployment of MailPilot application" (commit 7a88b43)
import threading
from langchain_core.messages import BaseMessage
from langchain_aws import ChatBedrock
from app.llm.token.token_manager import TokenManager
from app.core.config import settings
from app.llm.llm_interface import LLMInterface
class BedrockProvider(LLMInterface):
    """Singleton LLM provider backed by AWS Bedrock through LangChain's ``ChatBedrock``.

    Every ``BedrockProvider()`` call returns the same instance. The underlying
    ``ChatBedrock`` client is built once, on first construction, from the model id,
    AWS credentials, region and provider found in ``settings``. After each query the
    response's total token count is reported to the class-level ``token_manager``.
    """

    _instance = None
    # Guards both instance creation (__new__) and one-time client setup (__init__).
    _lock = threading.Lock()
    # Class attribute, created at import time and shared by the (single) instance.
    token_manager = TokenManager(token_limit=50000, reset_interval=30)

    def __new__(cls):
        # Double-checked locking: the unlocked check keeps the common path cheap,
        # the locked re-check ensures concurrent first calls create one instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every BedrockProvider() call, even when __new__
        # returned the cached instance, so skip work once setup has completed.
        if self._initialized:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished setup
            # while this one was waiting, and must not build a second client.
            if self._initialized:
                return
            self.model_id = settings.BEDROCK_MODEL_ID
            self.aws_access_key = settings.AWS_ACCESS_KEY
            self.aws_secret_key = settings.AWS_SECRET_KEY
            self.aws_region = settings.AWS_REGION
            self.provider = settings.BEDROCK_PROVIDER
            self.llm = ChatBedrock(
                model_id=self.model_id,
                region_name=self.aws_region,
                aws_access_key_id=self.aws_access_key,
                aws_secret_access_key=self.aws_secret_key,
                provider=self.provider,
                streaming=False,
                # Forwarded to the Bedrock model as-is.
                model_kwargs={
                    "temperature": 0.7,
                    "max_tokens": 2000,
                },
            )
            # Set only after the client was built, so a failed setup (e.g. bad
            # settings) is retried on the next construction instead of leaving
            # a half-initialised singleton.
            self._initialized = True

    def query(self, messages: list[BaseMessage]) -> BaseMessage:
        """Send *messages* to Bedrock synchronously and return the model's reply.

        Token usage reported in the reply is recorded in ``token_manager``.
        Exceptions raised by ``ChatBedrock.invoke`` propagate unchanged.
        """
        response = self.llm.invoke(messages)
        self._track_tokens(response)
        return response

    async def aquery(self, messages: list[BaseMessage]) -> BaseMessage:
        """Asynchronous counterpart of :meth:`query` (uses ``ChatBedrock.ainvoke``)."""
        response = await self.llm.ainvoke(messages)
        self._track_tokens(response)
        return response

    def _track_tokens(self, response: BaseMessage) -> None:
        """Report the total token count found in *response* to ``token_manager``.

        Reports 0 when the response carries no recognisable usage information.
        """
        self.token_manager.track_tokens(self._extract_total_tokens(response))

    @staticmethod
    def _extract_total_tokens(response: BaseMessage) -> int:
        """Return ``total_tokens`` from the first usage record present on *response*, else 0.

        Sources are checked in this order:
        1. ``response.usage_metadata`` — LangChain's standardized usage field on AI messages.
        2. ``response.response_metadata["usage"]`` — presumably where langchain-aws
           ``ChatBedrock`` reports usage (NOTE(review): confirm against the installed version).
        3. ``response.response_metadata["token_usage"]`` — OpenAI-style key; the only
           source the previous implementation read, kept for backward compatibility.
        """
        candidates = [getattr(response, "usage_metadata", None)]
        metadata = getattr(response, "response_metadata", None)
        if isinstance(metadata, dict):
            candidates.append(metadata.get("usage"))
            candidates.append(metadata.get("token_usage"))
        for usage in candidates:
            if isinstance(usage, dict) and usage.get("total_tokens") is not None:
                return int(usage["total_tokens"])
        return 0