import json
import logging
import os
import shutil
from itertools import islice
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain.storage import create_kv_docstore
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore, ByteStore
from pie_datasets import Dataset, DatasetDict

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class BasicPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses a client to store and retrieve documents."""

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        if byte_store is not None:
            # a raw byte store takes precedence and is wrapped so that
            # documents are (de)serialized transparently
            client = create_kv_docstore(byte_store)
        elif client is None:
            raise ValueError("You must pass either a `client` or a `byte_store` parameter.")

        self.client = client

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents stored under the given keys."""
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the entries for the given keys."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over all stored keys, optionally restricted to a prefix."""
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        all_doc_ids = []
        all_metadata = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        # ensure we have a true iterator so that islice() consumes the keys
        # batch by batch
        doc_ids_iter = iter(self.client.yield_keys())
        mode = "w"
        # process the store in batches to keep memory usage bounded
        while batch_doc_ids := list(islice(doc_ids_iter, batch_size or 1000)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                # split each entry into the wrapped PIE document and the
                # remaining metadata, which are persisted separately
                pie_doc = doc.metadata[self.METADATA_KEY_PIE_DOCUMENT]
                pie_docs.append(pie_doc)
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path, mode=mode)
            mode = "a"  # append after the first batch
        if len(all_doc_ids) > 0:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if len(all_metadata) > 0:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)
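
    # Resulting on-disk layout (as written above and read back below):
    #
    #   <path>/
    #       pie_documents/   # PIE dataset serialized as JSON ("train" split)
    #       doc_ids.json     # list of docstore keys, in save order
    #       metadata.json    # per-document extra metadata, same order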

    def _load_from_directory(self, path: str, **kwargs) -> None:
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, not loading any metadata.")
            all_metadata = [{} for _ in pie_docs]
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(
                f"File {doc_ids_path} does not exist, falling back to the documents' own ids."
            )
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(zip(all_doc_ids, docs))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")