Source code for langchain_google_vertexai.model_garden

from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    Generation,
    LLMResult,
)
from langchain_core.pydantic_v1 import Field, root_validator

from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
    """Large language models served from Vertex AI Model Garden."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)

        if self.single_example_per_request and len(instances) > 1:
            results = []
            for instance in instances:
                response = self.client.predict(
                    endpoint=self.endpoint_path, instances=[instance]
                )
                results.append(self._parse_prediction(response.predictions[0]))
            return LLMResult(
                generations=[[Generation(text=result)] for result in results]
            )

        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        return self._parse_response(response)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)

        if self.single_example_per_request and len(instances) > 1:
            responses = []
            for instance in instances:
                responses.append(
                    self.async_client.predict(
                        endpoint=self.endpoint_path, instances=[instance]
                    )
                )
            responses = await asyncio.gather(*responses)
            return LLMResult(
                generations=[
                    [Generation(text=self._parse_prediction(response.predictions[0]))]
                    for response in responses
                ]
            )

        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)
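
# A minimal usage sketch (not part of the module): calling a model deployed to
# a Model Garden endpoint. The project and endpoint IDs below are placeholder
# values, and the accepted prompt/response shapes depend on the deployed
# model's request schema, which `_prepare_request` / `_parse_response` handle.
#
#     from langchain_google_vertexai import VertexAIModelGarden
#
#     llm = VertexAIModelGarden(
#         project="my-project",      # placeholder GCP project ID
#         endpoint_id="1234567890",  # placeholder deployed-endpoint ID
#         location="us-central1",
#     )
#     print(llm.invoke("What is the meaning of life?"))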
class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
    """Anthropic chat models (Claude) served through Vertex AI."""

    async_client: Any = None  #: :meta private:
    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
    "Underlying model name."
    max_output_tokens: int = Field(default=1024, alias="max_tokens")
    access_token: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        # Build sync and async clients from the anthropic SDK's Vertex support.
        from anthropic import (  # type: ignore[import-not-found]
            AnthropicVertex,
            AsyncAnthropicVertex,
        )

        values["client"] = AnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
            access_token=values["access_token"],
        )
        values["async_client"] = AsyncAnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
            access_token=values["access_token"],
        )
        return values

    @property
    def _default_params(self):
        return {
            "model": self.model_name,
            "max_tokens": self.max_output_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        system_message, formatted_messages = _format_messages_anthropic(messages)
        params = self._default_params
        params.update(kwargs)
        # A per-call "model" kwarg takes precedence over "model_name".
        if kwargs.get("model_name"):
            params["model"] = params["model_name"]
        if kwargs.get("model"):
            params["model"] = kwargs["model"]
        params.pop("model_name", None)
        params.update(
            {
                "system": system_message,
                "messages": formatted_messages,
                "stop_sequences": stop,
            }
        )
        return {k: v for k, v in params.items() if v is not None}

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        data_dict = data.model_dump()
        content = data_dict["content"]
        llm_output = {
            k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
        }
        if len(content) == 1 and content[0]["type"] == "text":
            msg = AIMessage(content=content[0]["text"])
        else:
            msg = AIMessage(content=content)
        return ChatResult(
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        data = self.client.messages.create(**params)
        return self._format_output(data, **kwargs)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        data = await self.async_client.messages.create(**params)
        return self._format_output(data, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat-vertexai"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        with self.client.messages.stream(**params) as stream:
            for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        async with self.async_client.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    await run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk
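
# A minimal usage sketch (not part of the module): chatting with a Claude model
# on Vertex AI. The model version, project, and location strings below are
# placeholders; the `anthropic` package must be installed for the clients
# created in `validate_environment` to import.
#
#     from langchain_google_vertexai.model_garden import ChatAnthropicVertex
#
#     chat = ChatAnthropicVertex(
#         model="claude-3-sonnet@20240229",  # placeholder model version string
#         project="my-project",              # placeholder GCP project ID
#         location="us-east5",               # placeholder region
#     )
#     response = chat.invoke("Tell me a joke.")
#     print(response.content)
#
#     # Streaming is routed through `_stream`, yielding one chunk per token:
#     for chunk in chat.stream("Tell me a joke."):
#         print(chunk.content, end="")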