# web-ui/src/utils/llm_provider.py
import os
from typing import Any, Optional

from openai import AsyncOpenAI, OpenAI

from langchain_anthropic import ChatAnthropic
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ibm import ChatWatsonx
from langchain_mistralai import ChatMistralAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from src.utils import config


class DeepSeekR1ChatOpenAI(ChatOpenAI):
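    """ChatOpenAI subclass for DeepSeek R1 ("deepseek-reasoner").

    DeepSeek's OpenAI-compatible API returns the reasoning trace in an extra
    `reasoning_content` field on each choice, which the stock ChatOpenAI
    discards. This wrapper calls the OpenAI SDK directly and copies that field
    onto the returned AIMessage.
    """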
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Replace the parent's clients with full SDK clients so invoke/ainvoke
        # below can read the extra `reasoning_content` field. An async client
        # is also created so `ainvoke` does not block the event loop.
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key")
        )
        self.async_client = AsyncOpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key")
        )
    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        # Convert LangChain messages to the OpenAI chat format; anything that
        # is not a system or assistant message is treated as a user message.
        # `input` is assumed to be a sequence of messages.
        message_history = []
        for input_ in input:
            if isinstance(input_, SystemMessage):
                message_history.append({"role": "system", "content": input_.content})
            elif isinstance(input_, AIMessage):
                message_history.append({"role": "assistant", "content": input_.content})
            else:
                message_history.append({"role": "user", "content": input_.content})

        # Use the async client so the call does not block the event loop
        # (the original implementation called the sync client here).
        response = await self.async_client.chat.completions.create(
            model=self.model_name,
            messages=message_history
        )

        reasoning_content = response.choices[0].message.reasoning_content
        content = response.choices[0].message.content
        return AIMessage(content=content, reasoning_content=reasoning_content)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AIMessage:
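        # Synchronous twin of `ainvoke`: same message conversion, but it uses
        # the blocking OpenAI client created in __init__.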
message_history = []
for input_ in input:
if isinstance(input_, SystemMessage):
message_history.append({"role": "system", "content": input_.content})
elif isinstance(input_, AIMessage):
message_history.append({"role": "assistant", "content": input_.content})
else:
message_history.append({"role": "user", "content": input_.content})
response = self.client.chat.completions.create(
model=self.model_name,
messages=message_history
)
reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content
return AIMessage(content=content, reasoning_content=reasoning_content)
class DeepSeekR1ChatOllama(ChatOllama):
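    """ChatOllama subclass for locally served DeepSeek R1 models.

    Ollama returns the reasoning trace inline as a `<think>...</think>` block,
    so this wrapper splits it into `reasoning_content` and keeps only the final
    answer in `content`.
    """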
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AIMessage:
        org_ai_message = await super().ainvoke(input=input, config=config, stop=stop, **kwargs)
        org_content = org_ai_message.content
        # Split the inline reasoning block from the final answer; fall back to
        # the raw content if the model did not emit a </think> tag (the
        # original unconditionally indexed the split and could raise).
        if "</think>" in org_content:
            reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
            content = org_content.split("</think>")[1]
        else:
            reasoning_content = ""
            content = org_content
        # Some prompts make the model echo a "**JSON Response:**" label; keep
        # only what follows it.
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AIMessage:
        org_ai_message = super().invoke(input=input, config=config, stop=stop, **kwargs)
        org_content = org_ai_message.content
        # Same <think> handling as `ainvoke`.
        if "</think>" in org_content:
            reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
            content = org_content.split("</think>")[1]
        else:
            reasoning_content = ""
            content = org_content
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)
def get_llm_model(provider: str, **kwargs):
    """
    Instantiate a LangChain chat model for the given provider.

    :param provider: Provider id, e.g. "openai", "anthropic", "deepseek", "ollama".
    :param kwargs: Optional overrides such as model_name, temperature, base_url,
        api_key, and provider-specific settings (num_ctx, num_predict, api_version).
    :return: A configured chat model instance.
    :raises ValueError: If a required API key is missing or the provider is unknown.
    """
if provider not in ["ollama", "bedrock"]:
env_var = f"{provider.upper()}_API_KEY"
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
if not api_key:
provider_display = config.PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
error_msg = f"💥 {provider_display} API key not found! 🔑 Please set the `{env_var}` environment variable or provide it in the UI."
raise ValueError(error_msg)
kwargs["api_key"] = api_key
if provider == "anthropic":
if not kwargs.get("base_url", ""):
base_url = "https://api.anthropic.com"
else:
base_url = kwargs.get("base_url")
return ChatAnthropic(
model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
    elif provider == "mistral":
        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
elif provider == "openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
model=kwargs.get("model_name", "gpt-4o"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == "grok":
if not kwargs.get("base_url", ""):
base_url = os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
model=kwargs.get("model_name", "grok-3"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == "deepseek":
if not kwargs.get("base_url", ""):
base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
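        # "deepseek-reasoner" (R1) returns a separate reasoning trace, so it
        # needs the DeepSeekR1ChatOpenAI wrapper defined above.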
if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
return DeepSeekR1ChatOpenAI(
model=kwargs.get("model_name", "deepseek-reasoner"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
else:
return ChatOpenAI(
model=kwargs.get("model_name", "deepseek-chat"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == "google":
return ChatGoogleGenerativeAI(
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
temperature=kwargs.get("temperature", 0.0),
api_key=api_key,
)
elif provider == "ollama":
if not kwargs.get("base_url", ""):
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
else:
base_url = kwargs.get("base_url")
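        # DeepSeek R1 models served by Ollama emit <think> blocks, so route
        # them through the parsing wrapper defined above.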
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
return DeepSeekR1ChatOllama(
model=kwargs.get("model_name", "deepseek-r1:14b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=base_url,
)
else:
return ChatOllama(
model=kwargs.get("model_name", "qwen2.5:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
num_predict=kwargs.get("num_predict", 1024),
base_url=base_url,
)
elif provider == "azure_openai":
if not kwargs.get("base_url", ""):
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
return AzureChatOpenAI(
model=kwargs.get("model_name", "gpt-4o"),
temperature=kwargs.get("temperature", 0.0),
api_version=api_version,
azure_endpoint=base_url,
api_key=api_key,
)
elif provider == "alibaba":
if not kwargs.get("base_url", ""):
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
model=kwargs.get("model_name", "qwen-plus"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
)
elif provider == "ibm":
parameters = {
"temperature": kwargs.get("temperature", 0.0),
"max_tokens": kwargs.get("num_ctx", 32000)
}
if not kwargs.get("base_url", ""):
base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
else:
base_url = kwargs.get("base_url")
return ChatWatsonx(
model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
url=base_url,
project_id=os.getenv("IBM_PROJECT_ID"),
apikey=os.getenv("IBM_API_KEY"),
params=parameters
)
elif provider == "moonshot":
return ChatOpenAI(
model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
temperature=kwargs.get("temperature", 0.0),
base_url=os.getenv("MOONSHOT_ENDPOINT"),
api_key=os.getenv("MOONSHOT_API_KEY"),
)
elif provider == "unbound":
return ChatOpenAI(
model=kwargs.get("model_name", "gpt-4o-mini"),
temperature=kwargs.get("temperature", 0.0),
base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
api_key=api_key,
)
elif provider == "siliconflow":
if not kwargs.get("api_key", ""):
api_key = os.getenv("SiliconFLOW_API_KEY", "")
else:
api_key = kwargs.get("api_key")
if not kwargs.get("base_url", ""):
base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
api_key=api_key,
base_url=base_url,
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
temperature=kwargs.get("temperature", 0.0),
)
elif provider == "modelscope":
if not kwargs.get("api_key", ""):
api_key = os.getenv("MODELSCOPE_API_KEY", "")
else:
api_key = kwargs.get("api_key")
if not kwargs.get("base_url", ""):
base_url = os.getenv("MODELSCOPE_ENDPOINT", "")
else:
base_url = kwargs.get("base_url")
return ChatOpenAI(
api_key=api_key,
base_url=base_url,
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
temperature=kwargs.get("temperature", 0.0),
)
else:
raise ValueError(f"Unsupported provider: {provider}")
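

# Minimal usage sketch, assuming OPENAI_API_KEY is set in the environment;
# the model name and prompt are illustrative:
#
#     llm = get_llm_model("openai", model_name="gpt-4o", temperature=0.0)
#     print(llm.invoke("Say hello").content)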