File size: 4,725 Bytes
14149fd
60122df
14149fd
60122df
 
14149fd
60122df
14149fd
60122df
14149fd
60122df
b79ff5a
14149fd
 
60122df
14149fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79ff5a
 
14149fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79ff5a
14149fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# model_logic.py
import codecs
import json
import logging
import os
from collections.abc import Iterator

import requests
from dotenv import load_dotenv

# Pull API keys and related settings from a local .env file into os.environ.
load_dotenv()

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Name of the environment variable that holds each provider's API key.
API_KEYS_ENV_VARS = {"GROQ": 'GROQ_API_KEY', "OPENROUTER": 'OPENROUTER_API_KEY', "TOGETHERAI": 'TOGETHERAI_API_KEY', "COHERE": 'COHERE_API_KEY', "XAI": 'XAI_API_KEY', "OPENAI": 'OPENAI_API_KEY', "GOOGLE": 'GOOGLE_API_KEY', "HUGGINGFACE": 'HF_TOKEN'}
# Chat endpoint per provider (keys uppercase).
# NOTE(review): GOOGLE and HUGGINGFACE entries are base paths that appear to
# expect a model id appended, and COHERE uses a non-OpenAI schema — confirm
# before relying on them from call_model_stream.
API_URLS = {"GROQ": 'https://api.groq.com/openai/v1/chat/completions', "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions', "TOGETHERAI": 'https://api.together.ai/v1/chat/completions', "COHERE": 'https://api.cohere.ai/v1/chat', "XAI": 'https://api.x.ai/v1/chat/completions', "OPENAI": 'https://api.openai.com/v1/chat/completions', "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/', "HUGGINGFACE": 'https://api-inference.huggingface.co/models/'}

# Load the provider -> model catalogue from models.json at import time.
# Falls back to a minimal Groq-only catalogue so the module stays importable
# when the file is missing or malformed. encoding is pinned to UTF-8 so the
# result does not depend on the platform's default locale encoding.
try:
    with open("models.json", "r", encoding="utf-8") as f:
        MODELS_BY_PROVIDER = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    logger.warning("models.json not found or invalid. Using a fallback model list.")
    MODELS_BY_PROVIDER = {"groq": {"default": "llama3-8b-8192", "models": {"Llama 3 8B (Groq)": "llama3-8b-8192"}}}

def _get_api_key(provider: str, ui_api_key_override: str = None) -> str | None:
    provider_upper = provider.upper()
    if ui_api_key_override and ui_api_key_override.strip(): return ui_api_key_override.strip()
    env_var_name = API_KEYS_ENV_VARS.get(provider_upper)
    if env_var_name:
        env_key = os.getenv(env_var_name)
        if env_key and env_key.strip(): return env_key.strip()
    return None

def get_available_providers() -> list[str]:
    """Return every configured provider key, alphabetically sorted."""
    return sorted(MODELS_BY_PROVIDER)

def get_model_display_names_for_provider(provider: str) -> list[str]:
    """Alphabetical display names of the models offered by *provider*.

    Unknown providers yield an empty list rather than raising.
    """
    provider_entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    return sorted(provider_entry.get("models", {}))

def get_default_model_display_name_for_provider(provider: str) -> str | None:
    """Display name of the provider's default model.

    Falls back to the alphabetically-first model when no default is
    configured (or the configured default id is not in the catalogue);
    returns None when the provider has no models at all.
    """
    entry = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models = entry.get("models", {})
    default_id = entry.get("default")
    if default_id:
        for display_name, model_id in models.items():
            if model_id == default_id:
                return display_name
    return min(models) if models else None

def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a human-readable model name back to its provider model id (or None)."""
    models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return models.get(display_name)

def call_model_stream(provider: str, model_display_name: str, messages: list[dict], api_key_override: str | None = None, temperature: float = 0.7, max_tokens: int | None = None) -> Iterator[str]:
    """Stream chat-completion text deltas from an OpenAI-style SSE endpoint.

    Yields content fragments as they arrive. On any failure (missing
    configuration, HTTP error, network error) a single "Error: ..." string
    is yielded instead of raising, so callers can render errors as stream
    content.

    NOTE(review): the GOOGLE, HUGGINGFACE and COHERE entries in API_URLS are
    base paths / non-OpenAI schemas, while this function posts an OpenAI
    chat payload to the URL as-is — likely broken for those providers;
    confirm before enabling them.
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)
    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)
    if not all([api_key, base_url, model_id]):
        yield f"Error: Configuration missing for {provider}/{model_display_name}."
        return

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
    if max_tokens:
        payload["max_tokens"] = max_tokens
    if provider_lower == "openrouter":
        # OpenRouter asks for a referrer header for app attribution.
        headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER", "http://localhost")

    try:
        # `with` releases the pooled connection even on the early return
        # below when the stream ends with [DONE].
        with requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180) as response:
            response.raise_for_status()
            # Incremental decoder: a multibyte UTF-8 sequence can straddle a
            # chunk boundary; decoding each chunk independently (the old
            # chunk.decode(...)) would corrupt such characters.
            decoder = codecs.getincrementaldecoder("utf-8")("replace")
            buffer = ""
            for chunk in response.iter_content(chunk_size=None):
                buffer += decoder.decode(chunk)
                # SSE events are separated by a blank line.
                while '\n\n' in buffer:
                    event_str, buffer = buffer.split('\n\n', 1)
                    if not event_str.strip() or not event_str.startswith('data: '):
                        continue
                    data_json = event_str[len('data: '):].strip()
                    if data_json == '[DONE]':
                        return
                    try:
                        data = json.loads(data_json)
                        choices = data.get("choices")
                        if choices:
                            delta = choices[0].get("delta", {})
                            if delta and delta.get("content"):
                                yield delta["content"]
                    except json.JSONDecodeError:
                        # Ignore malformed keep-alive / partial events.
                        continue
    except requests.exceptions.HTTPError as e:
        yield f"Error: API HTTP Error ({e.response.status_code}): {e.response.text[:200]}"
    except Exception as e:
        yield f"Error: An unexpected error occurred: {e}"