Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import threading | |
| from typing import List, Optional | |
| from typing_extensions import override | |
| from datetime import datetime | |
| from aworld.cmd import SessionModel, ChatCompletionMessage | |
| from aworld.logs.util import logger | |
| from .base_session_service import BaseSessionService | |
class SimpleSessionService(BaseSessionService):
    """Session store that persists all sessions to a single pickle file.

    Sessions are kept in one dict keyed by ``"{user_id}:{session_id}"`` and
    stored at ``./data/session.bin``. Every operation reloads the whole file.
    Every mutation rewrites it through a temp file followed by ``os.replace``.
    A ``threading.Lock`` serialises the load -> modify -> save sequence among
    threads of this process. It does not coordinate with other processes that
    might use the same file.

    NOTE(review): the async methods take a blocking ``threading.Lock`` and do
    blocking file I/O directly, so they stall the event loop while the file
    is read or written. Consider ``asyncio.to_thread`` if the file grows large.
    """

    def __init__(self):
        # Relative path: resolved against the process's current working
        # directory each time the file is opened.
        self.data_file = os.path.join(os.curdir, "data", "session.bin")
        os.makedirs(os.path.dirname(self.data_file), exist_ok=True)
        self._lock = threading.Lock()

    @staticmethod
    def _session_key(user_id: str, session_id: str) -> str:
        """Build the storage key identifying one session of one user."""
        return f"{user_id}:{session_id}"

    def _load_sessions(self) -> dict:
        """Read the full session dict from disk.

        Returns:
            The stored ``{session_key: SessionModel}`` dict. Returns ``{}``
            when the file does not exist or cannot be unpickled; in the
            second case the error is logged.
        """
        if not os.path.exists(self.data_file):
            return {}
        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. This is only safe while session.bin is written exclusively
            # by this service; never point data_file at untrusted input.
            with open(self.data_file, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            # NOTE(review): returning {} here means the next mutating call
            # (delete_session / append_messages) saves a dict without the
            # unreadable data, discarding previously stored sessions.
            logger.error(f"Error loading sessions: {e}")
            return {}

    def _save_sessions(self, sessions: dict):
        """Replace the session file with ``sessions``.

        The data is written to a sibling ``.tmp`` file and fsynced. That file
        is then moved over the real file with a single ``os.replace``, so the
        real file is never left half-written.

        Raises:
            Exception: any write/rename error is logged, the temp file is
                removed if present, and the error is re-raised.
        """
        temp_file = self.data_file + ".tmp"
        try:
            with open(temp_file, "wb") as f:
                pickle.dump(sessions, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(temp_file, self.data_file)
        except Exception as e:
            logger.error(f"Error saving sessions: {e}")
            if os.path.exists(temp_file):
                os.remove(temp_file)
            raise

    async def get_session(
        self, user_id: str, session_id: str
    ) -> Optional[SessionModel]:
        """Return the stored session, or ``None`` if it does not exist."""
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()
        return sessions.get(session_key)

    async def list_sessions(self, user_id: str) -> List[SessionModel]:
        """Return every session owned by ``user_id``, newest first.

        Sessions are ordered by ``created_at`` in descending order.
        """
        # Match on "<user_id>:" rather than the bare user_id. Otherwise user
        # "bob" would also receive the sessions stored under "bobby:...".
        prefix = f"{user_id}:"
        with self._lock:
            sessions = self._load_sessions()
        user_sessions = [
            session for key, session in sessions.items() if key.startswith(prefix)
        ]
        return sorted(user_sessions, key=lambda s: s.created_at, reverse=True)

    async def delete_session(self, user_id: str, session_id: str) -> None:
        """Remove a session. A missing session is logged and ignored."""
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()
            if session_key not in sessions:
                logger.warning(f"Session {session_key} not found")
                return
            del sessions[session_key]
            self._save_sessions(sessions)

    async def append_messages(
        self, user_id: str, session_id: str, messages: List[ChatCompletionMessage]
    ) -> None:
        """Append ``messages`` to a session, creating the session if needed.

        A newly created session takes its ``name`` and ``description`` from
        the content of the first message. An empty ``messages`` list is a
        no-op: nothing is created and nothing is written.
        """
        if not messages:
            # Nothing to append. This also prevents the IndexError that
            # messages[0] would raise when creating a new session.
            return
        session_key = self._session_key(user_id, session_id)
        with self._lock:
            sessions = self._load_sessions()
            # Take one timestamp so created_at == updated_at for a new session.
            now = datetime.now()
            if session_key not in sessions:
                logger.info(f"Session {session_key} not found, creating new session")
                sessions[session_key] = SessionModel(
                    user_id=user_id,
                    session_id=session_id,
                    name=messages[0].content,
                    description=messages[0].content,
                    created_at=now,
                    updated_at=now,
                    messages=[],
                )
            session = sessions[session_key]
            session.messages.extend(messages)
            session.updated_at = now
            self._save_sessions(sessions)