# A monkey patch that routes translation-agent's get_completion through llama-index LLMs
import os
from typing import Optional, Union, Callable
from functools import wraps
from src.translation_agent.utils import *


from llama_index.llms.groq import Groq
from llama_index.llms.cohere import Cohere
from llama_index.llms.openai import OpenAI
from llama_index.llms.together import TogetherLLM
from llama_index.llms.ollama import Ollama
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI

from llama_index.core import Settings
from llama_index.core.llms import ChatMessage


# Add your LLMs here

def model_load(
    endpoint: str,
    model: str,
    api_key: Optional[str] = None,
    context_window: int = 4096,
    num_output: int = 512,
):
    """Instantiate the requested llama-index LLM and register it in the global Settings."""
    if endpoint == "Groq":
        llm = Groq(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "Cohere":
        llm = Cohere(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "OpenAI":
        llm = OpenAI(
            model=model,
            api_key=api_key if api_key else os.getenv("OPENAI_API_KEY"),
        )
    elif endpoint == "TogetherAI":
        llm = TogetherLLM(
            model=model,
            api_key=api_key,
        )
    elif endpoint == "ollama":
        llm = Ollama(
            model=model,
            request_timeout=120.0)
    elif endpoint == "Huggingface":
        llm = HuggingFaceInferenceAPI(
            model_name=model,
            token=api_key,
            task="text-generation",
        )
    else:
        # guard against an unbound llm when an unknown endpoint is passed
        raise ValueError(f"Unsupported endpoint: {endpoint}")
    Settings.llm = llm
    # maximum input size to the LLM
    Settings.context_window = context_window

    # number of tokens reserved for text generation.
    Settings.num_output = num_output



def completion_wrapper(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(
        prompt: str,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.3,
        json_mode: bool = False,
    ) -> Union[str, dict]:
        """
        Generate a completion using the llama-index LLM registered in Settings.

        Args:
            prompt (str): The user's prompt or query.
            system_message (str, optional): The system message that sets the context for the assistant.
                Defaults to "You are a helpful assistant.".
            temperature (float, optional): The sampling temperature controlling the randomness of the generated text.
                Defaults to 0.3.
            json_mode (bool, optional): Whether to request the response in JSON format.
                Defaults to False.

        Returns:
            Union[str, dict]: The generated completion.
                If json_mode is True, the generated text is returned as a JSON-formatted string.
                If json_mode is False, the generated text is returned as a plain string.
        """
        llm = Settings.llm
        if llm.class_name() == "HuggingFaceInferenceAPI":
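            # HuggingFaceInferenceAPI takes the system message via its
            # system_prompt attribute rather than as a "system" chat message.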
            llm.system_prompt = system_message
            messages = [
                ChatMessage(
                    role="user", content=prompt),
            ]
            response = llm.chat(
                messages=messages,
                temperature=temperature,
                top_p=1,
            )
            return response.message.content
        else:
            messages = [
                ChatMessage(
                    role="system", content=system_message),
                ChatMessage(
                    role="user", content=prompt),
            ]

            if json_mode:
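                # response_format is forwarded to the underlying API as an extra
                # kwarg; JSON mode is only honored by backends that support it
                # (e.g. OpenAI-compatible endpoints).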
                response = llm.chat(
                    temperature=temperature,
                    top_p=1,
                    response_format={"type": "json_object"},
                    messages=messages,
                )
                return response.message.content
            else:
                response = llm.chat(
                    temperature=temperature,
                    top_p=1,
                    messages=messages,
                )
                return response.message.content

    return wrapper

# Monkey patch: keep a reference to the original (OpenAI-based) get_completion
# imported from src.translation_agent.utils, then rebind the name to the
# llama-index-backed wrapper.
openai_completion = get_completion
get_completion = completion_wrapper(openai_completion)
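

# A minimal usage sketch. The endpoint, model name, and environment variable
# below are illustrative assumptions, not values defined by this module.
if __name__ == "__main__":
    model_load(
        endpoint="Groq",
        model="llama3-70b-8192",  # assumed model id; substitute your own
        api_key=os.getenv("GROQ_API_KEY"),
    )
    print(get_completion("Translate to French: Hello, world!"))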