# libraries
from __future__ import annotations

import json
import logging
from typing import Any, Callable, Dict, Optional

from colorama import Fore, Style
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate

from ..prompts import generate_subtopics_prompt
from .costs import estimate_llm_cost
from .validators import Subtopics


def get_llm(llm_provider, **kwargs):
    from gpt_researcher.llm_provider import GenericLLMProvider
    return GenericLLMProvider.from_provider(llm_provider, **kwargs)


async def create_chat_completion(
    messages: list,  # type: ignore
    model: Optional[str] = None,
    temperature: Optional[float] = 0.4,
    max_tokens: Optional[int] = 4000,
    llm_provider: Optional[str] = None,
    stream: Optional[bool] = False,
    websocket: Any | None = None,
    llm_kwargs: Dict[str, Any] | None = None,
    cost_callback: Optional[Callable[[float], None]] = None,
) -> str:
    """Create a chat completion using the OpenAI API

    Args:

        messages (list[dict[str, str]]): The messages to send to the chat completion

        model (str, optional): The model to use. Defaults to None.

        temperature (float, optional): The temperature to use. Defaults to 0.4.

        max_tokens (int, optional): The max tokens to use. Defaults to 4000.

        stream (bool, optional): Whether to stream the response. Defaults to False.

        llm_provider (str, optional): The LLM Provider to use.

        webocket (WebSocket): The websocket used in the currect request,

        cost_callback: Callback function for updating cost

    Returns:

        str: The response from the chat completion

    """
    # validate input
    if model is None:
        raise ValueError("Model cannot be None")
    if max_tokens is not None and max_tokens > 16000:
        raise ValueError(
            f"Max tokens cannot be more than 16,000, but got {max_tokens}")

    # Get the provider from supported providers
    provider = get_llm(llm_provider, model=model, temperature=temperature,
                       max_tokens=max_tokens, **(llm_kwargs or {}))

    response = ""
    # create response
    for _ in range(10):  # maximum of 10 attempts
        response = await provider.get_chat_response(
            messages, stream, websocket
        )

        if cost_callback:
            llm_costs = estimate_llm_cost(str(messages), response)
            cost_callback(llm_costs)

        return response

    logging.error(f"Failed to get response from {llm_provider} API")
    raise RuntimeError(f"Failed to get response from {llm_provider} API")
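

# Illustrative usage sketch (not part of the original module): the provider key
# "openai" and the model name "gpt-4o-mini" below are assumptions; in
# gpt-researcher these values normally come from the runtime Config rather than
# being hard-coded.
async def _example_create_chat_completion() -> str:
    """Hypothetical demo of calling create_chat_completion directly."""
    return await create_chat_completion(
        messages=[{"role": "user", "content": "Summarize LangChain in one sentence."}],
        model="gpt-4o-mini",      # assumed model name
        llm_provider="openai",    # assumed provider key
        temperature=0.4,
        max_tokens=500,
    )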


async def construct_subtopics(task: str, data: str, config, subtopics: list | None = None) -> list:
    """Construct subtopics based on the given task and data.

    Args:
        task (str): The main task or topic.
        data (str): Additional data for context.
        config: Configuration settings.
        subtopics (list, optional): Existing subtopics. Defaults to None.

    Returns:
        list: A list of constructed subtopics.
    """
    subtopics = subtopics or []
    try:
        parser = PydanticOutputParser(pydantic_object=Subtopics)

        prompt = PromptTemplate(
            template=generate_subtopics_prompt(),
            input_variables=["task", "data", "subtopics", "max_subtopics"],
            partial_variables={
                "format_instructions": parser.get_format_instructions()},
        )

        print(f"\n🤖 Calling {config.smart_llm_model}...\n")

        temperature = config.temperature
        # temperature = 0 # Note: temperature throughout the code base is currently set to Zero
        provider = get_llm(
            config.smart_llm_provider,
            model=config.smart_llm_model,
            temperature=temperature,
            max_tokens=config.smart_token_limit,
            **config.llm_kwargs,
        )
        model = provider.llm

        chain = prompt | model | parser

        output = await chain.ainvoke({
            "task": task,
            "data": data,
            "subtopics": subtopics,
            "max_subtopics": config.max_subtopics
        })

        return output

    except Exception as e:
        print(f"Exception in parsing subtopics: {e}")
        return subtopics
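

# Illustrative usage sketch (not part of the original module): the SimpleNamespace
# below is a hypothetical stand-in exposing only the attributes that
# construct_subtopics reads; the provider/model values are assumptions.
async def _example_construct_subtopics() -> list:
    """Hypothetical demo of calling construct_subtopics with a stub config."""
    from types import SimpleNamespace  # local import: only needed for this sketch

    stub_config = SimpleNamespace(
        smart_llm_provider="openai",    # assumed provider key
        smart_llm_model="gpt-4o-mini",  # assumed model name
        temperature=0.4,
        smart_token_limit=4000,
        llm_kwargs={},
        max_subtopics=5,
    )
    return await construct_subtopics(
        task="Impact of electric vehicles on urban air quality",
        data="(research context gathered earlier)",
        config=stub_config,
    )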