# NOTE: hub upload page residue removed (commit 4d85aba) — it was not valid Python.
"""
🧠 LLM Client for CourseCrafter AI
Multi-provider LLM client with streaming support.
"""
import json
from typing import Dict, List, Any, Optional, AsyncGenerator
from dataclasses import dataclass
from abc import ABC, abstractmethod
import os
import openai
import anthropic
import google.generativeai as genai
from ..types import LLMProvider, StreamChunk
from ..utils.config import config
@dataclass
class Message:
    """A single chat turn in the provider-agnostic message format.

    Attributes:
        role: One of "system", "user", or "assistant".
        content: The text payload of the turn.
    """
    role: str
    content: str
class BaseLLMClient(ABC):
    """Common interface shared by every provider-specific LLM client.

    Construction resolves the provider's settings through the global
    ``config`` helper; subclasses supply :meth:`generate_stream`.
    """

    def __init__(self, provider: LLMProvider):
        """Remember the provider id and load its configuration.

        Args:
            provider: Identifier of the LLM provider this client wraps.
        """
        self.provider = provider
        self.config = config.get_llm_config(provider)

    @abstractmethod
    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Yield ``StreamChunk`` objects for the given conversation."""
        ...
class OpenAIClient(BaseLLMClient):
    """OpenAI client with streaming support (works with OpenAI and compatible endpoints)"""

    def __init__(self, provider: LLMProvider = "openai"):
        super().__init__(provider)
        # Assemble constructor arguments; a placeholder key lets keyless
        # OpenAI-compatible endpoints still be instantiated.
        init_args: Dict[str, Any] = {
            "api_key": self.config.api_key or "dummy",
            "timeout": self.config.timeout,
        }
        # A custom base_url routes requests to compatible servers (proxies, local models).
        base_url = getattr(self.config, "base_url", None)
        if base_url:
            init_args["base_url"] = base_url
        self.client = openai.AsyncOpenAI(**init_args)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert internal messages into OpenAI's role/content dict shape."""
        formatted = []
        for msg in messages:
            formatted.append({"role": msg.role, "content": msg.content})
        return formatted

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Stream text deltas from the chat completions endpoint.

        Yields:
            A "text" StreamChunk per non-empty content delta, or a single
            "error" chunk if the API call raises.
        """
        request: Dict[str, Any] = {
            "model": self.config.model,
            "messages": self._format_messages(messages),
            "temperature": self.config.temperature,
            "stream": True,
        }
        # Only forward max_tokens when the config actually sets one.
        if self.config.max_tokens:
            request["max_tokens"] = self.config.max_tokens
        try:
            stream = await self.client.chat.completions.create(**request)
            async for chunk in stream:
                choices = chunk.choices
                if not (choices and choices[0].delta):
                    continue
                text = choices[0].delta.content
                if text:
                    yield StreamChunk(type="text", content=text)
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"OpenAI API error: {str(e)}"
            )
class AnthropicClient(BaseLLMClient):
    """Anthropic client with streaming support"""

    # The Anthropic Messages API requires max_tokens on every request, so a
    # fallback is needed when the configuration leaves it unset.
    DEFAULT_MAX_TOKENS = 4096

    def __init__(self):
        super().__init__("anthropic")
        self.client = anthropic.AsyncAnthropic(
            api_key=self.config.api_key,
            timeout=self.config.timeout
        )

    def _format_messages(self, messages: List[Message]) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """Format messages for Anthropic.

        Anthropic takes the system prompt as a separate top-level parameter
        rather than as a conversation message, so it is split out here.

        Returns:
            A ``(messages, system)`` pair; ``system`` is None when no system
            message is present. If several system messages appear, the last
            one wins.
        """
        formatted = []
        system_message = None
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            elif msg.role in ["user", "assistant"]:
                formatted.append({
                    "role": msg.role,
                    "content": msg.content
                })
        return formatted, system_message

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response from Anthropic.

        Yields:
            A "text" StreamChunk per content-block delta, or a single
            "error" chunk if the request fails.
        """
        formatted_messages, system_message = self._format_messages(messages)
        kwargs = {
            "model": self.config.model,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "stream": True,
            # Fix: max_tokens is mandatory for the Messages API; previously it
            # was omitted whenever the config left it unset, which made every
            # request fail. Fall back to a sane default instead.
            "max_tokens": self.config.max_tokens or self.DEFAULT_MAX_TOKENS,
        }
        if system_message:
            kwargs["system"] = system_message
        try:
            stream = await self.client.messages.create(**kwargs)
            async for chunk in stream:
                if chunk.type == "content_block_delta":
                    if hasattr(chunk.delta, 'text'):
                        yield StreamChunk(
                            type="text",
                            content=chunk.delta.text
                        )
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"Anthropic API error: {str(e)}"
            )
class GoogleClient(BaseLLMClient):
    """Google Gemini client with streaming support"""

    def __init__(self):
        super().__init__("google")
        genai.configure(api_key=self.config.api_key)
        self.model = genai.GenerativeModel(self.config.model)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Translate internal messages into Gemini's role/parts structure.

        Gemini has no dedicated system role, so system messages are folded
        into a user turn prefixed with "System: "; assistant turns map to
        Gemini's "model" role.
        """
        entries = []
        for msg in messages:
            if msg.role == "system":
                entries.append({
                    "role": "user",
                    "parts": [{"text": f"System: {msg.content}"}]
                })
            elif msg.role == "user":
                entries.append({
                    "role": "user",
                    "parts": [{"text": msg.content}]
                })
            elif msg.role == "assistant":
                entries.append({
                    "role": "model",
                    "parts": [{"text": msg.content}]
                })
        return entries

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Stream text chunks from Gemini.

        Yields:
            A "text" StreamChunk per non-empty chunk, or a single "error"
            chunk if the request raises.
        """
        contents = self._format_messages(messages)
        gen_config: Dict[str, Any] = {"temperature": self.config.temperature}
        if self.config.max_tokens:
            gen_config["max_output_tokens"] = self.config.max_tokens
        try:
            response = await self.model.generate_content_async(
                contents,
                generation_config=gen_config,
                stream=True
            )
            async for chunk in response:
                if chunk.text:
                    yield StreamChunk(
                        type="text",
                        content=chunk.text
                    )
        except Exception as e:
            yield StreamChunk(
                type="error",
                content=f"Google API error: {str(e)}"
            )
class LlmClient:
    """
    Unified LLM client that manages multiple providers
    """

    def __init__(self):
        # Maps provider name -> initialized BaseLLMClient instance.
        self.clients: Dict[str, BaseLLMClient] = {}
        self._initialize_clients()

    @staticmethod
    def _build_client(provider: str) -> Optional[BaseLLMClient]:
        """Construct the client for *provider*, or return None if unrecognized."""
        if provider in ["openai", "openai_compatible"]:
            return OpenAIClient(provider)
        if provider == "anthropic":
            return AnthropicClient()
        if provider == "google":
            return GoogleClient()
        return None

    def _initialize_clients(self):
        """Initialize a client for every provider the config reports available."""
        available_providers = config.get_available_llm_providers()
        for provider in available_providers:
            try:
                client = self._build_client(provider)
                if client is None:
                    # Fix: an unrecognized provider used to print the success
                    # message even though no client was created.
                    print(f"❌ Unknown provider {provider}, skipping")
                    continue
                self.clients[provider] = client
                print(f"βœ… Initialized {provider} client")
            except Exception as e:
                print(f"❌ Failed to initialize {provider} client: {e}")

    def update_provider_config(self, provider: str, api_key: Optional[str] = None, **kwargs):
        """Update configuration for a specific provider and reinitialize client.

        Args:
            provider: One of "openai", "anthropic", "google", "openai_compatible".
            api_key: New API key for the provider, if any.
            **kwargs: Extra settings; "base_url" and "model" are honored for
                the openai_compatible provider.

        Returns:
            True when the client was rebuilt successfully, False otherwise.
        """
        # Persist new settings via environment variables, which the config
        # helper is presumably reading when a client is constructed
        # (NOTE(review): confirm against ..utils.config).
        if provider == "openai" and api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        elif provider == "anthropic" and api_key:
            os.environ["ANTHROPIC_API_KEY"] = api_key
        elif provider == "google" and api_key:
            os.environ["GOOGLE_API_KEY"] = api_key
        elif provider == "openai_compatible":
            if api_key:
                os.environ["OPENAI_COMPATIBLE_API_KEY"] = api_key
            if kwargs.get("base_url"):
                os.environ["OPENAI_COMPATIBLE_BASE_URL"] = kwargs["base_url"]
            if kwargs.get("model"):
                os.environ["OPENAI_COMPATIBLE_MODEL"] = kwargs["model"]
        # Reinitialize the specific client
        try:
            client = self._build_client(provider)
            if client is None:
                # Fix: unknown providers used to report success and return True
                # without rebuilding anything.
                print(f"❌ Unknown provider {provider}, nothing to update")
                return False
            self.clients[provider] = client
            print(f"βœ… Updated and reinitialized {provider} client")
            return True
        except Exception as e:
            print(f"❌ Failed to reinitialize {provider} client: {e}")
            return False

    def get_available_providers(self) -> List[LLMProvider]:
        """Get list of available providers"""
        return list(self.clients.keys())

    def get_client(self, provider: LLMProvider) -> BaseLLMClient:
        """Get client for specific provider.

        Raises:
            ValueError: If no client was initialized for *provider*.
        """
        if provider not in self.clients:
            raise ValueError(f"Provider {provider} not available")
        return self.clients[provider]

    async def generate_stream(
        self,
        provider: LLMProvider,
        messages: List[Message]
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate streaming response using specified provider"""
        client = self.get_client(provider)
        async for chunk in client.generate_stream(messages):
            yield chunk