Source code for langchain.callbacks.streaming_aiter

from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult

# TODO If used by two LLM runs in parallel this won't work as expected

[docs]class AsyncIteratorCallbackHandler(AsyncCallbackHandler): """Callback handler that returns an async iterator.""" queue: asyncio.Queue[str] done: asyncio.Event @property def always_verbose(self) -> bool: return True
[docs] def __init__(self) -> None: self.queue = asyncio.Queue() self.done = asyncio.Event()
[docs] async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: # If two calls are made in a row, this resets the state self.done.clear()
[docs] async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: if token is not None and token != "": self.queue.put_nowait(token)
[docs] async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.done.set()
[docs] async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: self.done.set()
# TODO implement the other methods
[docs] async def aiter(self) -> AsyncIterator[str]: while not self.queue.empty() or not self.done.is_set(): # Wait for the next token in the queue, # but stop waiting if the done event is set done, other = await asyncio.wait( [ # NOTE: If you add other tasks here, update the code below, # which assumes each set has exactly one task each asyncio.ensure_future(self.queue.get()), asyncio.ensure_future(self.done.wait()), ], return_when=asyncio.FIRST_COMPLETED, ) # Cancel the other task if other: other.pop().cancel() # Extract the value of the first completed task token_or_done = cast(Union[str, Literal[True]], done.pop().result()) # If the extracted value is the boolean True, the done event was set if token_or_done is True: break # Otherwise, the extracted value is a token, which we yield yield token_or_done