Spaces:
Sleeping
Sleeping
File size: 6,833 Bytes
27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 27bfa71 2709e97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import asyncio
from typing import AsyncGenerator, List, Dict, Tuple
from config import logger
from api import ask_openai, ask_anthropic, ask_gemini
def _history_to_messages(
    history: List[Dict[str, str]],
    provider_key: str
) -> List[Dict[str, str]]:
    """Build one provider's chat transcript from the stored conversation turns.

    Each turn contributes its "user" text (when the key is present) followed by
    this provider's reply (when that key holds a non-empty value).
    """
    messages: List[Dict[str, str]] = []
    for turn in history:
        if "user" in turn:
            messages.append({"role": "user", "content": turn["user"]})
        if turn.get(provider_key):
            messages.append({"role": "assistant", "content": turn[provider_key]})
    return messages


async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Stream answers to ``query`` from the selected providers concurrently.

    Args:
        query: The new user message.
        providers: Provider names to query. Only "OpenAI", "Anthropic" and
            "Gemini" are recognised; other names are ignored.
        history: Prior turns. Each turn is a dict with a "user" key and optional
            "openai" / "anthropic" / "gemini" keys holding assistant replies.

    Yields:
        Tuples ``("", openai_messages, anthropic_messages, gemini_messages, hist)``.

        - The first tuple adds the new user message and an empty assistant
          placeholder for every selected provider.
        - Each following tuple reflects one newly received chunk from some
          provider; ``hist`` is the unchanged ``history``.
        - The last tuple has ``hist`` extended with this turn.

    A provider whose stream raises is logged and dropped. Whatever it produced
    so far is kept, and the other providers keep streaming.
    """
    logger.info(f"Processing query with providers: {providers}")

    # (UI provider name, history key, streaming function)
    provider_specs = (
        ("OpenAI", "openai", ask_openai),
        ("Anthropic", "anthropic", ask_anthropic),
        ("Gemini", "gemini", ask_gemini),
    )
    messages = {name: _history_to_messages(history, key) for name, key, _ in provider_specs}
    responses = {name: "" for name, _, _ in provider_specs}
    selected = [spec for spec in provider_specs if spec[0] in providers]

    # Append the user query and an empty assistant slot that chunks will fill.
    for name, _, _ in selected:
        messages[name].append({"role": "user", "content": query})
        messages[name].append({"role": "assistant", "content": ""})

    def _state(current_history: List[Dict[str, str]]):
        return "", messages["OpenAI"], messages["Anthropic"], messages["Gemini"], current_history

    logger.info(f"Yielding initial state with user query: {query}")
    yield _state(history)

    streams = {name: ask(query, history) for name, _, ask in selected}

    # Map each in-flight "next chunk" task to the provider it reads from.
    pending: Dict[asyncio.Task, str] = {}
    for name, stream in streams.items():
        pending[asyncio.create_task(stream.__anext__())] = name
        logger.debug(f"Created task for {name}")

    try:
        while pending:
            done, _ = await asyncio.wait(set(pending), return_when=asyncio.FIRST_COMPLETED)
            # Iterate a snapshot, in scheduling order, so that every finished
            # task is consumed exactly once and `pending` can be edited freely.
            for task in [t for t in pending if t in done]:
                name = pending.pop(task)
                try:
                    chunk = task.result()
                except StopAsyncIteration:
                    logger.info(f"Generator for {name} completed")
                    continue
                except Exception:
                    logger.exception(f"Provider {name} failed while streaming")
                    continue
                responses[name] += chunk
                messages[name][-1] = {"role": "assistant", "content": responses[name]}
                logger.info(f"Yielding update for {name}: {responses[name][:50]}...")
                yield _state(history)
                pending[asyncio.create_task(streams[name].__anext__())] = name
                logger.debug(f"Created task for {name}")
    finally:
        # If the consumer closes this generator early, or an error escapes,
        # do not leave provider reads running in the background.
        for task in pending:
            task.cancel()

    updated_history = history + [{
        "user": query,
        "openai": responses["OpenAI"].strip(),
        "anthropic": responses["Anthropic"].strip(),
        "gemini": responses["Gemini"].strip()
    }]
    logger.info(f"Updated history: {updated_history}")
    yield _state(updated_history)
async def submit_query(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Validate the user's input, then relay every update produced by query_model.

    A blank query, or an empty provider selection, yields exactly one tuple. That
    tuple places the same notice in all three chat panes and leaves ``history``
    untouched. Otherwise each tuple from ``query_model`` is logged (first 50
    characters of each pane's last message) and forwarded unchanged.
    """
    # The blank-query check wins when both inputs are invalid.
    problem = None
    if not query.strip():
        problem = "Please enter a query."
    elif not providers:
        problem = "Please select at least one provider."

    if problem is not None:
        notice = {"role": "assistant", "content": problem}
        yield "", [notice], [notice], [notice], history
        return

    def _preview(pane):
        return pane[-1]['content'][:50] if pane else ''

    async for update in query_model(query, providers, history):
        _, openai_pane, anthropic_pane, gemini_pane, latest_history = update
        logger.info(
            f"Submitting update to UI: OpenAI: {_preview(openai_pane)}, "
            f"Anthropic: {_preview(anthropic_pane)}, "
            f"Gemini: {_preview(gemini_pane)}"
        )
        yield "", openai_pane, anthropic_pane, gemini_pane, latest_history
def clear_history():
    """Reset the conversation by returning four empty lists.

    Returns:
        A 4-tuple of empty lists. NOTE(review): presumably these map to the
        OpenAI, Anthropic and Gemini chat panes plus the stored history, in the
        order the UI wires its outputs; confirm against the caller.
    """
    logger.info("Clearing history")
    return [], [], [], []