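"""Concurrent multi-provider query streaming.

Fans a single user query out to OpenAI, Anthropic, and Gemini, multiplexes
the streamed chunks, and yields incremental UI state updates.
"""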
import asyncio
from typing import AsyncGenerator, List, Dict, Tuple
from config import logger
from api import ask_openai, ask_anthropic, ask_gemini

async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
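    """Stream `query` to the selected providers.

    Yields 5-tuples of (empty string, openai_messages, anthropic_messages,
    gemini_messages, history); the final yield carries the updated history.
    """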
    logger.info(f"Processing query with providers: {providers}")
    openai_response = ""
    anthropic_response = ""
    gemini_response = ""

    openai_messages = []
    anthropic_messages = []
    gemini_messages = []

    # Build message history for each provider
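    # Each history entry is a dict of the form
    # {"user": ..., "openai": ..., "anthropic": ..., "gemini": ...}
    # (see updated_history below), fanned out into per-provider transcripts.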
    for msg in history:
        if "user" in msg:
            openai_messages.append({"role": "user", "content": msg["user"]})
            anthropic_messages.append({"role": "user", "content": msg["user"]})
            gemini_messages.append({"role": "user", "content": msg["user"]})
        if msg.get("openai"):
            openai_messages.append({"role": "assistant", "content": msg["openai"]})
        if msg.get("anthropic"):
            anthropic_messages.append({"role": "assistant", "content": msg["anthropic"]})
        if msg.get("gemini"):
            gemini_messages.append({"role": "assistant", "content": msg["gemini"]})

    # Append the user query plus an empty assistant placeholder that
    # collect_chunks below will stream the response into
    if "OpenAI" in providers:
        openai_messages.append({"role": "user", "content": query})
        openai_messages.append({"role": "assistant", "content": ""})
    if "Anthropic" in providers:
        anthropic_messages.append({"role": "user", "content": query})
        anthropic_messages.append({"role": "assistant", "content": ""})
    if "Gemini" in providers:
        gemini_messages.append({"role": "user", "content": query})
        gemini_messages.append({"role": "assistant", "content": ""})

    # Yield initial state with user query
    logger.info(f"Yielding initial state with user query: {query}")
    yield "", openai_messages, anthropic_messages, gemini_messages, history

    tasks = []
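    # Each ask_* helper is consumed as an async generator of text chunks;
    # every tuple pairs a provider with its response accumulator and messages.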
    if "OpenAI" in providers:
        tasks.append(("OpenAI", ask_openai(query, history), openai_response, openai_messages))
    if "Anthropic" in providers:
        tasks.append(("Anthropic", ask_anthropic(query, history), anthropic_response, anthropic_messages))
    if "Gemini" in providers:
        tasks.append(("Gemini", ask_gemini(query, history), gemini_response, gemini_messages))

    async def collect_chunks(
        provider: str,
        generator: AsyncGenerator[str, None],
        response: str,
        messages: List[Dict[str, str]]
    ) -> AsyncGenerator[Tuple[str, str, List[Dict[str, str]]], None]:
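        # Grow `response` chunk by chunk and rewrite the trailing assistant
        # placeholder in place, so each yield carries the full text so far.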
        async for chunk in generator:
            response += chunk
            messages[-1] = {"role": "assistant", "content": response}
            yield provider, response, messages

    # Each state triple is (provider, chunk_generator, pending_task).
    generator_states = [
        (provider, collect_chunks(provider, gen, resp, msgs), None)
        for provider, gen, resp, msgs in tasks
    ]
    active_generators = generator_states[:]

    # Multiplex the streams: keep at most one pending __anext__() task per
    # generator and wake up as soon as any provider yields a chunk.
    while active_generators:
        tasks_to_wait = []
        new_generator_states = []

        for provider, gen, active_task in active_generators:
            if active_task is None or active_task.done():
                # Schedule the next chunk. create_task() itself never raises
                # StopAsyncIteration; exhaustion surfaces when the task result
                # is read below.
                task = asyncio.create_task(gen.__anext__())
                new_generator_states.append((provider, gen, task))
                tasks_to_wait.append(task)
                logger.debug(f"Created task for {provider}")
            else:
                new_generator_states.append((provider, gen, active_task))
                tasks_to_wait.append(active_task)

        if not tasks_to_wait:
            break

        done, _ = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED)

        # Iterate over a snapshot so removing exhausted generators below
        # cannot skip entries of the list being mutated.
        for provider, gen, task in list(new_generator_states):
            if task not in done:
                continue
            try:
                _, response, messages = task.result()
                if provider == "OpenAI":
                    openai_response = response
                    openai_messages = messages
                elif provider == "Anthropic":
                    anthropic_response = response
                    anthropic_messages = messages
                elif provider == "Gemini":
                    gemini_response = response
                    gemini_messages = messages
                logger.info(f"Yielding update for {provider}: {response[:50]}...")
                yield "", openai_messages, anthropic_messages, gemini_messages, history
                # Clear the pending task so a new __anext__() gets scheduled
                # for this generator on the next pass.
                idx = new_generator_states.index((provider, gen, task))
                new_generator_states[idx] = (provider, gen, None)
            except StopAsyncIteration:
                logger.info(f"Generator for {provider} completed")
                new_generator_states.remove((provider, gen, task))

        active_generators = new_generator_states

    updated_history = history + [{
        "user": query,
        "openai": openai_response.strip() if openai_response else "",
        "anthropic": anthropic_response.strip() if anthropic_response else "",
        "gemini": gemini_response.strip() if gemini_response else ""
    }]

    logger.info(f"Updated history: {updated_history}")
    yield "", openai_messages, anthropic_messages, gemini_messages, updated_history

async def submit_query(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
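    """Validate inputs, then relay streaming updates from query_model.

    Yields the same 5-tuple shape as query_model; on invalid input, yields a
    single advisory assistant message to all three panes and returns.
    """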
    if not query.strip():
        msg = {"role": "assistant", "content": "Please enter a query."}
        yield "", [msg], [msg], [msg], history
        return

    if not providers:
        msg = {"role": "assistant", "content": "Please select at least one provider."}
        yield "", [msg], [msg], [msg], history
        return

    async for _, openai_msgs, anthropic_msgs, gemini_msgs, updated_history in query_model(query, providers, history):
        logger.info(f"Submitting update to UI: OpenAI: {openai_msgs[-1]['content'][:50] if openai_msgs else ''}, "
                    f"Anthropic: {anthropic_msgs[-1]['content'][:50] if anthropic_msgs else ''}, "
                    f"Gemini: {gemini_msgs[-1]['content'][:50] if gemini_msgs else ''}")
        yield "", openai_msgs, anthropic_msgs, gemini_msgs, updated_history

def clear_history():
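    # Returns empty lists for the three per-provider message panes and the
    # shared history state.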
    logger.info("Clearing history")
    return [], [], [], []
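
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the app): drives submit_query
    # end to end. Assumes the api module is configured with valid credentials
    # and that each ask_* helper is an async generator of text chunks.
    async def _demo() -> None:
        final_history: List[Dict[str, str]] = []
        async for _, _, _, _, hist in submit_query("Say hello in one word.", ["OpenAI"], []):
            final_history = hist
        print(final_history[-1] if final_history else "no response recorded")

    asyncio.run(_demo())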