Final_Assignment_Template / llm_factory.py
24Arys11's picture
bugfixing; fixed toolbox; isolated [Base|AI|Human]Message crap logic to the agent interface; implemented tests
e4f6727
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from args import LLMInterface, Args, AgentPreset
class LLMFactory:
    """Build LangChain chat models from an ``AgentPreset``.

    The preset's interface selects the backend (OpenAI-compatible client or a
    Hugging Face endpoint wrapped as a chat model). The preset's model name,
    temperature, max-token limit and repeat penalty are forwarded to the
    chosen backend's constructor.
    """

    @classmethod
    def create(cls, agent_preset: AgentPreset):
        """Return a chat model configured from *agent_preset*.

        Args:
            agent_preset: Preset whose ``get_interface()`` selects the backend.

        Returns:
            The model built by ``_create_openai_model`` when the interface is
            ``LLMInterface.OPENAI``, or by ``_create_huggingface_model`` when it
            is ``LLMInterface.HUGGINGFACE``.

        Raises:
            ValueError: If the interface matches neither supported backend.
        """
        interface = agent_preset.get_interface()
        if interface == LLMInterface.OPENAI:
            return cls._create_openai_model(agent_preset)
        if interface == LLMInterface.HUGGINGFACE:
            return cls._create_huggingface_model(agent_preset)
        raise ValueError(f"Interface '{interface}' is not supported !")

    @staticmethod
    def _create_openai_model(agent_preset: AgentPreset):
        """Build a ``ChatOpenAI`` client for the preset.

        The model name is used both as the runnable ``name`` and as the
        requested ``model``. The endpoint and credentials come from
        ``Args.api_base`` and ``Args.api_key``.
        """
        # All preset values are read up front, before touching ``Args``.
        name = agent_preset.get_model_name()
        temp = agent_preset.get_temperature()
        token_limit = agent_preset.get_max_tokens()
        penalty = agent_preset.get_repeat_penalty()
        # NOTE(review): this preset value is sent as OpenAI `frequency_penalty`
        # here, but as Hugging Face `repetition_penalty` in the sibling
        # builder. Verify that both backends interpret it on the same scale.
        return ChatOpenAI(
            name=name,
            model=name,
            base_url=Args.api_base,
            api_key=Args.api_key,
            temperature=temp,
            max_completion_tokens=token_limit,
            frequency_penalty=penalty,
        )

    @staticmethod
    def _create_huggingface_model(agent_preset: AgentPreset):
        """Build a ``HuggingFaceEndpoint`` for the preset and wrap it in ``ChatHuggingFace``.

        The model name is used both as the endpoint's ``name`` and ``model``.
        No token or URL is passed here, so any credentials must come from
        wherever ``HuggingFaceEndpoint`` looks by default.
        """
        name = agent_preset.get_model_name()
        temp = agent_preset.get_temperature()
        token_limit = agent_preset.get_max_tokens()
        penalty = agent_preset.get_repeat_penalty()
        endpoint = HuggingFaceEndpoint(
            name=name,
            model=name,
            temperature=temp,
            max_new_tokens=token_limit,
            repetition_penalty=penalty,
        )
        return ChatHuggingFace(llm=endpoint)