File size: 3,731 Bytes
8af6ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import Any, Dict, List, Optional
from . import Checkpoint, BaseCheckpointRepository, VersionUtils

class InMemoryCheckpointRepository(BaseCheckpointRepository):
    """
    In-memory implementation of BaseCheckpointRepository.

    Checkpoints are kept in a plain dictionary keyed by checkpoint id, and a
    secondary index maps each session id to the ordered ids of the checkpoints
    stored for that session (oldest first, most recently stored last).
    Nothing is persisted, and thread safety is not guaranteed.
    """

    def __init__(self) -> None:
        """
        Initialize the in-memory checkpoint repository with empty storage.
        """
        # Primary storage: checkpoint id -> checkpoint.
        self._checkpoints: Dict[str, Checkpoint] = {}
        # Secondary index: session id -> checkpoint ids in insertion order.
        self._session_index: Dict[str, List[str]] = {}

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint by its unique identifier.

        Args:
            checkpoint_id (str): The unique identifier of the checkpoint.
        Returns:
            Optional[Checkpoint]: The checkpoint if found, otherwise None.
        """
        return self._checkpoints.get(checkpoint_id)

    @staticmethod
    def _matches(checkpoint: Checkpoint, params: Dict[str, Any]) -> bool:
        """
        Tell whether a checkpoint satisfies every filter in ``params``.

        ``session_id`` and ``task_id`` are compared against the checkpoint's
        metadata; any other key is looked up with ``checkpoint.get(key)``.
        Evaluation stops at the first filter that does not match.
        """
        for key, expected in params.items():
            if key == 'session_id':
                actual = checkpoint.metadata.session_id
            elif key == 'task_id':
                actual = checkpoint.metadata.task_id
            else:
                # NOTE(review): relies on Checkpoint exposing a dict-style
                # ``get`` accessor - confirm against the Checkpoint type.
                actual = checkpoint.get(key)
            if actual != expected:
                return False
        return True

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        List checkpoints matching the given parameters.

        Every key/value pair in ``params`` must match for a checkpoint to be
        returned; an empty ``params`` returns every stored checkpoint.

        Args:
            params (dict): Parameters to filter checkpoints.
        Returns:
            List[Checkpoint]: List of matching checkpoints, in storage order.
        """
        return [
            cp for cp in self._checkpoints.values()
            if self._matches(cp, params)
        ]

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Store a checkpoint and record it as the latest one for its session.

        Storing a checkpoint whose id is already present replaces the stored
        object; its id then appears exactly once in the session index, at the
        "latest" position.

        Args:
            checkpoint (Checkpoint): The checkpoint to store.
        Raises:
            ValueError: If VersionUtils.is_version_less reports the new
                checkpoint as older than the session's latest checkpoint.
        """
        session_id = checkpoint.metadata.session_id

        # Optimistic locking: compare against the newest checkpoint of the
        # same session, if there is one.
        last_checkpoint = self.get_by_session(session_id)
        if last_checkpoint is not None:
            if VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
                raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")

        # Store the new checkpoint (overwrites any entry with the same id).
        self._checkpoints[checkpoint.id] = checkpoint

        # Update the session index; checkpoints without a session id are
        # stored but not indexed.
        if session_id:
            session_ids = self._session_index.setdefault(session_id, [])
            # Drop an earlier occurrence so a re-stored checkpoint does not
            # leave duplicate ids behind, while still becoming the latest.
            if checkpoint.id in session_ids:
                session_ids.remove(checkpoint.id)
            session_ids.append(checkpoint.id)

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Get the latest checkpoint for a session.

        "Latest" means the checkpoint most recently stored for the session,
        i.e. the last id recorded in the session index.

        Args:
            session_id (str): The session identifier.
        Returns:
            Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
        """
        ids = self._session_index.get(session_id, [])
        if not ids:
            return None
        # The index is append-ordered, so the last id is the newest one.
        return self._checkpoints.get(ids[-1])

    def delete_by_session(self, session_id: str) -> None:
        """
        Delete all checkpoints related to a session.

        Unknown session ids are ignored.

        Args:
            session_id (str): The session identifier.
        """
        # Remove the index entry first, then every checkpoint it referenced.
        for checkpoint_id in self._session_index.pop(session_id, []):
            self._checkpoints.pop(checkpoint_id, None)