Ganesh Chintalapati committed
Commit 2709e97 · 1 Parent(s): 27bfa71

final version v1

Files changed (3)
  1. api.py +52 -110
  2. app.py +10 -9
  3. core.py +64 -54
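All three files share one underlying change: the chat history moves from a single combined `bot` string per turn to one key per provider. For reference while reading the diffs below, a sketch of the new entry shape (keys taken from the diff; the values are illustrative only):

```python
# A history entry as core.py now builds it and api.py now consumes it.
# Keys come straight from the diff; the values here are made up.
turn = {
    "user": "What is the capital of the United States?",
    "openai": "Washington, D.C.",     # streamed reply from ask_openai
    "anthropic": "Washington, D.C.",  # streamed reply from ask_anthropic
    "gemini": "",                     # empty when a provider was not selected
}
history = [turn]
```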
api.py CHANGED
@@ -5,6 +5,7 @@ import traceback
 from typing import AsyncGenerator, List, Dict
 from config import logger
 
+# ===== OpenAI =====
 async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
     openai_api_key = os.getenv("OPENAI_API_KEY")
     if not openai_api_key:
@@ -12,12 +13,11 @@ async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerato
         yield "Error: OpenAI API key not provided."
         return
 
-    # Build message history with user and assistant roles
     messages = []
     for msg in history:
         messages.append({"role": "user", "content": msg["user"]})
-        if msg["bot"]:
-            messages.append({"role": "assistant", "content": msg["bot"]})
+        if msg.get("openai"):
+            messages.append({"role": "assistant", "content": msg["openai"]})
     messages.append({"role": "user", "content": query})
 
     headers = {
@@ -43,32 +43,24 @@ async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerato
                         line, buffer = buffer.split("\n", 1)
                         if line.startswith("data: "):
                             data = line[6:]
-                            if data == "[DONE]":
+                            if data.strip() == "[DONE]":
                                 break
                             if not data.strip():
                                 continue
                             try:
                                 json_data = json.loads(data)
-                                if "choices" in json_data and json_data["choices"]:
-                                    delta = json_data["choices"][0].get("delta", {})
-                                    if "content" in delta and delta["content"] is not None:
-                                        logger.info(f"OpenAI yielding chunk: {delta['content']}")
-                                        yield delta["content"]
-                            except json.JSONDecodeError as e:
-                                logger.error(f"Error parsing OpenAI stream chunk: {str(e)} - Data: {data}")
-                                yield f"Error parsing stream: {str(e)}"
+                                delta = json_data["choices"][0].get("delta", {})
+                                if "content" in delta:
+                                    yield delta["content"]
                             except Exception as e:
-                                logger.error(f"Unexpected error in OpenAI stream: {str(e)} - Data: {data}")
-                                yield f"Error in stream: {str(e)}"
-
-    except httpx.HTTPStatusError as e:
-        response_text = await e.response.aread()
-        logger.error(f"OpenAI HTTP Status Error: {e.response.status_code}, {response_text}")
-        yield f"Error: OpenAI HTTP Status Error: {e.response.status_code}, {response_text.decode('utf-8')}"
+                                logger.error(f"OpenAI parse error: {e}")
+                                yield f"[OpenAI Error]: {e}"
     except Exception as e:
-        logger.error(f"OpenAI Error: {str(e)}")
-        yield f"Error: OpenAI Error: {str(e)}"
+        logger.error(f"OpenAI API error: {e}")
+        yield f"[OpenAI Error]: {e}"
 
+
+# ===== Anthropic =====
 async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
     anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
     if not anthropic_api_key:
@@ -79,8 +71,8 @@ async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGener
     messages = []
     for msg in history:
         messages.append({"role": "user", "content": msg["user"]})
-        if msg["bot"]:
-            messages.append({"role": "assistant", "content": msg["bot"]})
+        if msg.get("anthropic"):
+            messages.append({"role": "assistant", "content": msg["anthropic"]})
     messages.append({"role": "user", "content": query})
 
     headers = {
@@ -97,8 +89,7 @@ async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGener
     }
 
     try:
-        async with httpx.AsyncClient(timeout=30.0) as client:
-            logger.info(f"Sending Anthropic streaming request: {payload}")
+        async with httpx.AsyncClient() as client:
             async with client.stream("POST", "https://api.anthropic.com/v1/messages", headers=headers, json=payload) as response:
                 response.raise_for_status()
                 buffer = ""
@@ -115,26 +106,17 @@ async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGener
                                 continue
                             try:
                                 json_data = json.loads(data)
-                                if json_data.get("type") == "content_block_delta" and "delta" in json_data and "text" in json_data["delta"]:
-                                    logger.info(f"Anthropic yielding chunk: {json_data['delta']['text']}")
-                                    yield json_data["delta"]["text"]
-                                elif json_data.get("type") == "message_start" or json_data.get("type") == "message_delta":
-                                    continue
-                            except json.JSONDecodeError as e:
-                                logger.error(f"Error parsing Anthropic stream chunk: {str(e)} - Data: {data}")
-                                yield f"Error parsing stream: {str(e)}"
+                                if json_data.get("type") == "content_block_delta" and "delta" in json_data:
+                                    yield json_data["delta"].get("text", "")
                             except Exception as e:
-                                logger.error(f"Unexpected error in Anthropic stream: {str(e)} - Data: {data}")
-                                yield f"Error in stream: {str(e)}"
-
-    except httpx.HTTPStatusError as e:
-        response_text = await e.response.aread()
-        logger.error(f"Anthropic HTTP Status Error: {e.response.status_code}, {response_text.decode('utf-8')}")
-        yield f"Error: Anthropic HTTP Status Error: {e.response.status_code}, {response_text.decode('utf-8')}"
+                                logger.error(f"Anthropic parse error: {e}")
+                                yield f"[Anthropic Error]: {e}"
     except Exception as e:
-        logger.error(f"Anthropic Error: {str(e)}\nStack trace: {traceback.format_exc()}")
-        yield f"Error: Anthropic Error: {str(e)}"
+        logger.error(f"Anthropic API error: {e}")
+        yield f"[Anthropic Error]: {e}"
+
 
+# ===== Gemini =====
 async def ask_gemini(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
     gemini_api_key = os.getenv("GEMINI_API_KEY")
     if not gemini_api_key:
@@ -144,20 +126,18 @@ async def ask_gemini(query: str, history: List[Dict[str, str]]) -> AsyncGenerato
 
     history_text = ""
     for msg in history:
-        history_text += f"User: {msg['user']}\nAssistant: {msg['bot']}\n" if msg['bot'] else f"User: {msg['user']}\n"
-    full_query = history_text + f"User: {query}\n"
-
-    headers = {
-        "Content-Type": "application/json"
-    }
+        history_text += f"User: {msg['user']}\n"
+        if msg.get("gemini"):
+            history_text += f"Assistant: {msg['gemini']}\n"
+    full_prompt = f"{history_text}User: {query}\n"
 
+    headers = {"Content-Type": "application/json"}
     payload = {
-        "contents": [{"parts": [{"text": full_query}]}]
+        "contents": [{"parts": [{"text": full_prompt}]}]
     }
 
     try:
-        async with httpx.AsyncClient(timeout=30.0) as client:
-            logger.info(f"Sending Gemini streaming request: {payload}")
+        async with httpx.AsyncClient() as client:
             async with client.stream(
                 "POST",
                 f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent?key={gemini_api_key}",
@@ -167,63 +147,25 @@ async def ask_gemini(query: str, history: List[Dict[str, str]]) -> AsyncGenerato
                 response.raise_for_status()
                 buffer = ""
                 async for chunk in response.aiter_text():
-                    if chunk:
-                        buffer += chunk
-                        logger.info(f"Gemini stream chunk: {chunk}")
-                        while buffer.strip():
-                            try:
-                                json_data = json.loads(buffer)
-                                logger.info(f"Parsed Gemini JSON: {json_data}")
-                                buffer = ""
-                                objects = json_data if isinstance(json_data, list) else [json_data]
-                                for obj in objects:
-                                    if isinstance(obj, dict) and "candidates" in obj and obj["candidates"]:
-                                        content = obj["candidates"][0].get("content", {})
-                                        if "parts" in content and content["parts"]:
-                                            text = content["parts"][0].get("text", "")
-                                            if text:
-                                                logger.info(f"Gemini yielding chunk: {text}")
-                                                yield text
-                                break
-                            except json.JSONDecodeError:
-                                brace_count = 0
-                                split_index = -1
-                                for i, char in enumerate(buffer):
-                                    if char == '{':
-                                        brace_count += 1
-                                    elif char == '}':
-                                        brace_count -= 1
-                                        if brace_count == 0:
-                                            split_index = i + 1
-                                if split_index > 0:
-                                    try:
-                                        json_str = buffer[:split_index]
-                                        json_data = json.loads(json_str)
-                                        logger.info(f"Parsed Gemini JSON: {json_data}")
-                                        buffer = buffer[split_index:].lstrip(',')
-                                        objects = json_data if isinstance(json_data, list) else [json_data]
-                                        for obj in objects:
-                                            if isinstance(obj, dict) and "candidates" in obj and obj["candidates"]:
-                                                content = obj["candidates"][0].get("content", {})
-                                                if "parts" in content and content["parts"]:
-                                                    text = content["parts"][0].get("text", "")
-                                                    if text:
-                                                        logger.info(f"Gemini yielding chunk: {text}")
-                                                        yield text
-                                        continue
-                                    except json.JSONDecodeError:
-                                        pass
-                                break
-                            except Exception as e:
-                                logger.error(f"Unexpected error in Gemini stream: {str(e)} - Buffer: {buffer}")
-                                yield f"Error in stream: {str(e)}"
-                                buffer = ""
-                                break
-
-    except httpx.HTTPStatusError as e:
-        response_text = await e.response.aread()
-        logger.error(f"Gemini HTTP Status Error: {e.response.status_code}, {response_text.decode('utf-8')}")
-        yield f"Error: Gemini HTTP Status Error: {e.response.status_code}, {response_text.decode('utf-8')}"
+                    if not chunk.strip():
+                        continue
+                    buffer += chunk
+                    try:
+                        json_data = json.loads(buffer.strip(", \n"))
+                        buffer = ""
+
+                        # handle both list and dict format
+                        objects = json_data if isinstance(json_data, list) else [json_data]
+                        for obj in objects:
+                            candidates = obj.get("candidates", [])
+                            if candidates:
+                                parts = candidates[0].get("content", {}).get("parts", [])
+                                for part in parts:
+                                    text = part.get("text", "")
+                                    if text:
+                                        yield text
+                    except json.JSONDecodeError:
+                        continue  # wait for more data
     except Exception as e:
-        logger.error(f"Gemini Error: {str(e)}\nStack trace: {traceback.format_exc()}")
-        yield f"Error: Gemini Error: {str(e)}"
+        logger.error(f"Gemini API error: {e}")
+        yield f"[Gemini Error]: {e}"
app.py CHANGED
@@ -5,32 +5,33 @@ from core import submit_query, clear_history
 with gr.Blocks(theme=gr.themes.Soft(), css=".full-height { height: 100%; display: flex; align-items: stretch; min-height: 40px; } .full-height button { height: 100%; padding: 8px 16px; } .providers-row { height: 100%; display: flex; align-items: stretch; min-height: 40px; } .providers-row .checkbox-group { height: 100%; display: flex; flex-direction: row; align-items: center; gap: 10px; }") as demo:
     gr.Markdown("# Multi-Model Chat")
     gr.Markdown("Chat with OpenAI, Anthropic, or Gemini. Select providers and compare responses side by side!")
-
+
     with gr.Row(elem_classes="providers-row"):
         providers = gr.CheckboxGroup(choices=["OpenAI", "Anthropic", "Gemini"], label="Select Providers", value=["OpenAI"], elem_classes="checkbox-group")
-
+
     with gr.Row(elem_classes="full-height"):
         query = gr.Textbox(label="Enter your query", placeholder="e.g., What is the capital of the United States?", scale=4)
         submit_button = gr.Button("Submit", scale=1)
-
+
     with gr.Row():
         clear_button = gr.Button("Clear History")
-
+
     with gr.Row():
         openai_chatbot = gr.Chatbot(label="OpenAI", type="messages", scale=1)
         anthropic_chatbot = gr.Chatbot(label="Anthropic", type="messages", scale=1)
         gemini_chatbot = gr.Chatbot(label="Gemini", type="messages", scale=1)
 
+    chat_history = gr.State([])
+
     submit_button.click(
         fn=submit_query,
-        inputs=[query, providers],
-        outputs=[query, openai_chatbot, anthropic_chatbot, gemini_chatbot]
+        inputs=[query, providers, chat_history],
+        outputs=[query, openai_chatbot, anthropic_chatbot, gemini_chatbot, chat_history]
     )
     clear_button.click(
         fn=clear_history,
         inputs=[],
-        outputs=[openai_chatbot, anthropic_chatbot, gemini_chatbot]
+        outputs=[openai_chatbot, anthropic_chatbot, gemini_chatbot, chat_history]
     )
 
-    # Launch the Gradio app
-    demo.launch()
+demo.launch()
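The new `chat_history = gr.State([])` follows Gradio's state round-trip: a `gr.State` component holds a per-session value, and any handler that reads or updates it must list it in both `inputs` and `outputs`. A minimal standalone sketch of the same pattern (toy handler names, not the app's code):

```python
import gradio as gr

def echo_with_count(message, history):
    # State arrives as `history` and flows back out as the last return value.
    history = history + [message]
    return "", f"{len(history)} turn(s) so far", history

with gr.Blocks() as demo:
    box = gr.Textbox(label="Message")
    status = gr.Textbox(label="Status")
    state = gr.State([])  # per-session list, like chat_history in app.py
    box.submit(fn=echo_with_count, inputs=[box, state], outputs=[box, status, state])

demo.launch()
```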
 
core.py CHANGED
@@ -3,30 +3,37 @@ from typing import AsyncGenerator, List, Dict, Tuple
 from config import logger
 from api import ask_openai, ask_anthropic, ask_gemini
 
-async def query_model(query: str, providers: List[str], history: List[Dict[str, str]]) -> AsyncGenerator[Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]], None]:
+async def query_model(
+    query: str,
+    providers: List[str],
+    history: List[Dict[str, str]]
+) -> AsyncGenerator[
+    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
+    None
+]:
     logger.info(f"Processing query with providers: {providers}")
     openai_response = ""
     anthropic_response = ""
     gemini_response = ""
-
+
     openai_messages = []
     anthropic_messages = []
     gemini_messages = []
-
+
+    # Build message history for each provider
     for msg in history:
-        openai_messages.append({"role": "user", "content": msg["user"]})
-        anthropic_messages.append({"role": "user", "content": msg["user"]})
-        gemini_messages.append({"role": "user", "content": msg["user"]})
-        if msg["bot"]:
-            lines = msg["bot"].split("\n\n")
-            for line in lines:
-                if line.startswith("[OpenAI]:"):
-                    openai_messages.append({"role": "assistant", "content": line[len("[OpenAI]:"):].strip()})
-                elif line.startswith("[Anthropic]:"):
-                    anthropic_messages.append({"role": "assistant", "content": line[len("[Anthropic]:"):].strip()})
-                elif line.startswith("[Gemini]:"):
-                    gemini_messages.append({"role": "assistant", "content": line[len("[Gemini]:"):].strip()})
+        if "user" in msg:
+            openai_messages.append({"role": "user", "content": msg["user"]})
+            anthropic_messages.append({"role": "user", "content": msg["user"]})
+            gemini_messages.append({"role": "user", "content": msg["user"]})
+        if msg.get("openai"):
+            openai_messages.append({"role": "assistant", "content": msg["openai"]})
+        if msg.get("anthropic"):
+            anthropic_messages.append({"role": "assistant", "content": msg["anthropic"]})
+        if msg.get("gemini"):
+            gemini_messages.append({"role": "assistant", "content": msg["gemini"]})
 
+    # Append the user query and prepare for assistant response
     if "OpenAI" in providers:
         openai_messages.append({"role": "user", "content": query})
         openai_messages.append({"role": "assistant", "content": ""})
@@ -37,6 +44,10 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
         gemini_messages.append({"role": "user", "content": query})
         gemini_messages.append({"role": "assistant", "content": ""})
 
+    # Yield initial state with user query
+    logger.info(f"Yielding initial state with user query: {query}")
+    yield "", openai_messages, anthropic_messages, gemini_messages, history
+
     tasks = []
     if "OpenAI" in providers:
         tasks.append(("OpenAI", ask_openai(query, history), openai_response, openai_messages))
@@ -45,7 +56,12 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
     if "Gemini" in providers:
         tasks.append(("Gemini", ask_gemini(query, history), gemini_response, gemini_messages))
 
-    async def collect_chunks(provider: str, generator: AsyncGenerator, response: str, messages: List[Dict[str, str]]) -> AsyncGenerator[Tuple[str, str, List[Dict[str, str]]], None]:
+    async def collect_chunks(
+        provider: str,
+        generator: AsyncGenerator[str, None],
+        response: str,
+        messages: List[Dict[str, str]]
+    ) -> AsyncGenerator[Tuple[str, str, List[Dict[str, str]]], None]:
         async for chunk in generator:
             response += chunk
             messages[-1] = {"role": "assistant", "content": response}
@@ -57,7 +73,7 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
     while active_generators:
         tasks_to_wait = []
         new_generator_states = []
-
+
         for provider, gen, active_task in active_generators:
             if active_task is None or active_task.done():
                 try:
@@ -71,12 +87,12 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
             else:
                 new_generator_states.append((provider, gen, active_task))
                 tasks_to_wait.append(active_task)
-
+
         if not tasks_to_wait:
             break
 
         done, _ = await asyncio.wait(tasks_to_wait, return_when=asyncio.FIRST_COMPLETED)
-
+
         for provider, gen, task in new_generator_states:
             if task in done:
                 try:
@@ -91,7 +107,7 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
                         gemini_response = response
                         gemini_messages = messages
                     logger.info(f"Yielding update for {provider}: {response[:50]}...")
-                    yield "", openai_messages, anthropic_messages, gemini_messages
+                    yield "", openai_messages, anthropic_messages, gemini_messages, history
                     new_generator_states[new_generator_states.index((provider, gen, task))] = (provider, gen, None)
                 except StopAsyncIteration:
                     logger.info(f"Generator for {provider} completed")
@@ -102,46 +118,40 @@ async def query_model(query: str, providers: List[str], history: List[Dict[str,
 
         active_generators = new_generator_states
 
-    responses = []
-    if openai_response.strip() and not openai_response.startswith("Error:"):
-        responses.append(f"[OpenAI]: {openai_response}")
-    if anthropic_response.strip() and not anthropic_response.startswith("Error:"):
-        responses.append(f"[Anthropic]: {anthropic_response}")
-    if gemini_response.strip() and not gemini_response.startswith("Error:"):
-        responses.append(f"[Gemini]: {gemini_response}")
-
-    combined_response = "\n\n".join(responses) if responses else "No valid responses received."
-    updated_history = history + [{"user": query, "bot": combined_response}]
-    logger.info(f"Updated history: {updated_history}")
+    updated_history = history + [{
+        "user": query,
+        "openai": openai_response.strip() if openai_response else "",
+        "anthropic": anthropic_response.strip() if anthropic_response else "",
+        "gemini": gemini_response.strip() if gemini_response else ""
+    }]
 
-    yield "", openai_messages, anthropic_messages, gemini_messages
+    logger.info(f"Updated history: {updated_history}")
+    yield "", openai_messages, anthropic_messages, gemini_messages, updated_history
 
-async def submit_query(query: str, providers: List[str]) -> AsyncGenerator[Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]], None]:
+async def submit_query(
+    query: str,
+    providers: List[str],
+    history: List[Dict[str, str]]
+) -> AsyncGenerator[
+    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
+    None
+]:
     if not query.strip():
-        openai_messages = [{"role": "assistant", "content": "Please enter a query."}]
-        anthropic_messages = [{"role": "assistant", "content": "Please enter a query."}]
-        gemini_messages = [{"role": "assistant", "content": "Please enter a query."}]
-        logger.info("Yielding empty query response")
-        yield "", openai_messages, anthropic_messages, gemini_messages
+        msg = {"role": "assistant", "content": "Please enter a query."}
+        yield "", [msg], [msg], [msg], history
         return
-
+
     if not providers:
-        openai_messages = [{"role": "assistant", "content": "Please select at least one provider."}]
-        anthropic_messages = [{"role": "assistant", "content": "Please select at least one provider."}]
-        gemini_messages = [{"role": "assistant", "content": "Please select at least one provider."}]
-        logger.info("Yielding no providers response")
-        yield "", openai_messages, anthropic_messages, gemini_messages
+        msg = {"role": "assistant", "content": "Please select at least one provider."}
+        yield "", [msg], [msg], [msg], history
        return
 
-    history = []
-
-    async for response_chunk, openai_messages, anthropic_messages, gemini_messages in query_model(query, providers, history):
-        logger.info(f"Submitting update to UI: OpenAI: {openai_messages[-1]['content'][:50] if openai_messages else ''}, "
-                    f"Anthropic: {anthropic_messages[-1]['content'][:50] if anthropic_messages else ''}, "
-                    f"Gemini: {gemini_messages[-1]['content'][:50] if gemini_messages else ''}")
-        yield "", openai_messages, anthropic_messages, gemini_messages
-    logger.info("Final UI update")
-    yield "", openai_messages, anthropic_messages, gemini_messages
+    async for _, openai_msgs, anthropic_msgs, gemini_msgs, updated_history in query_model(query, providers, history):
+        logger.info(f"Submitting update to UI: OpenAI: {openai_msgs[-1]['content'][:50] if openai_msgs else ''}, "
+                    f"Anthropic: {anthropic_msgs[-1]['content'][:50] if anthropic_msgs else ''}, "
+                    f"Gemini: {gemini_msgs[-1]['content'][:50] if gemini_msgs else ''}")
        yield "", openai_msgs, anthropic_msgs, gemini_msgs, updated_history
 
 def clear_history():
-    return [], [], []
+    logger.info("Clearing history")
+    return [], [], [], []
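`query_model` keeps its multiplexing loop: each provider stream's pending `__anext__()` is wrapped in a task, and `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` lets whichever provider produces a chunk first drive the next UI update. A self-contained sketch of that pattern, with invented stand-in streams (names and timings are illustrative, not the app's code):

```python
import asyncio
from typing import AsyncGenerator, List, Tuple

async def fake_stream(name: str, delay: float) -> AsyncGenerator[str, None]:
    # Stand-in for ask_openai / ask_anthropic / ask_gemini.
    for i in range(3):
        await asyncio.sleep(delay)
        yield f"{name}-{i}"

async def interleave(gens: List[Tuple[str, AsyncGenerator[str, None]]]):
    # (provider, generator, pending task) triples, as in query_model.
    states = [(name, gen, None) for name, gen in gens]
    while states:
        # Arm a __anext__ task for any generator that lacks one.
        states = [
            (name, gen, task if task else asyncio.ensure_future(gen.__anext__()))
            for name, gen, task in states
        ]
        done, _ = await asyncio.wait(
            [task for _, _, task in states], return_when=asyncio.FIRST_COMPLETED
        )
        remaining = []
        for name, gen, task in states:
            if task in done:
                try:
                    yield name, task.result()
                    remaining.append((name, gen, None))  # re-arm next round
                except StopAsyncIteration:
                    pass  # stream exhausted: drop it
            else:
                remaining.append((name, gen, task))
        states = remaining

async def main() -> None:
    async for name, chunk in interleave(
        [("fast", fake_stream("fast", 0.05)), ("slow", fake_stream("slow", 0.12))]
    ):
        print(name, chunk)

asyncio.run(main())
```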