|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
import requests |
|
import uuid |
|
import threading |
|
|
|
from src.config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS |
|
|
|
class SessionWithID(requests.Session):
    """
    A ``requests.Session`` that also carries an identifier and cancellation state.

    Each instance gets a fresh random UUID string as ``session_id`` so an
    individual user session can be tracked, plus two cancellation primitives
    that let ongoing requests be stopped:

    * ``stop_event`` -- an ``asyncio.Event``.
    * ``cancel_token`` -- a mutable dict whose ``"cancelled"`` flag starts False.
    """

    def __init__(self):
        super().__init__()
        # Random UUID4 in its canonical string form.
        self.session_id = f"{uuid.uuid4()}"
        # Plain dict so the flag can be flipped by reference from elsewhere.
        self.cancel_token = {"cancelled": False}
        # NOTE(review): on Python < 3.10 an asyncio.Event binds to the current
        # event loop at construction time; confirm the runtime version if
        # sessions are created outside the loop that later awaits this event.
        self.stop_event = asyncio.Event()
|
|
|
def create_session():
    """
    Build a brand-new ``SessionWithID``.

    Intended to be called whenever a user session begins or the chat is reset,
    so the caller receives a fresh ID and fresh cancellation state.
    """
    fresh_session = SessionWithID()
    return fresh_session
|
|
|
def ensure_stop_event(sess):
    """
    Backfill missing cancellation attributes on ``sess`` in place.

    If ``sess`` lacks ``stop_event`` it receives a new ``asyncio.Event``.
    If it lacks ``cancel_token`` it receives ``{"cancelled": False}``.
    Attributes that already exist are left untouched, which makes this safe
    to call on restored or reused session objects. Returns ``None``.
    """
    # EAFP probes: only AttributeError means "missing", exactly as hasattr does.
    try:
        sess.stop_event
    except AttributeError:
        sess.stop_event = asyncio.Event()

    try:
        sess.cancel_token
    except AttributeError:
        sess.cancel_token = {"cancelled": False}
|
|
|
def marked_item(item, marked, attempts, threshold=3, cooldown=300):
    """
    Record a failure for a provider key or host and mark it as problematic.

    Every call adds ``item`` to ``marked`` and increments its failure count in
    ``attempts``. When the count reaches ``threshold``, a single background
    timer is started. After ``cooldown`` seconds it removes ``item`` from
    ``marked`` and clears its count, so the item can be retried. This helps
    avoid repeatedly using failing providers.

    Args:
        item: Hashable identifier (provider key or host).
        marked: Set of items currently considered problematic; mutated in place.
        attempts: Dict mapping item -> failure count; mutated in place.
        threshold: Failure count at which the unmark timer is scheduled
            (default 3; values below 1 are treated as 1).
        cooldown: Seconds before the item is unmarked (default 300, i.e. 5 minutes).

    Returns:
        The started ``threading.Timer`` if this call scheduled one, else ``None``.

    NOTE(review): an item is marked on its first failure, but the unmark timer
    is only scheduled once its count reaches ``threshold``. If callers skip
    marked items, an item that failed fewer than ``threshold`` times may stay
    marked indefinitely -- confirm this is intended.
    """
    marked.add(item)
    count = attempts.get(item, 0) + 1
    attempts[item] = count

    # Schedule exactly once per failure streak. A ">=" check would start a new
    # timer on every failure past the threshold, and a stale timer could then
    # unmark an item that had been re-marked after the first timer fired.
    if count != max(threshold, 1):
        return None

    def _release():
        # Runs on the timer thread once the cooldown elapses.
        marked.discard(item)
        attempts.pop(item, None)

    timer = threading.Timer(cooldown, _release)
    # Daemon so a pending cooldown does not keep the interpreter alive at exit.
    timer.daemon = True
    timer.start()
    return timer
|
|
|
def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Reverse-look-up a model's internal key from its display name.

    Scans ``MODEL_MAPPING`` (key -> display name) in iteration order and
    returns the first key whose value equals ``display``. Falls back to
    ``DEFAULT_MODEL_KEY`` when no entry matches.
    """
    for model_key, display_name in MODEL_MAPPING.items():
        if display_name == display:
            return model_key
    return DEFAULT_MODEL_KEY
|
|