Source code for langchain_ai21.embeddings
from itertools import islice
from typing import Any, Iterator, List, Optional
from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings
from langchain_ai21.ai21_base import AI21Base
# Default value for ``AI21Embeddings.batch_size``: how many texts go into a
# single embed request when the caller does not supply a batch size.
_DEFAULT_BATCH_SIZE = 128
def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
texts_itr = iter(texts)
return iter(lambda: list(islice(texts_itr, batch_size)), [])
class AI21Embeddings(Embeddings, AI21Base):
    """AI21 embedding model integration.

    Install ``langchain_ai21`` and set environment variable ``AI21_API_KEY``.

    .. code-block:: bash

        pip install -U langchain_ai21
        export AI21_API_KEY="your-api-key"

    Key init args — client params:
        api_key: Optional[SecretStr]
        batch_size: int
            The number of texts that will be sent to the API in each batch.
            Use larger batch sizes if working with many short texts. This will reduce
            the number of API calls made, and can improve the time it takes to embed
            a large number of texts.
        num_retries: Optional[int]
            Maximum number of retries for API requests before giving up.
        timeout_sec: Optional[float]
            Timeout in seconds for API requests. If not set, it will default to the
            value of the environment variable `AI21_TIMEOUT_SEC` or 300 seconds.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_ai21 import AI21Embeddings

            embed = AI21Embeddings(
                # api_key="...",
                # batch_size=128,
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            vector = embed.embed_query(input_text)
            print(vector[:3])

        .. code-block:: python

            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

    Embed multiple texts:
        .. code-block:: python

            input_texts = ["Document 1...", "Document 2..."]
            vectors = embed.embed_documents(input_texts)
            print(len(vectors))
            # The first 3 coordinates for the first vector
            print(vectors[0][:3])

        .. code-block:: python

            2
            [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
    """

    batch_size: int = _DEFAULT_BATCH_SIZE
    """Maximum number of texts to embed in each batch"""

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[List[float]]:
        """Embed a list of documents as ``EmbedType.SEGMENT`` inputs.

        Args:
            texts: The documents to embed.
            batch_size: Per-call override for the number of texts per request.
                A falsy value (``None`` or ``0``) falls back to
                ``self.batch_size``.
            **kwargs: Forwarded unchanged to the client's embed call.

        Returns:
            The embeddings gathered from every response, in batch order.
        """
        # Falsy override -> use the instance-level default.
        effective_batch_size = batch_size if batch_size else self.batch_size
        return self._send_embeddings(
            texts=texts,
            batch_size=effective_batch_size,
            embed_type=EmbedType.SEGMENT,
            **kwargs,
        )

    def embed_query(
        self,
        text: str,
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[float]:
        """Embed a single query string as an ``EmbedType.QUERY`` input.

        Args:
            text: The query to embed.
            batch_size: Per-call override for the batch size. A falsy value
                (``None`` or ``0``) falls back to ``self.batch_size``.
            **kwargs: Forwarded unchanged to the client's embed call.

        Returns:
            The first embedding produced for ``text``.
        """
        effective_batch_size = batch_size if batch_size else self.batch_size
        # The query is wrapped in a one-element list so it can share the
        # batched code path; only the first returned vector is kept.
        vectors = self._send_embeddings(
            texts=[text],
            batch_size=effective_batch_size,
            embed_type=EmbedType.QUERY,
            **kwargs,
        )
        return vectors[0]

    def _send_embeddings(
        self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
    ) -> List[List[float]]:
        """Send ``texts`` to the embed endpoint in batches and flatten the results.

        Args:
            texts: The texts to embed.
            batch_size: Maximum number of texts per request.
            embed_type: The ``EmbedType`` passed as ``type`` on every request.
            **kwargs: Forwarded unchanged to ``self.client.embed.create``.

        Returns:
            ``result.embedding`` for every entry of every response's
            ``results``, in the order the batches were sent.
        """
        # Phase 1: issue one request per batch. Every request is made before
        # any response is unpacked.
        responses = []
        for batch in _split_texts_into_batches(texts, batch_size):
            responses.append(
                self.client.embed.create(
                    texts=batch,
                    type=embed_type,
                    **kwargs,
                )
            )

        # Phase 2: flatten the per-batch results into a single list of vectors.
        embeddings: List[List[float]] = []
        for response in responses:
            embeddings.extend(result.embedding for result in response.results)
        return embeddings