# ScientificArgumentRecommender / src/langchain_modules/datasets_pie_document_store.py
# Committed by ArneBinder
# Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
# Commit: d868d2e (verified)
import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument
from .pie_document_store import PieDocumentStore
# Module-level logger; used for the warnings/info messages emitted when saving and loading.
logger = logging.getLogger(__name__)
class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend.

    All documents live as rows of a single PIE ``Dataset`` (``self._data``). ``self._keys``
    maps every live key to its row index. Rows that are no longer referenced by any key
    (because the key was deleted or overwritten) stay in the dataset until
    ``_purge_invalid_entries`` compacts it, which happens before saving. Per-key metadata is
    kept in ``self._metadata``; keys whose metadata is empty have no entry there.
    """

    def __init__(self) -> None:
        self._data: Optional[Dataset] = None
        # keys map to indices (rows) in the dataset
        self._keys: Dict[str, int] = {}
        # only non-empty metadata is stored, see _set_metadata
        self._metadata: Dict[str, Any] = {}

    def __len__(self) -> int:
        """Return the number of live keys (not the number of rows in the dataset)."""
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
        """Select the rows at ``indices`` from the backing dataset (empty if there is none)."""
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def _set_metadata(
        self, keys: Iterable[str], metadatas: Iterable[Dict[str, Any]]
    ) -> None:
        """Record the metadata for freshly (re)assigned keys.

        Empty metadata removes any existing entry. This matters when a key is overwritten:
        the metadata of the replaced document must not stick to the new one.
        """
        for key, metadata in zip(keys, metadatas):
            if metadata:
                self._metadata[key] = metadata
            else:
                self._metadata.pop(key, None)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Return the stored documents for ``keys``, wrapped as LangChain documents.

        Keys that are not in the store are skipped, so the result may be shorter than
        ``keys``. NOTE(review): LangChain's ``BaseStore.mget`` contract returns ``None`` for
        missing keys instead; the skipping behaviour is kept for backward compatibility.
        """
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        pie_docs = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given ``(key, document)`` pairs.

        The documents are appended as new rows. Existing keys are re-pointed to the new rows;
        the old rows become orphans until the next purge.
        """
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        # the new rows were appended at the end, in the order of ``keys``
        self._keys.update({key: idx for idx, key in enumerate(keys, start=idx_start)})
        self._set_metadata(keys, metadatas)

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Add all documents of a PIE dataset to the store.

        Args:
            dataset: The documents to add. An empty dataset is a no-op.
            keys: One unique, non-None key per document. Defaults to the document ids.
            metadatas: One metadata dict per document. Defaults to empty metadata.

        Raises:
            ValueError: If the keys are not unique or contain None, or if keys, dataset and
                metadatas do not all have the same length.
        """
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")
        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        # the new rows were appended at the end, in the order of ``keys``
        self._keys.update({key: idx for idx, key in enumerate(keys, start=idx_start)})
        self._set_metadata(keys, metadatas)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Remove the given keys (unknown keys are ignored).

        The underlying rows are only dropped from the dataset on the next purge.
        """
        for key in keys:
            if self._keys.pop(key, None) is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over the live keys, optionally only those starting with ``prefix``."""
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self) -> None:
        """Compact the dataset so that it holds exactly the rows referenced by ``self._keys``.

        Rows of deleted or overwritten keys are dropped, and the remaining rows are ordered
        like the keys. ``self._keys`` is then re-indexed to the new row positions, so row
        ``i`` belongs to the ``i``-th key.
        """
        if self._data is None:
            return
        all_keys = list(self._keys)
        indices = [self._keys[key] for key in all_keys]
        if indices == list(range(len(self._data))):
            # already compact and aligned with the key order: nothing to do
            return
        self._data = self._get_pie_docs_by_indices(indices)
        # the selected rows now sit at positions 0..n-1, so the old indices are invalid
        self._keys = {key: idx for idx, key in enumerate(all_keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Write documents, document ids and metadata below ``path``.

        Layout: ``pie_documents/`` (replaced if it exists), ``doc_ids.json`` and
        ``metadata.json``. The i-th id and metadata entry belong to the i-th saved document.
        """
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        # after purging, the dict order of self._keys matches the row order of self._data
        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]

        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path, mode="w")

        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w", encoding="utf-8") as f:
            json.dump(all_doc_ids, f)

        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load documents written by ``_save_to_directory`` and add them to the store.

        Missing ``doc_ids.json`` / ``metadata.json`` files fall back to the document ids and
        empty metadata. A missing ``pie_documents`` directory loads nothing.
        """
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r", encoding="utf-8") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = None

        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r", encoding="utf-8") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = None

        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None
        # If we have a dataset already loaded, we use its features to load the new dataset
        # This is to mitigate issues caused by slightly different inferred features.
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")