Source code for langchain.chains.llm_checker.base

"""Chain for question-answering with self-verification."""
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.prompt import (
    CHECK_ASSERTIONS_PROMPT,
    CREATE_DRAFT_ANSWER_PROMPT,
    LIST_ASSERTIONS_PROMPT,
    REVISED_ANSWER_PROMPT,
)
from langchain.chains.sequential import SequentialChain


def _load_question_to_checked_assertions_chain(
    llm: BaseLanguageModel,
    create_draft_answer_prompt: PromptTemplate,
    list_assertions_prompt: PromptTemplate,
    check_assertions_prompt: PromptTemplate,
    revised_answer_prompt: PromptTemplate,
) -> SequentialChain:
    create_draft_answer_chain = LLMChain(
        llm=llm,
        prompt=create_draft_answer_prompt,
        output_key="statement",
    )
    list_assertions_chain = LLMChain(
        llm=llm,
        prompt=list_assertions_prompt,
        output_key="assertions",
    )
    check_assertions_chain = LLMChain(
        llm=llm,
        prompt=check_assertions_prompt,
        output_key="checked_assertions",
    )
    revised_answer_chain = LLMChain(
        llm=llm,
        prompt=revised_answer_prompt,
        output_key="revised_statement",
    )
    chains = [
        create_draft_answer_chain,
        list_assertions_chain,
        check_assertions_chain,
        revised_answer_chain,
    ]
    question_to_checked_assertions_chain = SequentialChain(
        chains=chains,  # type: ignore[arg-type]
        input_variables=["question"],
        output_variables=["revised_statement"],
        verbose=True,
    )
    return question_to_checked_assertions_chain


[docs]class LLMCheckerChain(Chain): """Chain for question-answering with self-verification. Example: .. code-block:: python from langchain_community.llms import OpenAI from langchain.chains import LLMCheckerChain llm = OpenAI(temperature=0.7) checker_chain = LLMCheckerChain.from_llm(llm) """ question_to_checked_assertions_chain: SequentialChain llm: Optional[BaseLanguageModel] = None """[Deprecated] LLM wrapper to use.""" create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT """[Deprecated]""" list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT """[Deprecated]""" check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT """[Deprecated]""" revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT """[Deprecated] Prompt to use when questioning the documents.""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: class Config: """Configuration for this pydantic object.""" extra = Extra.forbid arbitrary_types_allowed = True @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: if "llm" in values: warnings.warn( "Directly instantiating an LLMCheckerChain with an llm is deprecated. " "Please instantiate with question_to_checked_assertions_chain " "or using the from_llm class method." ) if ( "question_to_checked_assertions_chain" not in values and values["llm"] is not None ): question_to_checked_assertions_chain = ( _load_question_to_checked_assertions_chain( values["llm"], values.get( "create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT ), values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT), values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT), ) ) values[ "question_to_checked_assertions_chain" ] = question_to_checked_assertions_chain return values @property def input_keys(self) -> List[str]: """Return the singular input key. :meta private: """ return [self.input_key] @property def output_keys(self) -> List[str]: """Return the singular output key. :meta private: """ return [self.output_key] def _call( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] output = self.question_to_checked_assertions_chain( {"question": question}, callbacks=_run_manager.get_child() ) return {self.output_key: output["revised_statement"]} @property def _chain_type(self) -> str: return "llm_checker_chain"
[docs] @classmethod def from_llm( cls, llm: BaseLanguageModel, create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT, list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT, check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT, **kwargs: Any, ) -> LLMCheckerChain: question_to_checked_assertions_chain = ( _load_question_to_checked_assertions_chain( llm, create_draft_answer_prompt, list_assertions_prompt, check_assertions_prompt, revised_answer_prompt, ) ) return cls( question_to_checked_assertions_chain=question_to_checked_assertions_chain, **kwargs, )