File size: 3,990 Bytes
ad33df7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from unittest.mock import patch

from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought

# Canned assistant replies, in the order the mocked API will return them:
# first call answers the French translation, second the Japanese one.
_MOCK_REPLY_CONTENTS = ["Bonjour", "こんにけは (Konnichiwa)"]


def _mock_chat_completion(content):
    """Build a ChatCompletion object whose single choice carries *content*."""
    payload = {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": content,
                    "function_call": None,
                    "tool_calls": None,
                },
                "logprobs": None,
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    }
    return ChatCompletion.parse_obj(payload)


# Used as `side_effect` for the patched create(): one response per API call.
_openai_chat_completion_response = [
    _mock_chat_completion(content) for content in _MOCK_REPLY_CONTENTS
]


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=_openai_chat_completion_response,
)
def test_cot_plus_operator(openai_completion):
    """Two Thoughts combined with `+` should chain their post-processed outputs."""
    llm = AzureChatOpenAI(
        api_key="dummy",
        api_version="2024-05-01-preview",
        azure_deployment="gpt-4o",
        azure_endpoint="https://test.openai.azure.com/",
    )
    translate_step = Thought(
        prompt="Word {word} in {language} is ",
        llm=llm,
        post_process=lambda text: {"translated": text},
    )
    japanese_step = Thought(
        prompt="Translate {translated} to Japanese",
        llm=llm,
        post_process=lambda text: {"output": text},
    )
    chain = translate_step + japanese_step
    result = chain(word="hello", language="French")
    # The chain accumulates the original kwargs plus each step's contribution.
    expected = {
        "word": "hello",
        "language": "French",
        "translated": "Bonjour",
        "output": "こんにけは (Konnichiwa)",
    }
    assert result.content == expected


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=_openai_chat_completion_response,
)
def test_cot_manual(openai_completion):
    """ManualSequentialChainOfThought supplies the shared llm to each Thought."""
    llm = AzureChatOpenAI(
        api_key="dummy",
        api_version="2024-05-01-preview",
        azure_deployment="gpt-4o",
        azure_endpoint="https://test.openai.azure.com/",
    )
    # Note: the Thoughts themselves carry no llm; the chain injects it.
    steps = [
        Thought(
            prompt="Word {word} in {language} is ",
            post_process=lambda text: {"translated": text},
        ),
        Thought(
            prompt="Translate {translated} to Japanese",
            post_process=lambda text: {"output": text},
        ),
    ]
    chain = ManualSequentialChainOfThought(thoughts=steps, llm=llm)
    result = chain(word="hello", language="French")
    expected = {
        "word": "hello",
        "language": "French",
        "translated": "Bonjour",
        "output": "こんにけは (Konnichiwa)",
    }
    assert result.content == expected


@patch(
    "openai.resources.chat.completions.Completions.create",
    side_effect=_openai_chat_completion_response,
)
def test_cot_with_termination_callback(openai_completion):
    """The chain stops early when `terminate` returns True.

    After the first Thought produces "Bonjour", the terminate callback fires,
    so the second Thought never runs and "output" is absent from the result.
    """
    llm = AzureChatOpenAI(
        api_key="dummy",
        api_version="2024-05-01-preview",
        azure_deployment="gpt-4o",
        azure_endpoint="https://test.openai.azure.com/",
    )
    thought1 = Thought(
        prompt="Word {word} in {language} is ",
        post_process=lambda string: {"translated": string},
    )
    thought2 = Thought(
        prompt="Translate {translated} to Japanese",
        post_process=lambda string: {"output": string},
    )
    thought = ManualSequentialChainOfThought(
        thoughts=[thought1, thought2],
        llm=llm,
        # The comparison already yields a bool; no need for `True if ... else False`.
        terminate=lambda d: d.get("translated", "") == "Bonjour",
    )
    output = thought(word="hallo", language="French")
    assert output.content == {
        "word": "hallo",
        "language": "French",
        "translated": "Bonjour",
    }