File size: 3,442 Bytes
d29a129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import pickle
import threading

from typing import List, Optional
from typing_extensions import override
from datetime import datetime
from aworld.cmd import SessionModel, ChatCompletionMessage
from aworld.logs.util import logger
from .base_session_service import BaseSessionService


class SimpleSessionService(BaseSessionService):
    """Session store that keeps every session in a single pickle file.

    All sessions live in one dict, keyed by ``"{user_id}:{session_id}"``,
    which is read from and rewritten to ``self.data_file`` on each operation.
    A per-instance ``threading.Lock`` serializes those read/modify/write
    cycles within this instance; separate instances (or processes) pointing
    at the same file are not coordinated with each other.

    NOTE(review): the methods are ``async`` but perform blocking file I/O
    while holding a ``threading.Lock``.
    """

    def __init__(self):
        # Relative path: resolved against the process's working directory.
        self.data_file = os.path.join(os.curdir, "data", "session.bin")
        os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
        self._lock = threading.Lock()

    @staticmethod
    def _session_key(user_id: str, session_id: str) -> str:
        """Build the dict key under which a session is stored."""
        return f"{user_id}:{session_id}"

    def _load_sessions(self) -> dict:
        """Read the whole session dict from disk.

        Returns:
            The unpickled content of ``self.data_file``, or an empty dict
            when the file does not exist or cannot be read/unpickled.
        """
        if not os.path.exists(self.data_file):
            return {}
        try:
            with open(self.data_file, "rb") as f:
                # SECURITY: pickle.load can execute arbitrary code. This is
                # only safe while the data file is written exclusively by
                # this service and is not writable by untrusted parties.
                return pickle.load(f)
        except Exception as e:
            # NOTE(review): returning {} here means the next successful save
            # overwrites the unreadable file, discarding whatever it held.
            logger.error(f"Error loading sessions: {e}")
            return {}

    def _save_sessions(self, sessions: dict):
        """Write the whole session dict to disk.

        The data is pickled to a sibling ``.tmp`` file, flushed and fsynced,
        then moved over ``self.data_file`` with ``os.replace`` so readers
        never observe a partially written file.

        Args:
            sessions: Complete mapping of session key to session object.

        Raises:
            Exception: Whatever the write/replace raised, after logging it
                and attempting to remove the temporary file.
        """
        # Bound before the try block so the cleanup path can always see it.
        temp_file = self.data_file + ".tmp"
        try:
            with open(temp_file, "wb") as f:
                pickle.dump(sessions, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(temp_file, self.data_file)
        except Exception as e:
            logger.error(f"Error saving sessions: {e}")
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
            except OSError as cleanup_error:
                # Cleanup is best-effort; it must not mask the original error.
                logger.error(f"Error removing temp file {temp_file}: {cleanup_error}")
            raise

    @override
    async def get_session(
            self, user_id: str, session_id: str
    ) -> Optional[SessionModel]:
        """Return the stored session for ``user_id``/``session_id``.

        Returns:
            The session object, or ``None`` when no such key is stored.
        """
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()
            return sessions.get(session_key)

    @override
    async def list_sessions(self, user_id: str) -> List[SessionModel]:
        """Return all sessions of ``user_id``, newest ``created_at`` first."""
        # Match on "<user_id>:" rather than the bare user_id; otherwise user
        # "a" would also be handed the sessions of "ab", "alice", ...
        key_prefix = f"{user_id}:"
        with self._lock:
            sessions = self._load_sessions()
            user_sessions = [
                session
                for key, session in sessions.items()
                if key.startswith(key_prefix)
            ]
            # Sort by created_at in descending order (newest first)
            return sorted(user_sessions, key=lambda x: x.created_at, reverse=True)

    @override
    async def delete_session(self, user_id: str, session_id: str) -> None:
        """Remove a session from the store.

        Logs a warning and leaves the file untouched when the session does
        not exist.
        """
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()
            if session_key not in sessions:
                logger.warning(f"Session {session_key} not found")
                return
            del sessions[session_key]
            self._save_sessions(sessions)

    @override
    async def append_messages(
            self, user_id: str, session_id: str, messages: List[ChatCompletionMessage]
    ) -> None:
        """Append ``messages`` to a session, creating the session if needed.

        A newly created session takes its ``name`` and ``description`` from
        the content of the first message. When the session does not exist
        and ``messages`` is empty there is nothing to derive them from, so
        the call logs a warning and returns without creating anything.

        Args:
            user_id: Owner of the session.
            session_id: Identifier of the session within that user.
            messages: Messages to append, in order.

        Raises:
            Exception: Propagated from ``_save_sessions`` when the write fails.
        """
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()

            if session_key not in sessions:
                if not messages:
                    # Previously this path raised IndexError on messages[0].
                    logger.warning(
                        f"Session {session_key} not found and no messages given, "
                        "nothing to create"
                    )
                    return
                logger.info(f"Session {session_key} not found, creating new session")
                # Single timestamp so created_at == updated_at on creation.
                now = datetime.now()
                sessions[session_key] = SessionModel(
                    user_id=user_id,
                    session_id=session_id,
                    name=messages[0].content,
                    description=messages[0].content,
                    created_at=now,
                    updated_at=now,
                    messages=[],
                )

            session = sessions[session_key]
            session.messages.extend(messages)
            session.updated_at = datetime.now()
            self._save_sessions(sessions)