Update model_logic.py
model_logic.py  (+220 −163)  CHANGED
@@ -2,6 +2,7 @@ import os
 import requests
 import json
 import logging
+import time  # Import time for retries

 logging.basicConfig(
     level=logging.INFO,
@@ -10,7 +11,7 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)

 API_KEYS = {
-    "HUGGINGFACE": 'HF_TOKEN',
+    "HUGGINGFACE": 'HF_TOKEN',  # Note: HF_TOKEN is also for HF Hub, so maybe rename this in UI label?
     "GROQ": 'GROQ_API_KEY',
     "OPENROUTER": 'OPENROUTER_API_KEY',
     "TOGETHERAI": 'TOGETHERAI_API_KEY',
"COHERE": 'https://api.cohere.ai/v1/chat',
|
30 |
"XAI": 'https://api.x.ai/v1/chat/completions',
|
31 |
"OPENAI": 'https://api.openai.com/v1/chat/completions',
|
32 |
+
"GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', # Base URL, model ID added later
|
33 |
}
|
34 |
|
35 |
+
# Load model configuration from JSON
|
36 |
+
try:
|
37 |
+
with open("models.json", "r") as f:
|
38 |
+
MODELS_BY_PROVIDER = json.load(f)
|
39 |
+
logger.info("models.json loaded successfully.")
|
40 |
+
except FileNotFoundError:
|
41 |
+
logger.error("models.json not found. Using hardcoded fallback models.")
|
42 |
+
# Keep the hardcoded fallback as a safety measure
|
43 |
+
MODELS_BY_PROVIDER = {
|
44 |
+
"groq": {
|
45 |
+
"default": "llama3-8b-8192",
|
46 |
+
"models": {
|
47 |
+
"Llama 3 8B (Groq)": "llama3-8b-8192",
|
48 |
+
"Llama 3 70B (Groq)": "llama3-70b-8192",
|
49 |
+
"Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
|
50 |
+
"Gemma 7B (Groq)": "gemma-7b-it",
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"openrouter": {
|
54 |
+
"default": "nousresearch/llama-3-8b-instruct",
|
55 |
+
"models": {
|
56 |
+
"Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
|
57 |
+
"Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
|
58 |
+
"Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
|
59 |
+
}
|
60 |
+
},
|
61 |
+
"google": {
|
62 |
+
"default": "gemini-1.5-flash-latest",
|
63 |
+
"models": {
|
64 |
+
"Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
|
65 |
+
"Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
|
66 |
+
}
|
67 |
+
},
|
68 |
+
"openai": {
|
69 |
+
"default": "gpt-3.5-turbo",
|
70 |
+
"models": {
|
71 |
+
"GPT-4o mini (OpenAI)": "gpt-4o-mini",
|
72 |
+
"GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
|
73 |
+
}
|
74 |
+
},
|
75 |
+
# Add other providers here if needed for fallback
|
76 |
+
}
|
77 |
+
except json.JSONDecodeError:
|
78 |
+
logger.error("Error decoding models.json. Using hardcoded fallback models.")
|
79 |
+
# Keep the hardcoded fallback as a safety measure
|
80 |
+
MODELS_BY_PROVIDER = {
|
81 |
+
"groq": {
|
82 |
+
"default": "llama3-8b-8192",
|
83 |
+
"models": {
|
84 |
+
"Llama 3 8B (Groq)": "llama3-8b-8192",
|
85 |
+
"Llama 3 70B (Groq)": "llama3-70b-8192",
|
86 |
+
"Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
|
87 |
+
"Gemma 7B (Groq)": "gemma-7b-it",
|
88 |
+
}
|
89 |
+
},
|
90 |
+
"openrouter": {
|
91 |
+
"default": "nousresearch/llama-3-8b-instruct",
|
92 |
+
"models": {
|
93 |
+
"Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
|
94 |
+
"Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
|
95 |
+
"Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
|
96 |
+
}
|
97 |
+
},
|
98 |
+
"google": {
|
99 |
+
"default": "gemini-1.5-flash-latest",
|
100 |
+
"models": {
|
101 |
+
"Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
|
102 |
+
"Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
|
103 |
+
}
|
104 |
+
},
|
105 |
+
"openai": {
|
106 |
+
"default": "gpt-3.5-turbo",
|
107 |
+
"models": {
|
108 |
+
"GPT-4o mini (OpenAI)": "gpt-4o-mini",
|
109 |
+
"GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
|
110 |
+
}
|
111 |
+
},
|
112 |
+
# Add other providers here if needed for fallback
|
113 |
}
|
114 |
+
|
115 |
|
116 |
def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
|
117 |
if ui_api_key_override:
|
118 |
+
logger.debug(f"Using UI API key override for {provider}")
|
119 |
return ui_api_key_override.strip()
|
120 |
|
121 |
env_var_name = API_KEYS.get(provider.upper())
|
122 |
if env_var_name:
|
123 |
env_key = os.getenv(env_var_name)
|
124 |
if env_key:
|
125 |
+
logger.debug(f"Using env var {env_var_name} for {provider}")
|
126 |
return env_key.strip()
|
127 |
|
128 |
+
# Special case for Hugging Face, HF_TOKEN is common
|
129 |
if provider.lower() == 'huggingface':
|
130 |
hf_token = os.getenv("HF_TOKEN")
|
131 |
+
if hf_token:
|
132 |
+
logger.debug(f"Using HF_TOKEN env var for {provider}")
|
133 |
+
return hf_token.strip()
|
134 |
|
135 |
logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
|
136 |
return None
|
|
|
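For reference, the loader above expects models.json to mirror the fallback structure: a provider key mapping to a "default" model ID and a "models" dict of display name to model ID. A minimal sketch that writes such a file (the two provider entries are illustrative, taken from the fallback, not an exhaustive list):

    # Sketch only: produce a minimal models.json in the shape json.load() expects above.
    import json

    minimal_models = {
        "groq": {
            "default": "llama3-8b-8192",
            "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"},
        },
        "google": {
            "default": "gemini-1.5-flash-latest",
            "models": {"Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest"},
        },
    }

    with open("models.json", "w") as f:
        json.dump(minimal_models, f, indent=2)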
@@ -132,9 +145,11 @@ def get_default_model_for_provider(provider: str) -> str | None:
     models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
     default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
     if default_model_id:
+        # Find the display name corresponding to the default model ID
         for display_name, model_id in models_dict.items():
             if model_id == default_model_id:
                 return display_name
+    # Fallback: If no default specified or found, return the first model in the sorted list
     if models_dict:
         return sorted(list(models_dict.keys()))[0]
     return None
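A quick illustration of the default-resolution logic in this hunk, using sample entries from the fallback table (the literal dict below is only for demonstration):

    # Sketch: map a provider's default model ID back to its display name.
    models_dict = {
        "Llama 3 8B (Groq)": "llama3-8b-8192",
        "Llama 3 70B (Groq)": "llama3-70b-8192",
    }
    default_model_id = "llama3-8b-8192"

    display = next((name for name, mid in models_dict.items() if mid == default_model_id), None)
    print(display)  # Llama 3 8B (Groq)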
@@ -164,6 +179,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
     headers = {}
     payload = {}
     request_url = base_url
+    timeout_seconds = 180  # Increased timeout

     logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

@@ -173,17 +189,27 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
             payload = {
                 "model": model_id,
                 "messages": messages,
-                "stream": True
+                "stream": True,
+                "temperature": 0.7,  # Add temperature
+                "max_tokens": 4096  # Add max_tokens
             }
             if provider_lower == "openrouter":
+                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-commander"  # Use space name
+                headers["X-Title"] = "Hugging Face Space Commander"  # Use project title

+            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
             response.raise_for_status()

             byte_buffer = b""
             for chunk in response.iter_content(chunk_size=8192):
+                # Check for potential HTTP errors during streaming
+                if response.status_code != 200:
+                    # Attempt to read error body if available
+                    error_body = response.text
+                    logger.error(f"HTTP Error during stream: {response.status_code}, Body: {error_body}")
+                    yield f"API HTTP Error ({response.status_code}) during stream: {error_body}"
+                    return  # Stop streaming on error
+
                 byte_buffer += chunk
                 while b'\n' in byte_buffer:
                     line, byte_buffer = byte_buffer.split(b'\n', 1)
@@ -191,7 +217,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
                     if decoded_line.startswith('data: '):
                         data = decoded_line[6:]
                         if data == '[DONE]':
-                            byte_buffer = b''
+                            byte_buffer = b''  # Clear buffer after DONE
                             break
                         try:
                             event_data = json.loads(data)
@@ -200,11 +226,13 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
                             if delta and delta.get("content"):
                                 yield delta["content"]
                         except json.JSONDecodeError:
+                            # Log warning but continue, partial data might be okay or next line fixes it
+                            logger.warning(f"Failed to decode JSON from stream line: {decoded_line.strip()}")
                         except Exception as e:
-                            logger.error(f"Error processing stream data: {e}, Data: {decoded_line}")
+                            logger.error(f"Error processing stream data: {e}, Data: {decoded_line.strip()}")
+            # Process any remaining data in the buffer after the loop
             if byte_buffer:
-                remaining_line = byte_buffer.decode('utf-8', errors='ignore')
+                remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
                 if remaining_line.startswith('data: '):
                     data = remaining_line[6:]
                     if data != '[DONE]':
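The loop in the two hunks above assumes the OpenAI-compatible SSE framing, where each payload line looks like "data: {json}" and the stream finishes with "data: [DONE]". A self-contained sketch of that parsing step (sample_lines is fabricated input, not captured output):

    # Sketch: parse OpenAI-style streaming lines the same way the loop above does.
    import json

    sample_lines = [
        'data: {"choices": [{"delta": {"content": "Hel"}}]}',
        'data: {"choices": [{"delta": {"content": "lo"}}]}',
        'data: [DONE]',
    ]

    for decoded_line in sample_lines:
        if decoded_line.startswith('data: '):
            data = decoded_line[6:]
            if data == '[DONE]':
                break
            delta = json.loads(data)["choices"][0].get("delta", {})
            if delta.get("content"):
                print(delta["content"], end="")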
|
253 |
filtered_messages = []
|
254 |
for msg in messages:
|
255 |
if msg["role"] == "system":
|
256 |
+
# Google's API takes system instruction separately or expects a specific history format
|
257 |
+
# Let's extract the system instruction
|
258 |
system_instruction = msg["content"]
|
259 |
else:
|
260 |
+
# Map roles: 'user' -> 'user', 'assistant' -> 'model'
|
261 |
role = "model" if msg["role"] == "assistant" else msg["role"]
|
262 |
filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})
|
263 |
|
264 |
+
# Ensure conversation history alternates roles correctly for Google
|
265 |
+
# Simple check: if last two roles are same, it's invalid.
|
266 |
+
for i in range(1, len(filtered_messages)):
|
267 |
+
if filtered_messages[i]["role"] == filtered_messages[i-1]["role"]:
|
268 |
+
yield f"Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format."
|
269 |
+
return # Stop if history format is invalid
|
270 |
+
|
271 |
payload = {
|
272 |
"contents": filtered_messages,
|
273 |
+
"safetySettings": [ # Default safety settings to allow helpful but potentially sensitive code/instructions
|
274 |
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
275 |
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
276 |
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
|
|
@@ -240,69 +278,68 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
             ],
             "generationConfig": {
                 "temperature": 0.7,
+                "maxOutputTokens": 4096  # Google's max_tokens equivalent
             }
         }
+        # System instruction is passed separately
         if system_instruction:
             payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

         request_url = f"{base_url}{model_id}:streamGenerateContent"
+        # API key is passed as a query parameter for Google
         request_url = f"{request_url}?key={api_key}"
+        headers = {"Content-Type": "application/json"}  # Content-Type is still application/json

+        response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
         response.raise_for_status()

         byte_buffer = b""
         for chunk in response.iter_content(chunk_size=8192):
-            decoded_line = decoded_line[6:].strip()
-            if not decoded_line: continue
-            try:
-                event_data_list = json.loads(f"[{decoded_line}]")
-                if not isinstance(event_data_list, list): event_data_list = [event_data_list]
-                for event_data in event_data_list:
-                    if not isinstance(event_data, dict): continue
-                    if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                        candidate = event_data["candidates"][0]
-                        if candidate.get("content") and candidate["content"].get("parts"):
-                            full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                            if full_text_chunk:
-                                yield full_text_chunk
-            except json.JSONDecodeError:
-                logger.warning(f"Failed to decode JSON from Google stream chunk: {decoded_line}. Accumulating buffer.")
-                pass
-            except Exception as e:
-                logger.error(f"Error processing Google stream data: {e}, Data: {decoded_line}")
+            # Check for potential HTTP errors during streaming
+            if response.status_code != 200:
+                error_body = response.text
+                logger.error(f"HTTP Error during Google stream: {response.status_code}, Body: {error_body}")
+                yield f"API HTTP Error ({response.status_code}) during Google stream: {error_body}"
+                return  # Stop streaming on error
+
+            byte_buffer += chunk
+            # Google's streaming can send multiple JSON objects in one chunk, sometimes split by newlines
+            # Or just single JSON objects. They don't strictly follow the Server-Sent Events 'data:' format.
+            # We need to find JSON objects in the buffer.
+            json_decoder = json.JSONDecoder()
+            while byte_buffer:
+                try:
+                    # Attempt to decode a JSON object from the start of the buffer
+                    obj, idx = json_decoder.raw_decode(byte_buffer.decode('utf-8', errors='ignore').lstrip())  # lstrip to handle leading whitespace/newlines
+                    # If successful, process the object
+                    byte_buffer = byte_buffer[len(byte_buffer.decode('utf-8', errors='ignore').lstrip()[:idx]).encode('utf-8'):]  # Remove the decoded part from the buffer
+
+                    if obj.get("candidates") and len(obj["candidates"]) > 0:
+                        candidate = obj["candidates"][0]
+                        if candidate.get("content") and candidate["content"].get("parts"):
+                            full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
+                            if full_text_chunk:
+                                yield full_text_chunk
+                    # Check for potential errors in the response object itself
+                    if obj.get("error"):
+                        error_details = obj["error"].get("message", str(obj["error"]))
+                        logger.error(f"Google API returned error in stream data: {error_details}")
+                        yield f"API Error (Google): {error_details}"
+                        return  # Stop streaming
+
+                except json.JSONDecodeError:
+                    # If raw_decode fails, it means the buffer doesn't contain a complete JSON object at the start.
+                    # Break the inner while loop and wait for more data.
+                    break
+                except Exception as e:
+                    logger.error(f"Error processing Google stream data object: {e}, Object: {obj}")
+                    # Decide if this is a fatal error or just a bad chunk
+                    # For now, log and continue might be okay for processing subsequent chunks.
+
+        # If loop finishes and buffer still has data, log it (incomplete data)
         if byte_buffer:
-            if remaining_line:
-                try:
-                    event_data_list = json.loads(f"[{remaining_line}]")
-                    if not isinstance(event_data_list, list): event_data_list = [event_data_list]
-                    for event_data in event_data_list:
-                        if not isinstance(event_data, dict): continue
-                        if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                            candidate = event_data["candidates"][0]
-                            if candidate.get("content") and candidate["content"].get("parts"):
-                                full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                                if full_text_chunk:
-                                    yield full_text_chunk
-                except json.JSONDecodeError:
-                    logger.warning(f"Failed to decode final Google stream buffer JSON: {remaining_line}")
-                except Exception as e:
-                    logger.error(f"Error processing final Google stream buffer data: {e}, Data: {remaining_line}")
+            logger.warning(f"Remaining data in Google stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")


     elif provider_lower == "cohere":
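The new Google branch leans on json.JSONDecoder().raw_decode, which returns the parsed object together with the index where decoding stopped; that pair is what makes incremental parsing of concatenated JSON possible. A self-contained sketch on a plain string buffer (the buffer contents are invented):

    # Sketch: pull successive JSON objects out of one string with raw_decode.
    import json

    buffer = '{"a": 1}\n{"b": 2}  {"c": 3}'
    decoder = json.JSONDecoder()

    while buffer.strip():
        stripped = buffer.lstrip()   # raw_decode does not skip leading whitespace itself
        obj, end = decoder.raw_decode(stripped)
        print(obj)                   # {'a': 1}, then {'b': 2}, then {'c': 3}
        buffer = stripped[end:]      # keep whatever follows the decoded object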
@@ -313,76 +350,93 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
         system_prompt_for_cohere = None
         current_message_for_cohere = ""

+        # Cohere requires a specific history format and separates system/preamble
+        # The last message is the "message", previous are "chat_history"
         temp_history = []
         for msg in messages:
             if msg["role"] == "system":
+                # If multiple system prompts, concatenate them for preamble
+                if system_prompt_for_cohere: system_prompt_for_cohere += "\n" + msg["content"]
+                else: system_prompt_for_cohere = msg["content"]
             elif msg["role"] == "user" or msg["role"] == "assistant":
                 temp_history.append(msg)

-        yield "Error: User message not found for Cohere API call."
+        if not temp_history:
+            yield "Error: No user message found for Cohere API call."
+            return
+        if temp_history[-1]["role"] != "user":
+            yield "Error: Last message must be from user for Cohere API call."
             return

+        current_message_for_cohere = temp_history[-1]["content"]
+        # Map roles: 'user' -> 'user', 'assistant' -> 'chatbot'
+        chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]
+
         payload = {
             "model": model_id,
             "message": current_message_for_cohere,
             "stream": True,
-            "temperature": 0.7
+            "temperature": 0.7,
+            "max_tokens": 4096  # Add max_tokens
         }
         if chat_history_for_cohere:
             payload["chat_history"] = chat_history_for_cohere
         if system_prompt_for_cohere:
             payload["preamble"] = system_prompt_for_cohere

+        response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
         response.raise_for_status()

         byte_buffer = b""
         for chunk in response.iter_content(chunk_size=8192):
+            # Check for potential HTTP errors during streaming
+            if response.status_code != 200:
+                error_body = response.text
+                logger.error(f"HTTP Error during Cohere stream: {response.status_code}, Body: {error_body}")
+                yield f"API HTTP Error ({response.status_code}) during Cohere stream: {error_body}"
+                return  # Stop streaming on error
+
             byte_buffer += chunk
-            while b'\n\n' in byte_buffer:
+            while b'\n\n' in byte_buffer:  # Cohere uses \n\n as event separator
                 event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                 lines = event_chunk.strip().split(b'\n')
                 event_type = None
                 event_data = None

                 for l in lines:
+                    if l.strip() == b"": continue  # Skip blank lines within an event
                     if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                     elif l.startswith(b"data: "):
                         try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                         except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")
+                    else:
+                        # Log unexpected lines in event chunk
+                        logger.warning(f"Cohere: Unexpected line in event chunk: {l.decode('utf-8', errors='ignore').strip()}")

                 if event_type == "text-generation" and event_data and "text" in event_data:
                     yield event_data["text"]
                 elif event_type == "stream-end":
+                    logger.debug("Cohere stream-end event received.")
+                    byte_buffer = b''  # Clear buffer after stream-end
+                    break  # Exit the while loop
+                elif event_type == "error":
+                    error_msg = event_data.get("message", str(event_data)) if event_data else "Unknown Cohere stream error"
+                    logger.error(f"Cohere stream error event: {error_msg}")
+                    yield f"API Error (Cohere stream): {error_msg}"
+                    return  # Stop streaming on error
+
+        # Process any remaining data in the buffer after the loop
         if byte_buffer:
-            if event_chunk:
-                lines = event_chunk.split(b'\n')
-                event_type = None
-                event_data = None
-                for l in lines:
-                    if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
-                    elif l.startswith(b"data: "):
-                        try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
-                        except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode final event data JSON: {l[6:].strip()}")
-                if event_type == "text-generation" and event_data and "text" in event_data:
-                    yield event_data["text"]
-                elif event_type == "stream-end":
-                    pass
+            logger.warning(f"Remaining data in Cohere stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")


     elif provider_lower == "huggingface":
+        # Hugging Face Inference API often supports streaming for text-generation,
+        # but chat completion streaming format varies greatly model by model, if supported.
+        # Standard OpenAI-like streaming is not guaranteed.
+        # Let's provide a more informative message.
+        yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. Standard OpenAI-like streaming is NOT guaranteed. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
         return

     else:
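For clarity, the history mapping in the Cohere hunk above reshapes an OpenAI-style message list into the preamble, chat_history, and message fields it sends. A small worked example of that reshaping (the message contents are invented):

    # Sketch: reshape OpenAI-style messages into Cohere chat fields as the branch above does.
    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello."},
        {"role": "user", "content": "Summarize our chat."},
    ]

    preamble = "\n".join(m["content"] for m in messages if m["role"] == "system")
    turns = [m for m in messages if m["role"] in ("user", "assistant")]
    message = turns[-1]["content"]  # the final user turn becomes "message"
    chat_history = [
        {"role": "chatbot" if m["role"] == "assistant" else m["role"], "message": m["content"]}
        for m in turns[:-1]
    ]

    print(preamble)      # You are terse.
    print(message)       # Summarize our chat.
    print(chat_history)  # [{'role': 'user', 'message': 'Hi'}, {'role': 'chatbot', 'message': 'Hello.'}]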
@@ -394,9 +448,12 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: str
         error_text = e.response.text if e.response is not None else 'No response text'
         logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
         yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
+    except requests.exceptions.Timeout:
+        logger.error(f"Request Timeout after {timeout_seconds} seconds for {provider}/{model_id}.")
+        yield f"API Request Timeout: The request took too long to complete ({timeout_seconds} seconds)."
     except requests.exceptions.RequestException as e:
         logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
         yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
     except Exception as e:
         logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
-        yield f"An unexpected error occurred: {e}"
+        yield f"An unexpected error occurred during streaming: {e}"