broadfield-dev committed
Commit e8b8149 · verified · 1 Parent(s): 3103a1e

Update model_logic.py

Files changed (1)
  1. model_logic.py +36 -115
model_logic.py CHANGED
@@ -2,7 +2,7 @@ import os
  import requests
  import json
  import logging
- import time # Import time for retries
+ import time

  logging.basicConfig(
  level=logging.INFO,
@@ -11,7 +11,7 @@ logging.basicConfig(
  logger = logging.getLogger(__name__)

  API_KEYS = {
- "HUGGINGFACE": 'HF_TOKEN', # Note: HF_TOKEN is also for HF Hub, so maybe rename this in UI label?
+ "HUGGINGFACE": 'HF_TOKEN',
  "GROQ": 'GROQ_API_KEY',
  "OPENROUTER": 'OPENROUTER_API_KEY',
  "TOGETHERAI": 'TOGETHERAI_API_KEY',
@@ -29,17 +29,15 @@ API_URLS = {
  "COHERE": 'https://api.cohere.ai/v1/chat',
  "XAI": 'https://api.x.ai/v1/chat/completions',
  "OPENAI": 'https://api.openai.com/v1/chat/completions',
- "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', # Base URL, model ID added later
+ "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
  }

- # Load model configuration from JSON
  try:
  with open("models.json", "r") as f:
  MODELS_BY_PROVIDER = json.load(f)
  logger.info("models.json loaded successfully.")
- except FileNotFoundError:
- logger.error("models.json not found. Using hardcoded fallback models.")
- # Keep the hardcoded fallback as a safety measure
+ except (FileNotFoundError, json.JSONDecodeError) as e:
+ logger.error(f"Error loading models.json: {e}. Using hardcoded fallback models.")
  MODELS_BY_PROVIDER = {
  "groq": {
  "default": "llama3-8b-8192",
@@ -70,47 +68,9 @@ except FileNotFoundError:
  "models": {
  "GPT-4o mini (OpenAI)": "gpt-4o-mini",
  "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
- }
- },
- # Add other providers here if needed for fallback
- }
- except json.JSONDecodeError:
- logger.error("Error decoding models.json. Using hardcoded fallback models.")
- # Keep the hardcoded fallback as a safety measure
- MODELS_BY_PROVIDER = {
- "groq": {
- "default": "llama3-8b-8192",
- "models": {
- "Llama 3 8B (Groq)": "llama3-8b-8192",
- "Llama 3 70B (Groq)": "llama3-70b-8192",
- "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
- "Gemma 7B (Groq)": "gemma-7b-it",
- }
- },
- "openrouter": {
- "default": "nousresearch/llama-3-8b-instruct",
- "models": {
- "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
- "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
- "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
- }
- },
- "google": {
- "default": "gemini-1.5-flash-latest",
- "models": {
- "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
- "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
- }
- },
- "openai": {
- "default": "gpt-3.5-turbo",
- "models": {
- "GPT-4o mini (OpenAI)": "gpt-4o-mini",
- "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
  }
  },
- # Add other providers here if needed for fallback
- }
+ }


  def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
@@ -125,7 +85,6 @@ def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
  logger.debug(f"Using env var {env_var_name} for {provider}")
  return env_key.strip()

- # Special case for Hugging Face, HF_TOKEN is common
  if provider.lower() == 'huggingface':
  hf_token = os.getenv("HF_TOKEN")
  if hf_token:
@@ -145,11 +104,9 @@ def get_default_model_for_provider(provider: str) -> str | None:
  models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
  default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
  if default_model_id:
- # Find the display name corresponding to the default model ID
  for display_name, model_id in models_dict.items():
  if model_id == default_model_id:
  return display_name
- # Fallback: If no default specified or found, return the first model in the sorted list
  if models_dict:
  return sorted(list(models_dict.keys()))[0]
  return None
@@ -179,7 +136,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  headers = {}
  payload = {}
  request_url = base_url
- timeout_seconds = 180 # Increased timeout
+ timeout_seconds = 180

  logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

@@ -190,25 +147,23 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  "model": model_id,
  "messages": messages,
  "stream": True,
- "temperature": 0.7, # Add temperature
- "max_tokens": 4096 # Add max_tokens
+ "temperature": 0.7,
+ "max_tokens": 4096
  }
  if provider_lower == "openrouter":
- headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-commander" # Use space name
- headers["X-Title"] = "Hugging Face Space Commander" # Use project title
+ headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/huggingface/ai-space-commander"
+ headers["X-Title"] = "Hugging Face Space Commander"

  response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
  response.raise_for_status()

  byte_buffer = b""
  for chunk in response.iter_content(chunk_size=8192):
- # Check for potential HTTP errors during streaming
  if response.status_code != 200:
- # Attempt to read error body if available
  error_body = response.text
  logger.error(f"HTTP Error during stream: {response.status_code}, Body: {error_body}")
  yield f"API HTTP Error ({response.status_code}) during stream: {error_body}"
- return # Stop streaming on error
+ return

  byte_buffer += chunk
  while b'\n' in byte_buffer:
@@ -217,7 +172,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  if decoded_line.startswith('data: '):
  data = decoded_line[6:]
  if data == '[DONE]':
- byte_buffer = b'' # Clear buffer after DONE
+ byte_buffer = b''
  break
  try:
  event_data = json.loads(data)
@@ -226,11 +181,9 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  if delta and delta.get("content"):
  yield delta["content"]
  except json.JSONDecodeError:
- # Log warning but continue, partial data might be okay or next line fixes it
  logger.warning(f"Failed to decode JSON from stream line: {decoded_line.strip()}")
  except Exception as e:
  logger.error(f"Error processing stream data: {e}, Data: {decoded_line.strip()}")
- # Process any remaining data in the buffer after the loop
  if byte_buffer:
  remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
  if remaining_line.startswith('data: '):
@@ -253,24 +206,20 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  filtered_messages = []
  for msg in messages:
  if msg["role"] == "system":
- # Google's API takes system instruction separately or expects a specific history format
- # Let's extract the system instruction
- system_instruction = msg["content"]
+ if system_instruction: system_instruction += "\n" + msg["content"]
+ else: system_instruction = msg["content"]
  else:
- # Map roles: 'user' -> 'user', 'assistant' -> 'model'
  role = "model" if msg["role"] == "assistant" else msg["role"]
  filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

- # Ensure conversation history alternates roles correctly for Google
- # Simple check: if last two roles are same, it's invalid.
  for i in range(1, len(filtered_messages)):
  if filtered_messages[i]["role"] == filtered_messages[i-1]["role"]:
  yield f"Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format."
- return # Stop if history format is invalid
+ return

  payload = {
  "contents": filtered_messages,
- "safetySettings": [ # Default safety settings to allow helpful but potentially sensitive code/instructions
+ "safetySettings": [
  {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
  {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
  {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
@@ -278,66 +227,51 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  ],
  "generationConfig": {
  "temperature": 0.7,
- "maxOutputTokens": 4096 # Google's max_tokens equivalent
+ "maxOutputTokens": 4096
  }
  }
- # System instruction is passed separately
  if system_instruction:
  payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

-
  request_url = f"{base_url}{model_id}:streamGenerateContent"
- # API key is passed as a query parameter for Google
  request_url = f"{request_url}?key={api_key}"
- headers = {"Content-Type": "application/json"} # Content-Type is still application/json
+ headers = {"Content-Type": "application/json"}

  response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
  response.raise_for_status()

  byte_buffer = b""
+ json_decoder = json.JSONDecoder()
  for chunk in response.iter_content(chunk_size=8192):
- # Check for potential HTTP errors during streaming
  if response.status_code != 200:
  error_body = response.text
  logger.error(f"HTTP Error during Google stream: {response.status_code}, Body: {error_body}")
  yield f"API HTTP Error ({response.status_code}) during Google stream: {error_body}"
- return # Stop streaming on error
+ return

  byte_buffer += chunk
- # Google's streaming can send multiple JSON objects in one chunk, sometimes split by newlines
- # Or just single JSON objects. They don't strictly follow the Server-Sent Events 'data:' format.
- # We need to find JSON objects in the buffer.
- json_decoder = json.JSONDecoder()
- while byte_buffer:
+ decoded_buffer = byte_buffer.decode('utf-8', errors='ignore')
+ buffer_index = 0
+ while buffer_index < len(decoded_buffer):
  try:
- # Attempt to decode a JSON object from the start of the buffer
- obj, idx = json_decoder.raw_decode(byte_buffer.decode('utf-8', errors='ignore').lstrip()) # lstrip to handle leading whitespace/newlines
- # If successful, process the object
- byte_buffer = byte_buffer[len(byte_buffer.decode('utf-8', errors='ignore').lstrip()[:idx]).encode('utf-8'):] # Remove the decoded part from the buffer
-
+ obj, idx = json_decoder.raw_decode(decoded_buffer[buffer_index:].lstrip())
+ buffer_index += len(decoded_buffer[buffer_index:].lstrip()[:idx])
  if obj.get("candidates") and len(obj["candidates"]) > 0:
  candidate = obj["candidates"][0]
  if candidate.get("content") and candidate["content"].get("parts"):
  full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
  if full_text_chunk:
  yield full_text_chunk
- # Check for potential errors in the response object itself
  if obj.get("error"):
  error_details = obj["error"].get("message", str(obj["error"]))
  logger.error(f"Google API returned error in stream data: {error_details}")
  yield f"API Error (Google): {error_details}"
- return # Stop streaming
-
+ return
  except json.JSONDecodeError:
- # If raw_decode fails, it means the buffer doesn't contain a complete JSON object at the start.
- # Break the inner while loop and wait for more data.
  break
  except Exception as e:
  logger.error(f"Error processing Google stream data object: {e}, Object: {obj}")
- # Decide if this is a fatal error or just a bad chunk
- # For now, log and continue might be okay for processing subsequent chunks.

- # If loop finishes and buffer still has data, log it (incomplete data)
  if byte_buffer:
  logger.warning(f"Remaining data in Google stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")

@@ -350,12 +284,9 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  system_prompt_for_cohere = None
  current_message_for_cohere = ""

- # Cohere requires a specific history format and separates system/preamble
- # The last message is the "message", previous are "chat_history"
  temp_history = []
  for msg in messages:
  if msg["role"] == "system":
- # If multiple system prompts, concatenate them for preamble
  if system_prompt_for_cohere: system_prompt_for_cohere += "\n" + msg["content"]
  else: system_prompt_for_cohere = msg["content"]
  elif msg["role"] == "user" or msg["role"] == "assistant":
@@ -369,7 +300,6 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  return

  current_message_for_cohere = temp_history[-1]["content"]
- # Map roles: 'user' -> 'user', 'assistant' -> 'chatbot'
  chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]

  payload = {
@@ -377,7 +307,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  "message": current_message_for_cohere,
  "stream": True,
  "temperature": 0.7,
- "max_tokens": 4096 # Add max_tokens
+ "max_tokens": 4096
  }
  if chat_history_for_cohere:
  payload["chat_history"] = chat_history_for_cohere
@@ -389,54 +319,45 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st

  byte_buffer = b""
  for chunk in response.iter_content(chunk_size=8192):
- # Check for potential HTTP errors during streaming
  if response.status_code != 200:
  error_body = response.text
  logger.error(f"HTTP Error during Cohere stream: {response.status_code}, Body: {error_body}")
  yield f"API HTTP Error ({response.status_code}) during Cohere stream: {error_body}"
- return # Stop streaming on error
+ return

  byte_buffer += chunk
- while b'\n\n' in byte_buffer: # Cohere uses \n\n as event separator
+ while b'\n\n' in byte_buffer:
  event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
  lines = event_chunk.strip().split(b'\n')
  event_type = None
  event_data = None

  for l in lines:
- if l.strip() == b"": continue # Skip blank lines within an event
+ if l.strip() == b"": continue
  if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
  elif l.startswith(b"data: "):
  try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
  except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")
  else:
- # Log unexpected lines in event chunk
  logger.warning(f"Cohere: Unexpected line in event chunk: {l.decode('utf-8', errors='ignore').strip()}")

-
  if event_type == "text-generation" and event_data and "text" in event_data:
  yield event_data["text"]
  elif event_type == "stream-end":
  logger.debug("Cohere stream-end event received.")
- byte_buffer = b'' # Clear buffer after stream-end
- break # Exit the while loop
+ byte_buffer = b''
+ break
  elif event_type == "error":
  error_msg = event_data.get("message", str(event_data)) if event_data else "Unknown Cohere stream error"
  logger.error(f"Cohere stream error event: {error_msg}")
  yield f"API Error (Cohere stream): {error_msg}"
- return # Stop streaming on error
+ return

- # Process any remaining data in the buffer after the loop
  if byte_buffer:
  logger.warning(f"Remaining data in Cohere stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")

-
  elif provider_lower == "huggingface":
- # Hugging Face Inference API often supports streaming for text-generation,
- # but chat completion streaming format varies greatly model by model, if supported.
- # Standard OpenAI-like streaming is not guaranteed.
- # Let's provide a more informative message.
- yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. Standard OpenAI-like streaming is NOT guaranteed. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
+ yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
  return

  else:
@@ -456,4 +377,4 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
  yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
  except Exception as e:
  logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
- yield f"An unexpected error occurred during streaming: {e}"
+ yield f"An unexpected error occurred during streaming: {e}"