File size: 2,046 Bytes
f99ad65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#

import asyncio
import requests
import uuid
import threading

from src.config import LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS

class SessionWithID(requests.Session):
    """
    A requests.Session that also carries an identifier and stop/cancel state.

    Attributes assigned when an instance is built:
        session_id:   string rendering of a freshly generated uuid4.
        stop_event:   an asyncio.Event, initially not set.
        cancel_token: the dict {"cancelled": False}. Presumably other code
                      flips "cancelled" to request cancellation (not visible
                      here; confirm against callers).
    """

    def __init__(self):
        super().__init__()
        identifier = uuid.uuid4()
        self.session_id = str(identifier)
        # Stop signal for async code; created unset.
        self.stop_event = asyncio.Event()
        # Mutable flag container so holders of the dict can observe updates.
        self.cancel_token = {"cancelled": False}

def create_session():
    """
    Build and hand back a brand-new SessionWithID.

    Each call produces an independent session with its own session_id,
    stop_event and cancel_token.
    """
    new_session = SessionWithID()
    return new_session

def ensure_stop_event(sess):
    """
    Backfill missing control attributes on a session-like object, in place.

    If `sess` lacks `stop_event`, an unset asyncio.Event is attached; if it
    lacks `cancel_token`, the dict {"cancelled": False} is attached.
    Attributes that already exist are left untouched (even if set/cancelled).
    Returns None.
    """
    # (attribute name, zero-arg factory) pairs, checked in this order.
    required = (
        ("stop_event", asyncio.Event),
        ("cancel_token", lambda: {"cancelled": False}),
    )
    for attr_name, build_default in required:
        if hasattr(sess, attr_name):
            continue
        setattr(sess, attr_name, build_default())

def marked_item(item, marked, attempts, threshold=3, unmark_after=300):
    """
    Record one failure for `item` (e.g. a provider key or host).

    The item is added to `marked` on every call and its failure count in
    `attempts` is incremented. When the count reaches exactly `threshold`,
    a single daemon timer is started that, after `unmark_after` seconds
    (default 300 s = 5 minutes), removes the item from `marked` and drops its
    entry from `attempts`, so it can be retried.

    NOTE(review): an item with fewer than `threshold` recorded failures stays
    in `marked` and is not scheduled for removal by this function.
    NOTE(review): the removal runs on the timer thread and mutates `marked`
    and `attempts` without a lock; confirm that is acceptable for callers.

    Args:
        item: hashable identifier to mark.
        marked: set of currently marked items (mutated in place).
        attempts: dict mapping item -> failure count (mutated in place).
        threshold: failure count at which the unmark timer is scheduled.
        unmark_after: delay in seconds before the item is unmarked.

    Returns:
        The started threading.Timer when one was scheduled by this call,
        otherwise None.
    """
    marked.add(item)
    attempts[item] = attempts.get(item, 0) + 1
    # Schedule exactly one unmark per failure streak. With ">=" every extra
    # failure past the threshold started another timer, and those stale
    # timers could later unmark an item that had been freshly re-marked.
    if attempts[item] != threshold:
        return None

    def remove():
        marked.discard(item)
        attempts.pop(item, None)

    timer = threading.Timer(unmark_after, remove)
    # Daemon so a pending multi-minute timer never blocks interpreter shutdown.
    timer.daemon = True
    timer.start()
    return timer

def get_model_key(display, MODEL_MAPPING, DEFAULT_MODEL_KEY):
    """
    Reverse-look up a model key from its display label.

    Scans MODEL_MAPPING (key -> display label) in iteration order and returns
    the first key whose label equals `display`. If no label matches,
    DEFAULT_MODEL_KEY is returned instead.
    """
    for model_key, label in MODEL_MAPPING.items():
        if label == display:
            return model_key
    return DEFAULT_MODEL_KEY