"""Multi-provider LLM chat-streaming client.

Loads a model catalog from models.json (falling back to a hardcoded catalog)
and exposes generate_stream(), which streams chat completions from
OpenAI-compatible endpoints (Groq, OpenRouter, TogetherAI, OpenAI, xAI),
Google Gemini, and Cohere.
"""
import os
import json
import logging
from typing import Iterator

import requests

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Environment variable name expected for each provider's API key.
API_KEYS = {
    "HUGGINGFACE": 'HF_TOKEN',
    "GROQ": 'GROQ_API_KEY',
    "OPENROUTER": 'OPENROUTER_API_KEY',
    "TOGETHERAI": 'TOGETHERAI_API_KEY',
    "COHERE": 'COHERE_API_KEY',
    "XAI": 'XAI_API_KEY',
    "OPENAI": 'OPENAI_API_KEY',
    "GOOGLE": 'GOOGLE_API_KEY',
}

# Base endpoint for each provider's chat/inference API.
API_URLS = {
    "HUGGINGFACE": 'https://api-inference.huggingface.co/models/',
    "GROQ": 'https://api.groq.com/openai/v1/chat/completions',
    "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions',
    "TOGETHERAI": 'https://api.together.ai/v1/chat/completions',
    "COHERE": 'https://api.cohere.ai/v1/chat',
    "XAI": 'https://api.x.ai/v1/chat/completions',
    "OPENAI": 'https://api.openai.com/v1/chat/completions',
    "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
}

# Hardcoded fallback catalog used when models.json is missing or unreadable.
_FALLBACK_MODELS = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        },
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
        },
    },
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        },
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        },
    },
}

try:
    with open("models.json", "r") as f:
        MODELS_BY_PROVIDER = json.load(f)
    logger.info("models.json loaded successfully.")
except FileNotFoundError:
    logger.error("models.json not found. Using hardcoded fallback models.")
    MODELS_BY_PROVIDER = _FALLBACK_MODELS
except json.JSONDecodeError:
    logger.error("Error decoding models.json. Using hardcoded fallback models.")
    MODELS_BY_PROVIDER = _FALLBACK_MODELS

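# Expected models.json shape (an assumption inferred from the fallback
# catalog above; display names map to provider-specific model IDs):
# {
#   "groq": {
#     "default": "llama3-8b-8192",
#     "models": {"Llama 3 8B (Groq)": "llama3-8b-8192", ...}
#   },
#   ...
# }
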
def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve an API key: UI override first, then the provider's environment variable."""
    if ui_api_key_override:
        logger.debug(f"Using UI API key override for {provider}")
        return ui_api_key_override.strip()

    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = os.getenv(env_var_name)
        if env_key:
            logger.debug(f"Using env var {env_var_name} for {provider}")
            return env_key.strip()

    # Explicit Hugging Face fallback (same HF_TOKEN variable as in API_KEYS).
    if provider.lower() == 'huggingface':
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            logger.debug(f"Using HF_TOKEN env var for {provider}")
            return hf_token.strip()

    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None

def get_available_providers() -> list[str]:
    """Return the sorted provider keys from the model catalog."""
    return sorted(MODELS_BY_PROVIDER.keys())


def get_models_for_provider(provider: str) -> list[str]:
    """Return the sorted model display names for a provider."""
    return sorted(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys())


def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name of a provider's default model, if any."""
    provider_info = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models_dict = provider_info.get("models", {})
    default_model_id = provider_info.get("default")
    if default_model_id:
        # The default is stored as a model ID; map it back to its display name.
        for display_name, model_id in models_dict.items():
            if model_id == default_model_id:
                return display_name
    # No default configured (or it is missing from the catalog): fall back to
    # the first display name alphabetically.
    if models_dict:
        return sorted(models_dict.keys())[0]
    return None


def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Map a model display name back to its provider-specific model ID."""
    models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return models.get(display_name)

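# Example lookups against the fallback catalog:
#   get_available_providers()               -> ['google', 'groq', 'openai', 'openrouter']
#   get_default_model_for_provider('groq')  -> 'Llama 3 8B (Groq)'
#   get_model_id_from_display_name('groq', 'Llama 3 8B (Groq)') -> 'llama3-8b-8192'
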
def generate_stream(provider: str, model_display_name: str, api_key_override: str | None, messages: list[dict]) -> Iterator[str]:
    """Stream chat-completion text from the selected provider.

    Yields content fragments as they arrive; error conditions are yielded as
    human-readable strings rather than raised.
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)

    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)

    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return

    headers = {}
    payload = {}
    request_url = base_url
    timeout_seconds = 180

    logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

    try:
        if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
            # These providers all speak the OpenAI-compatible chat-completions protocol.
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            payload = {
                "model": model_id,
                "messages": messages,
                "stream": True,
                "temperature": 0.7,
                "max_tokens": 4096,
            }
            if provider_lower == "openrouter":
                # OpenRouter asks callers to identify their app via these headers.
                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-commander"
                headers["X-Title"] = "Hugging Face Space Commander"

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
            response.raise_for_status()

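            # Illustrative shape of one streamed SSE line (an assumed example
            # for documentation, not a captured response):
            #   data: {"choices":[{"delta":{"content":"Hel"}}]}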
byte_buffer = b"" |
|
for chunk in response.iter_content(chunk_size=8192): |
|
|
|
if response.status_code != 200: |
|
|
|
error_body = response.text |
|
logger.error(f"HTTP Error during stream: {response.status_code}, Body: {error_body}") |
|
yield f"API HTTP Error ({response.status_code}) during stream: {error_body}" |
|
return |
|
|
|
byte_buffer += chunk |
|
while b'\n' in byte_buffer: |
|
line, byte_buffer = byte_buffer.split(b'\n', 1) |
|
decoded_line = line.decode('utf-8', errors='ignore') |
|
if decoded_line.startswith('data: '): |
|
data = decoded_line[6:] |
|
if data == '[DONE]': |
|
byte_buffer = b'' |
|
break |
|
try: |
|
event_data = json.loads(data) |
|
if event_data.get("choices") and len(event_data["choices"]) > 0: |
|
delta = event_data["choices"][0].get("delta") |
|
if delta and delta.get("content"): |
|
yield delta["content"] |
|
except json.JSONDecodeError: |
|
|
|
logger.warning(f"Failed to decode JSON from stream line: {decoded_line.strip()}") |
|
except Exception as e: |
|
logger.error(f"Error processing stream data: {e}, Data: {decoded_line.strip()}") |
|
|
|
            # Flush any complete event still in the buffer after the stream ends.
            if byte_buffer:
                remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
                if remaining_line.startswith('data: '):
                    data = remaining_line[6:]
                    if data != '[DONE]':
                        try:
                            event_data = json.loads(data)
                            if event_data.get("choices") and len(event_data["choices"]) > 0:
                                delta = event_data["choices"][0].get("delta")
                                if delta and delta.get("content"):
                                    yield delta["content"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to decode final stream buffer JSON: {remaining_line}")
                        except Exception as e:
                            logger.error(f"Error processing final stream buffer data: {e}, Data: {remaining_line}")

elif provider_lower == "google": |
|
system_instruction = None |
|
filtered_messages = [] |
|
for msg in messages: |
|
if msg["role"] == "system": |
|
|
|
|
|
system_instruction = msg["content"] |
|
else: |
|
|
|
role = "model" if msg["role"] == "assistant" else msg["role"] |
|
filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]}) |
|
|
|
|
|
|
|
for i in range(1, len(filtered_messages)): |
|
if filtered_messages[i]["role"] == filtered_messages[i-1]["role"]: |
|
yield f"Error: Google API requires alternating user/model roles in chat history. Please check prompt or history format." |
|
return |
|
|
|
            payload = {
                "contents": filtered_messages,
                "safetySettings": [
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                ],
                "generationConfig": {
                    "temperature": 0.7,
                    "maxOutputTokens": 4096,
                },
            }
            if system_instruction:
                payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

            # streamGenerateContent delivers a JSON array incrementally; the
            # API key is passed as a query parameter rather than a header.
            request_url = f"{base_url}{model_id}:streamGenerateContent?key={api_key}"
            headers = {"Content-Type": "application/json"}

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds)
            response.raise_for_status()

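            # Illustrative shape of one streamed object (an assumed example for
            # documentation, not a captured response):
            #   {"candidates": [{"content": {"parts": [{"text": "Hel"}]}}]}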
byte_buffer = b"" |
|
for chunk in response.iter_content(chunk_size=8192): |
|
|
|
if response.status_code != 200: |
|
error_body = response.text |
|
logger.error(f"HTTP Error during Google stream: {response.status_code}, Body: {error_body}") |
|
yield f"API HTTP Error ({response.status_code}) during Google stream: {error_body}" |
|
return |
|
|
|
byte_buffer += chunk |
|
|
|
|
|
|
|
json_decoder = json.JSONDecoder() |
|
while byte_buffer: |
|
try: |
|
|
|
obj, idx = json_decoder.raw_decode(byte_buffer.decode('utf-8', errors='ignore').lstrip()) |
|
|
|
byte_buffer = byte_buffer[len(byte_buffer.decode('utf-8', errors='ignore').lstrip()[:idx]).encode('utf-8'):] |
|
|
|
if obj.get("candidates") and len(obj["candidates"]) > 0: |
|
candidate = obj["candidates"][0] |
|
if candidate.get("content") and candidate["content"].get("parts"): |
|
full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"]) |
|
if full_text_chunk: |
|
yield full_text_chunk |
|
|
|
if obj.get("error"): |
|
error_details = obj["error"].get("message", str(obj["error"])) |
|
logger.error(f"Google API returned error in stream data: {error_details}") |
|
yield f"API Error (Google): {error_details}" |
|
return |
|
|
|
except json.JSONDecodeError: |
|
|
|
|
|
break |
|
except Exception as e: |
|
logger.error(f"Error processing Google stream data object: {e}, Object: {obj}") |
|
|
|
|
|
|
|
|
|
if byte_buffer: |
|
logger.warning(f"Remaining data in Google stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}") |
|
|
|
|
|
elif provider_lower == "cohere": |
|
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} |
|
request_url = f"{base_url}" |
|
|
|
chat_history_for_cohere = [] |
|
system_prompt_for_cohere = None |
|
current_message_for_cohere = "" |
|
|
|
|
|
|
|
temp_history = [] |
|
for msg in messages: |
|
if msg["role"] == "system": |
|
|
|
if system_prompt_for_cohere: system_prompt_for_cohere += "\n" + msg["content"] |
|
else: system_prompt_for_cohere = msg["content"] |
|
elif msg["role"] == "user" or msg["role"] == "assistant": |
|
temp_history.append(msg) |
|
|
|
if not temp_history: |
|
yield "Error: No user message found for Cohere API call." |
|
return |
|
if temp_history[-1]["role"] != "user": |
|
yield "Error: Last message must be from user for Cohere API call." |
|
return |
|
|
|
current_message_for_cohere = temp_history[-1]["content"] |
|
|
|
chat_history_for_cohere = [{"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]} for m in temp_history[:-1]] |
|
|
|
payload = { |
|
"model": model_id, |
|
"message": current_message_for_cohere, |
|
"stream": True, |
|
"temperature": 0.7, |
|
"max_tokens": 4096 |
|
} |
|
if chat_history_for_cohere: |
|
payload["chat_history"] = chat_history_for_cohere |
|
if system_prompt_for_cohere: |
|
payload["preamble"] = system_prompt_for_cohere |
|
|
|
response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=timeout_seconds) |
|
response.raise_for_status() |
|
|
|
byte_buffer = b"" |
|
for chunk in response.iter_content(chunk_size=8192): |
|
|
|
if response.status_code != 200: |
|
error_body = response.text |
|
logger.error(f"HTTP Error during Cohere stream: {response.status_code}, Body: {error_body}") |
|
yield f"API HTTP Error ({response.status_code}) during Cohere stream: {error_body}" |
|
return |
|
|
|
byte_buffer += chunk |
|
while b'\n\n' in byte_buffer: |
|
event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1) |
|
lines = event_chunk.strip().split(b'\n') |
|
event_type = None |
|
event_data = None |
|
|
|
for l in lines: |
|
if l.strip() == b"": continue |
|
if l.startswith(b"event: "): event_type = l[7:].strip().decode('utf-8', errors='ignore') |
|
elif l.startswith(b"data: "): |
|
try: event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore')) |
|
except json.JSONDecodeError: logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}") |
|
else: |
|
|
|
logger.warning(f"Cohere: Unexpected line in event chunk: {l.decode('utf-8', errors='ignore').strip()}") |
|
|
|
|
|
if event_type == "text-generation" and event_data and "text" in event_data: |
|
yield event_data["text"] |
|
elif event_type == "stream-end": |
|
logger.debug("Cohere stream-end event received.") |
|
byte_buffer = b'' |
|
break |
|
elif event_type == "error": |
|
error_msg = event_data.get("message", str(event_data)) if event_data else "Unknown Cohere stream error" |
|
logger.error(f"Cohere stream error event: {error_msg}") |
|
yield f"API Error (Cohere stream): {error_msg}" |
|
return |
|
|
|
|
|
if byte_buffer: |
|
logger.warning(f"Remaining data in Cohere stream buffer after processing: {byte_buffer.decode('utf-8', errors='ignore')}") |
|
|
|
|
|
elif provider_lower == "huggingface": |
|
|
|
|
|
|
|
|
|
yield f"Error: Direct Hugging Face Inference API streaming for chat models is highly experimental and depends heavily on the specific model's implementation. Standard OpenAI-like streaming is NOT guaranteed. For better compatibility with HF models that support the OpenAI format, consider using the OpenRouter or TogetherAI providers and selecting the HF models listed there." |
|
return |
|
|
|
else: |
|
yield f"Error: Unsupported provider '{provider}' for streaming chat." |
|
return |
|
|
|
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.Timeout:
        logger.error(f"Request Timeout after {timeout_seconds} seconds for {provider}/{model_id}.")
        yield f"API Request Timeout: The request took too long to complete ({timeout_seconds} seconds)."
    except requests.exceptions.RequestException as e:
        logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
        yield f"An unexpected error occurred during streaming: {e}"