broadfield-dev commited on
Commit
6b5f0c3
·
verified ·
1 Parent(s): b80b022

Create model_logic.py

Browse files
Files changed (1) hide show
  1. model_logic.py +402 -0
model_logic.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import logging
5
+
# Module-wide logging: timestamped records tagged with the logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Environment-variable name holding each provider's API key (keys are
# upper-cased provider ids).
API_KEYS = {
    "HUGGINGFACE": "HF_TOKEN",
    "GROQ": "GROQ_API_KEY",
    "OPENROUTER": "OPENROUTER_API_KEY",
    "TOGETHERAI": "TOGETHERAI_API_KEY",
    "COHERE": "COHERE_API_KEY",
    "XAI": "XAI_API_KEY",
    "OPENAI": "OPENAI_API_KEY",
    "GOOGLE": "GOOGLE_API_KEY",
}

# Base endpoint per provider. Most are full chat-completion URLs; the
# Hugging Face and Google entries are prefixes that get a model id appended.
API_URLS = {
    "HUGGINGFACE": "https://api-inference.huggingface.co/models/",
    "GROQ": "https://api.groq.com/openai/v1/chat/completions",
    "OPENROUTER": "https://openrouter.ai/api/v1/chat/completions",
    "TOGETHERAI": "https://api.together.ai/v1/chat/completions",
    "COHERE": "https://api.cohere.ai/v1/chat",
    "XAI": "https://api.x.ai/v1/chat/completions",
    "OPENAI": "https://api.openai.com/v1/chat/completions",
    "GOOGLE": "https://generativelanguage.googleapis.com/v1beta/models/",
}

# Per provider (lower-cased key): the default model id plus a mapping of
# human-readable display name -> API model id.
MODELS_BY_PROVIDER = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        },
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
            "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
            "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
            "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
        },
    },
    "togetherai": {
        "default": "meta-llama/Llama-3-8b-chat-hf",
        "models": {
            "Llama 3 8B Chat (TogetherAI)": "meta-llama/Llama-3-8b-chat-hf",
            "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
            "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
            "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        },
    },
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        },
    },
    "cohere": {
        "default": "command-light",
        "models": {
            "Command R (Cohere)": "command-r",
            "Command R+ (Cohere)": "command-r-plus",
            "Command Light (Cohere)": "command-light",
            "Command (Cohere)": "command",
        },
    },
    "huggingface": {
        "default": "HuggingFaceH4/zephyr-7b-beta",
        "models": {
            "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
            "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
            "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
            "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
        },
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o (OpenAI)": "gpt-4o",
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        },
    },
    "xai": {
        "default": "grok-1",
        "models": {
            "Grok-1 (xAI)": "grok-1",
        },
    },
}
def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve the API key for *provider*.

    Precedence: a non-empty UI override, then the provider's environment
    variable from API_KEYS, then (for Hugging Face only) the HF_TOKEN
    variable as a legacy fallback.

    Args:
        provider: Provider id, case-insensitive.
        ui_api_key_override: Key typed into the UI; wins when non-empty.

    Returns:
        The stripped key string, or None when no key is configured.
        (The original annotation claimed ``-> str`` but None is a real
        return value on the failure path — annotation fixed accordingly.)
    """
    if ui_api_key_override:
        return ui_api_key_override.strip()

    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = os.getenv(env_var_name)
        if env_key:
            return env_key.strip()

    # Redundant with API_KEYS["HUGGINGFACE"] == "HF_TOKEN" today, but kept
    # so Hugging Face keeps working if that mapping is ever edited.
    if provider.lower() == 'huggingface':
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            return hf_token.strip()

    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None
def get_available_providers() -> list[str]:
    """Return every configured provider key, alphabetically sorted."""
    # Iterating a dict yields its keys; sorted() already returns a list.
    return sorted(MODELS_BY_PROVIDER)
def get_models_for_provider(provider: str) -> list[str]:
    """Return the sorted model display names for *provider* ([] if unknown)."""
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return sorted(provider_entry.get("models", {}))
def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name of *provider*'s default model.

    Falls back to the alphabetically-first display name when the configured
    default id is not present in the models map, and returns None when the
    provider has no models at all.

    Improvements over the original: the provider entry is looked up once
    instead of twice, and the redundant list() around keys() is dropped.
    """
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models_dict = provider_entry.get("models", {})
    default_model_id = provider_entry.get("default")
    if default_model_id:
        # Reverse lookup: display name whose mapped id equals the default id.
        for display_name, model_id in models_dict.items():
            if model_id == default_model_id:
                return display_name
    if models_dict:
        return sorted(models_dict)[0]
    return None
def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a human-readable *display_name* to its API model id (None if unknown)."""
    provider_models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return provider_models.get(display_name)
def generate_stream(provider: str, model_display_name: str, api_key_override: str, messages: list[dict]) -> iter:
    # NOTE(review): "-> iter" annotates the *builtin function* iter, not a
    # type; typing.Iterator[str] is what is meant — confirm and fix separately.
    """Stream a chat reply from the selected provider as plain-text chunks.

    Resolves the API key, base URL and model id, then speaks the provider's
    native streaming protocol: OpenAI-compatible SSE (groq/openrouter/
    togetherai/openai/xai), Google's streamed JSON array, or Cohere's SSE
    events.  Failures never raise to the caller: every error path yields a
    single human-readable "Error ..." string and returns.

    Args:
        provider: Provider key, case-insensitive (e.g. "groq", "google").
        model_display_name: Display name as listed in MODELS_BY_PROVIDER.
        api_key_override: Key from the UI; falls back to environment variables.
        messages: OpenAI-style history: [{"role": ..., "content": ...}, ...].

    Yields:
        str: incremental fragments of the assistant's reply, or one error string.
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)

    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)

    # --- Preconditions: surface configuration problems as yielded text. ---
    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return

    headers = {}
    payload = {}
    request_url = base_url

    logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

    try:
        # --- OpenAI-compatible providers: SSE lines of "data: {json}". ---
        if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            payload = {
                "model": model_id,
                "messages": messages,
                "stream": True
            }
            if provider_lower == "openrouter":
                # OpenRouter attribution headers; SPACE_HOST is presumably the
                # Hugging Face Space hostname when deployed there — TODO confirm.
                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
                headers["X-Title"] = "AI Space Builder"

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            # Accumulate raw bytes and process one newline-terminated SSE
            # line at a time; a partial line stays in byte_buffer.
            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                while b'\n' in byte_buffer:
                    line, byte_buffer = byte_buffer.split(b'\n', 1)
                    decoded_line = line.decode('utf-8', errors='ignore')
                    if decoded_line.startswith('data: '):
                        data = decoded_line[6:]
                        if data == '[DONE]':
                            # End-of-stream sentinel: discard leftovers and
                            # stop draining this chunk.
                            byte_buffer = b''
                            break
                        try:
                            event_data = json.loads(data)
                            # Delta format: choices[0].delta.content carries
                            # the incremental text.
                            if event_data.get("choices") and len(event_data["choices"]) > 0:
                                delta = event_data["choices"][0].get("delta")
                                if delta and delta.get("content"):
                                    yield delta["content"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to decode JSON from stream line: {decoded_line}")
                        except Exception as e:
                            logger.error(f"Error processing stream data: {e}, Data: {decoded_line}")
            # Flush a final line that arrived without a trailing newline.
            if byte_buffer:
                remaining_line = byte_buffer.decode('utf-8', errors='ignore')
                if remaining_line.startswith('data: '):
                    data = remaining_line[6:]
                    if data != '[DONE]':
                        try:
                            event_data = json.loads(data)
                            if event_data.get("choices") and len(event_data["choices"]) > 0:
                                delta = event_data["choices"][0].get("delta")
                                if delta and delta.get("content"):
                                    yield delta["content"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to decode final stream buffer JSON: {remaining_line}")
                        except Exception as e:
                            logger.error(f"Error processing final stream buffer data: {e}, Data: {remaining_line}")

        # --- Google Generative Language API (Gemini). ---
        elif provider_lower == "google":
            # Gemini keeps the system prompt separate from the turn history
            # and names the assistant role "model".
            system_instruction = None
            filtered_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    system_instruction = msg["content"]
                else:
                    role = "model" if msg["role"] == "assistant" else msg["role"]
                    filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

            payload = {
                "contents": filtered_messages,
                # All four safety categories are disabled on purpose here.
                "safetySettings": [
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                ],
                "generationConfig": {
                    "temperature": 0.7,
                }
            }
            if system_instruction:
                payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

            # Model id goes in the URL path; the API key travels as a query
            # parameter rather than a header.
            request_url = f"{base_url}{model_id}:streamGenerateContent"
            headers = {"Content-Type": "application/json"}
            request_url = f"{request_url}?key={api_key}"

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                while b'\n' in byte_buffer:
                    line, byte_buffer = byte_buffer.split(b'\n', 1)
                    decoded_line = line.decode('utf-8', errors='ignore')

                    if decoded_line.startswith('data: '):
                        decoded_line = decoded_line[6:].strip()

                    if not decoded_line: continue

                    try:
                        # The response is a JSON array split across lines;
                        # wrapping a single line in [...] lets complete
                        # objects parse.  NOTE(review): lines holding only a
                        # *partial* object fail here and are dropped after the
                        # warning — they are not re-buffered, despite the
                        # "Accumulating buffer" log text.
                        event_data_list = json.loads(f"[{decoded_line}]")
                        if not isinstance(event_data_list, list): event_data_list = [event_data_list]

                        for event_data in event_data_list:
                            if not isinstance(event_data, dict): continue

                            # Gemini shape: candidates[0].content.parts[*].text
                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
                                candidate = event_data["candidates"][0]
                                if candidate.get("content") and candidate["content"].get("parts"):
                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
                                    if full_text_chunk:
                                        yield full_text_chunk

                    except json.JSONDecodeError:
                        logger.warning(f"Failed to decode JSON from Google stream chunk: {decoded_line}. Accumulating buffer.")
                        pass

                    except Exception as e:
                        logger.error(f"Error processing Google stream data: {e}, Data: {decoded_line}")

            # Flush any trailing bytes that never got a newline.
            if byte_buffer:
                remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
                if remaining_line:
                    try:
                        event_data_list = json.loads(f"[{remaining_line}]")
                        if not isinstance(event_data_list, list): event_data_list = [event_data_list]
                        for event_data in event_data_list:
                            if not isinstance(event_data, dict): continue
                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
                                candidate = event_data["candidates"][0]
                                if candidate.get("content") and candidate["content"].get("parts"):
                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
                                    if full_text_chunk:
                                        yield full_text_chunk
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to decode final Google stream buffer JSON: {remaining_line}")
                    except Exception as e:
                        logger.error(f"Error processing final Google stream buffer data: {e}, Data: {remaining_line}")

        # --- Cohere chat API: SSE events separated by blank lines. ---
        elif provider_lower == "cohere":
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            request_url = f"{base_url}"

            # Cohere wants the latest user turn as "message", earlier turns
            # as "chat_history" (assistant role renamed "chatbot"), and the
            # system prompt as "preamble".
            chat_history_for_cohere = []
            system_prompt_for_cohere = None
            current_message_for_cohere = ""

            temp_history = []
            for msg in messages:
                if msg["role"] == "system":
                    system_prompt_for_cohere = msg["content"]
                elif msg["role"] == "user" or msg["role"] == "assistant":
                    temp_history.append(msg)

            if temp_history:
                current_message_for_cohere = temp_history[-1]["content"]
                chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]]

            if not current_message_for_cohere:
                yield "Error: User message not found for Cohere API call."
                return

            payload = {
                "model": model_id,
                "message": current_message_for_cohere,
                "stream": True,
                "temperature": 0.7
            }
            if chat_history_for_cohere:
                payload["chat_history"] = chat_history_for_cohere
            if system_prompt_for_cohere:
                payload["preamble"] = system_prompt_for_cohere

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            # Events are "event: <type>" / "data: <json>" pairs delimited by
            # a blank line (\n\n); only "text-generation" events carry text.
            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                while b'\n\n' in byte_buffer:
                    event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                    lines = event_chunk.strip().split(b'\n')
                    event_type = None
                    event_data = None

                    for l in lines:
                        if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                        elif l.startswith(b"data: "):
                            try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                            except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")

                    if event_type == "text-generation" and event_data and "text" in event_data:
                        yield event_data["text"]
                    elif event_type == "stream-end":
                        # Terminal event: drop leftovers and stop draining.
                        byte_buffer = b''
                        break

            # Process one final event that lacked the trailing blank line.
            if byte_buffer:
                event_chunk = byte_buffer.strip()
                if event_chunk:
                    lines = event_chunk.split(b'\n')
                    event_type = None
                    event_data = None
                    for l in lines:
                        if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore')
                        elif l.startswith(b"data: "):
                            try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                            except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode final event data JSON: {l[6:].strip()}")

                    if event_type == "text-generation" and event_data and "text" in event_data:
                        yield event_data["text"]
                    elif event_type == "stream-end":
                        pass

        # --- Hugging Face Inference API: deliberately unsupported here. ---
        elif provider_lower == "huggingface":
            yield f"Error: Direct Hugging Face Inference API streaming for chat models is experimental and model-dependent. Consider using OpenRouter or TogetherAI for HF models with standardized streaming."
            return

        else:
            yield f"Error: Unsupported provider '{provider}' for streaming chat."
            return

    # Transport failures are reported as yielded text, never raised.
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.RequestException as e:
        logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
        yield f"An unexpected error occurred: {e}"