Update model_logic.py
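This change collapses the duplicated except FileNotFoundError / except json.JSONDecodeError fallback tables into a single except (FileNotFoundError, json.JSONDecodeError) block, strips inline commentary, fills in the OpenRouter referer URL, and rewrites the Google stream parser around json.JSONDecoder. For reference, the sketch below illustrates the provider-to-{"default", "models"} layout that models.json is expected to follow, mirroring the hardcoded fallback kept in the file; the EXAMPLE_MODELS_BY_PROVIDER name and the two provider entries are illustrative only, not the project's actual models.json.

    import json

    # Illustrative sketch only: the nested shape model_logic.py expects models.json to
    # have, mirroring the hardcoded fallback in this diff. A real models.json would list
    # every provider in use; these two entries are copied from the fallback table.
    EXAMPLE_MODELS_BY_PROVIDER = {
        "groq": {
            "default": "llama3-8b-8192",                 # model id whose display name is treated as the default
            "models": {
                "Llama 3 8B (Groq)": "llama3-8b-8192",   # display name -> provider model id
                "Llama 3 70B (Groq)": "llama3-70b-8192",
            },
        },
        "openai": {
            "default": "gpt-3.5-turbo",
            "models": {
                "GPT-4o mini (OpenAI)": "gpt-4o-mini",
                "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
            },
        },
    }

    if __name__ == "__main__":
        # Writing this out produces a models.json the consolidated loader would accept.
        with open("models.json", "w") as f:
            json.dump(EXAMPLE_MODELS_BY_PROVIDER, f, indent=2)

With this layout, get_default_model_for_provider("groq") resolves the default id back to the display name "Llama 3 8B (Groq)" and otherwise falls back to the first display name in sorted order.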
model_logic.py (+36 -115) CHANGED
@@ -2,7 +2,7 @@ import os
 import requests
 import json
 import logging
-import time
+import time
 
 logging.basicConfig(
     level=logging.INFO,
@@ -11,7 +11,7 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 API_KEYS = {
-    "HUGGINGFACE": 'HF_TOKEN',
+    "HUGGINGFACE": 'HF_TOKEN',
     "GROQ": 'GROQ_API_KEY',
     "OPENROUTER": 'OPENROUTER_API_KEY',
     "TOGETHERAI": 'TOGETHERAI_API_KEY',
@@ -29,17 +29,15 @@ API_URLS = {
     "COHERE": 'https://api.cohere.ai/v1/chat',
     "XAI": 'https://api.x.ai/v1/chat/completions',
     "OPENAI": 'https://api.openai.com/v1/chat/completions',
-    "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
+    "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
 }
 
-# Load model configuration from JSON
 try:
     with open("models.json", "r") as f:
         MODELS_BY_PROVIDER = json.load(f)
     logger.info("models.json loaded successfully.")
-except FileNotFoundError:
-    logger.error("models.json not found. Using hardcoded fallback models.")
-    # Keep the hardcoded fallback as a safety measure
+except (FileNotFoundError, json.JSONDecodeError) as e:
+    logger.error(f"Error loading models.json: {e}. Using hardcoded fallback models.")
     MODELS_BY_PROVIDER = {
         "groq": {
             "default": "llama3-8b-8192",
@@ -70,47 +68,9 @@ except FileNotFoundError:
             "models": {
                 "GPT-4o mini (OpenAI)": "gpt-4o-mini",
                 "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
-            }
-        },
-        # Add other providers here if needed for fallback
-    }
-except json.JSONDecodeError:
-    logger.error("Error decoding models.json. Using hardcoded fallback models.")
-    # Keep the hardcoded fallback as a safety measure
-    MODELS_BY_PROVIDER = {
-        "groq": {
-            "default": "llama3-8b-8192",
-            "models": {
-                "Llama 3 8B (Groq)": "llama3-8b-8192",
-                "Llama 3 70B (Groq)": "llama3-70b-8192",
-                "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
-                "Gemma 7B (Groq)": "gemma-7b-it",
-            }
-        },
-        "openrouter": {
-            "default": "nousresearch/llama-3-8b-instruct",
-            "models": {
-                "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
-                "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
-                "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
-            }
-        },
-        "google": {
-            "default": "gemini-1.5-flash-latest",
-            "models": {
-                "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
-                "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
-            }
-        },
-        "openai": {
-            "default": "gpt-3.5-turbo",
-            "models": {
-                "GPT-4o mini (OpenAI)": "gpt-4o-mini",
-                "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
             }
         },
-
-}
+    }
 
 
 def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
@@ -125,7 +85,6 @@ def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
         logger.debug(f"Using env var {env_var_name} for {provider}")
         return env_key.strip()
 
-    # Special case for Hugging Face, HF_TOKEN is common
     if provider.lower() == 'huggingface':
         hf_token = os.getenv("HF_TOKEN")
         if hf_token:
@@ -145,11 +104,9 @@ def get_default_model_for_provider(provider: str) -> str | None:
     models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
     default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
     if default_model_id:
-        # Find the display name corresponding to the default model ID
         for display_name, model_id in models_dict.items():
            if model_id == default_model_id:
                 return display_name
-    # Fallback: If no default specified or found, return the first model in the sorted list
     if models_dict:
         return sorted(list(models_dict.keys()))[0]
     return None
@@ -179,7 +136,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
     headers = {}
     payload = {}
     request_url = base_url
-    timeout_seconds = 180
+    timeout_seconds = 180
 
     logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")
 
@@ -190,25 +147,23 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
             "model": model_id,
             "messages": messages,
             "stream": True,
-            "temperature": 0.7,
-            "max_tokens": 4096
+            "temperature": 0.7,
+            "max_tokens": 4096
         }
         if provider_lower == "openrouter":
-            headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/
-            headers["X-Title"] = "Hugging Face Space Commander"
+            headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/huggingface/ai-space-commander"
+            headers["X-Title"] = "Hugging Face Space Commander"
 
         response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
         response.raise_for_status()
 
         byte_buffer = b""
         for chunk in response.iter_content(chunk_size=8192):
-            # Check for potential HTTP errors during streaming
             if response.status_code != 200:
-                # Attempt to read error body if available
                 error_body = response.text
                 logger.error(f"HTTP Error during stream: {response.status_code}, Body: {error_body}")
                 yield f"API HTTP Error ({response.status_code}) during stream: {error_body}"
-                return
+                return
 
             byte_buffer += chunk
             while b'\n' in byte_buffer:
@@ -217,7 +172,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
                 if decoded_line.startswith('data: '):
                     data = decoded_line[6:]
                     if data == '[DONE]':
-                        byte_buffer = b''
+                        byte_buffer = b''
                         break
                     try:
                         event_data = json.loads(data)
@@ -226,11 +181,9 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
                         if delta and delta.get("content"):
                             yield delta["content"]
                     except json.JSONDecodeError:
-                        # Log warning but continue, partial data might be okay or next line fixes it
                         logger.warning(f"Failed to decode JSON from stream line: {decoded_line.strip()}")
                     except Exception as e:
                         logger.error(f"Error processing stream data: {e}, Data: {decoded_line.strip()}")
-        # Process any remaining data in the buffer after the loop
         if byte_buffer:
             remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
             if remaining_line.startswith('data: '):
@@ -253,24 +206,20 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
         filtered_messages = []
         for msg in messages:
             if msg["role"] == "system":
-
-
-                system_instruction = msg["content"]
+                if system_instruction: system_instruction += "\n" + msg["content"]
+                else: system_instruction = msg["content"]
             else:
-                # Map roles: 'user' -> 'user', 'assistant' -> 'model'
                 role = "model" if msg["role"] == "assistant" else msg["role"]
                 filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})
 
-        # Ensure conversation history alternates roles correctly for Google
-        # Simple check: if last two roles are same, it's invalid.
         for i in range(1, len(filtered_messages)):
             if filtered_messages[i]["role"] == filtered_messages[i-1]["role"]:
                 yield f"Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format."
-                return
+                return
 
         payload = {
             "contents": filtered_messages,
-            "safetySettings": [
+            "safetySettings": [
                 {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                 {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                 {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
@@ -278,66 +227,51 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
             ],
             "generationConfig": {
                 "temperature": 0.7,
-                "maxOutputTokens": 4096
+                "maxOutputTokens": 4096
             }
         }
-        # System instruction is passed separately
         if system_instruction:
             payload["system_instruction"] = {"parts": [{"text": system_instruction}]}
 
-
         request_url = f"{base_url}{model_id}:streamGenerateContent"
-        # API key is passed as a query parameter for Google
         request_url = f"{request_url}?key={api_key}"
-        headers = {"Content-Type": "application/json"}
+        headers = {"Content-Type": "application/json"}
 
         response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
        response.raise_for_status()
 
         byte_buffer = b""
+        json_decoder = json.JSONDecoder()
         for chunk in response.iter_content(chunk_size=8192):
-            # Check for potential HTTP errors during streaming
             if response.status_code != 200:
                 error_body = response.text
                 logger.error(f"HTTP Error during Google stream: {response.status_code}, Body: {error_body}")
                 yield f"API HTTP Error ({response.status_code}) during Google stream: {error_body}"
-                return
+                return
 
             byte_buffer += chunk
-
-
-
-            json_decoder = json.JSONDecoder()
-            while byte_buffer:
+            decoded_buffer = byte_buffer.decode('utf-8', errors='ignore')
+            buffer_index = 0
+            while buffer_index < len(decoded_buffer):
                 try:
-
-
-                    # If successful, process the object
-                    byte_buffer = byte_buffer[len(byte_buffer.decode('utf-8', errors='ignore').lstrip()[:idx]).encode('utf-8'):] # Remove the decoded part from the buffer
-
+                    obj, idx = json_decoder.raw_decode(decoded_buffer[buffer_index:].lstrip())
+                    buffer_index += len(decoded_buffer[buffer_index:].lstrip()[:idx])
                     if obj.get("candidates") and len(obj["candidates"]) > 0:
                         candidate = obj["candidates"][0]
                         if candidate.get("content") and candidate["content"].get("parts"):
                             full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
                             if full_text_chunk:
                                 yield full_text_chunk
-                    # Check for potential errors in the response object itself
                     if obj.get("error"):
                         error_details = obj["error"].get("message", str(obj["error"]))
                         logger.error(f"Google API returned error in stream data: {error_details}")
                         yield f"API Error (Google): {error_details}"
-                        return
-
+                        return
                 except json.JSONDecodeError:
-                    # If raw_decode fails, it means the buffer doesn't contain a complete JSON object at the start.
-                    # Break the inner while loop and wait for more data.
                     break
                 except Exception as e:
                     logger.error(f"Error processing Google stream data object: {e}, Object: {obj}")
-                    # Decide if this is a fatal error or just a bad chunk
-                    # For now, log and continue might be okay for processing subsequent chunks.
 
-        # If loop finishes and buffer still has data, log it (incomplete data)
         if byte_buffer:
             logger.warning(f"Remaining data in Google stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")
 
@@ -350,12 +284,9 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
         system_prompt_for_cohere = None
         current_message_for_cohere = ""
 
-        # Cohere requires a specific history format and separates system/preamble
-        # The last message is the "message", previous are "chat_history"
         temp_history = []
         for msg in messages:
             if msg["role"] == "system":
-                # If multiple system prompts, concatenate them for preamble
                 if system_prompt_for_cohere: system_prompt_for_cohere += "\n" + msg["content"]
                 else: system_prompt_for_cohere = msg["content"]
             elif msg["role"] == "user" or msg["role"] == "assistant":
@@ -369,7 +300,6 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
             return
 
         current_message_for_cohere = temp_history[-1]["content"]
-        # Map roles: 'user' -> 'user', 'assistant' -> 'chatbot'
         chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]
 
         payload = {
@@ -377,7 +307,7 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
             "message": current_message_for_cohere,
             "stream": True,
             "temperature": 0.7,
-            "max_tokens": 4096
+            "max_tokens": 4096
         }
         if chat_history_for_cohere:
             payload["chat_history"] = chat_history_for_cohere
@@ -389,54 +319,45 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
 
         byte_buffer = b""
         for chunk in response.iter_content(chunk_size=8192):
-            # Check for potential HTTP errors during streaming
             if response.status_code != 200:
                 error_body = response.text
                 logger.error(f"HTTP Error during Cohere stream: {response.status_code}, Body: {error_body}")
                 yield f"API HTTP Error ({response.status_code}) during Cohere stream: {error_body}"
-                return
+                return
 
             byte_buffer += chunk
-            while b'\n\n' in byte_buffer:
+            while b'\n\n' in byte_buffer:
                 event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                 lines = event_chunk.strip().split(b'\n')
                 event_type = None
                 event_data = None
 
                 for l in lines:
-                    if l.strip() == b"": continue
+                    if l.strip() == b"": continue
                     if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                     elif l.startswith(b"data: "):
                         try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                         except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")
                     else:
-                        # Log unexpected lines in event chunk
                         logger.warning(f"Cohere: Unexpected line in event chunk: {l.decode('utf-8', errors='ignore').strip()}")
 
-
                 if event_type == "text-generation" and event_data and "text" in event_data:
                     yield event_data["text"]
                 elif event_type == "stream-end":
                     logger.debug("Cohere stream-end event received.")
-                    byte_buffer = b''
-                    break
+                    byte_buffer = b''
+                    break
                 elif event_type == "error":
                     error_msg = event_data.get("message", str(event_data)) if event_data else "Unknown Cohere stream error"
                     logger.error(f"Cohere stream error event: {error_msg}")
                     yield f"API Error (Cohere stream): {error_msg}"
-                    return
+                    return
 
-        # Process any remaining data in the buffer after the loop
         if byte_buffer:
             logger.warning(f"Remaining data in Cohere stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}")
 
-
     elif provider_lower == "huggingface":
-
-        # but chat completion streaming format varies greatly model by model, if supported.
-        # Standard OpenAI-like streaming is not guaranteed.
-        # Let's provide a more informative message.
-        yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. Standard OpenAI-like streaming is NOT guaranteed. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
+        yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there."
         return
 
     else:
@@ -456,4 +377,4 @@ def generate_stream(provider: str, model_display_name: str, api_key_override: st
         yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
     except Exception as e:
         logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
-        yield f"An unexpected error occurred during streaming: {e}"
+        yield f"An unexpected error occurred during streaming: {e}"
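The rewritten Google branch drops the old buffer-slicing decode in favour of an index-based loop over json.JSONDecoder().raw_decode. Below is a minimal standalone sketch of that raw_decode pattern, assuming whitespace-separated JSON values; drain_json_objects and the sample input are hypothetical names and data used only to illustrate the technique, while the actual handler in the diff additionally yields candidate text and returns when an error object appears.

    import json

    _decoder = json.JSONDecoder()

    def drain_json_objects(buffer: str):
        """Parse as many complete JSON values as possible from the front of buffer.

        Returns (objects, remainder); remainder is any trailing partial value that
        should be kept until more of the stream has arrived.
        """
        objects = []
        index = 0
        while index < len(buffer):
            # Skip whitespace between concatenated values.
            while index < len(buffer) and buffer[index].isspace():
                index += 1
            if index >= len(buffer):
                break
            try:
                # raw_decode returns the parsed value and the index just past it.
                obj, end = _decoder.raw_decode(buffer, index)
            except json.JSONDecodeError:
                break  # incomplete value at the tail; wait for the next chunk
            objects.append(obj)
            index = end
        return objects, buffer[index:]

    # Two complete objects have arrived; the third is still being streamed.
    objs, rest = drain_json_objects('{"a": 1} {"b": 2} {"c":')
    print(objs)  # [{'a': 1}, {'b': 2}]
    print(rest)  # {"c":

Because raw_decode reports where each value ended, a stream handler built this way can advance past the parsed objects and keep only the partial tail in its buffer until the next network chunk arrives.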