# Source code for langchain.chains.combine_documents.base

"""Base interface for chains combining documents."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type

from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from langchain.chains.base import Chain

# Blank-line separator string. Not referenced in this section; presumably used
# elsewhere to join formatted documents (NOTE(review): confirm usage).
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
# Name of the input variable a prompt must declare; _validate_prompt raises
# ValueError for prompts whose input_variables do not include it.
DOCUMENTS_KEY = "context"
# Default per-document prompt: a template whose only variable is `page_content`.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")


def _validate_prompt(prompt: BasePromptTemplate) -> None:
    """Check that ``prompt`` accepts the documents variable.

    Args:
        prompt: The prompt template to inspect.

    Raises:
        ValueError: If ``DOCUMENTS_KEY`` is not among the prompt's
            ``input_variables``.
    """
    declared = prompt.input_variables
    if DOCUMENTS_KEY in declared:
        return
    raise ValueError(
        f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
        f"with input variables: {declared}"
    )


class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents.

    Subclasses of this chain deal with combining documents in a variety of
    ways. This base class exists to add some uniformity in the interface these
    types of chains should expose. Namely, they expect an input key related to
    the documents to use (default `input_documents`), and then also expose a
    method to calculate the length of a prompt from documents (useful for
    outside callers to use to determine whether it's safe to pass a list of
    documents into this chain or whether that will be longer than the context
    length).
    """

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return a model named ``CombineDocumentsInput`` describing the input.

        The model has one field, named after ``input_key``, typed as a list of
        ``Document`` objects with a default of ``None``.
        """
        return create_model(
            "CombineDocumentsInput",
            **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
        )

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return a model named ``CombineDocumentsOutput`` describing the output.

        The model has one field, named after ``output_key``, typed as ``str``
        with a default of ``None``.
        """
        return create_model(
            "CombineDocumentsOutput",
            **{self.output_key: (str, None)},  # type: ignore[call-overload]
        )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        This can be used by a caller to determine whether passing in a list
        of documents would exceed a certain prompt length. This is useful when
        trying to ensure that the size of a prompt remains below a certain
        context limit.

        Args:
            docs: List[Document], a list of documents to use to calculate the
                total prompt length.

        Returns:
            Returns None if the method does not depend on the prompt length,
            otherwise the length of the prompt in tokens.
        """
        return None

    @abstractmethod
    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
        """Combine documents into a single string.

        Args:
            docs: List[Document], the documents to combine
            **kwargs: Other parameters to use in combining documents, often
                other inputs to the prompt.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """

    @abstractmethod
    async def acombine_docs(
        self, docs: List[Document], **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents into a single string.

        Args:
            docs: List[Document], the documents to combine
            **kwargs: Other parameters to use in combining documents, often
                other inputs to the prompt.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Prepare inputs, call combine docs, prepare outputs.

        Args:
            inputs: Must contain ``input_key`` (the documents); every other
                entry is forwarded to ``combine_docs`` as a keyword argument.
            run_manager: Callback manager for this run; a no-op manager is
                used when None.

        Returns:
            The extra dict returned by ``combine_docs``, with the combined
            string stored under ``output_key``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = self.combine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        # The dict returned by combine_docs is updated in place and returned.
        extra_return_dict[self.output_key] = output
        return extra_return_dict

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async version of ``_call``: prepare inputs, combine docs, prepare outputs.

        Args:
            inputs: Must contain ``input_key`` (the documents); every other
                entry is forwarded to ``acombine_docs`` as a keyword argument.
            run_manager: Async callback manager for this run; a no-op manager
                is used when None.

        Returns:
            The extra dict returned by ``acombine_docs``, with the combined
            string stored under ``output_key``.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = await self.acombine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        # The dict returned by acombine_docs is updated in place and returned.
        extra_return_dict[self.output_key] = output
        return extra_return_dict
class AnalyzeDocumentChain(Chain):
    """Chain that splits a document, then analyzes it in pieces.

    This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
    This chain takes a single document (raw text) as input, splits it up into
    chunks, and then passes those chunks to the CombineDocumentsChain.
    """

    input_key: str = "input_document"  #: :meta private:
    text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
    combine_docs_chain: BaseCombineDocumentsChain

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return self.combine_docs_chain.output_keys

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return a model with one ``str`` field named after ``input_key``."""
        return create_model(
            "AnalyzeDocumentChain",
            **{self.input_key: (str, None)},  # type: ignore[call-overload]
        )

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Delegate to the wrapped combine-documents chain's output schema."""
        return self.combine_docs_chain.get_output_schema(config)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Split document into chunks and pass to CombineDocumentsChain.

        Args:
            inputs: Must contain ``input_key`` holding the text to split; all
                other entries are forwarded unchanged to ``combine_docs_chain``.
            run_manager: Callback manager for this run; a no-op manager is
                used when None.

        Returns:
            Only the outputs of ``combine_docs_chain`` (``return_only_outputs``
            is set), which may include keys beyond its ``output_key``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        document = inputs[self.input_key]
        docs = self.text_splitter.create_documents([document])
        # Other keys are assumed to be needed for LLM prediction
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        # The split chunks are supplied under the wrapped chain's own input key.
        other_keys[self.combine_docs_chain.input_key] = docs
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )