import asyncio
from typing import AsyncGenerator, List, Dict, Tuple

from config import logger
from api import ask_openai, ask_anthropic, ask_gemini


async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
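    """Stream responses from the selected providers for a single query.

    Yields (query_text, openai_messages, anthropic_messages, gemini_messages,
    history) tuples: once before any provider has responded, once per streamed
    chunk, and a final time with the completed turn appended to the history.
    """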
logger.info(f"Processing query with providers: {providers}") | |
openai_response = "" | |
anthropic_response = "" | |
gemini_response = "" | |
openai_messages = [] | |
anthropic_messages = [] | |
gemini_messages = [] | |

    # Build message history for each provider
    for msg in history:
        if "user" in msg:
            openai_messages.append({"role": "user", "content": msg["user"]})
            anthropic_messages.append({"role": "user", "content": msg["user"]})
            gemini_messages.append({"role": "user", "content": msg["user"]})
        if msg.get("openai"):
            openai_messages.append({"role": "assistant", "content": msg["openai"]})
        if msg.get("anthropic"):
            anthropic_messages.append({"role": "assistant", "content": msg["anthropic"]})
        if msg.get("gemini"):
            gemini_messages.append({"role": "assistant", "content": msg["gemini"]})

    # Append the new user query plus an empty assistant placeholder that the
    # streaming loop below fills in chunk by chunk
    if "OpenAI" in providers:
        openai_messages.append({"role": "user", "content": query})
        openai_messages.append({"role": "assistant", "content": ""})
    if "Anthropic" in providers:
        anthropic_messages.append({"role": "user", "content": query})
        anthropic_messages.append({"role": "assistant", "content": ""})
    if "Gemini" in providers:
        gemini_messages.append({"role": "user", "content": query})
        gemini_messages.append({"role": "assistant", "content": ""})

    # Yield initial state with user query
    logger.info(f"Yielding initial state with user query: {query}")
    yield "", openai_messages, anthropic_messages, gemini_messages, history
    tasks = []
    if "OpenAI" in providers:
        tasks.append(("OpenAI", ask_openai(query, history), openai_response, openai_messages))
    if "Anthropic" in providers:
        tasks.append(("Anthropic", ask_anthropic(query, history), anthropic_response, anthropic_messages))
    if "Gemini" in providers:
        tasks.append(("Gemini", ask_gemini(query, history), gemini_response, gemini_messages))

    async def collect_chunks(
        provider: str,
        generator: AsyncGenerator[str, None],
        response: str,
        messages: List[Dict[str, str]]
    ) -> AsyncGenerator[Tuple[str, str, List[Dict[str, str]]], None]:
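        """Fold streamed chunks into the provider's trailing assistant message."""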
        async for chunk in generator:
            response += chunk
            messages[-1] = {"role": "assistant", "content": response}
            yield provider, response, messages
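
    # Drive all provider generators concurrently: keep one pending __anext__()
    # task per generator, wait for whichever finishes first, and stream that
    # provider's partial response to the UI. Exhausted generators drop out.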
    generator_states = [
        (provider, collect_chunks(provider, gen, resp, msgs), None)
        for provider, gen, resp, msgs in tasks
    ]
    active_generators = generator_states[:]
    while active_generators:
        tasks_to_wait = []
        pending_states = []
        for provider, gen, active_task in active_generators:
            if active_task is None:
                # Schedule the next chunk. StopAsyncIteration is raised when
                # the task is awaited, not here at creation time.
                task = asyncio.create_task(gen.__anext__())
                logger.debug(f"Created task for {provider}")
            else:
                # A task that completed while we were yielding is handed back
                # to asyncio.wait, which returns already-done tasks at once,
                # so its chunk is processed rather than dropped.
                task = active_task
            pending_states.append((provider, gen, task))
            tasks_to_wait.append(task)
        if not tasks_to_wait:
            break
        done, _ = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED)
        active_generators = []
        for provider, gen, task in pending_states:
            if task not in done:
                active_generators.append((provider, gen, task))
                continue
            try:
                name, response, messages = task.result()
            except StopAsyncIteration:
                logger.info(f"Generator for {provider} completed")
                continue  # drop the exhausted generator from the rotation
            if name == "OpenAI":
                openai_response = response
                openai_messages = messages
            elif name == "Anthropic":
                anthropic_response = response
                anthropic_messages = messages
            elif name == "Gemini":
                gemini_response = response
                gemini_messages = messages
            logger.info(f"Yielding update for {name}: {response[:50]}...")
            yield "", openai_messages, anthropic_messages, gemini_messages, history
            active_generators.append((provider, gen, None))

    updated_history = history + [{
        "user": query,
        "openai": openai_response.strip() if openai_response else "",
        "anthropic": anthropic_response.strip() if anthropic_response else "",
        "gemini": gemini_response.strip() if gemini_response else ""
    }]
    logger.info(f"Updated history: {updated_history}")
    yield "", openai_messages, anthropic_messages, gemini_messages, updated_history


async def submit_query(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
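    """Validate the inputs, then relay the streaming updates from query_model."""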
    if not query.strip():
        msg = {"role": "assistant", "content": "Please enter a query."}
        yield "", [msg], [msg], [msg], history
        return
    if not providers:
        msg = {"role": "assistant", "content": "Please select at least one provider."}
        yield "", [msg], [msg], [msg], history
        return
    async for _, openai_msgs, anthropic_msgs, gemini_msgs, updated_history in query_model(query, providers, history):
        logger.info(
            f"Submitting update to UI: OpenAI: {openai_msgs[-1]['content'][:50] if openai_msgs else ''}, "
            f"Anthropic: {anthropic_msgs[-1]['content'][:50] if anthropic_msgs else ''}, "
            f"Gemini: {gemini_msgs[-1]['content'][:50] if gemini_msgs else ''}"
        )
        yield "", openai_msgs, anthropic_msgs, gemini_msgs, updated_history


def clear_history():
    logger.info("Clearing history")
    return [], [], [], []
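

# A minimal sketch of how these handlers could be wired into a Gradio Blocks
# app (assumed wiring; the component names and layout below are hypothetical,
# not taken from this module):
#
#   import gradio as gr
#
#   with gr.Blocks() as demo:
#       history = gr.State([])
#       query = gr.Textbox(label="Query")
#       providers = gr.CheckboxGroup(["OpenAI", "Anthropic", "Gemini"])
#       openai_chat = gr.Chatbot(type="messages")
#       anthropic_chat = gr.Chatbot(type="messages")
#       gemini_chat = gr.Chatbot(type="messages")
#       query.submit(
#           submit_query,
#           inputs=[query, providers, history],
#           outputs=[query, openai_chat, anthropic_chat, gemini_chat, history],
#       )
#       gr.Button("Clear").click(
#           clear_history,
#           outputs=[openai_chat, anthropic_chat, gemini_chat, history],
#       )
#   demo.launch()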