broadfield-dev committed on
Commit 544ffe5 · verified · 1 Parent(s): 031a90d

Update model_logic.py

Files changed (1)
  1. model_logic.py +220 -163
model_logic.py CHANGED
@@ -2,6 +2,7 @@ import os
  import requests
  import json
  import logging

  logging.basicConfig(
      level=logging.INFO,
@@ -10,7 +11,7 @@ logging.basicConfig(
  logger = logging.getLogger(__name__)

  API_KEYS = {
-     "HUGGINGFACE": 'HF_TOKEN',
      "GROQ": 'GROQ_API_KEY',
      "OPENROUTER": 'OPENROUTER_API_KEY',
      "TOGETHERAI": 'TOGETHERAI_API_KEY',
@@ -28,96 +29,108 @@ API_URLS = {
      "COHERE": 'https://api.cohere.ai/v1/chat',
      "XAI": 'https://api.x.ai/v1/chat/completions',
      "OPENAI": 'https://api.openai.com/v1/chat/completions',
-     "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
  }

- MODELS_BY_PROVIDER = {
-     "groq": {
-         "default": "llama3-8b-8192",
-         "models": {
-             "Llama 3 8B (Groq)": "llama3-8b-8192",
-             "Llama 3 70B (Groq)": "llama3-70b-8192",
-             "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
-             "Gemma 7B (Groq)": "gemma-7b-it",
-         }
-     },
-     "openrouter": {
-         "default": "nousresearch/llama-3-8b-instruct",
-         "models": {
-             "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
-             "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
-             "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
-             "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
-             "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
-             "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
-             "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
-         }
-     },
-     "togetherai": {
-         "default": "meta-llama/Llama-3-8b-chat-hf",
-         "models": {
-             "Llama 3 8B Chat (TogetherAI)": "meta-llama/Llama-3-8b-chat-hf",
-             "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
-             "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-             "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
-             "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
-         }
-     },
-     "google": {
-         "default": "gemini-1.5-flash-latest",
-         "models": {
-             "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
-             "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
-         }
-     },
-     "cohere": {
-         "default": "command-light",
-         "models": {
-             "Command R (Cohere)": "command-r",
-             "Command R+ (Cohere)": "command-r-plus",
-             "Command Light (Cohere)": "command-light",
-             "Command (Cohere)": "command",
-         }
-     },
-     "huggingface": {
-         "default": "HuggingFaceH4/zephyr-7b-beta",
-         "models": {
-             "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
-             "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
-             "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
-             "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
-         }
-     },
-     "openai": {
-         "default": "gpt-3.5-turbo",
-         "models": {
-             "GPT-4o (OpenAI)": "gpt-4o",
-             "GPT-4o mini (OpenAI)": "gpt-4o-mini",
-             "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
-             "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
-         }
-     },
-     "xai": {
-         "default": "grok-1",
-         "models": {
-             "Grok-1 (xAI)": "grok-1",
-         }
      }
- }

  def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
      if ui_api_key_override:
          return ui_api_key_override.strip()

      env_var_name = API_KEYS.get(provider.upper())
      if env_var_name:
          env_key = os.getenv(env_var_name)
          if env_key:
              return env_key.strip()

      if provider.lower() == 'huggingface':
          hf_token = os.getenv("HF_TOKEN")
-         if hf_token: return hf_token.strip()

      logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
      return None
@@ -132,9 +145,11 @@ def get_default_model_for_provider(provider: str) -> str | None:
      models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
      default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
      if default_model_id:
          for display_name, model_id in models_dict.items():
              if model_id == default_model_id:
                  return display_name
      if models_dict:
          return sorted(list(models_dict.keys()))[0]
      return None
@@ -164,6 +179,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
      headers = {}
      payload = {}
      request_url = base_url

      logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

@@ -173,17 +189,27 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
              payload = {
                  "model": model_id,
                  "messages": messages,
-                 "stream": True
              }
              if provider_lower == "openrouter":
-                 headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
-                 headers["X-Title"] = "AI Space Builder"

-             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
                  byte_buffer += chunk
                  while b'\n' in byte_buffer:
                      line, byte_buffer = byte_buffer.split(b'\n', 1)
@@ -191,7 +217,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
                      if decoded_line.startswith('data: '):
                          data = decoded_line[6:]
                          if data == '[DONE]':
-                             byte_buffer = b''
                              break
                          try:
                              event_data = json.loads(data)
@@ -200,11 +226,13 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
                              if delta and delta.get("content"):
                                  yield delta["content"]
                          except json.JSONDecodeError:
-                             logger.warning(f"Failed to decode JSON from stream line: {decoded_line}")
                          except Exception as e:
-                             logger.error(f"Error processing stream data: {e}, Data: {decoded_line}")
              if byte_buffer:
-                 remaining_line = byte_buffer.decode('utf-8', errors='ignore')
                  if remaining_line.startswith('data: '):
                      data = remaining_line[6:]
                      if data != '[DONE]':
@@ -225,14 +253,24 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
              filtered_messages = []
              for msg in messages:
                  if msg["role"] == "system":
                      system_instruction = msg["content"]
                  else:
                      role = "model" if msg["role"] == "assistant" else msg["role"]
                      filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

              payload = {
                  "contents": filtered_messages,
-                 "safetySettings": [
                      {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                      {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                      {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
@@ -240,69 +278,68 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
                  ],
                  "generationConfig": {
                      "temperature": 0.7,
                  }
              }
              if system_instruction:
                  payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

              request_url = f"{base_url}{model_id}:streamGenerateContent"
-             headers = {"Content-Type": "application/json"}
              request_url = f"{request_url}?key={api_key}"

-             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
-                 byte_buffer += chunk
-                 while b'\n' in byte_buffer:
-                     line, byte_buffer = byte_buffer.split(b'\n', 1)
-                     decoded_line = line.decode('utf-8', errors='ignore')
-
-                     if decoded_line.startswith('data: '):
-                         decoded_line = decoded_line[6:].strip()
-
-                     if not decoded_line: continue
-
-                     try:
-                         event_data_list = json.loads(f"[{decoded_line}]")
-                         if not isinstance(event_data_list, list): event_data_list = [event_data_list]
-
-                         for event_data in event_data_list:
-                             if not isinstance(event_data, dict): continue
-
-                             if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                                 candidate = event_data["candidates"][0]
-                                 if candidate.get("content") and candidate["content"].get("parts"):
-                                     full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                                     if full_text_chunk:
-                                         yield full_text_chunk
-
-                     except json.JSONDecodeError:
-                         logger.warning(f"Failed to decode JSON from Google stream chunk: {decoded_line}. Accumulating buffer.")
-                         pass
-
-                     except Exception as e:
-                         logger.error(f"Error processing Google stream data: {e}, Data: {decoded_line}")
              if byte_buffer:
-                 remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
-                 if remaining_line:
-                     try:
-                         event_data_list = json.loads(f"[{remaining_line}]")
-                         if not isinstance(event_data_list, list): event_data_list = [event_data_list]
-                         for event_data in event_data_list:
-                             if not isinstance(event_data, dict): continue
-                             if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                                 candidate = event_data["candidates"][0]
-                                 if candidate.get("content") and candidate["content"].get("parts"):
-                                     full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                                     if full_text_chunk:
-                                         yield full_text_chunk
-                     except json.JSONDecodeError:
-                         logger.warning(f"Failed to decode final Google stream buffer JSON: {remaining_line}")
-                     except Exception as e:
-                         logger.error(f"Error processing final Google stream buffer data: {e}, Data: {remaining_line}")


          elif provider_lower == "cohere":
@@ -313,76 +350,93 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
              system_prompt_for_cohere = None
              current_message_for_cohere = ""

              temp_history = []
              for msg in messages:
                  if msg["role"] == "system":
-                     system_prompt_for_cohere = msg["content"]
                  elif msg["role"] == "user" or msg["role"] == "assistant":
                      temp_history.append(msg)

-             if temp_history:
-                 current_message_for_cohere = temp_history[-1]["content"]
-                 chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]
-
-             if not current_message_for_cohere:
-                 yield "Error: User message not found for Cohere API call."
                  return

              payload = {
                  "model": model_id,
                  "message": current_message_for_cohere,
                  "stream": True,
-                 "temperature": 0.7
              }
              if chat_history_for_cohere:
                  payload["chat_history"] = chat_history_for_cohere
              if system_prompt_for_cohere:
                  payload["preamble"] = system_prompt_for_cohere

-             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
                  byte_buffer += chunk
-                 while b'\n\n' in byte_buffer:
                      event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                      lines = event_chunk.strip().split(b'\n')
                      event_type = None
                      event_data = None

                      for l in lines:
                          if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                          elif l.startswith(b"data: "):
                              try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                              except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")

                      if event_type == "text-generation" and event_data and "text" in event_data:
                          yield event_data["text"]
                      elif event_type == "stream-end":
-                         byte_buffer = b''
-                         break
-
              if byte_buffer:
-                 event_chunk = byte_buffer.strip()
-                 if event_chunk:
-                     lines = event_chunk.split(b'\n')
-                     event_type = None
-                     event_data = None
-                     for l in lines:
-                         if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
-                         elif l.startswith(b"data: "):
-                             try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
-                             except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode final event data JSON: {l[6:].strip()}")
-
-                     if event_type == "text-generation" and event_data and "text" in event_data:
-                         yield event_data["text"]
-                     elif event_type == "stream-end":
-                         pass


          elif provider_lower == "huggingface":
-             yield f"Error: Direct Hugging Face Inference API streaming for chat models is experimental and model-dependent. Consider using OpenRouter or TogetherAI for HF models with standardized streaming."
              return

          else:
@@ -394,9 +448,12 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
          error_text = e.response.text if e.response is not None else 'No response text'
          logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
          yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
      except requests.exceptions.RequestException as e:
          logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
          yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
      except Exception as e:
          logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
-         yield f"An unexpected error occurred: {e}"
 
model_logic.py (updated file):

  import requests
  import json
  import logging
+ import time # Import time for retries

  logging.basicConfig(
      level=logging.INFO,
  logger = logging.getLogger(__name__)

  API_KEYS = {
+     "HUGGINGFACE": 'HF_TOKEN', # Note: HF_TOKEN is also for HF Hub, so maybe rename this in UI label?
      "GROQ": 'GROQ_API_KEY',
      "OPENROUTER": 'OPENROUTER_API_KEY',
      "TOGETHERAI": 'TOGETHERAI_API_KEY',

      "COHERE": 'https://api.cohere.ai/v1/chat',
      "XAI": 'https://api.x.ai/v1/chat/completions',
      "OPENAI": 'https://api.openai.com/v1/chat/completions',
+     "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', # Base URL, model ID added later
  }

+ # Load model configuration from JSON
+ try:
+     with open("models.json", "r") as f:
+         MODELS_BY_PROVIDER = json.load(f)
+     logger.info("models.json loaded successfully.")
+ except FileNotFoundError:
+     logger.error("models.json not found. Using hardcoded fallback models.")
+     # Keep the hardcoded fallback as a safety measure
+     MODELS_BY_PROVIDER = {
+         "groq": {
+             "default": "llama3-8b-8192",
+             "models": {
+                 "Llama 3 8B (Groq)": "llama3-8b-8192",
+                 "Llama 3 70B (Groq)": "llama3-70b-8192",
+                 "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
+                 "Gemma 7B (Groq)": "gemma-7b-it",
+             }
+         },
+         "openrouter": {
+             "default": "nousresearch/llama-3-8b-instruct",
+             "models": {
+                 "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
+                 "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
+                 "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
+             }
+         },
+         "google": {
+             "default": "gemini-1.5-flash-latest",
+             "models": {
+                 "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
+                 "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
+             }
+         },
+         "openai": {
+             "default": "gpt-3.5-turbo",
+             "models": {
+                 "GPT-4o mini (OpenAI)": "gpt-4o-mini",
+                 "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
+             }
+         },
+         # Add other providers here if needed for fallback
+     }
+ except json.JSONDecodeError:
+     logger.error("Error decoding models.json. Using hardcoded fallback models.")
+     # Keep the hardcoded fallback as a safety measure
+     MODELS_BY_PROVIDER = {
+         "groq": {
+             "default": "llama3-8b-8192",
+             "models": {
+                 "Llama 3 8B (Groq)": "llama3-8b-8192",
+                 "Llama 3 70B (Groq)": "llama3-70b-8192",
+                 "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
+                 "Gemma 7B (Groq)": "gemma-7b-it",
+             }
+         },
+         "openrouter": {
+             "default": "nousresearch/llama-3-8b-instruct",
+             "models": {
+                 "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
+                 "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
+                 "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
+             }
+         },
+         "google": {
+             "default": "gemini-1.5-flash-latest",
+             "models": {
+                 "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
+                 "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
+             }
+         },
+         "openai": {
+             "default": "gpt-3.5-turbo",
+             "models": {
+                 "GPT-4o mini (OpenAI)": "gpt-4o-mini",
+                 "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
+             }
+         },
+         # Add other providers here if needed for fallback
      }
+
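Note: the loader above expects models.json to mirror the fallback structure exactly: one key per provider, each holding a "default" model ID plus a "models" map from UI display name to model ID. A minimal sketch that writes such a file (the provider and model values here are illustrative, not part of this commit):

    import json

    minimal_config = {
        "groq": {
            "default": "llama3-8b-8192",
            "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"},
        },
        "google": {
            "default": "gemini-1.5-flash-latest",
            "models": {"Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest"},
        },
    }

    with open("models.json", "w") as f:
        json.dump(minimal_config, f, indent=2)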

  def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
      if ui_api_key_override:
+         logger.debug(f"Using UI API key override for {provider}")
          return ui_api_key_override.strip()

      env_var_name = API_KEYS.get(provider.upper())
      if env_var_name:
          env_key = os.getenv(env_var_name)
          if env_key:
+             logger.debug(f"Using env var {env_var_name} for {provider}")
              return env_key.strip()

+     # Special case for Hugging Face, HF_TOKEN is common
      if provider.lower() == 'huggingface':
          hf_token = os.getenv("HF_TOKEN")
+         if hf_token:
+             logger.debug(f"Using HF_TOKEN env var for {provider}")
+             return hf_token.strip()

      logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
      return None
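Note: the resolution order above is UI override first, then the provider's environment variable from API_KEYS, then HF_TOKEN as a Hugging Face-only fallback. A quick sanity check of that precedence, assuming model_logic is importable and no real keys are set:

    import os
    from model_logic import _get_api_key

    os.environ["GROQ_API_KEY"] = "gsk-example"           # hypothetical key
    assert _get_api_key("groq") == "gsk-example"         # env var is used
    assert _get_api_key("groq", " ui-key ") == "ui-key"  # UI override wins, stripped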
 
      models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
      default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
      if default_model_id:
+         # Find the display name corresponding to the default model ID
          for display_name, model_id in models_dict.items():
              if model_id == default_model_id:
                  return display_name
+     # Fallback: If no default specified or found, return the first model in the sorted list
      if models_dict:
          return sorted(list(models_dict.keys()))[0]
      return None
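Note: this maps the configured default model ID back to its display name, with the alphabetically first display name as a fallback. For example, under the fallback config above:

    from model_logic import get_default_model_for_provider

    # "default" is "gemini-1.5-flash-latest", so its display name is returned
    print(get_default_model_for_provider("google"))  # "Gemini 1.5 Flash (Latest)"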
 
      headers = {}
      payload = {}
      request_url = base_url
+     timeout_seconds = 180 # Increased timeout

      logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

              payload = {
                  "model": model_id,
                  "messages": messages,
+                 "stream": True,
+                 "temperature": 0.7, # Add temperature
+                 "max_tokens": 4096 # Add max_tokens
              }
              if provider_lower == "openrouter":
+                 headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-commander" # Use space name
+                 headers["X-Title"] = "Hugging Face Space Commander" # Use project title

+             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
+                 # Check for potential HTTP errors during streaming
+                 if response.status_code != 200:
+                     # Attempt to read error body if available
+                     error_body = response.text
+                     logger.error(f"HTTP Error during stream: {response.status_code}, Body: {error_body}")
+                     yield f"API HTTP Error ({response.status_code}) during stream: {error_body}"
+                     return # Stop streaming on error
+
                  byte_buffer += chunk
                  while b'\n' in byte_buffer:
                      line, byte_buffer = byte_buffer.split(b'\n', 1)

                      if decoded_line.startswith('data: '):
                          data = decoded_line[6:]
                          if data == '[DONE]':
+                             byte_buffer = b'' # Clear buffer after DONE
                              break
                          try:
                              event_data = json.loads(data)

                              if delta and delta.get("content"):
                                  yield delta["content"]
                          except json.JSONDecodeError:
+                             # Log warning but continue, partial data might be okay or next line fixes it
+                             logger.warning(f"Failed to decode JSON from stream line: {decoded_line.strip()}")
                          except Exception as e:
+                             logger.error(f"Error processing stream data: {e}, Data: {decoded_line.strip()}")
+             # Process any remaining data in the buffer after the loop
              if byte_buffer:
+                 remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
                  if remaining_line.startswith('data: '):
                      data = remaining_line[6:]
                      if data != '[DONE]':
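Note: this branch consumes the OpenAI-style SSE framing shared by Groq, OpenRouter, TogetherAI, OpenAI and xAI: one "data: <json>" line per event, terminated by a literal "data: [DONE]". A reduced sketch of the per-line handling (the function name is illustrative):

    import json

    def sse_delta(line: str):
        """Return the text delta carried by one 'data: ...' SSE line, else None."""
        if not line.startswith("data: "):
            return None
        data = line[len("data: "):]
        if data == "[DONE]":
            return None
        event = json.loads(data)
        delta = event.get("choices", [{}])[0].get("delta", {})
        return delta.get("content")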
 
              filtered_messages = []
              for msg in messages:
                  if msg["role"] == "system":
+                     # Google's API takes system instruction separately or expects a specific history format
+                     # Let's extract the system instruction
                      system_instruction = msg["content"]
                  else:
+                     # Map roles: 'user' -> 'user', 'assistant' -> 'model'
                      role = "model" if msg["role"] == "assistant" else msg["role"]
                      filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

+             # Ensure conversation history alternates roles correctly for Google
+             # Simple check: if last two roles are same, it's invalid.
+             for i in range(1, len(filtered_messages)):
+                 if filtered_messages[i]["role"] == filtered_messages[i-1]["role"]:
+                     yield f"Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format."
+                     return # Stop if history format is invalid
+
              payload = {
                  "contents": filtered_messages,
+                 "safetySettings": [ # Default safety settings to allow helpful but potentially sensitive code/instructions
                      {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                      {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                      {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},

                  ],
                  "generationConfig": {
                      "temperature": 0.7,
+                     "maxOutputTokens": 4096 # Google's max_tokens equivalent
                  }
              }
+             # System instruction is passed separately
              if system_instruction:
                  payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

              request_url = f"{base_url}{model_id}:streamGenerateContent"
+             # API key is passed as a query parameter for Google
              request_url = f"{request_url}?key={api_key}"
+             headers = {"Content-Type": "application/json"} # Content-Type is still application/json

+             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
+                 # Check for potential HTTP errors during streaming
+                 if response.status_code != 200:
+                     error_body = response.text
+                     logger.error(f"HTTP Error during Google stream: {response.status_code}, Body: {error_body}")
+                     yield f"API HTTP Error ({response.status_code}) during Google stream: {error_body}"
+                     return # Stop streaming on error

+                 byte_buffer += chunk
+                 # Google's streaming can send multiple JSON objects in one chunk, sometimes split by newlines,
+                 # or just single JSON objects. They don't strictly follow the Server-Sent Events 'data:' format.
+                 # We need to find JSON objects in the buffer.
+                 json_decoder = json.JSONDecoder()
+                 while byte_buffer:
+                     try:
+                         # Attempt to decode a JSON object from the start of the buffer
+                         text_buffer = byte_buffer.decode('utf-8', errors='ignore').lstrip() # lstrip to handle leading whitespace/newlines
+                         obj, idx = json_decoder.raw_decode(text_buffer)
+                         # If successful, drop the decoded prefix and re-encode the
+                         # remainder so the buffer stays in bytes
+                         byte_buffer = text_buffer[idx:].encode('utf-8')
+
+                         if obj.get("candidates") and len(obj["candidates"]) > 0:
+                             candidate = obj["candidates"][0]
+                             if candidate.get("content") and candidate["content"].get("parts"):
+                                 full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
+                                 if full_text_chunk:
+                                     yield full_text_chunk
+                         # Check for potential errors in the response object itself
+                         if obj.get("error"):
+                             error_details = obj["error"].get("message", str(obj["error"]))
+                             logger.error(f"Google API returned error in stream data: {error_details}")
+                             yield f"API Error (Google): {error_details}"
+                             return # Stop streaming
+
+                     except json.JSONDecodeError:
+                         # If raw_decode fails, it means the buffer doesn't contain a complete JSON object at the start.
+                         # Break the inner while loop and wait for more data.
+                         break
+                     except Exception as e:
+                         logger.error(f"Error processing Google stream data object: {e}, Object: {obj}")
+                         # Decide if this is a fatal error or just a bad chunk
+                         # For now, log and continue might be okay for processing subsequent chunks.

+             # If loop finishes and buffer still has data, log it (incomplete data)
              if byte_buffer:
+                 logger.warning(f"Remaining data in Google stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")

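Note: the Google branch above cannot rely on "data:" framing, so it decodes concatenated JSON objects incrementally with json.JSONDecoder.raw_decode, which returns the parsed object together with the index where it ended. A self-contained sketch of that technique (the buffer contents are illustrative):

    import json

    decoder = json.JSONDecoder()
    buffer = '{"a": 1}\n{"b": 2}{"c": 3}'
    objects = []
    while buffer:
        buffer = buffer.lstrip()   # raw_decode rejects leading whitespace
        if not buffer:
            break
        try:
            obj, idx = decoder.raw_decode(buffer)
        except json.JSONDecodeError:
            break                  # incomplete object: wait for more data
        objects.append(obj)
        buffer = buffer[idx:]      # keep only the undecoded tail

    print(objects)  # [{'a': 1}, {'b': 2}, {'c': 3}]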
          elif provider_lower == "cohere":

              system_prompt_for_cohere = None
              current_message_for_cohere = ""

+             # Cohere requires a specific history format and separates system/preamble
+             # The last message is the "message", previous are "chat_history"
              temp_history = []
              for msg in messages:
                  if msg["role"] == "system":
+                     # If multiple system prompts, concatenate them for preamble
+                     if system_prompt_for_cohere: system_prompt_for_cohere += "\n" + msg["content"]
+                     else: system_prompt_for_cohere = msg["content"]
                  elif msg["role"] == "user" or msg["role"] == "assistant":
                      temp_history.append(msg)

+             if not temp_history:
+                 yield "Error: No user message found for Cohere API call."
+                 return
+             if temp_history[-1]["role"] != "user":
+                 yield "Error: Last message must be from user for Cohere API call."
                  return

+             current_message_for_cohere = temp_history[-1]["content"]
+             # Map roles: 'user' -> 'user', 'assistant' -> 'chatbot'
+             chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]
+
              payload = {
                  "model": model_id,
                  "message": current_message_for_cohere,
                  "stream": True,
+                 "temperature": 0.7,
+                 "max_tokens": 4096 # Add max_tokens
              }
              if chat_history_for_cohere:
                  payload["chat_history"] = chat_history_for_cohere
              if system_prompt_for_cohere:
                  payload["preamble"] = system_prompt_for_cohere

+             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
              response.raise_for_status()

              byte_buffer = b""
              for chunk in response.iter_content(chunk_size=8192):
+                 # Check for potential HTTP errors during streaming
+                 if response.status_code != 200:
+                     error_body = response.text
+                     logger.error(f"HTTP Error during Cohere stream: {response.status_code}, Body: {error_body}")
+                     yield f"API HTTP Error ({response.status_code}) during Cohere stream: {error_body}"
+                     return # Stop streaming on error
+
                  byte_buffer += chunk
+                 while b'\n\n' in byte_buffer: # Cohere uses \n\n as event separator
                      event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                      lines = event_chunk.strip().split(b'\n')
                      event_type = None
                      event_data = None

                      for l in lines:
+                         if l.strip() == b"": continue # Skip blank lines within an event
                          if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                          elif l.startswith(b"data: "):
                              try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                              except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")
+                         else:
+                             # Log unexpected lines in event chunk
+                             logger.warning(f"Cohere: Unexpected line in event chunk: {l.decode('utf-8', errors='ignore').strip()}")

                      if event_type == "text-generation" and event_data and "text" in event_data:
                          yield event_data["text"]
                      elif event_type == "stream-end":
+                         logger.debug("Cohere stream-end event received.")
+                         byte_buffer = b'' # Clear buffer after stream-end
+                         break # Exit the while loop
+                     elif event_type == "error":
+                         error_msg = event_data.get("message", str(event_data)) if event_data else "Unknown Cohere stream error"
+                         logger.error(f"Cohere stream error event: {error_msg}")
+                         yield f"API Error (Cohere stream): {error_msg}"
+                         return # Stop streaming on error
+
+             # Process any remaining data in the buffer after the loop
              if byte_buffer:
+                 logger.warning(f"Remaining data in Cohere stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")

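Note: Cohere's stream is conventional SSE with events separated by blank lines: each event carries an "event: <type>" line and a "data: <json>" line, and text arrives in "text-generation" events until "stream-end". A reduced sketch of parsing one buffered event under that framing (the function name is illustrative):

    import json

    def parse_cohere_event(event_chunk: bytes):
        """Return (event_type, event_data) for one double-newline-delimited SSE event."""
        event_type, event_data = None, None
        for line in event_chunk.strip().split(b"\n"):
            if line.startswith(b"event: "):
                event_type = line[7:].strip().decode("utf-8", errors="ignore")
            elif line.startswith(b"data: "):
                event_data = json.loads(line[6:].strip())
        return event_type, event_data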
          elif provider_lower == "huggingface":
+             # Hugging Face Inference API often supports streaming for text-generation,
+             # but chat completion streaming format varies greatly model by model, if supported.
+             # Standard OpenAI-like streaming is not guaranteed.
+             # Let's provide a more informative message.
+             yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. Standard OpenAI-like streaming is NOT guaranteed. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
              return

          else:

          error_text = e.response.text if e.response is not None else 'No response text'
          logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
          yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
+     except requests.exceptions.Timeout:
+         logger.error(f"Request Timeout after {timeout_seconds} seconds for {provider}/{model_id}.")
+         yield f"API Request Timeout: The request took too long to complete ({timeout_seconds} seconds)."
      except requests.exceptions.RequestException as e:
          logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
          yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
      except Exception as e:
          logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
+         yield f"An unexpected error occurred during streaming: {e}"