import os
import asyncio
import configparser
import logging

from dotenv import load_dotenv

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Load variables from a local .env file
load_dotenv()
def getconfig(configfile_path: str) -> configparser.ConfigParser:
    """
    Read the config file.

    Params
    ----------------
    configfile_path: file path of the .cfg file
    """
    config = configparser.ConfigParser()
    try:
        with open(configfile_path) as f:
            config.read_file(f)
        return config
    except OSError:
        logging.error("Config file not found: %s", configfile_path)
        raise
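
# This module expects a params.cfg next to it with a [generator] section; the
# sketch below shows the keys read further down (PROVIDER, MODEL, MAX_TOKENS,
# TEMPERATURE). The concrete values are illustrative assumptions, not defaults.
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 512
#   TEMPERATURE = 0.2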
# ---------------------------------------------------------------------
# Provider-agnostic authentication and configuration
# ---------------------------------------------------------------------
def get_auth_config(provider: str) -> dict:
    """Get the authentication configuration for the selected provider."""
    auth_configs = {
        "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
        "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
    }
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    api_key = auth_config.get("api_key")
    if not api_key:
        raise RuntimeError(
            f"Missing API key for provider '{provider}'. "
            "Please set the appropriate environment variable."
        )
    return auth_config
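
# A minimal .env file for local development might look like the sketch below;
# only the key for the provider selected in params.cfg is actually required,
# and the placeholder values are illustrative, not real credentials.
#
#   OPENAI_API_KEY=sk-...
#   ANTHROPIC_API_KEY=sk-ant-...
#   COHERE_API_KEY=...
#   HF_TOKEN=hf_...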
# ---------------------------------------------------------------------
# Model / client initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")
PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))

# Set up authentication for the selected provider
auth_config = get_auth_config(PROVIDER)
def get_chat_model():
    """Initialize the appropriate LangChain chat model based on the provider."""
    common_params = {
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    if PROVIDER == "openai":
        return ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "anthropic":
        return ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "cohere":
        return ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            **common_params
        )
    elif PROVIDER == "huggingface":
        # HuggingFaceEndpoint takes its generation parameters directly
        llm = HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS
        )
        return ChatHuggingFace(llm=llm)
    else:
        raise ValueError(f"Unsupported provider: {PROVIDER}")


# Initialize the provider-agnostic chat model
chat_model = get_chat_model()
# ---------------------------------------------------------------------
# Core generation function for both Gradio UI and MCP
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """
    Provider-agnostic LLM call using LangChain.

    Args:
        messages: List of LangChain message objects

    Returns:
        Generated response content as a string
    """
    try:
        # Use the async invoke API so callers can run requests concurrently
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise
def build_messages(question: str, context: str) -> list:
    """
    Build messages in LangChain format.

    Args:
        question: The user's question
        context: The relevant context for answering

    Returns:
        List of LangChain message objects
    """
    system_content = (
        "You are an expert assistant. Answer the USER question using only the "
        "CONTEXT provided. If the context is insufficient, say 'I don't know.'"
    )
    user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
    return [
        SystemMessage(content=system_content),
        HumanMessage(content=user_content)
    ]
async def rag_generate(query: str, context: str) -> str:
    """
    Generate an answer to a query using provided context through RAG.

    This function takes a user query and relevant context, then uses a language
    model to generate a comprehensive answer based on the provided information.

    Args:
        query (str): The user's question or query
        context (str): The relevant context/documents to use for answering

    Returns:
        str: The generated answer based on the query and context
    """
    if not query.strip():
        return "Error: Query cannot be empty"
    if not context.strip():
        return "Error: Context cannot be empty"
    try:
        messages = build_messages(query, context)
        answer = await _call_llm(messages)
        return answer
    except Exception as e:
        logging.exception("Generation failed")
        return f"Error: {str(e)}"
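

# A minimal smoke test for the generation path, assuming valid credentials and
# a params.cfg are present; the sample question/context strings are illustrative
# assumptions, not part of the application data.
async def _demo_generate() -> None:
    sample_context = "The Eiffel Tower is 330 metres tall."
    sample_query = "How tall is the Eiffel Tower?"
    print(await rag_generate(sample_query, sample_context))

# To try it locally: asyncio.run(_demo_generate())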