# Source code for langchain_google_vertexai.vectorstores.document_storage

from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Sequence, Tuple

from google.cloud import storage  # type: ignore[attr-defined, unused-ignore]
from langchain_core.documents import Document
from langchain_core.stores import BaseStore

if TYPE_CHECKING:
    from google.cloud import datastore  # type: ignore[attr-defined, unused-ignore]


class DocumentStorage(BaseStore[str, Document]):
    """Abstract key -> document store used to retrieve documents by id.

    Specializes ``BaseStore`` to ``str`` keys and ``Document`` values and adds
    no members of its own.
    """
class GCSDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud Storage.

    Each ``(id, document)`` pair is serialized to JSON text and uploaded to the
    blob named ``{prefix}/{id}``, or simply ``{id}`` when the prefix is
    ``None``.
    """

    def __init__(
        self, bucket: storage.Bucket, prefix: Optional[str] = "documents"
    ) -> None:
        """Constructor.

        Args:
            bucket: Bucket where the documents will be stored.
            prefix: Prefix that is prepended to all document names. ``None``
                stores every document directly under its id.
        """
        super().__init__()
        self._bucket = bucket
        self._prefix = prefix

    def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
        """Stores a series of documents, each under its own key.

        Documents are uploaded one at a time, in the order given.

        Args:
            key_value_pairs: A sequence of ``(key, document)`` pairs.
        """
        for key, value in key_value_pairs:
            self._set_one(key, value)

    def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
        """Gets a batch of documents by key.

        Documents are fetched one at a time, in the order given.

        Args:
            keys: Keys of the documents to retrieve.

        Returns:
            One entry per requested key, in the same order. The entry is
            ``None`` when no blob exists for that key.
        """
        return [self._get_one(key) for key in keys]

    def mdelete(self, keys: Sequence[str]) -> None:
        """Deletes a batch of documents by key.

        Args:
            keys: Keys of the documents to delete.
        """
        for key in keys:
            self._delete_one(key)

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Yields the keys present in the storage.

        Args:
            prefix: Ignored. Uses the prefix provided in the constructor.

        Yields:
            The key of every blob stored under the constructor prefix.
        """
        # Build the listing prefix with the same helper that names the blobs,
        # so it carries the trailing "/" separator. Listing with the bare
        # prefix would also match sibling prefixes such as "documents_old/".
        blob_prefix = self._get_blob_name("")
        for blob in self._bucket.list_blobs(prefix=blob_prefix or None):
            # Strip the prefix instead of keeping the last path segment, so a
            # key that itself contains "/" is returned whole.
            key = blob.name[len(blob_prefix):]
            if key:  # skip a placeholder object named exactly "<prefix>/"
                yield key

    def _get_one(self, key: str) -> Document | None:
        """Gets a document by its key.

        Args:
            key: Key of the document to get from the storage.

        Returns:
            Document rebuilt from the stored JSON if the blob exists,
            otherwise ``None``.
        """
        blob = self._bucket.get_blob(self._get_blob_name(key))
        if blob is None:
            return None

        document_json: Dict[str, Any] = json.loads(blob.download_as_text())
        return Document(**document_json)

    def _set_one(self, key: str, value: Document) -> None:
        """Stores a document under the given key.

        Args:
            key: Key of the document to be stored.
            value: Document to be stored; serialized as JSON text.
        """
        blob = self._bucket.blob(self._get_blob_name(key))
        blob.upload_from_string(json.dumps(value.dict()))

    def _delete_one(self, key: str) -> None:
        """Deletes one document by its key.

        Any error raised by the storage client propagates to the caller.

        Args:
            key: Key of the document to delete.
        """
        self._bucket.blob(self._get_blob_name(key)).delete()

    def _get_blob_name(self, document_id: str) -> str:
        """Builds a blob name using the prefix and the document_id.

        Args:
            document_id: Id of the document.

        Returns:
            Name of the blob that the document will be/is stored in.
        """
        # The signature allows ``prefix=None``; formatting it into the name
        # would produce the literal blob name "None/<id>".
        if self._prefix is None:
            return document_id
        return f"{self._prefix}/{document_id}"
class DataStoreDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud DataStore."""

    def __init__(
        self,
        datastore_client: datastore.Client,
        kind: str = "document_id",
        text_property_name: str = "text",
        metadata_property_name: str = "metadata",
    ) -> None:
        """Constructor.

        Args:
            datastore_client: Datastore client used for every read and write.
            kind: Entity kind under which the documents are stored.
            text_property_name: Name of the entity property that holds the
                document's ``page_content``.
            metadata_property_name: Name of the entity property that holds the
                document's metadata.
        """
        super().__init__()
        self._client = datastore_client
        self._text_property_name = text_property_name
        self._metadata_property_name = metadata_property_name
        self._kind = kind

    def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
        """Gets a batch of documents by key.

        Args:
            keys: Keys of the documents to retrieve.

        Returns:
            One entry per requested key, in the same order. The entry is
            ``None`` when no entity was returned for that key.
        """
        ds_keys = [self._client.key(self._kind, id_) for id_ in keys]
        entities = self._client.get_multi(ds_keys)

        # Entities are not sorted by key by default, the order is unclear.
        # Index them by id so the result can follow the order of ``keys``.
        entity_by_id = {entity.key.id_or_name: entity for entity in entities}

        documents: List[Optional[Document]] = []
        for id_ in keys:
            entity = entity_by_id.get(id_)
            if entity is None:
                documents.append(None)
                continue
            documents.append(
                Document(
                    page_content=entity[self._text_property_name],
                    metadata=self._convert_entity_to_dict(
                        entity[self._metadata_property_name]
                    ),
                )
            )
        return documents

    def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
        """Stores a series of documents, each under its own key.

        All entities are written with a single ``put_multi`` call inside one
        transaction.

        Args:
            key_value_pairs: A sequence of ``(key, document)`` pairs.
        """
        with self._client.transaction():
            entities = []
            for id_, document in key_value_pairs:
                entity = self._client.entity(key=self._client.key(self._kind, id_))
                entity[self._text_property_name] = document.page_content
                entity[self._metadata_property_name] = document.metadata
                entities.append(entity)
            self._client.put_multi(entities)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Deletes a sequence of documents by key.

        Args:
            keys: Keys of the documents to delete.
        """
        with self._client.transaction():
            # Use a separate name: rebinding ``keys`` would silently change
            # the parameter from document ids to datastore keys.
            ds_keys = [self._client.key(self._kind, id_) for id_ in keys]
            self._client.delete_multi(ds_keys)

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Yields the keys of all documents in the storage.

        Args:
            prefix: Ignored.

        Yields:
            The id or name of every entity of the configured kind, as ``str``.
        """
        query = self._client.query(kind=self._kind)
        query.keys_only()
        for entity in query.fetch():
            yield str(entity.key.id_or_name)

    def _convert_entity_to_dict(self, entity: datastore.Entity) -> Dict[str, Any]:
        """Recursively transform an entity into a plain dictionary.

        Embedded entities are converted wherever they appear: directly as a
        property value or nested inside list values.

        Args:
            entity: Entity (or embedded entity) to convert.

        Returns:
            A plain ``dict`` with every embedded entity converted as well.
        """
        # Imported here because the module-level import is TYPE_CHECKING only.
        from google.cloud import datastore  # type: ignore[attr-defined, unused-ignore]

        def convert(value: Any) -> Any:
            if isinstance(value, datastore.Entity):
                return self._convert_entity_to_dict(value)
            if isinstance(value, list):
                # NOTE(review): assumes a stored list of dicts is read back as
                # a list of embedded entities - confirm against the client.
                return [convert(item) for item in value]
            return value

        return {key: convert(value) for key, value in dict(entity).items()}