"""Embeddings-based document filter (``langchain.retrievers.document_compressors.embeddings_filter``)."""

from typing import Callable, Dict, Optional, Sequence

import numpy as np
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator

from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)


def _get_similarity_function() -> Callable:
    try:
        from langchain_community.utils.math import cosine_similarity
    except ImportError:
        raise ImportError(
            "To use please install langchain-community "
            "with `pip install langchain-community`."
        )
    return cosine_similarity


class EmbeddingsFilter(BaseDocumentCompressor):
    """Document compressor that uses embeddings to drop documents unrelated to the query."""

    embeddings: Embeddings
    """Embeddings to use for embedding document contents and queries."""
    similarity_fn: Callable = Field(default_factory=_get_similarity_function)
    """Similarity function for comparing documents. Function expected to take as
    input two matrices (List[List[float]]) and return a matrix of scores where
    higher values indicate greater similarity."""
    k: Optional[int] = 20
    """The number of relevant documents to return. Can be set to None, in which
    case `similarity_threshold` must be specified. Defaults to 20."""
    similarity_threshold: Optional[float] = None
    """Minimum similarity to the query a document must exceed (strictly) to be
    kept. Defaults to None, must be specified if `k` is set to None."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(skip_on_failure=True)
    def validate_params(cls, values: Dict) -> Dict:
        """Validate similarity parameters.

        Raises:
            ValueError: If neither `k` nor `similarity_threshold` is set.
        """
        # skip_on_failure: if a field already failed validation it is missing
        # from ``values``; report that field error instead of a raw KeyError.
        if values.get("k") is None and values.get("similarity_threshold") is None:
            raise ValueError("Must specify one of `k` or `similarity_threshold`.")
        return values

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter documents based on similarity of their embeddings to the query.

        Args:
            documents: Documents to filter.
            query: Query text the documents are compared against.
            callbacks: Unused by this implementation.

        Returns:
            The retained documents. When `k` is set they are ordered by
            descending similarity; otherwise they keep their input order. Each
            retained document records its score in
            ``state["query_similarity_score"]``.

        Raises:
            ImportError: If ``langchain-community`` is not installed.
        """
        try:
            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
                _get_embeddings_from_stateful_docs,
                get_stateful_documents,
            )
        except ImportError as e:
            raise ImportError(
                "To use please install langchain-community "
                "with `pip install langchain-community`."
            ) from e

        # Nothing to rank: avoid embedding an empty batch and indexing row 0 of
        # a possibly empty score matrix, which would raise instead of
        # returning an empty result.
        if not documents:
            return []

        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        embedded_query = self.embeddings.embed_query(query)
        # Score the single query row against every document. np.asarray lets
        # similarity functions that return plain nested lists work with the
        # array indexing below.
        similarity = np.asarray(
            self.similarity_fn([embedded_query], embedded_documents)[0]
        )

        included_idxs = np.arange(len(embedded_documents))
        if self.k is not None:
            # Indices of the k highest scores, best first.
            included_idxs = np.argsort(similarity)[::-1][: self.k]
        if self.similarity_threshold is not None:
            # Keep candidates scoring strictly above the threshold; boolean
            # masking preserves the current (possibly ranked) order.
            included_idxs = included_idxs[
                similarity[included_idxs] > self.similarity_threshold
            ]

        for i in included_idxs:
            stateful_documents[i].state["query_similarity_score"] = similarity[i]
        return [stateful_documents[i] for i in included_idxs]