broadfield-dev committed on
Commit
60122df
·
verified ·
1 Parent(s): 081e43b

Create model_logic.py

Browse files
Files changed (1) hide show
  1. model_logic.py +84 -0
model_logic.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# model_logic.py
import json
import logging
import os
from collections.abc import Iterator

import requests
from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+
10
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
11
+ logger = logging.getLogger(__name__)
12
+
13
+ API_KEYS_ENV_VARS = {"GROQ": 'GROQ_API_KEY', "OPENROUTER": 'OPENROUTER_API_KEY', "TOGETHERAI": 'TOGETHERAI_API_KEY', "COHERE": 'COHERE_API_KEY', "XAI": 'XAI_API_KEY', "OPENAI": 'OPENAI_API_KEY', "GOOGLE": 'GOOGLE_API_KEY', "HUGGINGFACE": 'HF_TOKEN'}
14
+ API_URLS = {"GROQ": 'https://api.groq.com/openai/v1/chat/completions', "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions', "TOGETHERAI": 'https://api.together.ai/v1/chat/completions', "COHERE": 'https://api.cohere.ai/v1/chat', "XAI": 'https://api.x.ai/v1/chat/completions', "OPENAI": 'https://api.openai.com/v1/chat/completions', "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', "HUGGINGFACE": 'https://api-inference.huggingface.co/models/'}
15
+
16
+ try:
17
+ with open("models.json", "r") as f: MODELS_BY_PROVIDER = json.load(f)
18
+ except (FileNotFoundError, json.JSONDecodeError):
19
+ logger.warning("models.json not found or invalid. Using a fallback model list.")
20
+ MODELS_BY_PROVIDER = {"groq": {"default": "llama3-8b-8192", "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"}}}
21
+
22
+ def _get_api_key(provider: str, ui_api_key_override: str = None) -> str | None:
23
+ provider_upper = provider.upper()
24
+ if ui_api_key_override and ui_api_key_override.strip(): return ui_api_key_override.strip()
25
+ env_var_name = API_KEYS_ENV_VARS.get(provider_upper)
26
+ if env_var_name:
27
+ env_key = os.getenv(env_var_name)
28
+ if env_key and env_key.strip(): return env_key.strip()
29
+ return None
30
+
31
+ def get_available_providers() -> list[str]:
32
+ return sorted(list(MODELS_BY_PROVIDER.keys()))
33
+
34
+ def get_model_display_names_for_provider(provider: str) -> list[str]:
35
+ return sorted(list(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys()))
36
+
37
+ def get_default_model_display_name_for_provider(provider: str) -> str | None:
38
+ provider_data = MODELS_BY_PROVIDER.get(provider.lower(), {})
39
+ models_dict = provider_data.get("models", {})
40
+ default_model_id = provider_data.get("default")
41
+ if default_model_id and models_dict:
42
+ for display_name, model_id_val in models_dict.items():
43
+ if model_id_val == default_model_id: return display_name
44
+ if models_dict: return sorted(list(models_dict.keys()))[0]
45
+ return None
46
+
47
+ def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
48
+ return MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).get(display_name)
49
+
50
+ def call_model_stream(provider: str, model_display_name: str, messages: list[dict], api_key_override: str = None, temperature: float = 0.7, max_tokens: int = None) -> iter:
51
+ provider_lower = provider.lower()
52
+ api_key = _get_api_key(provider_lower, api_key_override)
53
+ base_url = API_URLS.get(provider.upper())
54
+ model_id = get_model_id_from_display_name(provider_lower, model_display_name)
55
+ if not all([api_key, base_url, model_id]):
56
+ yield f"Error: Configuration missing for {provider}/{model_display_name}."
57
+ return
58
+
59
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
60
+ payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
61
+ if max_tokens: payload["max_tokens"] = max_tokens
62
+ if provider_lower == "openrouter": headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER", "http://localhost")
63
+
64
+ try:
65
+ response = requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180)
66
+ response.raise_for_status()
67
+ buffer = ""
68
+ for chunk in response.iter_content(chunk_size=None):
69
+ buffer += chunk.decode('utf-8', errors='replace')
70
+ while '\n\n' in buffer:
71
+ event_str, buffer = buffer.split('\n\n', 1)
72
+ if not event_str.strip() or not event_str.startswith('data: '): continue
73
+ data_json = event_str[len('data: '):].strip()
74
+ if data_json == '[DONE]': return
75
+ try:
76
+ data = json.loads(data_json)
77
+ if data.get("choices") and len(data["choices"]) > 0:
78
+ delta = data["choices"][0].get("delta", {})
79
+ if delta and delta.get("content"): yield delta["content"]
80
+ except json.JSONDecodeError: continue
81
+ except requests.exceptions.HTTPError as e:
82
+ yield f"Error: API HTTP Error ({e.response.status_code}): {e.response.text[:200]}"
83
+ except Exception as e:
84
+ yield f"Error: An unexpected error occurred: {e}"