"""
LLM Client for CourseCrafter AI

Multi-provider LLM client with streaming support.
"""

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional

import openai
import anthropic
import google.generativeai as genai

from ..types import LLMProvider, StreamChunk
from ..utils.config import config


@dataclass
class Message:
    """Standard message format shared across providers."""

    role: str
    content: str


class BaseLLMClient(ABC):
    """Abstract base class for LLM clients."""

    def __init__(self, provider: LLMProvider):
        self.provider = provider
        self.config = config.get_llm_config(provider)

    @abstractmethod
    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response."""
        pass


class OpenAIClient(BaseLLMClient):
    """OpenAI client with streaming support (works with OpenAI and compatible endpoints)."""

    def __init__(self, provider: LLMProvider = "openai"):
        super().__init__(provider)

        client_kwargs = {
            # Some OpenAI-compatible servers ignore the key, but the SDK requires one.
            "api_key": self.config.api_key or "dummy",
            "timeout": self.config.timeout,
        }

        # Point the SDK at a custom endpoint when one is configured.
        if getattr(self.config, "base_url", None):
            client_kwargs["base_url"] = self.config.base_url

        self.client = openai.AsyncOpenAI(**client_kwargs)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Format messages for the OpenAI chat completions API."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response from OpenAI."""
        kwargs = {
            "model": self.config.model,
            "messages": self._format_messages(messages),
            "temperature": self.config.temperature,
            "stream": True,
        }

        if self.config.max_tokens:
            kwargs["max_tokens"] = self.config.max_tokens

        try:
            stream = await self.client.chat.completions.create(**kwargs)

            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta:
                    delta = chunk.choices[0].delta
                    if delta.content:
                        yield StreamChunk(type="text", content=delta.content)

        except Exception as e:
            yield StreamChunk(type="error", content=f"OpenAI API error: {e}")


class AnthropicClient(BaseLLMClient):
    """Anthropic client with streaming support."""

    def __init__(self):
        super().__init__("anthropic")
        self.client = anthropic.AsyncAnthropic(
            api_key=self.config.api_key,
            timeout=self.config.timeout,
        )

    def _format_messages(self, messages: List[Message]) -> tuple[List[Dict[str, Any]], Optional[str]]:
        """Format messages for Anthropic, splitting out the system prompt."""
        formatted = []
        system_message = None

        for msg in messages:
            if msg.role == "system":
                # Anthropic takes the system prompt as a separate parameter.
                system_message = msg.content
            elif msg.role in ("user", "assistant"):
                formatted.append({"role": msg.role, "content": msg.content})

        return formatted, system_message

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response from Anthropic."""
        formatted_messages, system_message = self._format_messages(messages)

        kwargs = {
            "model": self.config.model,
            "messages": formatted_messages,
            "temperature": self.config.temperature,
            "stream": True,
            # The Anthropic Messages API requires max_tokens, so fall back
            # to a default when none is configured.
            "max_tokens": self.config.max_tokens or 4096,
        }

        if system_message:
            kwargs["system"] = system_message

        try:
            stream = await self.client.messages.create(**kwargs)

            async for chunk in stream:
                if chunk.type == "content_block_delta" and hasattr(chunk.delta, "text"):
                    yield StreamChunk(type="text", content=chunk.delta.text)

        except Exception as e:
            yield StreamChunk(type="error", content=f"Anthropic API error: {e}")


class GoogleClient(BaseLLMClient):
    """Google Gemini client with streaming support."""

    def __init__(self):
        super().__init__("google")
        genai.configure(api_key=self.config.api_key)
        self.model = genai.GenerativeModel(self.config.model)

    def _format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Format messages for the Gemini API."""
        formatted = []

        for msg in messages:
            if msg.role == "system":
                # Gemini has no system role in chat history, so fold the
                # system prompt into a prefixed user turn.
                formatted.append({
                    "role": "user",
                    "parts": [{"text": f"System: {msg.content}"}],
                })
            elif msg.role == "user":
                formatted.append({
                    "role": "user",
                    "parts": [{"text": msg.content}],
                })
            elif msg.role == "assistant":
                # Gemini calls the assistant role "model".
                formatted.append({
                    "role": "model",
                    "parts": [{"text": msg.content}],
                })

        return formatted

    async def generate_stream(self, messages: List[Message]) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response from Google."""
        formatted_messages = self._format_messages(messages)

        generation_config = {
            "temperature": self.config.temperature,
        }

        if self.config.max_tokens:
            generation_config["max_output_tokens"] = self.config.max_tokens

        try:
            response = await self.model.generate_content_async(
                formatted_messages,
                generation_config=generation_config,
                stream=True,
            )

            async for chunk in response:
                # Guard on chunk.parts: accessing .text on a partless chunk
                # (e.g. one blocked by safety filters) raises a ValueError.
                if chunk.parts and chunk.text:
                    yield StreamChunk(type="text", content=chunk.text)

        except Exception as e:
            yield StreamChunk(type="error", content=f"Google API error: {e}")


class LlmClient:
    """Unified LLM client that manages multiple providers."""

    def __init__(self):
        self.clients: Dict[str, BaseLLMClient] = {}
        self._initialize_clients()

    def _create_client(self, provider: str) -> Optional[BaseLLMClient]:
        """Construct the concrete client for a provider, or None if unknown."""
        if provider in ("openai", "openai_compatible"):
            return OpenAIClient(provider)
        if provider == "anthropic":
            return AnthropicClient()
        if provider == "google":
            return GoogleClient()
        return None

    def _initialize_clients(self):
        """Initialize a client for every configured provider."""
        for provider in config.get_available_llm_providers():
            try:
                client = self._create_client(provider)
                if client:
                    self.clients[provider] = client
                    print(f"✅ Initialized {provider} client")
            except Exception as e:
                print(f"❌ Failed to initialize {provider} client: {e}")

    def update_provider_config(self, provider: str, api_key: Optional[str] = None, **kwargs) -> bool:
        """Update configuration for a specific provider and reinitialize its client."""
        # Expose the new settings via environment variables so that `config`
        # picks them up when the client is rebuilt.
        if provider == "openai" and api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        elif provider == "anthropic" and api_key:
            os.environ["ANTHROPIC_API_KEY"] = api_key
        elif provider == "google" and api_key:
            os.environ["GOOGLE_API_KEY"] = api_key
        elif provider == "openai_compatible":
            if api_key:
                os.environ["OPENAI_COMPATIBLE_API_KEY"] = api_key
            if kwargs.get("base_url"):
                os.environ["OPENAI_COMPATIBLE_BASE_URL"] = kwargs["base_url"]
            if kwargs.get("model"):
                os.environ["OPENAI_COMPATIBLE_MODEL"] = kwargs["model"]

        try:
            client = self._create_client(provider)
            if client is None:
                print(f"❌ Unknown provider: {provider}")
                return False
            self.clients[provider] = client
            print(f"✅ Updated and reinitialized {provider} client")
            return True
        except Exception as e:
            print(f"❌ Failed to reinitialize {provider} client: {e}")
            return False

    def get_available_providers(self) -> List[LLMProvider]:
        """Get the list of initialized providers."""
        return list(self.clients.keys())

    def get_client(self, provider: LLMProvider) -> BaseLLMClient:
        """Get the client for a specific provider."""
        if provider not in self.clients:
            raise ValueError(f"Provider {provider} not available")
        return self.clients[provider]

    async def generate_stream(
        self,
        provider: LLMProvider,
        messages: List[Message],
    ) -> AsyncGenerator[StreamChunk, None]:
        """Generate a streaming response using the specified provider."""
        client = self.get_client(provider)
        async for chunk in client.generate_stream(messages):
            yield chunk
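

# --- Example usage ---
# A minimal sketch of how this module is meant to be driven, not part of the
# public API. It assumes at least one provider is configured (e.g. via
# OPENAI_API_KEY) and that StreamChunk exposes `type` and `content` fields,
# as the clients above rely on. Because this file uses relative imports, run
# it as a module (python -m <package>.llm_client) rather than as a script.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        client = LlmClient()
        providers = client.get_available_providers()
        if not providers:
            print("No LLM providers configured")
            return

        messages = [
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Outline a one-week Python course."),
        ]

        # Stream from the first available provider, printing text as it arrives.
        async for chunk in client.generate_stream(providers[0], messages):
            if chunk.type == "text":
                print(chunk.content, end="", flush=True)
            elif chunk.type == "error":
                print(f"\n{chunk.content}")

    asyncio.run(_demo())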