# Node-Brain-BU3 / model_logic.py — "Update model_logic.py", commit 14149fd (verified), by broadfield-dev
# model_logic.py
import json
import logging
import os
from collections.abc import Iterator

import requests
from dotenv import load_dotenv
# Load environment variables (API keys) from a local .env file, if present.
load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Upper-case provider key -> name of the environment variable holding its API key.
API_KEYS_ENV_VARS = {"GROQ": 'GROQ_API_KEY', "OPENROUTER": 'OPENROUTER_API_KEY', "TOGETHERAI": 'TOGETHERAI_API_KEY', "COHERE": 'COHERE_API_KEY', "XAI": 'XAI_API_KEY', "OPENAI": 'OPENAI_API_KEY', "GOOGLE": 'GOOGLE_API_KEY', "HUGGINGFACE": 'HF_TOKEN'}
# Upper-case provider key -> chat endpoint URL. NOTE(review): the GOOGLE and
# HUGGINGFACE values are base prefixes that appear to expect a model path
# appended, and COHERE's /chat is not OpenAI-compatible — confirm before use.
API_URLS = {"GROQ": 'https://api.groq.com/openai/v1/chat/completions', "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions', "TOGETHERAI": 'https://api.together.ai/v1/chat/completions', "COHERE": 'https://api.cohere.ai/v1/chat', "XAI": 'https://api.x.ai/v1/chat/completions', "OPENAI": 'https://api.openai.com/v1/chat/completions', "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', "HUGGINGFACE": 'https://api-inference.huggingface.co/models/'}
# Model catalog: {provider: {"default": <model id>, "models": {display name: model id}}}.
# Loaded once at import time; a minimal Groq fallback keeps the app usable
# when models.json is missing or malformed.
try:
    # Explicit encoding: the open() default is locale-dependent (e.g. cp1252
    # on Windows) and can corrupt non-ASCII display names.
    with open("models.json", "r", encoding="utf-8") as f:
        MODELS_BY_PROVIDER = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    logger.warning("models.json not found or invalid. Using a fallback model list.")
    MODELS_BY_PROVIDER = {"groq": {"default": "llama3-8b-8192", "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"}}}
def _get_api_key(provider: str, ui_api_key_override: str = None) -> str | None:
provider_upper = provider.upper()
if ui_api_key_override and ui_api_key_override.strip(): return ui_api_key_override.strip()
env_var_name = API_KEYS_ENV_VARS.get(provider_upper)
if env_var_name:
env_key = os.getenv(env_var_name)
if env_key and env_key.strip(): return env_key.strip()
return None
def get_available_providers() -> list[str]:
    """Return every configured provider key, alphabetically sorted."""
    return sorted(MODELS_BY_PROVIDER)
def get_model_display_names_for_provider(provider: str) -> list[str]:
    """Alphabetically sorted display names for *provider* ([] if unknown)."""
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return sorted(provider_entry.get("models", {}))
def get_default_model_display_name_for_provider(provider: str) -> str | None:
    """Display name of the provider's configured default model.

    Falls back to the alphabetically-first display name when no default id
    is configured or it matches no model; returns None when the provider
    has no models at all.
    """
    entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models = entry.get("models", {})
    if not models:
        return None
    default_id = entry.get("default")
    if default_id:
        # First display name (insertion order) whose id equals the default.
        match = next((name for name, model_id in models.items() if model_id == default_id), None)
        if match is not None:
            return match
    return min(models)
def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a UI display name to its wire-format model id (None if unknown)."""
    models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return models.get(display_name)
def call_model_stream(provider: str, model_display_name: str, messages: list[dict], api_key_override: str = None, temperature: float = 0.7, max_tokens: int = None) -> Iterator[str]:
    """Stream chat-completion text deltas from an OpenAI-compatible API.

    Yields content chunks as they arrive over the SSE stream. On any
    failure (missing config, HTTP error, network error) it yields a single
    human-readable "Error: ..." string instead of raising.

    Args:
        provider: Provider key, case-insensitive (e.g. "groq").
        model_display_name: UI display name, resolved via the model catalog.
        messages: Chat messages in OpenAI format ({"role": ..., "content": ...}).
        api_key_override: Optional key that beats the environment variable.
        temperature: Sampling temperature, forwarded verbatim.
        max_tokens: Optional cap on completion tokens (omitted when falsy).
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)
    base_url = API_URLS.get(provider_lower.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)
    if not all([api_key, base_url, model_id]):
        yield f"Error: Configuration missing for {provider}/{model_display_name}."
        return
    # NOTE(review): GOOGLE and HUGGINGFACE entries in API_URLS look like base
    # prefixes expecting a model path, and COHERE's /chat takes a different
    # payload; this OpenAI-style request presumably only works for the
    # OpenAI-compatible providers — confirm before enabling the others.
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
    if max_tokens:
        payload["max_tokens"] = max_tokens
    if provider_lower == "openrouter":
        headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER", "http://localhost")
    try:
        # Context manager ensures the streaming connection is closed even when
        # we return early on [DONE] or the caller abandons the generator
        # (the original leaked the response).
        with requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180) as response:
            response.raise_for_status()
            buffer = ""
            for chunk in response.iter_content(chunk_size=None):
                buffer += chunk.decode('utf-8', errors='replace')
                # SSE events are separated by a blank line; drain every
                # complete event currently buffered.
                while '\n\n' in buffer:
                    event_str, buffer = buffer.split('\n\n', 1)
                    if not event_str.strip() or not event_str.startswith('data: '):
                        continue
                    data_json = event_str[len('data: '):].strip()
                    if data_json == '[DONE]':
                        return
                    try:
                        data = json.loads(data_json)
                    except json.JSONDecodeError:
                        continue  # tolerate keep-alives / malformed frames
                    choices = data.get("choices")
                    if choices:
                        delta = choices[0].get("delta", {})
                        if delta and delta.get("content"):
                            yield delta["content"]
    except requests.exceptions.HTTPError as e:
        yield f"Error: API HTTP Error ({e.response.status_code}): {e.response.text[:200]}"
    except Exception as e:
        yield f"Error: An unexpected error occurred: {e}"