File size: 1,926 Bytes
4fb4269
 
 
f49023b
4fb4269
 
d26c7f3
 
 
4fb4269
 
 
 
 
cc6bd3b
4fb4269
d26c7f3
cc6bd3b
4fb4269
 
d26c7f3
 
4fb4269
 
 
 
 
d26c7f3
 
e4f6727
4fb4269
 
d26c7f3
4fb4269
 
 
d26c7f3
 
4fb4269
d26c7f3
4fb4269
d26c7f3
 
4fb4269
 
 
 
 
 
d26c7f3
e4f6727
4fb4269
 
 
 
d26c7f3
 
4fb4269
 
d26c7f3
4fb4269
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI

from args import LLMInterface, Args, AgentPreset


class LLMFactory:
    """Build a LangChain chat model from an ``AgentPreset``.

    The preset's interface picks the backend: an OpenAI-compatible API or a
    Hugging Face inference endpoint. The preset's model name and generation
    settings are forwarded to that backend's constructor.
    """

    @classmethod
    def create(cls, agent_preset: AgentPreset):
        """Instantiate the chat model described by ``agent_preset``.

        Args:
            agent_preset: Preset supplying the interface, model name,
                temperature, token limit and repeat penalty.

        Returns:
            A ``ChatOpenAI`` when the interface is ``LLMInterface.OPENAI``,
            or a ``ChatHuggingFace`` when it is ``LLMInterface.HUGGINGFACE``.

        Raises:
            ValueError: If the preset reports any other interface.
        """
        interface = agent_preset.get_interface()

        if interface == LLMInterface.OPENAI:
            return cls._create_openai_model(agent_preset)
        if interface == LLMInterface.HUGGINGFACE:
            return cls._create_huggingface_model(agent_preset)

        raise ValueError(f"Interface '{interface}' is not supported !")

    @staticmethod
    def _read_generation_settings(agent_preset: AgentPreset):
        """Return ``(model_name, temperature, max_tokens, repeat_penalty)`` read from the preset."""
        return (
            agent_preset.get_model_name(),
            agent_preset.get_temperature(),
            agent_preset.get_max_tokens(),
            agent_preset.get_repeat_penalty(),
        )

    @staticmethod
    def _create_openai_model(agent_preset: AgentPreset):
        """Build a ``ChatOpenAI`` client using the endpoint and key from ``Args``.

        The model name is used both as the runnable's ``name`` and as the
        ``model`` to request. The token limit is sent as
        ``max_completion_tokens``.
        """
        model_name, temperature, max_tokens, repeat_penalty = (
            LLMFactory._read_generation_settings(agent_preset)
        )

        # NOTE(review): the preset's repeat penalty is passed unchanged as
        # OpenAI's ``frequency_penalty``. The Hugging Face builder passes the
        # same value as ``repetition_penalty``, and those APIs use different
        # scales (additive with 0 = off vs. multiplicative with 1.0 = off).
        # Confirm the preset stores a value appropriate for this interface.
        return ChatOpenAI(
            name=model_name,
            model=model_name,
            base_url=Args.api_base,
            api_key=Args.api_key,
            temperature=temperature,
            max_completion_tokens=max_tokens,
            frequency_penalty=repeat_penalty,
        )

    @staticmethod
    def _create_huggingface_model(agent_preset: AgentPreset):
        """Build a ``ChatHuggingFace`` wrapper around a ``HuggingFaceEndpoint``.

        The token limit is sent as ``max_new_tokens`` and the repeat penalty
        as ``repetition_penalty``. No endpoint URL or API token is passed here.
        """
        model_name, temperature, max_tokens, repeat_penalty = (
            LLMFactory._read_generation_settings(agent_preset)
        )

        endpoint = HuggingFaceEndpoint(
            name=model_name,
            model=model_name,
            temperature=temperature,
            max_new_tokens=max_tokens,
            repetition_penalty=repeat_penalty,
        )
        return ChatHuggingFace(llm=endpoint)