Source code for langchain.retrievers.self_query.base

"""Retriever that generates and executes structured queries over its own data source."""

import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from langchain_core.structured_query import StructuredQuery, Visitor
from langchain_core.vectorstores import VectorStore

from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain.chains.query_constructor.schema import AttributeInfo

# Module-level logger; SelfQueryRetriever logs the generated structured query
# through it when its ``verbose`` flag is set.
logger = logging.getLogger(__name__)
# Run name that ``SelfQueryRetriever.from_llm`` attaches to the query
# constructor runnable via ``with_config(run_name=...)``.
QUERY_CONSTRUCTOR_RUN_NAME = "query_constructor"


def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
    """Return a structured-query translator instance matching ``vectorstore``.

    Resolution happens in three stages, first match wins:

    1. Stores whose translator is built from the store instance (or that need
       a dedicated branch). These are matched with ``isinstance``, so
       subclasses of those stores are covered as well.
    2. An exact-class lookup in a table of ``langchain_community`` stores
       whose translators are constructed without arguments. Subclasses are
       *not* matched here.
    3. Stores from optional partner packages (``langchain_astradb``,
       ``langchain_elasticsearch``, ``langchain_pinecone``); each package is
       only consulted if it can be imported.

    Args:
        vectorstore: The vector store instance to find a translator for.

    Returns:
        A newly constructed translator for the given vector store.

    Raises:
        ImportError: If ``langchain_community`` is not installed.
        ValueError: If no translator is known for the vector store's class.
    """
    try:
        import langchain_community  # noqa: F401
    except ImportError as e:
        # Chain the original error so the underlying import failure stays
        # visible in the traceback.
        raise ImportError(
            "The langchain-community package must be installed to use this feature."
            " Please install it using `pip install langchain-community`."
        ) from e

    from langchain_community.query_constructors.astradb import AstraDBTranslator
    from langchain_community.query_constructors.chroma import ChromaTranslator
    from langchain_community.query_constructors.dashvector import DashvectorTranslator
    from langchain_community.query_constructors.databricks_vector_search import (
        DatabricksVectorSearchTranslator,
    )
    from langchain_community.query_constructors.deeplake import DeepLakeTranslator
    from langchain_community.query_constructors.dingo import DingoDBTranslator
    from langchain_community.query_constructors.elasticsearch import (
        ElasticsearchTranslator,
    )
    from langchain_community.query_constructors.milvus import MilvusTranslator
    from langchain_community.query_constructors.mongodb_atlas import (
        MongoDBAtlasTranslator,
    )
    from langchain_community.query_constructors.myscale import MyScaleTranslator
    from langchain_community.query_constructors.opensearch import OpenSearchTranslator
    from langchain_community.query_constructors.pgvector import PGVectorTranslator
    from langchain_community.query_constructors.pinecone import PineconeTranslator
    from langchain_community.query_constructors.qdrant import QdrantTranslator
    from langchain_community.query_constructors.redis import RedisTranslator
    from langchain_community.query_constructors.supabase import SupabaseVectorTranslator
    from langchain_community.query_constructors.tencentvectordb import (
        TencentVectorDBTranslator,
    )
    from langchain_community.query_constructors.timescalevector import (
        TimescaleVectorTranslator,
    )
    from langchain_community.query_constructors.vectara import VectaraTranslator
    from langchain_community.query_constructors.weaviate import WeaviateTranslator
    from langchain_community.vectorstores import (
        AstraDB,
        Chroma,
        DashVector,
        DatabricksVectorSearch,
        DeepLake,
        Dingo,
        Milvus,
        MongoDBAtlasVectorSearch,
        MyScale,
        OpenSearchVectorSearch,
        PGVector,
        Qdrant,
        Redis,
        SupabaseVectorStore,
        TencentVectorDB,
        TimescaleVector,
        Vectara,
        Weaviate,
    )
    from langchain_community.vectorstores import (
        ElasticsearchStore as ElasticsearchStoreCommunity,
    )
    from langchain_community.vectorstores import (
        Pinecone as CommunityPinecone,
    )

    # Stage 1: stores handled by dedicated isinstance branches.
    if isinstance(vectorstore, DatabricksVectorSearch):
        return DatabricksVectorSearchTranslator()
    if isinstance(vectorstore, Qdrant):
        return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
    if isinstance(vectorstore, MyScale):
        return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
    if isinstance(vectorstore, Redis):
        return RedisTranslator.from_vectorstore(vectorstore)
    if isinstance(vectorstore, TencentVectorDB):
        # Only metadata fields flagged with ``index`` are handed to the
        # translator.
        fields = [
            field.name for field in (vectorstore.meta_fields or []) if field.index
        ]
        return TencentVectorDBTranslator(fields)

    # Stage 2: exact-class lookup for translators built without arguments.
    # Qdrant and MyScale are intentionally absent: they are always caught by
    # the isinstance branches above, so table entries for them were dead.
    BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
        AstraDB: AstraDBTranslator,
        PGVector: PGVectorTranslator,
        CommunityPinecone: PineconeTranslator,
        Chroma: ChromaTranslator,
        DashVector: DashvectorTranslator,
        Dingo: DingoDBTranslator,
        Weaviate: WeaviateTranslator,
        Vectara: VectaraTranslator,
        DeepLake: DeepLakeTranslator,
        ElasticsearchStoreCommunity: ElasticsearchTranslator,
        Milvus: MilvusTranslator,
        SupabaseVectorStore: SupabaseVectorTranslator,
        TimescaleVector: TimescaleVectorTranslator,
        OpenSearchVectorSearch: OpenSearchTranslator,
        MongoDBAtlasVectorSearch: MongoDBAtlasTranslator,
    }
    translator_cls = BUILTIN_TRANSLATORS.get(vectorstore.__class__)
    if translator_cls is not None:
        return translator_cls()

    # Stage 3: optional partner packages. A missing package simply means the
    # vector store cannot be of that type, so the ImportError is skipped on
    # purpose and the next candidate is tried.
    try:
        from langchain_astradb.vectorstores import AstraDBVectorStore
    except ImportError:
        pass
    else:
        if isinstance(vectorstore, AstraDBVectorStore):
            return AstraDBTranslator()

    try:
        from langchain_elasticsearch.vectorstores import ElasticsearchStore
    except ImportError:
        pass
    else:
        if isinstance(vectorstore, ElasticsearchStore):
            return ElasticsearchTranslator()

    try:
        from langchain_pinecone import PineconeVectorStore
    except ImportError:
        pass
    else:
        if isinstance(vectorstore, PineconeVectorStore):
            return PineconeTranslator()

    raise ValueError(
        f"Self query retriever with Vector Store type {vectorstore.__class__}"
        f" not supported."
    )


class SelfQueryRetriever(BaseRetriever):
    """Retriever that uses a vector store and an LLM to generate
    the vector store queries."""

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    query_constructor: Runnable[dict, StructuredQuery] = Field(alias="llm_chain")
    """The query constructor chain for generating the vector store queries.
    llm_chain is legacy name kept for backwards compatibility."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore
    search params."""
    # When true, the generated structured query is logged at INFO level.
    verbose: bool = False

    use_original_query: bool = False
    """Use original query instead of the revised new query from LLM"""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Fill in a built-in translator when none was supplied.

        Args:
            values: Raw field values passed to the model, before validation.

        Returns:
            The same mapping, with ``structured_query_translator`` set from
            the vector store's class if the caller did not provide one.
        """
        if "structured_query_translator" not in values:
            values["structured_query_translator"] = _get_builtin_translator(
                values["vectorstore"]
            )
        return values

    @property
    def llm_chain(self) -> Runnable:
        """llm_chain is legacy name kept for backwards compatibility."""
        return self.query_constructor

    def _prepare_query(
        self, query: str, structured_query: StructuredQuery
    ) -> Tuple[str, Dict[str, Any]]:
        """Translate a structured query into a search string and search kwargs.

        Args:
            query: The original user query.
            structured_query: The structured query produced by the query
                constructor.

        Returns:
            A ``(query_text, search_kwargs)`` pair. ``query_text`` is the
            original ``query`` when ``use_original_query`` is set, otherwise
            the translator's rewritten query. ``search_kwargs`` is
            ``self.search_kwargs`` overlaid with the translator's kwargs.
        """
        new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
            structured_query
        )
        # A limit from the structured query is passed to the search as ``k``.
        if structured_query.limit is not None:
            new_kwargs["k"] = structured_query.limit
        if self.use_original_query:
            new_query = query
        # Translator-derived kwargs take precedence over the configured ones.
        search_kwargs = {**self.search_kwargs, **new_kwargs}
        return new_query, search_kwargs

    def _get_docs_with_query(
        self, query: str, search_kwargs: Dict[str, Any]
    ) -> List[Document]:
        """Run ``vectorstore.search`` with the configured search type."""
        docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
        return docs

    async def _aget_docs_with_query(
        self, query: str, search_kwargs: Dict[str, Any]
    ) -> List[Document]:
        """Async counterpart of ``_get_docs_with_query`` using ``asearch``."""
        docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: callback manager for this run; a child of it is
                passed to the query constructor invocation.

        Returns:
            List of relevant documents
        """
        structured_query = self.query_constructor.invoke(
            {"query": query}, config={"callbacks": run_manager.get_child()}
        )
        if self.verbose:
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info("Generated Query: %s", structured_query)
        new_query, search_kwargs = self._prepare_query(query, structured_query)
        docs = self._get_docs_with_query(new_query, search_kwargs)
        return docs

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: callback manager for this run; a child of it is
                passed to the query constructor invocation.

        Returns:
            List of relevant documents
        """
        structured_query = await self.query_constructor.ainvoke(
            {"query": query}, config={"callbacks": run_manager.get_child()}
        )
        if self.verbose:
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info("Generated Query: %s", structured_query)
        new_query, search_kwargs = self._prepare_query(query, structured_query)
        docs = await self._aget_docs_with_query(new_query, search_kwargs)
        return docs

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: Sequence[Union[AttributeInfo, dict]],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        enable_limit: bool = False,
        use_original_query: bool = False,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever, creating the query constructor from an LLM.

        Args:
            llm: Language model handed to ``load_query_constructor_runnable``.
            vectorstore: Vector store the retriever searches.
            document_contents: Passed through to
                ``load_query_constructor_runnable``.
            metadata_field_info: Passed through to
                ``load_query_constructor_runnable``.
            structured_query_translator: Translator to use. When ``None``, a
                built-in translator is chosen from the vector store's class.
            chain_kwargs: Extra keyword arguments for
                ``load_query_constructor_runnable``. The caller's dict is not
                modified.
            enable_limit: Passed through to
                ``load_query_constructor_runnable``.
            use_original_query: Stored on the retriever; see the field of the
                same name.
            **kwargs: Additional keyword arguments for the class constructor.

        Returns:
            A configured ``SelfQueryRetriever``.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(vectorstore)
        # Work on a copy: the translator defaults added below must not leak
        # into a dict owned by the caller.
        chain_kwargs = dict(chain_kwargs or {})

        if (
            "allowed_comparators" not in chain_kwargs
            and structured_query_translator.allowed_comparators is not None
        ):
            chain_kwargs[
                "allowed_comparators"
            ] = structured_query_translator.allowed_comparators
        if (
            "allowed_operators" not in chain_kwargs
            and structured_query_translator.allowed_operators is not None
        ):
            chain_kwargs[
                "allowed_operators"
            ] = structured_query_translator.allowed_operators
        query_constructor = load_query_constructor_runnable(
            llm,
            document_contents,
            metadata_field_info,
            enable_limit=enable_limit,
            **chain_kwargs,
        )
        query_constructor = query_constructor.with_config(
            run_name=QUERY_CONSTRUCTOR_RUN_NAME
        )
        return cls(  # type: ignore[call-arg]
            query_constructor=query_constructor,
            vectorstore=vectorstore,
            use_original_query=use_original_query,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )