File size: 2,989 Bytes
f99ad65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#
import codecs # Reasoning
import httpx
import json
from src.cores.session import marked_item
from src.config import LINUX_SERVER_ERRORS, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS, RESPONSES
def _decode_reasoning(text):
    r"""
    Expand backslash escapes (for example a literal ``\n``) in reasoning text.

    Characters outside Latin-1 are first turned into ``\uXXXX`` escapes by the
    ``backslashreplace`` error handler, so the ``unicode_escape`` codec restores
    every character exactly. Encoding to UTF-8 instead would make the codec read
    each UTF-8 byte as a separate Latin-1 character, turning "café" into "cafÃ©".

    Args:
        text: Reasoning text taken from a stream delta.

    Returns:
        The text with escapes expanded, or ``text`` unchanged when it contains a
        malformed escape (such as a trailing backslash, which can appear if an
        escape sequence is split across stream chunks).
    """
    try:
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError:
        return text


def _iter_stream_chunks(data):
    """
    Yield ``(kind, text)`` tuples from one JSON chunk of the stream.

    Malformed JSON and unexpected shapes are skipped instead of raised, so one
    bad chunk does not end the stream. Each entry of ``choices`` is handled
    separately, so a malformed entry is skipped without dropping the others.

    Args:
        data: Payload of one SSE ``data:`` line.

    Yields:
        ``("reasoning", str)`` for non-empty string ``delta["reasoning"]`` values
        and ``("content", value)`` for non-empty ``delta["content"]`` values.
    """
    try:
        chunk = json.loads(data)
    except ValueError:
        # Malformed JSON (json.JSONDecodeError is a ValueError subclass): ignore it.
        return
    if not isinstance(chunk, dict):
        return
    choices = chunk.get("choices")
    if not isinstance(choices, list):
        return
    for choice in choices:
        delta = choice.get("delta") if isinstance(choice, dict) else None
        if not isinstance(delta, dict):
            continue
        # Stream reasoning text separately for UI
        reasoning = delta.get("reasoning")
        if reasoning and isinstance(reasoning, str):
            yield ("reasoning", _decode_reasoning(reasoning))
        # Stream main content text
        content = delta.get("content")
        if content:
            yield ("content", content)


async def fetch_response_stream_async(host, key, model, msgs, cfg, sid, stop_event, cancel_token):
    """
    Async generator that streams AI responses from a backend server.

    POSTs a streaming chat request to ``host`` and yields ``(kind, text)``
    tuples, where ``kind`` is ``"reasoning"`` or ``"content"``, so the UI can
    render reasoning and the main answer separately.

    The request is tried with a 5 s timeout and then with a 10 s timeout. The
    key is marked as failing (via ``marked_item``) in two cases: the server
    answers with a status code in ``LINUX_SERVER_ERRORS``, or both attempts
    finish without yielding anything. Once any output has been yielded, the
    request is never restarted, because a retry would send the caller the same
    text twice.

    Args:
        host: URL the request is POSTed to.
        key: Provider API key, sent as a Bearer token.
        model: Model name placed in the request body.
        msgs: Chat messages placed in the request body.
        cfg: Extra request-body fields, merged last so they override the defaults.
        sid: Session id placed in the request body.
        stop_event: Object whose ``is_set()`` returning true stops the stream.
        cancel_token: Mapping whose ``"cancelled"`` entry, when truthy, stops
            the stream.

    Yields:
        ``("reasoning", str)`` and ``("content", text)`` tuples.
    """
    emitted = False
    for timeout in (5, 10):
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream(
                    "POST",
                    host,
                    json={
                        "model": model,
                        "messages": msgs,
                        "session_id": sid,
                        "stream": True,
                        **cfg,
                    },
                    headers={"Authorization": f"Bearer {key}"},
                ) as response:
                    if response.status_code in LINUX_SERVER_ERRORS:
                        marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
                        return
                    async for line in response.aiter_lines():
                        if stop_event.is_set() or cancel_token["cancelled"]:
                            return
                        # SSE payload lines look like "data: <json>"; the space is optional.
                        if not line.startswith("data:"):
                            continue
                        data = line[5:]
                        if data.startswith(" "):
                            data = data[1:]
                        # NOTE(review): RESPONSE_10 is presumably the end-of-stream
                        # sentinel (e.g. "[DONE]") - confirm in src.config.
                        if data.strip() == RESPONSES["RESPONSE_10"]:
                            return
                        for item in _iter_stream_chunks(data):
                            emitted = True
                            yield item
        except Exception:
            # Network or other errors: fall through to the retry decision below.
            pass
        if emitted:
            # The caller already received text from this attempt; restarting the
            # request would duplicate it, so end the stream here.
            return
    marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
|