# Ganesh Chintalapati
# final version v1
# 2709e97
import asyncio
from typing import AsyncGenerator, List, Dict, Tuple
from config import logger
from api import ask_openai, ask_anthropic, ask_gemini
async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Stream one query to the selected providers and yield UI snapshots as chunks arrive.

    Each selected provider generator (``ask_openai`` / ``ask_anthropic`` /
    ``ask_gemini``, called as ``ask_x(query, history)``) is drained
    concurrently by its own task. Those tasks forward every chunk into a
    single queue, so chunks are applied in arrival order and none are lost
    when several providers produce output at the same time.

    Args:
        query: The new user prompt.
        providers: Provider names to query. Recognised values are "OpenAI",
            "Anthropic" and "Gemini"; other values are ignored.
        history: Prior turns. Each dict may hold a "user" prompt and
            "openai" / "anthropic" / "gemini" replies.

    Yields:
        ``("", openai_messages, anthropic_messages, gemini_messages, history)``.
        One tuple is yielded before streaming starts and one after every
        received chunk. A final tuple carries ``history`` extended with this
        turn's entry.

    Raises:
        Any exception raised by a provider generator is logged and re-raised
        here, after the remaining provider tasks have been cancelled.
    """
    logger.info(f"Processing query with providers: {providers}")

    # UI provider name -> key under which that provider's reply is stored in history.
    history_keys = {"OpenAI": "openai", "Anthropic": "anthropic", "Gemini": "gemini"}
    askers = {"OpenAI": ask_openai, "Anthropic": ask_anthropic, "Gemini": ask_gemini}

    messages: Dict[str, List[Dict[str, str]]] = {name: [] for name in history_keys}
    responses: Dict[str, str] = {name: "" for name in history_keys}

    # Rebuild every provider's chat transcript from the shared history.
    for turn in history:
        if "user" in turn:
            for transcript in messages.values():
                transcript.append({"role": "user", "content": turn["user"]})
        for name, key in history_keys.items():
            if turn.get(key):
                messages[name].append({"role": "assistant", "content": turn[key]})

    # Selected providers in a fixed OpenAI/Anthropic/Gemini order.
    selected = [name for name in history_keys if name in providers]
    for name in selected:
        messages[name].append({"role": "user", "content": query})
        # Placeholder that is overwritten in place as the reply streams in.
        messages[name].append({"role": "assistant", "content": ""})

    def snapshot(hist: List[Dict[str, str]]):
        return "", messages["OpenAI"], messages["Anthropic"], messages["Gemini"], hist

    logger.info(f"Yielding initial state with user query: {query}")
    yield snapshot(history)

    # Events are (provider, kind, payload); kind is "chunk", "done" or "error".
    events: asyncio.Queue = asyncio.Queue()

    async def pump(name: str) -> None:
        """Drain one provider's generator into the shared event queue."""
        try:
            async for chunk in askers[name](query, history):
                await events.put((name, "chunk", chunk))
        except Exception as exc:  # forwarded so the consumer can re-raise it
            await events.put((name, "error", exc))
        else:
            await events.put((name, "done", None))

    tasks = [asyncio.create_task(pump(name)) for name in selected]
    pending = len(tasks)
    try:
        while pending:
            name, kind, payload = await events.get()
            if kind == "error":
                logger.error(f"Provider {name} failed: {payload!r}")
                raise payload
            if kind == "done":
                logger.info(f"Generator for {name} completed")
                pending -= 1
                continue
            responses[name] += payload
            messages[name][-1] = {"role": "assistant", "content": responses[name]}
            logger.info(f"Yielding update for {name}: {responses[name][:50]}...")
            yield snapshot(history)
    finally:
        # Stop providers that are still streaming, either after an error or
        # because the consumer closed this generator early. Then reap the
        # tasks so none is left pending and unawaited.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    updated_history = history + [{
        "user": query,
        **{key: responses[name].strip() for name, key in history_keys.items()},
    }]
    logger.info(f"Updated history: {updated_history}")
    yield snapshot(updated_history)
async def submit_query(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Validate the request, then relay each update produced by ``query_model``.

    If ``query`` is blank or no provider is selected, one update is yielded
    and the generator stops. In that update, each of the three message lists
    holds only a notice message, and ``history`` is passed back unchanged.

    Yields:
        ``("", openai_msgs, anthropic_msgs, gemini_msgs, history)`` tuples.
    """
    if not query.strip():
        notice = "Please enter a query."
    elif not providers:
        notice = "Please select at least one provider."
    else:
        notice = None

    if notice is not None:
        warning = {"role": "assistant", "content": notice}
        yield "", [warning], [warning], [warning], history
        return

    def tail(msgs):
        # First 50 characters of the newest message, or "" for an empty list.
        return msgs[-1]['content'][:50] if msgs else ''

    async for update in query_model(query, providers, history):
        openai_msgs, anthropic_msgs, gemini_msgs, latest_history = update[1:]
        logger.info(
            f"Submitting update to UI: OpenAI: {tail(openai_msgs)}, "
            f"Anthropic: {tail(anthropic_msgs)}, "
            f"Gemini: {tail(gemini_msgs)}"
        )
        yield "", openai_msgs, anthropic_msgs, gemini_msgs, latest_history
def clear_history():
    """Log the reset and return a tuple of four fresh, empty lists.

    NOTE(review): the four lists presumably map to the OpenAI, Anthropic and
    Gemini chat displays plus the shared history state. Confirm against the
    UI wiring.
    """
    logger.info("Clearing history")
    return tuple([] for _ in range(4))