Spaces:
Sleeping
Sleeping
File size: 1,912 Bytes
4fb4269 f49023b 4fb4269 d26c7f3 4fb4269 cc6bd3b 4fb4269 d26c7f3 cc6bd3b 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 d26c7f3 4fb4269 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from args import LLMInterface, Args, AgentPreset
class LLMFactory:
    """Build LangChain chat models from an ``AgentPreset``.

    The preset is queried for the interface to use and for the generation
    settings (model name, temperature, token limit, repeat penalty); the
    matching private builder turns those into a chat model instance.
    """

    @classmethod
    def create(cls, agent_preset: AgentPreset):
        """Return a chat model for the interface selected by *agent_preset*.

        Args:
            agent_preset: Preset whose ``get_interface()`` result picks the
                builder; the preset itself is forwarded to that builder.

        Returns:
            The ``ChatOpenAI`` instance for ``LLMInterface.OPENAI`` or the
            ``ChatHuggingFace`` instance for ``LLMInterface.HUGGINGFACE``.

        Raises:
            ValueError: If the preset reports any other interface.
        """
        interface = agent_preset.get_interface()
        if interface == LLMInterface.OPENAI:
            return cls._create_openai_model(agent_preset)
        if interface == LLMInterface.HUGGINGFACE:
            return cls._create_huggingface_model(agent_preset)
        raise ValueError(f"Interface '{interface}' is not supported !")

    @staticmethod
    def _create_openai_model(agent_preset: AgentPreset):
        """Build a ``ChatOpenAI`` model from *agent_preset*.

        The endpoint URL and API key are read from ``Args.api_base`` and
        ``Args.api_key``; everything else comes from the preset's getters.
        """
        name = agent_preset.get_model_name()
        temp = agent_preset.get_temperature()
        token_limit = agent_preset.get_max_tokens()
        penalty = agent_preset.get_repeat_penalty()
        # Only frequency_penalty receives the preset's repeat penalty;
        # presence_penalty is deliberately left unset.
        return ChatOpenAI(
            model=name,
            base_url=Args.api_base,
            api_key=Args.api_key,
            temperature=temp,
            max_completion_tokens=token_limit,
            frequency_penalty=penalty,
        )

    @staticmethod
    def _create_huggingface_model(agent_preset: AgentPreset):
        """Build a ``ChatHuggingFace`` model from *agent_preset*.

        A ``HuggingFaceEndpoint`` is configured from the preset's getters and
        then wrapped in ``ChatHuggingFace`` to expose the chat interface.
        """
        name = agent_preset.get_model_name()
        temp = agent_preset.get_temperature()
        token_limit = agent_preset.get_max_tokens()
        penalty = agent_preset.get_repeat_penalty()
        endpoint = HuggingFaceEndpoint(
            model=name,
            temperature=temp,
            max_new_tokens=token_limit,
            repetition_penalty=penalty,
        )
        return ChatHuggingFace(llm=endpoint)
|