ai / src /cores /session.py
hadadrjt's picture
ai: Restructured repo for production.
f99ad65
#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#
import asyncio
import requests
import uuid
import threading
from src.config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS
class SessionWithID(requests.Session):
    """
    A ``requests.Session`` that also carries per-session bookkeeping.

    Attributes:
        session_id: Random UUID4 rendered as a string, generated once when
            the session object is constructed.
        stop_event: An ``asyncio.Event``, initially unset, available to
            signal that in-flight work for this session should stop.
        cancel_token: A dict of the form ``{"cancelled": False}`` used as a
            mutable cancellation flag.
    """

    def __init__(self):
        super().__init__()
        # Cancellation state starts out "not cancelled" / "not stopped".
        self.cancel_token = {"cancelled": False}
        self.stop_event = asyncio.Event()
        # Identifier distinguishing this session from every other one.
        fresh_id = uuid.uuid4()
        self.session_id = str(fresh_id)
def create_session():
    """
    Build a brand-new ``SessionWithID`` and hand it back to the caller.

    Each call constructs an independent object; nothing is cached or shared
    between calls.
    """
    new_session = SessionWithID()
    return new_session
def ensure_stop_event(sess):
    """
    Backfill the cancellation attributes on a session object, in place.

    If ``sess`` lacks ``stop_event`` it receives a fresh ``asyncio.Event``;
    if it lacks ``cancel_token`` it receives ``{"cancelled": False}``.
    Attributes that already exist are left untouched. Returns nothing.
    """
    # (attribute name, zero-argument factory producing its default value)
    defaults = (
        ("stop_event", asyncio.Event),
        ("cancel_token", lambda: {"cancelled": False}),
    )
    for attr_name, make_default in defaults:
        if hasattr(sess, attr_name):
            continue
        setattr(sess, attr_name, make_default())
def marked_item(item, marked, attempts, *, threshold=3, cooldown=300):
    """
    Record a failure for ``item`` and flag it as currently problematic.

    Every call adds ``item`` to ``marked`` and increments its counter in
    ``attempts``. Once the counter reaches ``threshold``, a background timer
    is started that, after ``cooldown`` seconds, removes ``item`` from
    ``marked`` and deletes its counter so the item can be tried again.

    Args:
        item: Hashable identifier of the thing that failed (the original
            docstring describes it as a provider key or host).
        marked: Mutable set of flagged items; modified in place.
        attempts: Dict mapping item -> failure count; modified in place.
        threshold: Failure count at which the unmark timer is scheduled.
            Defaults to 3, the value previously hard-coded.
        cooldown: Seconds to wait before unmarking. Defaults to 300
            (5 minutes), the value previously hard-coded.

    Returns:
        None.
    """
    marked.add(item)
    failures = attempts.get(item, 0) + 1
    attempts[item] = failures

    # Below the threshold the item stays marked and no timer is scheduled.
    if failures < threshold:
        return

    def remove():
        # Both operations tolerate the item already being gone.
        marked.discard(item)
        attempts.pop(item, None)

    # NOTE: every call at or above the threshold starts its own timer.
    timer = threading.Timer(cooldown, remove)
    # Daemon thread: a pending unmark must not keep the interpreter alive
    # for up to `cooldown` seconds when the process is shutting down.
    timer.daemon = True
    timer.start()
def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Reverse-look-up a model key by its display name.

    Scans ``MODEL_MAPPING`` (key -> display name) in iteration order and
    returns the first key whose value equals ``display``. When no value
    matches, ``DEFAULT_MODEL_KEY`` is returned instead.
    """
    for model_key, display_name in MODEL_MAPPING.items():
        if display_name == display:
            return model_key
    # No entry carries this display name.
    return DEFAULT_MODEL_KEY