# Source code for langchain_ai21.embeddings

from itertools import islice
from typing import Any, Iterator, List, Optional

from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings

from langchain_ai21.ai21_base import AI21Base

# Default for ``AI21Embeddings.batch_size``: the maximum number of texts sent
# in a single embedding request when the caller does not override it.
_DEFAULT_BATCH_SIZE = 128


def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
    texts_itr = iter(texts)
    return iter(lambda: list(islice(texts_itr, batch_size)), [])


class AI21Embeddings(Embeddings, AI21Base):
    """AI21 Embeddings embedding model.

    To use, you should have the 'AI21_API_KEY' environment variable set
    or pass as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21Embeddings

            embeddings = AI21Embeddings()
            query_result = embeddings.embed_query("Hello embeddings world!")
    """

    batch_size: int = _DEFAULT_BATCH_SIZE
    """Maximum number of texts to embed in each batch"""

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Texts to embed; they are sent with ``EmbedType.SEGMENT``.
            batch_size: Per-call override for the number of texts per request.
                ``None`` (or any other falsy value, e.g. ``0``) falls back to
                ``self.batch_size``.
            **kwargs: Forwarded unchanged to ``self.client.embed.create``.

        Returns:
            The ``embedding`` of every result the client returned, batch by
            batch, in the order the batches were sent.
        """
        return self._send_embeddings(
            texts=texts,
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.SEGMENT,
            **kwargs,
        )

    def embed_query(
        self,
        text: str,
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[float]:
        """Embed query text.

        Args:
            text: The single query string; it is sent with ``EmbedType.QUERY``.
            batch_size: Per-call override for the number of texts per request.
                ``None`` (or any other falsy value) falls back to
                ``self.batch_size``.
            **kwargs: Forwarded unchanged to ``self.client.embed.create``.

        Returns:
            The first embedding produced for ``text``.
        """
        return self._send_embeddings(
            texts=[text],
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.QUERY,
            **kwargs,
        )[0]

    def _send_embeddings(
        self,
        texts: List[str],
        *,
        batch_size: int,
        embed_type: EmbedType,
        **kwargs: Any,
    ) -> List[List[float]]:
        """Embed ``texts`` in batches and flatten the per-batch results.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts passed to a single
                ``self.client.embed.create`` call.
            embed_type: Value passed as ``type`` to the embed call.
            **kwargs: Extra keyword arguments forwarded to every embed call.

        Returns:
            A flat list with the ``embedding`` of each entry in every
            response's ``results``, in batch order.
        """
        # NOTE(review): ``self.client`` is not defined in this class;
        # presumably it is provided by ``AI21Base`` -- confirm.
        chunks = _split_texts_into_batches(texts, batch_size)

        # One request per batch; all requests are issued before flattening.
        responses = [
            self.client.embed.create(
                texts=chunk,
                type=embed_type,
                **kwargs,
            )
            for chunk in chunks
        ]

        return [
            result.embedding for response in responses for result in response.results
        ]