Source code for langchain.chains.transform

"""Chain that runs an arbitrary python function."""
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain

# Module-level logger named after this module; TransformChain._log_once
# emits its warnings through it.
logger = logging.getLogger(__name__)


class TransformChain(Chain):
    """Chain that transforms its inputs by calling a user-supplied function.

    The synchronous path (``_call``) hands the input dictionary to
    ``transform`` and returns whatever it returns. The asynchronous path
    (``_acall``) awaits ``atransform`` when one was supplied; otherwise it
    logs a warning and calls the synchronous ``transform`` instead.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            def extract_entities(inputs: dict) -> dict:
                return {"entities": inputs["text"].split()}

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=extract_entities,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
    """The transform function."""
    atransform_cb: Optional[
        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
    ] = Field(None, alias="atransform")
    """The async coroutine transform function."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log a warning, skipping messages that were already logged.

        The ``lru_cache`` wrapper memoizes calls by ``msg``, so a repeated
        call with the same text returns the cached ``None`` without logging
        again for as long as that entry stays in the cache.

        Args:
            msg: The warning text to emit.

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Expect input keys.

        Returns:
            The configured ``input_variables`` list.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output keys.

        Returns:
            The configured ``output_variables`` list.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the synchronous transform on ``inputs``.

        Args:
            inputs: Dictionary passed straight to the transform function.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The dictionary returned by the transform function.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the transform asynchronously, falling back to the sync one.

        Args:
            inputs: Dictionary passed straight to the transform function.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The dictionary produced by ``atransform`` when it is set,
            otherwise the dictionary produced by the synchronous transform.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)

        # No coroutine was supplied: warn (deduplicated by message) and call
        # the synchronous transform inline in this coroutine.
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)