Duibonduil commited on
Commit
d29a129
·
verified ·
1 Parent(s): bbbef25

Upload 2 files

Browse files
aworld/session/base_session_service.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List, Optional
3
+
4
+ from aworld.cmd import SessionModel
5
+ from aworld.cmd import ChatCompletionMessage
6
+
7
+
8
+ class BaseSessionService(abc.ABC):
9
+ @abc.abstractmethod
10
+ async def get_session(
11
+ self, user_id: str, session_id: str
12
+ ) -> Optional[SessionModel]:
13
+ pass
14
+
15
+ @abc.abstractmethod
16
+ async def list_sessions(self, user_id: str) -> List[SessionModel]:
17
+ pass
18
+
19
+ @abc.abstractmethod
20
+ async def delete_session(self, user_id: str, session_id: str) -> None:
21
+ pass
22
+
23
+ @abc.abstractmethod
24
+ async def append_messages(
25
+ self, user_id: str, session_id: str, messages: List[ChatCompletionMessage]
26
+ ) -> None:
27
+ pass
aworld/session/simple_session_service.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import threading
4
+
5
+ from typing import List, Optional
6
+ from typing_extensions import override
7
+ from datetime import datetime
8
+ from aworld.cmd import SessionModel, ChatCompletionMessage
9
+ from aworld.logs.util import logger
10
+ from .base_session_service import BaseSessionService
11
+
12
+
13
+ class SimpleSessionService(BaseSessionService):
14
+ def __init__(self):
15
+ self.data_file = os.path.join(os.curdir, "data", "session.bin")
16
+ os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
17
+ self._lock = threading.Lock()
18
+
19
+ def _load_sessions(self) -> dict:
20
+ if not os.path.exists(self.data_file):
21
+ return {}
22
+ try:
23
+ with open(self.data_file, "rb") as f:
24
+ return pickle.load(f)
25
+ except Exception as e:
26
+ logger.error(f"Error loading sessions: {e}")
27
+ return {}
28
+
29
+ def _save_sessions(self, sessions: dict):
30
+ try:
31
+ temp_file = self.data_file + ".tmp"
32
+ with open(temp_file, "wb") as f:
33
+ pickle.dump(sessions, f)
34
+ f.flush()
35
+ os.fsync(f.fileno())
36
+ os.replace(temp_file, self.data_file)
37
+ except Exception as e:
38
+ logger.error(f"Error saving sessions: {e}")
39
+ if os.path.exists(temp_file):
40
+ os.remove(temp_file)
41
+ raise
42
+
43
+ @override
44
+ async def get_session(
45
+ self, user_id: str, session_id: str
46
+ ) -> Optional[SessionModel]:
47
+ session_key = f"{user_id}:{session_id}"
48
+ with self._lock:
49
+ sessions = self._load_sessions()
50
+ return sessions.get(session_key)
51
+
52
+ @override
53
+ async def list_sessions(self, user_id: str) -> List[SessionModel]:
54
+ with self._lock:
55
+ sessions = self._load_sessions()
56
+ user_sessions = [
57
+ session for key, session in sessions.items() if key.startswith(user_id)
58
+ ]
59
+ # Sort by created_at in descending order (newest first)
60
+ return sorted(user_sessions, key=lambda x: x.created_at, reverse=True)
61
+
62
+ @override
63
+ async def delete_session(self, user_id: str, session_id: str) -> None:
64
+ session_key = f"{user_id}:{session_id}"
65
+ with self._lock:
66
+ sessions = self._load_sessions()
67
+ if session_key not in sessions:
68
+ logger.warning(f"Session {session_key} not found")
69
+ return
70
+ del sessions[session_key]
71
+ self._save_sessions(sessions)
72
+
73
+ @override
74
+ async def append_messages(
75
+ self, user_id: str, session_id: str, messages: List[ChatCompletionMessage]
76
+ ) -> None:
77
+ session_key = f"{user_id}:{session_id}"
78
+ with self._lock:
79
+ sessions = self._load_sessions()
80
+
81
+ if session_key not in sessions:
82
+ logger.info(f"Session {session_key} not found, creating new session")
83
+ sessions[session_key] = SessionModel(
84
+ user_id=user_id,
85
+ session_id=session_id,
86
+ name=messages[0].content,
87
+ description=messages[0].content,
88
+ created_at=datetime.now(),
89
+ updated_at=datetime.now(),
90
+ messages=[],
91
+ )
92
+
93
+ sessions[session_key].messages.extend(messages)
94
+ sessions[session_key].updated_at = datetime.now()
95
+ self._save_sessions(sessions)