broadfield-dev committed
Commit 0e640ed · verified · 1 Parent(s): 56badb0

Update model_logic.py

Files changed (1):
  1. model_logic.py +368 -263

model_logic.py CHANGED
@@ -1,7 +1,12 @@
 import os
 import requests
 import json
 import logging

 logging.basicConfig(
     level=logging.INFO,
@@ -9,15 +14,16 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)

-API_KEYS = {
-    "HUGGINGFACE": 'HF_TOKEN',
     "GROQ": 'GROQ_API_KEY',
     "OPENROUTER": 'OPENROUTER_API_KEY',
     "TOGETHERAI": 'TOGETHERAI_API_KEY',
     "COHERE": 'COHERE_API_KEY',
     "XAI": 'XAI_API_KEY',
     "OPENAI": 'OPENAI_API_KEY',
-    "GOOGLE": 'GOOGLE_API_KEY',
 }

 API_URLS = {
@@ -25,12 +31,13 @@ API_URLS = {
     "GROQ": 'https://api.groq.com/openai/v1/chat/completions',
     "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions',
     "TOGETHERAI": 'https://api.together.ai/v1/chat/completions',
-    "COHERE": 'https://api.cohere.ai/v1/chat',
     "XAI": 'https://api.x.ai/v1/chat/completions',
     "OPENAI": 'https://api.openai.com/v1/chat/completions',
     "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
 }

 MODELS_BY_PROVIDER = {
     "groq": {
         "default": "llama3-8b-8192",
@@ -42,15 +49,20 @@ MODELS_BY_PROVIDER = {
         }
     },
     "openrouter": {
-        "default": "nousresearch/llama-3-8b-instruct",
         "models": {
             "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
-            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
-            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
-            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
-            "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
-            "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
-            "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
         }
     },
     "togetherai": {
@@ -60,7 +72,7 @@
             "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
             "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
             "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
-            "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
         }
     },
     "google": {
@@ -68,335 +80,428 @@
         "models": {
             "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
             "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
         }
     },
     "cohere": {
-        "default": "command-light",
         "models": {
             "Command R (Cohere)": "command-r",
             "Command R+ (Cohere)": "command-r-plus",
             "Command Light (Cohere)": "command-light",
-            "Command (Cohere)": "command",
         }
     },
-    "huggingface": {
-        "default": "HuggingFaceH4/zephyr-7b-beta",
         "models": {
-            "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
             "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
-            "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
-            "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
         }
     },
     "openai": {
-        "default": "gpt-3.5-turbo",
         "models": {
             "GPT-4o (OpenAI)": "gpt-4o",
             "GPT-4o mini (OpenAI)": "gpt-4o-mini",
-            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
-            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
         }
     },
-    "xai": {
-        "default": "grok-1",
         "models": {
-            "Grok-1 (xAI)": "grok-1",
         }
     }
 }

-def _get_api_key(provider: str, ui_api_key_override: str = None) -> str:
-    if ui_api_key_override:
         return ui_api_key_override.strip()

-    env_var_name = API_KEYS.get(provider.upper())
     if env_var_name:
         env_key = os.getenv(env_var_name)
-        if env_key:
             return env_key.strip()

-    if provider.lower() == 'huggingface':
-        hf_token = os.getenv("HF_TOKEN")
-        if hf_token: return hf_token.strip()

-    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
     return None

 def get_available_providers() -> list[str]:
     return sorted(list(MODELS_BY_PROVIDER.keys()))

-def get_models_for_provider(provider: str) -> list[str]:
     return sorted(list(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys()))

-def get_default_model_for_provider(provider: str) -> str | None:
-    models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
-    default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
-    if default_model_id:
-        for display_name, model_id in models_dict.items():
-            if model_id == default_model_id:
                 return display_name
     if models_dict:
-        return sorted(list(models_dict.keys()))[0]
     return None

 def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
     models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
     return models.get(display_name)

-def generate_stream(provider: str, model_display_name: str, api_key_override: str, messages: list[dict]) -> iter:
     provider_lower = provider.lower()
     api_key = _get_api_key(provider_lower, api_key_override)
-
     base_url = API_URLS.get(provider.upper())
     model_id = get_model_id_from_display_name(provider_lower, model_display_name)

     if not api_key:
-        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
-        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
         return
     if not base_url:
         yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
         return
     if not model_id:
-        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
         return

     headers = {}
     payload = {}
     request_url = base_url

-    logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

-    try:
-        if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
-            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-            payload = {
-                "model": model_id,
-                "messages": messages,
-                "stream": True
-            }
-            if provider_lower == "openrouter":
-                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
-                headers["X-Title"] = "AI Space Builder"

             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()

-            byte_buffer = b""
-            for chunk in response.iter_content(chunk_size=8192):
-                byte_buffer += chunk
-                while b'\n' in byte_buffer:
-                    line, byte_buffer = byte_buffer.split(b'\n', 1)
-                    decoded_line = line.decode('utf-8', errors='ignore')
-                    if decoded_line.startswith('data: '):
-                        data = decoded_line[6:]
-                        if data == '[DONE]':
-                            byte_buffer = b''
-                            break
-                        try:
-                            event_data = json.loads(data)
-                            if event_data.get("choices") and len(event_data["choices"]) > 0:
-                                delta = event_data["choices"][0].get("delta")
-                                if delta and delta.get("content"):
-                                    yield delta["content"]
-                        except json.JSONDecodeError:
-                            logger.warning(f"Failed to decode JSON from stream line: {decoded_line}")
-                        except Exception as e:
-                            logger.error(f"Error processing stream data: {e}, Data: {decoded_line}")
-            if byte_buffer:
-                remaining_line = byte_buffer.decode('utf-8', errors='ignore')
-                if remaining_line.startswith('data: '):
-                    data = remaining_line[6:]
-                    if data != '[DONE]':
-                        try:
-                            event_data = json.loads(data)
-                            if event_data.get("choices") and len(event_data["choices"]) > 0:
-                                delta = event_data["choices"][0].get("delta")
-                                if delta and delta.get("content"):
-                                    yield delta["content"]
-                        except json.JSONDecodeError:
-                            logger.warning(f"Failed to decode final stream buffer JSON: {remaining_line}")
-                        except Exception as e:
-                            logger.error(f"Error processing final stream buffer data: {e}, Data: {remaining_line}")
-
-
-        elif provider_lower == "google":
-            system_instruction = None
-            filtered_messages = []
-            for msg in messages:
-                if msg["role"] == "system":
-                    system_instruction = msg["content"]
-                else:
-                    role = "model" if msg["role"] == "assistant" else msg["role"]
-                    filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})
-
-            payload = {
-                "contents": filtered_messages,
-                "safetySettings": [
-                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
-                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
-                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
-                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
-                ],
-                "generationConfig": {
-                    "temperature": 0.7,
-                }
-            }
-            if system_instruction:
-                payload["system_instruction"] = {"parts": [{"text": system_instruction}]}
-
-            request_url = f"{base_url}{model_id}:streamGenerateContent"
-            headers = {"Content-Type": "application/json"}
-            request_url = f"{request_url}?key={api_key}"

             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()
-
-            byte_buffer = b""
-            for chunk in response.iter_content(chunk_size=8192):
-                byte_buffer += chunk
-                while b'\n' in byte_buffer:
-                    line, byte_buffer = byte_buffer.split(b'\n', 1)
-                    decoded_line = line.decode('utf-8', errors='ignore')
-
-                    if decoded_line.startswith('data: '):
-                        decoded_line = decoded_line[6:].strip()
-
-                    if not decoded_line: continue

                     try:
-                        event_data_list = json.loads(f"[{decoded_line}]")
-                        if not isinstance(event_data_list, list): event_data_list = [event_data_list]

-                        for event_data in event_data_list:
-                            if not isinstance(event_data, dict): continue

-                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                                candidate = event_data["candidates"][0]
-                                if candidate.get("content") and candidate["content"].get("parts"):
-                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                                    if full_text_chunk:
-                                        yield full_text_chunk

-                    except json.JSONDecodeError:
-                        logger.warning(f"Failed to decode JSON from Google stream chunk: {decoded_line}. Accumulating buffer.")
-                        pass
-
-                    except Exception as e:
-                        logger.error(f"Error processing Google stream data: {e}, Data: {decoded_line}")
-
-            if byte_buffer:
-                remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
-                if remaining_line:
-                    try:
-                        event_data_list = json.loads(f"[{remaining_line}]")
-                        if not isinstance(event_data_list, list): event_data_list = [event_data_list]
-                        for event_data in event_data_list:
-                            if not isinstance(event_data, dict): continue
-                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
-                                candidate = event_data["candidates"][0]
-                                if candidate.get("content") and candidate["content"].get("parts"):
-                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
-                                    if full_text_chunk:
-                                        yield full_text_chunk
-                    except json.JSONDecodeError:
-                        logger.warning(f"Failed to decode final Google stream buffer JSON: {remaining_line}")
-                    except Exception as e:
-                        logger.error(f"Error processing final Google stream buffer data: {e}, Data: {remaining_line}")
-
-
-        elif provider_lower == "cohere":
-            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
-            request_url = f"{base_url}"
-
-            chat_history_for_cohere = []
-            system_prompt_for_cohere = None
-            current_message_for_cohere = ""
-
-            temp_history = []
-            for msg in messages:
-                if msg["role"] == "system":
-                    system_prompt_for_cohere = msg["content"]
-                elif msg["role"] == "user" or msg["role"] == "assistant":
-                    temp_history.append(msg)
-
-            if temp_history:
-                current_message_for_cohere = temp_history[-1]["content"]
-                chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]
-
-            if not current_message_for_cohere:
-                yield "Error: User message not found for Cohere API call."
-                return
-
-            payload = {
-                "model": model_id,
-                "message": current_message_for_cohere,
-                "stream": True,
-                "temperature": 0.7
-            }
-            if chat_history_for_cohere:
-                payload["chat_history"] = chat_history_for_cohere
-            if system_prompt_for_cohere:
-                payload["preamble"] = system_prompt_for_cohere

             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()

-            byte_buffer = b""
-            for chunk in response.iter_content(chunk_size=8192):
-                byte_buffer += chunk
-                while b'\n\n' in byte_buffer:
-                    event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
-                    lines = event_chunk.strip().split(b'\n')
-                    event_type = None
-                    event_data = None
-
-                    for l in lines:
-                        if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
-                        elif l.startswith(b"data: "):
-                            try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
-                            except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")
-
-                    if event_type == "text-generation" and event_data and "text" in event_data:
-                        yield event_data["text"]
-                    elif event_type == "stream-end":
-                        byte_buffer = b''
-                        break
-
-            if byte_buffer:
-                event_chunk = byte_buffer.strip()
-                if event_chunk:
-                    lines = event_chunk.split(b'\n')
-                    event_type = None
-                    event_data = None
-                    for l in lines:
-                        if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
-                        elif l.startswith(b"data: "):
-                            try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
-                            except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode final event data JSON: {l[6:].strip()}")
-
-                    if event_type == "text-generation" and event_data and "text" in event_data:
-                        yield event_data["text"]
-                    elif event_type == "stream-end":
-                        pass
-
-
-        elif provider_lower == "huggingface":
-            yield f"Error: Direct Hugging Face Inference API streaming for chat models is experimental and model-dependent. Consider using OpenRouter or TogetherAI for HF models with standardized streaming."
-            return
-
-        else:
-            yield f"Error: Unsupported provider '{provider}' for streaming chat."
-            return

-    except requests.exceptions.HTTPError as e:
-        status_code = e.response.status_code if e.response is not None else 'N/A'
-        error_text = e.response.text if e.response is not None else 'No response text'
-        logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
-        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
-        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
-    except Exception as e:
-        logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
-        yield f"An unexpected error occurred: {e}"
 
+# model_handler.py
 import os
 import requests
 import json
 import logging
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()

 logging.basicConfig(
     level=logging.INFO,

 )
 logger = logging.getLogger(__name__)

+# Maps provider name (uppercase) to environment variable name for API key
+API_KEYS_ENV_VARS = {
+    "HUGGINGFACE": 'HF_TOKEN',  # Note: HF_TOKEN is often used for general HF auth
     "GROQ": 'GROQ_API_KEY',
     "OPENROUTER": 'OPENROUTER_API_KEY',
     "TOGETHERAI": 'TOGETHERAI_API_KEY',
     "COHERE": 'COHERE_API_KEY',
     "XAI": 'XAI_API_KEY',
     "OPENAI": 'OPENAI_API_KEY',
+    "GOOGLE": 'GOOGLE_API_KEY',  # Or GOOGLE_GEMINI_API_KEY etc.
 }

 API_URLS = {

     "GROQ": 'https://api.groq.com/openai/v1/chat/completions',
     "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions',
     "TOGETHERAI": 'https://api.together.ai/v1/chat/completions',
+    "COHERE": 'https://api.cohere.ai/v1/chat',  # v1 is common for chat, was v2 in ai-learn
     "XAI": 'https://api.x.ai/v1/chat/completions',
     "OPENAI": 'https://api.openai.com/v1/chat/completions',
     "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
 }

+# Structure: provider_key: { "default": "model_id", "models": {"Display Name": "model_id", ...} }
 MODELS_BY_PROVIDER = {
     "groq": {
         "default": "llama3-8b-8192",

         }
     },
     "openrouter": {
+        "default": "nousresearch/llama-3-8b-instruct",  # Updated default
         "models": {
             "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
+            "Mistral 7B Instruct v0.3 (OpenRouter)": "mistralai/mistral-7b-instruct-v0.3",  # v0.3 is newer
+            "Mistral 7B Instruct (Free/OpenRouter)": "mistralai/mistral-7b-instruct:free",  # Keep free tier if distinct
+            "Gemma 2 9B Instruct (OpenRouter)": "google/gemma-2-9b-it",  # Gemma 2
+            "Gemma 7B Instruct (Free/OpenRouter)": "google/gemma-7b-it:free",
+            "Llama 3.1 8B Instruct (OpenRouter)": "meta-llama/llama-3.1-8b-instruct",  # Llama 3.1
+            "Llama 3.1 70B Instruct (OpenRouter)": "meta-llama/llama-3.1-70b-instruct",
+            "OpenAI GPT-4o mini (OpenRouter)": "openai/gpt-4o-mini",
+            "OpenAI GPT-4o (OpenRouter)": "openai/gpt-4o",
+            "Claude 3.5 Sonnet (OpenRouter)": "anthropic/claude-3.5-sonnet",
+            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",  # Older Mixtral
+            "Qwen 2 72B Instruct (OpenRouter)": "qwen/qwen-2-72b-instruct",
         }
     },
     "togetherai": {

             "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
             "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
             "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
+            "Qwen1.5-72B-Chat (TogetherAI)": "qwen/Qwen1.5-72B-Chat",
         }
     },
     "google": {

         "models": {
             "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
             "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
+            # "Gemini 1.0 Pro": "gemini-pro"  # Older model example
         }
     },
     "cohere": {
+        "default": "command-r",  # command-r is generally better than light
         "models": {
             "Command R (Cohere)": "command-r",
             "Command R+ (Cohere)": "command-r-plus",
             "Command Light (Cohere)": "command-light",
         }
     },
+    "huggingface": {  # Direct HF Inference API is tricky for chat, often better via OpenRouter/TogetherAI
+        "default": "mistralai/Mistral-7B-Instruct-v0.2",  # A common TGI compatible model
         "models": {
             "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
+            "Llama 3 8B Instruct (HF Inf.)": "meta-llama/Meta-Llama-3-8B-Instruct",  # Ensure this specific ID is for TGI
+            # "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",  # Older model
         }
     },
     "openai": {
+        "default": "gpt-4o-mini",  # New default
         "models": {
             "GPT-4o (OpenAI)": "gpt-4o",
             "GPT-4o mini (OpenAI)": "gpt-4o-mini",
+            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",  # Refers to latest gpt-4-turbo variant
+            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",  # Refers to latest gpt-3.5-turbo variant
         }
     },
+    "xai": {  # Assuming xAI might expand model list
+        "default": "grok-1.5-flash",  # Assuming Grok 1.5 flash is available
         "models": {
+            "Grok 1.5 Flash (xAI)": "grok-1.5-flash",
+            # "Grok-1 (xAI)": "grok-1",  # Older model
         }
     }
 }
 
+def _get_api_key(provider: str, ui_api_key_override: str = None) -> str | None:
+    """
+    Retrieves API key for a given provider.
+    Priority: UI Override > Environment Variable from API_KEYS_ENV_VARS > Specific (e.g. HF_TOKEN for HuggingFace).
+    """
+    provider_upper = provider.upper()
+    if ui_api_key_override and ui_api_key_override.strip():
+        logger.debug(f"Using UI-provided API key for {provider_upper}.")
         return ui_api_key_override.strip()

+    env_var_name = API_KEYS_ENV_VARS.get(provider_upper)
     if env_var_name:
         env_key = os.getenv(env_var_name)
+        if env_key and env_key.strip():
+            logger.debug(f"Using API key from env var '{env_var_name}' for {provider_upper}.")
             return env_key.strip()

+    # Specific fallback for HuggingFace if HF_TOKEN is set and API_KEYS_ENV_VARS['HUGGINGFACE'] wasn't specific enough
+    if provider_upper == 'HUGGINGFACE':
+        hf_token_fallback = os.getenv("HF_TOKEN")
+        if hf_token_fallback and hf_token_fallback.strip():
+            logger.debug("Using HF_TOKEN as fallback for HuggingFace provider.")
+            return hf_token_fallback.strip()

+    logger.warning(f"API Key not found for provider '{provider_upper}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
     return None

 def get_available_providers() -> list[str]:
+    """Returns a sorted list of available provider names (e.g., 'groq', 'openai')."""
     return sorted(list(MODELS_BY_PROVIDER.keys()))

+def get_model_display_names_for_provider(provider: str) -> list[str]:
+    """Returns a sorted list of model display names for a given provider."""
     return sorted(list(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys()))

+def get_default_model_display_name_for_provider(provider: str) -> str | None:
+    """Gets the default model's display name for a provider."""
+    provider_data = MODELS_BY_PROVIDER.get(provider.lower(), {})
+    models_dict = provider_data.get("models", {})
+    default_model_id = provider_data.get("default")
+
+    if default_model_id and models_dict:
+        for display_name, model_id_val in models_dict.items():
+            if model_id_val == default_model_id:
                 return display_name
+
+    # Fallback to the first model in the sorted list if default not found or not set
     if models_dict:
+        sorted_display_names = sorted(list(models_dict.keys()))
+        if sorted_display_names:
+            return sorted_display_names[0]
     return None

 def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
+    """Gets the actual model ID from its display name for a given provider."""
     models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
     return models.get(display_name)

+
+def call_model_stream(provider: str, model_display_name: str, messages: list[dict], api_key_override: str = None, temperature: float = 0.7, max_tokens: int = None) -> iter:
+    """
+    Calls the specified model via its provider and streams the response.
+    Handles provider-specific request formatting and error handling.
+    Yields chunks of the response text or an error string.
+    """
     provider_lower = provider.lower()
     api_key = _get_api_key(provider_lower, api_key_override)
     base_url = API_URLS.get(provider.upper())
     model_id = get_model_id_from_display_name(provider_lower, model_display_name)

     if not api_key:
+        env_var_name = API_KEYS_ENV_VARS.get(provider.upper(), 'N/A')
+        yield f"Error: API Key not found for {provider}. Please set it in the UI or env var '{env_var_name}'."
         return
     if not base_url:
         yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
         return
     if not model_id:
+        yield f"Error: Model ID not found for '{model_display_name}' under provider '{provider}'. Check configuration."
         return

     headers = {}
     payload = {}
     request_url = base_url

+    logger.info(f"Streaming from {provider}/{model_display_name} (ID: {model_id})...")
+
+    # --- Standard OpenAI-compatible providers ---
+    if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
+        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+        payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
+        if max_tokens: payload["max_tokens"] = max_tokens

+        if provider_lower == "openrouter":
+            headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER") or "http://localhost/gradio"  # Example Referer
+            headers["X-Title"] = os.getenv("OPENROUTER_X_TITLE") or "Gradio AI Researcher"  # Example Title

+        try:
             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()

+            # More robust SSE parsing
+            buffer = ""
+            for chunk in response.iter_content(chunk_size=None):  # Process raw bytes
+                buffer += chunk.decode('utf-8', errors='replace')
+                while '\n\n' in buffer:
+                    event_str, buffer = buffer.split('\n\n', 1)
+                    if not event_str.strip(): continue
+
+                    content_chunk = ""
+                    for line in event_str.splitlines():
+                        if line.startswith('data: '):
+                            data_json = line[len('data: '):].strip()
+                            if data_json == '[DONE]':
+                                return  # Stream finished
+                            try:
+                                data = json.loads(data_json)
+                                if data.get("choices") and len(data["choices"]) > 0:
+                                    delta = data["choices"][0].get("delta", {})
+                                    if delta and delta.get("content"):
+                                        content_chunk += delta["content"]
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to decode JSON from stream line: {data_json}")
+                    if content_chunk:
+                        yield content_chunk
+            # Process any remaining buffer content (less common with '\n\n' delimiter)
+            if buffer.strip():
+                logger.debug(f"Remaining buffer after OpenAI-like stream: {buffer}")
+
+        except requests.exceptions.HTTPError as e:
+            err_msg = f"API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
+            logger.error(f"{err_msg} for {provider}/{model_id}", exc_info=False)
+            yield f"Error: {err_msg}"
+        except requests.exceptions.RequestException as e:
+            logger.error(f"API Request Error for {provider}/{model_id}: {e}", exc_info=False)
+            yield f"Error: Could not connect to {provider} ({e})"
+        except Exception as e:
+            logger.exception(f"Unexpected error during {provider} stream:")
+            yield f"Error: An unexpected error occurred: {e}"
+        return
+
+    # --- Google Gemini ---
+    elif provider_lower == "google":
+        system_instruction = None
+        filtered_messages = []
+        for msg in messages:
+            if msg["role"] == "system": system_instruction = {"parts": [{"text": msg["content"]}]}
+            else:
+                role = "model" if msg["role"] == "assistant" else msg["role"]
+                filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})
+
+        payload = {
+            "contents": filtered_messages,
+            "safetySettings": [  # Example: more permissive settings
+                {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+                {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+            ],
+            "generationConfig": {"temperature": temperature}
+        }
+        if max_tokens: payload["generationConfig"]["maxOutputTokens"] = max_tokens
+        if system_instruction: payload["system_instruction"] = system_instruction
+
+        request_url = f"{base_url}{model_id}:streamGenerateContent?key={api_key}"  # API key in query param
+        headers = {"Content-Type": "application/json"}

+        try:
             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()
+
+            # Google's stream is a bit different, often newline-delimited JSON arrays/objects
+            buffer = ""
+            for chunk in response.iter_content(chunk_size=None):
+                buffer += chunk.decode('utf-8', errors='replace')
+                # Google might send chunks that are not complete JSON objects, or multiple objects.
+                # A common pattern is [ {obj1} , {obj2} ] where chunks split mid-array or mid-object.
+                # This parsing needs to be robust. A simple split by '\n' might not always work if JSON is pretty-printed.
+                # The previous code's `json.loads(f"[{decoded_line}]")` was an attempt to handle this.
+                # For now, let's assume newline delimited for simplicity, but this is a known tricky part.
+
+                while '\n' in buffer:
+                    line, buffer = buffer.split('\n', 1)
+                    line = line.strip()
+                    if not line: continue
+                    if line.startswith(','): line = line[1:]  # Handle leading commas if splitting an array

                    try:
+                        # Remove "data: " prefix if present (less common for Gemini direct API but good practice)
+                        if line.startswith('data: '): line = line[len('data: '):]
+
+                        # Gemini often streams an array of objects, or just one object.
+                        # Try to parse as a single object first. If fails, try as array.
+                        parsed_data = None
+                        try:
+                            parsed_data = json.loads(line)
+                        except json.JSONDecodeError:
+                            # If it's part of an array, it might be missing brackets.
+                            # This heuristic is fragile. A proper SSE parser or stateful JSON parser is better.
+                            if line.startswith('{') and line.endswith('}'):  # Looks like a complete object
+                                pass  # already tried json.loads
+                            # Try to wrap with [] if it seems like a list content without brackets
+                            elif line.startswith('{') or line.endswith('}'):
+                                try:
+                                    temp_parsed_list = json.loads(f"[{line}]")
+                                    if temp_parsed_list and isinstance(temp_parsed_list, list):
+                                        parsed_data = temp_parsed_list[0]  # take first if it becomes a list
+                                except json.JSONDecodeError:
+                                    logger.warning(f"Google: Still can't parse line even with array wrap: {line}")
+
+                        if parsed_data:
+                            data_to_process = [parsed_data] if isinstance(parsed_data, dict) else parsed_data  # Ensure list
+                            for event_data in data_to_process:
+                                if not isinstance(event_data, dict): continue
+                                if event_data.get("candidates"):
+                                    for candidate in event_data["candidates"]:
+                                        if candidate.get("content", {}).get("parts"):
+                                            for part in candidate["content"]["parts"]:
+                                                if part.get("text"):
+                                                    yield part["text"]
+                    except json.JSONDecodeError:
+                        logger.warning(f"Google: JSONDecodeError for line: {line}")
+                    except Exception as e_google_proc:
+                        logger.error(f"Google: Error processing stream data: {e_google_proc}, Line: {line}")
+
+        except requests.exceptions.HTTPError as e:
+            err_msg = f"Google API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
+            logger.error(err_msg, exc_info=False)
+            yield f"Error: {err_msg}"
+        except Exception as e:
+            logger.exception(f"Unexpected error during Google stream:")
+            yield f"Error: An unexpected error occurred with Google API: {e}"
+        return

+    # --- Cohere ---
+    elif provider_lower == "cohere":
+        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"}
+
+        # Cohere message format
+        chat_history_cohere = []
+        preamble_cohere = None
+        user_message_cohere = ""
+
+        temp_messages = list(messages)  # Work with a copy
+        if temp_messages and temp_messages[0]["role"] == "system":
+            preamble_cohere = temp_messages.pop(0)["content"]
+
+        if temp_messages:
+            user_message_cohere = temp_messages.pop()["content"]  # Last message is the current user query
+            for msg in temp_messages:  # Remaining are history
+                role = "USER" if msg["role"] == "user" else "CHATBOT"
+                chat_history_cohere.append({"role": role, "message": msg["content"]})
+
+        if not user_message_cohere:
+            yield "Error: User message is empty for Cohere."
+            return

+        payload = {
+            "model": model_id,
+            "message": user_message_cohere,
+            "stream": True,
+            "temperature": temperature
+        }
+        if max_tokens: payload["max_tokens"] = max_tokens  # Cohere uses max_tokens
+        if chat_history_cohere: payload["chat_history"] = chat_history_cohere
+        if preamble_cohere: payload["preamble"] = preamble_cohere
+
+        try:
+            response = requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180)
+            response.raise_for_status()
+
+            # Cohere SSE format is event: type\ndata: {json}\n\n
+            buffer = ""
+            for chunk_bytes in response.iter_content(chunk_size=None):
+                buffer += chunk_bytes.decode('utf-8', errors='replace')
+                while '\n\n' in buffer:
+                    event_str, buffer = buffer.split('\n\n', 1)
+                    if not event_str.strip(): continue
+
+                    event_type = None
+                    data_json_str = None
+                    for line in event_str.splitlines():
+                        if line.startswith("event:"): event_type = line[len("event:"):].strip()
+                        elif line.startswith("data:"): data_json_str = line[len("data:"):].strip()
+
+                    if data_json_str:
+                        try:
+                            data = json.loads(data_json_str)
+                            if event_type == "text-generation" and "text" in data:
+                                yield data["text"]
+                            elif event_type == "stream-end":
+                                logger.debug(f"Cohere stream ended. Finish reason: {data.get('finish_reason')}")
+                                return
+                        except json.JSONDecodeError:
+                            logger.warning(f"Cohere: Failed to decode JSON: {data_json_str}")
+            if buffer.strip():
+                logger.debug(f"Cohere: Remaining buffer: {buffer.strip()}")
+
+        except requests.exceptions.HTTPError as e:
+            err_msg = f"Cohere API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
+            logger.error(err_msg, exc_info=False)
+            yield f"Error: {err_msg}"
+        except Exception as e:
+            logger.exception(f"Unexpected error during Cohere stream:")
+            yield f"Error: An unexpected error occurred with Cohere API: {e}"
+        return

+    # --- HuggingFace Inference API (Basic TGI support) ---
+    # This is very basic and might not work for all models or complex scenarios.
+    # Assumes model is deployed with Text Generation Inference (TGI) and supports streaming.
+    elif provider_lower == "huggingface":
+        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+        # Construct prompt string for TGI (often needs specific formatting)
+        # This is a generic attempt, specific models might need <|user|>, <|assistant|> etc.
+        prompt_parts = []
+        for msg in messages:
+            role_prefix = ""
+            if msg['role'] == 'system': role_prefix = "System: "  # Or might be ignored/handled differently
+            elif msg['role'] == 'user': role_prefix = "User: "
+            elif msg['role'] == 'assistant': role_prefix = "Assistant: "
+            prompt_parts.append(f"{role_prefix}{msg['content']}")
+
+        # TGI typically expects a final "Assistant: " to start generating from
+        tgi_prompt = "\n".join(prompt_parts) + "\nAssistant: "
+
+        payload = {
+            "inputs": tgi_prompt,
+            "parameters": {
+                "temperature": temperature if temperature > 0 else 0.01,  # TGI needs temp > 0 for sampling
+                "max_new_tokens": max_tokens or 1024,  # Default TGI max_new_tokens
+                "return_full_text": False,  # We only want generated part
+                "do_sample": True if temperature > 0 else False,
+            },
+            "stream": True
+        }
+        request_url = f"{base_url}{model_id}"  # Model ID is part of URL path for HF

+        try:
             response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
             response.raise_for_status()

+            # TGI SSE stream: data: {"token": {"id": ..., "text": "...", "logprob": ..., "special": ...}}
+            # Or sometimes just data: "text_chunk" for simpler models/configs
+            buffer = ""
+            for chunk_bytes in response.iter_content(chunk_size=None):
+                buffer += chunk_bytes.decode('utf-8', errors='replace')
+                while '\n' in buffer:  # TGI often uses single newline
+                    line, buffer = buffer.split('\n', 1)
+                    line = line.strip()
+                    if not line: continue
+
+                    if line.startswith('data:'):
+                        data_json_str = line[len('data:'):].strip()
+                        try:
+                            data = json.loads(data_json_str)
+                            if "token" in data and "text" in data["token"]:
+                                yield data["token"]["text"]
+                            elif "generated_text" in data and data.get("details") is None:  # Sometimes a final non-streaming like object might appear
+                                # This case is tricky, if it's the *only* thing then it's not really streaming
+                                pass  # For now, ignore if it's not a token object
+                            # Some TGI might send raw text if not fully SSE compliant for stream
+                            # elif isinstance(data, str): yield data
+
+                        except json.JSONDecodeError:
+                            # If it's not JSON, it might be a raw string (less common for TGI stream=True)
+                            # For safety, only yield if it's a clear text string
+                            if not data_json_str.startswith('{') and not data_json_str.startswith('['):
+                                yield data_json_str
+                            else:
+                                logger.warning(f"HF: Failed to decode JSON and not raw string: {data_json_str}")
+            if buffer.strip():
+                logger.debug(f"HF: Remaining buffer: {buffer.strip()}")
+
+        except requests.exceptions.HTTPError as e:
+            err_msg = f"HF API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
+            logger.error(err_msg, exc_info=False)
+            yield f"Error: {err_msg}"
+        except Exception as e:
+            logger.exception(f"Unexpected error during HF stream:")
+            yield f"Error: An unexpected error occurred with HF API: {e}"
+        return

+    else:
+        yield f"Error: Provider '{provider}' is not configured for streaming in this handler."
+        return
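
A minimal usage sketch (not part of the commit) of how a caller might consume the interface introduced here, assuming the module is importable as model_logic and the relevant API key is available via the environment or a .env file. The consumer file name, provider choice, and prompt are illustrative only.

    # example_consumer.py (hypothetical)
    from model_logic import (
        call_model_stream,
        get_available_providers,
        get_default_model_display_name_for_provider,
    )

    print(get_available_providers())  # e.g. ['cohere', 'google', 'groq', ...]

    provider = "groq"  # assumes GROQ_API_KEY is set (see API_KEYS_ENV_VARS)
    model_display_name = get_default_model_display_name_for_provider(provider)
    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]

    # call_model_stream is a generator: it yields text chunks as they arrive,
    # and yields "Error: ..." strings instead of raising on most failures.
    for chunk in call_model_stream(provider, model_display_name, messages,
                                   temperature=0.7, max_tokens=256):
        print(chunk, end="", flush=True)
    print()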