Spaces:
Runtime error
Runtime error
Create model_logic.py
Browse files- model_logic.py +84 -0
model_logic.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# model_logic.py
import json
import logging
import os
from collections.abc import Iterator

import requests
from dotenv import load_dotenv

8 |
+
load_dotenv()
|
9 |
+
|
10 |
+
# Module-wide logging: timestamped INFO records via the root logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Provider name (upper-case) -> environment variable that holds its API key.
API_KEYS_ENV_VARS = {"GROQ": 'GROQ_API_KEY', "OPENROUTER": 'OPENROUTER_API_KEY', "TOGETHERAI": 'TOGETHERAI_API_KEY', "COHERE": 'COHERE_API_KEY', "XAI": 'XAI_API_KEY', "OPENAI": 'OPENAI_API_KEY', "GOOGLE": 'GOOGLE_API_KEY', "HUGGINGFACE": 'HF_TOKEN'}

# Provider name (upper-case) -> API endpoint.
# NOTE(review): GOOGLE and HUGGINGFACE entries are base URLs that appear to
# expect a model path appended, yet call_model_stream posts to them as-is —
# confirm these providers are actually supported end to end.
API_URLS = {"GROQ": 'https://api.groq.com/openai/v1/chat/completions', "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions', "TOGETHERAI": 'https://api.together.ai/v1/chat/completions', "COHERE": 'https://api.cohere.ai/v1/chat', "XAI": 'https://api.x.ai/v1/chat/completions', "OPENAI": 'https://api.openai.com/v1/chat/completions', "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', "HUGGINGFACE": 'https://api-inference.huggingface.co/models/'}

# Model catalog: provider -> {"default": model_id, "models": {display_name: model_id}}.
# Loaded from models.json; falls back to a minimal Groq-only catalog when the
# file is missing or malformed so the module stays importable.
try:
    # Explicit encoding: JSON is UTF-8 by spec; the original relied on the
    # platform default encoding, which can differ (e.g. on Windows).
    with open("models.json", "r", encoding="utf-8") as f:
        MODELS_BY_PROVIDER = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    logger.warning("models.json not found or invalid. Using a fallback model list.")
    MODELS_BY_PROVIDER = {"groq": {"default": "llama3-8b-8192", "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"}}}
|
21 |
+
|
22 |
+
def _get_api_key(provider: str, ui_api_key_override: str = None) -> str | None:
    """Resolve the API key for *provider*.

    A non-blank UI-supplied override always wins; otherwise the provider's
    configured environment variable (per API_KEYS_ENV_VARS) is consulted.
    Returns the stripped key, or None when no usable key is found.
    """
    # UI override takes precedence over the environment.
    override = (ui_api_key_override or "").strip()
    if override:
        return override

    env_name = API_KEYS_ENV_VARS.get(provider.upper())
    if env_name is None:
        return None

    candidate = (os.getenv(env_name) or "").strip()
    return candidate if candidate else None
|
30 |
+
|
31 |
+
def get_available_providers() -> list[str]:
    """Return every provider name in the model catalog, alphabetically sorted."""
    providers = list(MODELS_BY_PROVIDER)  # iterating a dict yields its keys
    providers.sort()
    return providers
|
33 |
+
|
34 |
+
def get_model_display_names_for_provider(provider: str) -> list[str]:
    """Return the sorted model display names offered by *provider* ([] if unknown)."""
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    display_names = provider_entry.get("models", {})
    return sorted(display_names)
|
36 |
+
|
37 |
+
def get_default_model_display_name_for_provider(provider: str) -> str | None:
    """Return the display name of *provider*'s default model.

    Prefers the display name whose model id matches the catalog's "default"
    entry; otherwise falls back to the alphabetically first display name.
    Returns None when the provider has no models at all.
    """
    entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models = entry.get("models", {})
    default_id = entry.get("default")

    if default_id:
        # Reverse lookup: display name whose id equals the declared default.
        match = next((name for name, mid in models.items() if mid == default_id), None)
        if match is not None:
            return match

    if not models:
        return None
    return min(models)  # alphabetically first display name
|
46 |
+
|
47 |
+
def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Look up the provider-specific model id behind a UI display name (None if absent)."""
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return provider_entry.get("models", {}).get(display_name)
|
49 |
+
|
50 |
+
def call_model_stream(provider: str, model_display_name: str, messages: list[dict], api_key_override: str = None, temperature: float = 0.7, max_tokens: int = None) -> Iterator[str]:
    """Stream chat-completion text chunks from the given provider.

    Yields content deltas as they arrive on the provider's SSE stream. On any
    failure a single human-readable "Error: ..." string is yielded instead of
    raising, so callers can surface it directly in a UI.

    Args:
        provider: Provider name (case-insensitive key of API_URLS).
        model_display_name: UI display name, mapped to a model id via the catalog.
        messages: OpenAI-style chat message dicts.
        api_key_override: Optional key that takes precedence over env vars.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Optional completion-length cap (omitted from the payload
            when falsy, including 0).
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)
    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)
    if not all([api_key, base_url, model_id]):
        yield f"Error: Configuration missing for {provider}/{model_display_name}."
        return

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # NOTE(review): payload/headers use the OpenAI chat-completions shape for
    # every provider; the COHERE/GOOGLE/HUGGINGFACE endpoints in API_URLS
    # expect different request formats — confirm those providers work.
    payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
    if max_tokens:
        payload["max_tokens"] = max_tokens
    if provider_lower == "openrouter":
        headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER", "http://localhost")

    try:
        # Context manager guarantees the streaming connection is released on
        # every exit path (the original leaked it on early return/exception).
        with requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180) as response:
            response.raise_for_status()
            buffer = ""
            for chunk in response.iter_content(chunk_size=None):
                buffer += chunk.decode('utf-8', errors='replace')
                # SSE events are separated by a blank line; drain each complete event.
                while '\n\n' in buffer:
                    event_str, buffer = buffer.split('\n\n', 1)
                    if not event_str.strip() or not event_str.startswith('data: '):
                        continue
                    data_json = event_str[len('data: '):].strip()
                    if data_json == '[DONE]':
                        return
                    try:
                        data = json.loads(data_json)
                        if data.get("choices") and len(data["choices"]) > 0:
                            delta = data["choices"][0].get("delta", {})
                            if delta and delta.get("content"):
                                yield delta["content"]
                    except json.JSONDecodeError:
                        continue  # skip malformed/keep-alive events
    except requests.exceptions.HTTPError as e:
        yield f"Error: API HTTP Error ({e.response.status_code}): {e.response.text[:200]}"
    except Exception as e:
        yield f"Error: An unexpected error occurred: {e}"
|