Source code for langchain.chains.retrieval_qa.base

"""Chain for question-answering against a vector database."""
from __future__ import annotations

import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR


class BaseRetrievalQA(Chain):
    """Base class for question-answering chains."""

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain that turns the fetched documents plus the question into an answer."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether the fetched documents are included in the chain output."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain reads (just ``input_key``).

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        ``"source_documents"`` is appended to ``output_key`` only when
        ``return_source_documents`` is set.

        :meta private:
        """
        if self.return_source_documents:
            return [self.output_key, "source_documents"]
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        llm_chain_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain around a "stuff" documents chain driven by ``llm``.

        Args:
            llm: Language model used to produce the answer.
            prompt: Question-answering prompt; when falsy, the prompt is
                picked by ``PROMPT_SELECTOR`` for the given ``llm``.
            callbacks: Callbacks attached to the LLM chain, the documents
                chain and the returned chain.
            llm_chain_kwargs: Extra keyword arguments for ``LLMChain``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of ``cls`` wired to the newly built documents chain.
        """
        qa_prompt = prompt if prompt else PROMPT_SELECTOR.get_prompt(llm)
        extra_llm_chain_kwargs = llm_chain_kwargs or {}
        answer_chain = LLMChain(
            llm=llm,
            prompt=qa_prompt,
            callbacks=callbacks,
            **extra_llm_chain_kwargs,
        )
        # Each document is rendered as "Context:\n<page_content>" before being
        # passed to the LLM chain under the "context" variable.
        stuff_chain = StuffDocumentsChain(
            llm_chain=answer_chain,
            document_variable_name="context",
            document_prompt=PromptTemplate(
                input_variables=["page_content"],
                template="Context:\n{page_content}",
            ),
            callbacks=callbacks,
        )
        return cls(
            combine_documents_chain=stuff_chain,
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain from a named question-answering chain type.

        Args:
            llm: Language model handed to ``load_qa_chain``.
            chain_type: Name of the chain type passed to ``load_qa_chain``.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of ``cls`` using the loaded documents chain.
        """
        qa_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get documents to do question answering over."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Fetch documents for the query and combine them into an answer.

        The answer is returned under ``output_key``. When
        ``return_source_documents`` is true, the fetched documents are also
        returned under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # Subclass overrides of ``_get_docs`` that do not declare a
        # ``run_manager`` parameter are called with the question only.
        fetch_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            fetch_kwargs["run_manager"] = manager
        documents = self._get_docs(query, **fetch_kwargs)

        answer = self.combine_documents_chain.run(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = documents
        return result

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get documents to do question answering over."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously fetch documents and combine them into an answer.

        The answer is returned under ``output_key``. When
        ``return_source_documents`` is true, the fetched documents are also
        returned under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # Subclass overrides of ``_aget_docs`` that do not declare a
        # ``run_manager`` parameter are awaited with the question only.
        fetch_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            fetch_kwargs["run_manager"] = manager
        documents = await self._aget_docs(query, **fetch_kwargs)

        answer = await self.combine_documents_chain.arun(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = documents
        return result
@deprecated(since="0.1.17", alternative="create_retrieval_chain", removal="0.3.0")
class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.

    This class is deprecated. See below for an example implementation using
    `create_retrieval_chain`:

        .. code-block:: python

            from langchain.chains import create_retrieval_chain
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import ChatOpenAI


            retriever = ...  # Your retriever
            llm = ChatOpenAI()

            system_prompt = (
                "Use the given context to answer the question. "
                "If you don't know the answer, say you don't know. "
                "Use three sentence maximum and keep the answer concise. "
                "Context: {context}"
            )
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("human", "{input}"),
                ]
            )
            question_answer_chain = create_stuff_documents_chain(llm, prompt)
            chain = create_retrieval_chain(retriever, question_answer_chain)

            chain.invoke({"input": query})

    Example:
        .. code-block:: python

            from langchain_community.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain_community.vectorstores import FAISS
            from langchain_core.vectorstores import VectorStoreRetriever
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Invoke the retriever with ``question`` and return its documents.

        A child of ``run_manager`` is passed to the retriever as its
        callbacks.
        """
        child_callbacks = run_manager.get_child()
        return self.retriever.invoke(
            question, config={"callbacks": child_callbacks}
        )

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Asynchronously invoke the retriever and return its documents.

        A child of ``run_manager`` is passed to the retriever as its
        callbacks.
        """
        child_callbacks = run_manager.get_child()
        return await self.retriever.ainvoke(
            question, config={"callbacks": child_callbacks}
        )

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "retrieval_qa"
class VectorDBQA(BaseRetrievalQA):
    """Chain for question-answering against a vector database."""

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector Database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn on every instantiation that this class is deprecated."""
        warnings.warn(
            "`VectorDBQA` is deprecated - "
            "please use `from langchain.chains import RetrievalQA`"
        )
        return values

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type.

        Raises:
            ValueError: If ``search_type`` is present and is neither
                ``"similarity"`` nor ``"mmr"``.
        """
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Search the vector store for documents matching ``question``.

        ``k`` and ``search_kwargs`` are forwarded to the vector store's
        ``similarity_search`` or ``max_marginal_relevance_search`` depending
        on ``search_type``. ``run_manager`` is accepted for interface
        compatibility and is not used.

        Raises:
            ValueError: If ``search_type`` is not ``"similarity"`` or ``"mmr"``.
        """
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            # The validator only runs at construction; the attribute may have
            # been reassigned afterwards, so re-check here.
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Asynchronously search the vector store for matching documents.

        Mirrors ``_get_docs`` using the vector store's async counterparts
        (``asimilarity_search`` / ``amax_marginal_relevance_search``).
        ``run_manager`` is accepted for interface compatibility and is not
        used.

        Raises:
            ValueError: If ``search_type`` is not ``"similarity"`` or ``"mmr"``.
        """
        if self.search_type == "similarity":
            docs = await self.vectorstore.asimilarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = await self.vectorstore.amax_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"