# mcp_mod_test/app/utils.py
import os
import asyncio
import configparser
import logging
from dotenv import load_dotenv
# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
# Local .env file
load_dotenv()

def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except FileNotFoundError:
        logging.warning("Config file not found: %s", configfile_path)
        # Return an empty parser so callers fail with a clear configparser error
        return config
# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------
def get_auth_config(provider: str) -> dict:
    """Get authentication configuration for different providers"""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    if not api_key:
        raise RuntimeError(
            f"Missing API key for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
    return auth_config
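
# For reference, the environment variables above are typically supplied via the
# local .env file loaded by load_dotenv(). An illustrative .env (values are
# placeholders, not real keys) might look like:
#
#   OPENAI_API_KEY=sk-...
#   ANTHROPIC_API_KEY=sk-ant-...
#   COHERE_API_KEY=...
#   HF_TOKEN=hf_...
#
# Only the variable for the provider selected in params.cfg needs to be set.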
# ---------------------------------------------------------------------
# Model / client initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")
PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
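
# For reference, a minimal params.cfg matching the reads above could look like
# this (the model name is only an example, not prescribed by this module):
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 512
#   TEMPERATURE = 0.0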
# Set up authentication for the selected provider
auth_config = get_auth_config(PROVIDER)

def get_chat_model():
    """Initialize the appropriate LangChain chat model based on provider"""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "huggingface":
        # Initialize HuggingFaceEndpoint with explicit parameters
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")
# Initialize provider-agnostic chat model
chat_model = get_chat_model()
# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as string
    """
    try:
        # Use async invoke for better performance
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception:
        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}'")
        raise

def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content),
    ]

async def rag_generate(query: str, context: str) -> str:
    """
    Generate an answer to a query using provided context through RAG.

    This function takes a user query and relevant context, then uses a language model
    to generate a comprehensive answer based on the provided information.

    Args:
        query (str): The user's question or query
        context (str): The relevant context/documents to use for answering

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    if not context.strip():
        return "Error: Context cannot be empty"
    try:
        messages = build_messages(query, context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"