import os
from functools import wraps
from typing import Callable, Optional

from src.translation_agent.utils import *

from llama_index.llms.groq import Groq
from llama_index.llms.cohere import Cohere
from llama_index.llms.openai import OpenAI
from llama_index.llms.together import TogetherLLM
from llama_index.llms.ollama import Ollama
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI

from llama_index.core import Settings
from llama_index.core.llms import ChatMessage


def model_load(
    endpoint: str,
    model: str,
    api_key: Optional[str] = None,
    context_window: int = 4096,
    num_output: int = 512,
):
    """Instantiate the LLM client for the given endpoint and register it globally.

    The selected client is stored on llama_index's ``Settings`` so that
    downstream calls (e.g. the wrapped ``get_completion``) pick it up
    automatically.
    """
    if endpoint == "Groq":
        llm = Groq(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "Cohere":
        llm = Cohere(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "OpenAI":
        llm = OpenAI(
            model=model,
            api_key=api_key if api_key else os.getenv("OPENAI_API_KEY"),
        )
    elif endpoint == "TogetherAI":
        llm = TogetherLLM(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "ollama":
        llm = Ollama(
            model=model,
            request_timeout=120.0,
        )
    elif endpoint == "Huggingface":
        llm = HuggingFaceInferenceAPI(
            model_name=model,
            token=api_key,
            task="text-generation",
        )
    else:
        # Without this guard, an unknown endpoint would leave ``llm`` unbound
        # and raise an opaque NameError on the assignment below.
        raise ValueError(f"Unsupported endpoint: {endpoint!r}")

    Settings.llm = llm
    Settings.context_window = context_window
    Settings.num_output = num_output
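
# Usage sketch (illustrative only; the model names and environment variable
# below are assumptions, not defined by this module):
#
#     model_load("Groq", "llama3-70b-8192", api_key=os.getenv("GROQ_API_KEY"))
#     model_load("ollama", "llama3")  # local Ollama server; no API key needed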


def completion_wrapper(func: Callable) -> Callable:
    """Wrap a completion function so it uses the LLM registered on ``Settings``."""

    @wraps(func)
    def wrapper(
        prompt: str,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.3,
        json_mode: bool = False,
    ) -> str:
        """
        Generate a completion using the LLM configured via ``model_load``.

        Args:
            prompt (str): The user's prompt or query.
            system_message (str, optional): The system message to set the
                context for the assistant.
                Defaults to "You are a helpful assistant.".
            temperature (float, optional): The sampling temperature for
                controlling the randomness of the generated text.
                Defaults to 0.3.
            json_mode (bool, optional): Whether to ask the model to respond
                in JSON format. Defaults to False.

        Returns:
            str: The generated text. When json_mode is True, the content is
            expected to be a JSON-formatted string (it is not parsed here).
        """
        llm = Settings.llm
        if llm.class_name() == "HuggingFaceInferenceAPI":
            # The Hugging Face Inference API client takes the system prompt
            # as an attribute rather than as a chat message.
            llm.system_prompt = system_message
            messages = [
                ChatMessage(role="user", content=prompt),
            ]
            response = llm.chat(
                messages=messages,
                temperature=temperature,
                top_p=1,
            )
            return response.message.content
        else:
            messages = [
                ChatMessage(role="system", content=system_message),
                ChatMessage(role="user", content=prompt),
            ]
            kwargs = {"temperature": temperature, "top_p": 1}
            if json_mode:
                kwargs["response_format"] = {"type": "json_object"}
            response = llm.chat(messages=messages, **kwargs)
            return response.message.content

    return wrapper


# Preserve the original OpenAI-only implementation (presumably supplied by the
# wildcard import from src.translation_agent.utils) under a distinct name,
# then rebind get_completion so callers transparently use whichever LLM
# model_load registered on Settings.
openai_completion = get_completion
get_completion = completion_wrapper(openai_completion)
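
# End-to-end sketch (hedged: assumes a valid API key is configured and the
# model name is illustrative, not prescribed by this module):
#
#     model_load("OpenAI", "gpt-4o-mini")
#     print(get_completion("Translate 'bonjour' into English."))
#     raw_json = get_completion("Reply with a JSON object.", json_mode=True)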