# Source code for langchain_nvidia_ai_endpoints._common

from __future__ import annotations

import json
import logging
import os
import time
from copy import deepcopy
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)
from urllib.parse import urlparse

import aiohttp
import requests
from langchain_core._api import deprecated, warn_deprecated
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
    PrivateAttr,
    SecretStr,
    root_validator,
    validator,
)
from requests.models import Response

from langchain_nvidia_ai_endpoints._statics import MODEL_SPECS, Model

logger = logging.getLogger(__name__)

_MODE_TYPE = Literal["nvidia", "nim"]


def default_payload_fn(payload: dict) -> dict:
    return payload


class NVEModel(BaseModel):
    """
    Underlying client for interacting with the AI Foundation Model Function API.
    Leveraged by _NVIDIAClient to provide a simple requests-oriented interface.
    Direct abstraction over NGC-recommended streaming/non-streaming Python solutions.

    NOTE: Models in the playground do not currently support raw text continuation.
    """

    # todo: add a validator for requests.Response (last_response attribute) and
    #       remove arbitrary_types_allowed=True
    class Config:
        arbitrary_types_allowed = True

    ## Core defaults. These probably should not be changed
    _api_key_var = "NVIDIA_API_KEY"
    base_url: str = Field(
        "https://api.nvcf.nvidia.com/v2/nvcf",
        description="Base URL for standard inference",
    )
    get_session_fn: Callable = Field(requests.Session)
    get_asession_fn: Callable = Field(aiohttp.ClientSession)
    endpoints: dict = Field(
        {
            "infer": "{base_url}/pexec/functions/{model_id}",
            "status": "{base_url}/pexec/status/{request_id}",
            "models": "{base_url}/functions",
        }
    )

    api_key: SecretStr = Field(..., description="API Key for service of choice")

    ## Generation arguments
    timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)")
    interval: float = Field(
        0.02, ge=0, description="Interval (s) for polling the response"
    )
    last_inputs: dict = Field({}, description="Last inputs sent over to the server")
    last_response: Optional[Response] = Field(
        None, description="Last response received from the server"
    )
    payload_fn: Callable = Field(
        default_payload_fn, description="Function to process payload"
    )
    headers_tmpl: dict = Field(
        ...,
        description="Headers template for API calls."
        " Should contain `call` and `stream` keys.",
    )
    _available_functions: Optional[List[dict]] = PrivateAttr(default=None)
    _available_models: Optional[dict] = PrivateAttr(default=None)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": self._api_key_var}

    @property
    def headers(self) -> dict:
        """Return headers with API key injected"""
        headers_ = self.headers_tmpl.copy()
        for header in headers_.values():
            if "{api_key}" in header["Authorization"]:
                header["Authorization"] = header["Authorization"].format(
                    api_key=self.api_key.get_secret_value(),
                )
        return headers_

    @validator("base_url")
    def validate_base_url(cls, v: str) -> str:
        if v is not None:
            result = urlparse(v)
            # Ensure scheme and netloc (domain name) are present
            if not (result.scheme and result.netloc):
                raise ValueError(
                    f"Invalid base_url, minimally needs scheme and netloc: {v}"
                )
        return v

    @root_validator(pre=True)
    def validate_model(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and update model arguments, including API key and formatting"""
        values["api_key"] = (
            values.get(cls._api_key_var.lower())
            or values.get("api_key")
            or os.getenv(cls._api_key_var)
            or ""
        )
        if "headers_tmpl" not in values:
            call_kvs = {
                "Accept": "application/json",
            }
            stream_kvs = {
                "Accept": "text/event-stream",
                "content-type": "application/json",
            }
            shared_kvs = {
                "Authorization": "Bearer {api_key}",
                "User-Agent": "langchain-nvidia-ai-endpoints",
            }
            values["headers_tmpl"] = {
                "call": {**call_kvs, **shared_kvs},
                "stream": {**stream_kvs, **shared_kvs},
            }
        return values

    @property
    def available_models(self) -> dict:
        """List the available models that can be invoked."""
        if self._available_models is not None:
            return self._available_models
        live_fns = self.available_functions
        if live_fns and "status" in live_fns[0]:
            live_fns = [v for v in live_fns if v.get("status") == "ACTIVE"]
            self._available_models = {v["name"]: v["id"] for v in live_fns}
        else:
            self._available_models = {v.get("id"): v.get("owned_by") for v in live_fns}
        return self._available_models
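
    # The mapping above takes one of two shapes depending on the backend
    # (values here are illustrative only):
    #   NVCF function listing:    {"playground_llama2_13b": "<function-id>"}
    #   OpenAPI-style listing:    {"meta/llama2-70b": "<owned-by>"}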

    @property
    def available_functions(self) -> list:
        """List the available functions that can be invoked."""
        if self._available_functions and isinstance(self._available_functions, list):
            return self._available_functions
        if not self.endpoints.get("models"):
            raise ValueError("No models endpoint found, so cannot retrieve model list.")
        invoke_url = self.endpoints.get("models", "").format(base_url=self.base_url)
        try:
            query_res = self.query(invoke_url)
        except Exception as e:
            raise ValueError(f"Failed to query model endpoint {invoke_url}.\n{e}")
        output = None
        if isinstance(query_res.get("functions"), list):
            output = query_res.get("functions")
        elif isinstance(query_res.get("data"), list):
            output = query_res.get("data")
        if not isinstance(output, list):
            raise ValueError(
                f"Unexpected response when querying {invoke_url}\n{query_res}"
            )
        # if there's an alias / model name for the function, add it as well
        # this lets users work with ai-gemma-2b and google/gemma-2b
        aliases = []
        for function in output:
            assert (
                "name" in function or "id" in function
            ), f"No name or id in function: {function}"
            name = function["name"] if "name" in function else function["id"]
            if name in MODEL_SPECS and "model_name" in MODEL_SPECS[name]:
                alias = function.copy()
                alias.update(name=MODEL_SPECS[name]["model_name"])
                aliases.append(alias)
        output.extend(aliases)
        self._available_functions = output
        return self._available_functions

    def reset_method_cache(self) -> None:
        """Reset method cache to force re-fetch of available functions"""
        self._available_functions = None
        self._available_models = None

    ####################################################################################
    ## Core utilities for posting and getting from NV Endpoints

    def _post(
        self,
        invoke_url: str,
        payload: Optional[dict] = None,
    ) -> Tuple[Response, Any]:
        """Method for posting to the AI Foundation Model Function API."""
        self.last_inputs = {
            "url": invoke_url,
            "headers": self.headers["call"],
            "json": self.payload_fn(payload),
            "stream": False,
        }
        session = self.get_session_fn()
        self.last_response = response = session.post(**self.last_inputs)
        self._try_raise(response)
        return response, session

    def _get(
        self,
        invoke_url: str,
        payload: Optional[dict] = None,
    ) -> Tuple[Response, Any]:
        """Method for getting from the AI Foundation Model Function API."""
        self.last_inputs = {
            "url": invoke_url,
            "headers": self.headers["call"],
            "stream": False,
        }
        if payload:
            self.last_inputs["json"] = self.payload_fn(payload)
        session = self.get_session_fn()
        self.last_response = response = session.get(**self.last_inputs)
        self._try_raise(response)
        return response, session

    def _wait(self, response: Response, session: Any) -> Response:
        """Wait for a response from API after an initial response is made"""
        start_time = time.time()
        while response.status_code == 202:
            time.sleep(self.interval)
            if (time.time() - start_time) > self.timeout:
                raise TimeoutError(
                    "Timeout reached without a successful response."
                    f"\nLast response: {str(response)}"
                )
            request_id = response.headers.get("NVCF-REQID", "")
            endpoint_args = {"base_url": self.base_url, "request_id": request_id}
            self.last_response = response = session.get(
                self.endpoints["status"].format(**endpoint_args),
                headers=self.headers["call"],
            )
        self._try_raise(response)
        return response
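
    # The polling flow above, sketched: an initial POST may return HTTP 202 with
    # an "NVCF-REQID" header; the client then GETs the status endpoint
    # ("{base_url}/pexec/status/{request_id}") every `interval` seconds until a
    # non-202 response arrives or `timeout` seconds elapse.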

    def _try_raise(self, response: Response) -> None:
        """Try to raise an error from a response"""
        try:
            response.raise_for_status()
        except requests.HTTPError:
            try:
                rd = response.json()
                if "detail" in rd and "reqId" in rd.get("detail", ""):
                    rd_buf = "- " + str(rd["detail"])
                    rd_buf = rd_buf.replace(": ", ", Error: ").replace(", ", "\n- ")
                    rd["detail"] = rd_buf
            except json.JSONDecodeError:
                rd = response.__dict__
                if "status_code" in rd:
                    if "headers" in rd and "WWW-Authenticate" in rd["headers"]:
                        rd["detail"] = rd.get("headers").get("WWW-Authenticate")
                        rd["detail"] = rd["detail"].replace(", ", "\n")
                else:
                    rd = rd.get("_content", rd)
                    if isinstance(rd, bytes):
                        rd = rd.decode("utf-8")[5:]  ## remove "data:" prefix
                    try:
                        rd = json.loads(rd)
                    except Exception:
                        rd = {"detail": rd}
            status = rd.get("status") or rd.get("status_code") or "###"
            title = (
                rd.get("title")
                or rd.get("error")
                or rd.get("reason")
                or "Unknown Error"
            )
            header = f"[{status}] {title}"
            body = ""
            if "requestId" in rd:
                if "detail" in rd:
                    body += f"{rd['detail']}\n"
                body += "RequestID: " + rd["requestId"]
            else:
                body = str(rd.get("detail", rd))
            if str(status) == "401":
                body += "\nPlease check or regenerate your API key."
            # todo: raise as an HTTPError
            raise Exception(f"{header}\n{body}") from None

    ####################################################################################
    ## Simple query interface to show the set of model options

    def query(
        self,
        invoke_url: str,
        payload: Optional[dict] = None,
        request: str = "get",
    ) -> dict:
        """Simple method for an end-to-end get query. Returns result dictionary"""
        if request == "get":
            response, session = self._get(invoke_url, payload)
        else:
            response, session = self._post(invoke_url, payload)
        response = self._wait(response, session)
        output = self._process_response(response)[0]
        return output
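
    # Illustrative use (the URL and response shape depend on the deployment):
    #     models = nve.query(f"{nve.base_url}/functions")
    #     # -> {"functions": [{"id": "...", "name": "...", "status": "ACTIVE"}, ...]}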

    def _process_response(self, response: Union[str, Response]) -> List[dict]:
        """General-purpose response processing for single responses and streams"""
        if hasattr(response, "json"):  ## For single response (i.e. non-streaming)
            try:
                return [response.json()]
            except json.JSONDecodeError:
                response = str(response.__dict__)
        if isinstance(response, str):  ## For set of responses (i.e. streaming)
            msg_list = []
            for msg in response.split("\n\n"):
                if "{" not in msg:
                    continue
                msg_list += [json.loads(msg[msg.find("{") :])]
            return msg_list
        raise ValueError(f"Received ill-formed response: {response}")

    def _get_invoke_url(
        self,
        model_name: Optional[str] = None,
        invoke_url: Optional[str] = None,
        endpoint: str = "",
    ) -> str:
        """Helper method to get invoke URL from a model name, URL, or endpoint stub"""
        if not invoke_url:
            endpoint_str = self.endpoints.get(endpoint, "")
            if not endpoint_str:
                raise ValueError(f"Unknown endpoint referenced {endpoint} provided")
            if "{model_id}" in endpoint_str:
                if not model_name:
                    raise ValueError("URL or model name must be specified to invoke")
                if model_name in self.available_models:
                    model_id = self.available_models[model_name]
                elif f"playground_{model_name}" in self.available_models:
                    model_id = self.available_models[f"playground_{model_name}"]
                else:
                    available_models_str = "\n".join(
                        [f"{k} - {v}" for k, v in self.available_models.items()]
                    )
                    raise ValueError(
                        f"Unknown model name {model_name} specified."
                        "\nAvailable models are:\n"
                        f"{available_models_str}"
                    )
            else:
                model_id = ""

            endpoint_args = {"base_url": self.base_url, "model_id": model_id}
            invoke_url = endpoint_str.format(**endpoint_args)

        if not invoke_url:
            raise ValueError("URL or model name must be specified to invoke")

        return invoke_url
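
    # Illustrative use (model name and resulting function id are hypothetical):
    #     url = nve._get_invoke_url("llama2_13b", endpoint="infer")
    #     # -> "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/<function-id>"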

    ####################################################################################
    ## Generation interface to allow users to generate new values from endpoints

    def get_req(
        self,
        model_name: Optional[str] = None,
        payload: dict = {},
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
        endpoint: str = "",
    ) -> Response:
        """Post to the API."""
        invoke_url = self._get_invoke_url(model_name, invoke_url, endpoint=endpoint)
        if payload.get("stream", False) is True:
            payload = {**payload, "stream": False}
        response, session = self._post(invoke_url, payload)
        return self._wait(response, session)

    def get_req_generation(
        self,
        model_name: Optional[str] = None,
        payload: dict = {},
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
        endpoint: str = "infer",
    ) -> dict:
        """Method for an end-to-end post query with NVE post-processing."""
        invoke_url = self._get_invoke_url(model_name, invoke_url, endpoint=endpoint)
        response = self.get_req(model_name, payload, invoke_url)
        output, _ = self.postprocess(response, stop=stop)
        return output

    def postprocess(
        self, response: Union[str, Response], stop: Optional[Sequence[str]] = None
    ) -> Tuple[dict, bool]:
        """Parses a response from the AI Foundation Model Function API.
        Strongly assumes that the API will return a single response.
        """
        msg_list = self._process_response(response)
        msg, is_stopped = self._aggregate_msgs(msg_list)
        msg, is_stopped = self._early_stop_msg(msg, is_stopped, stop=stop)
        return msg, is_stopped

    def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
        """Dig out relevant details of aggregated message"""
        content_buffer: Dict[str, Any] = dict()
        content_holder: Dict[Any, Any] = dict()
        usage_holder: Dict[Any, Any] = dict()  ####
        is_stopped = False
        for msg in msg_list:
            usage_holder = msg.get("usage", {})  ####
            if "choices" in msg:
                ## Tease out ['choices'][0]...['delta'/'message']
                msg = msg.get("choices", [{}])[0]
                is_stopped = msg.get("finish_reason", "") == "stop"
                msg = msg.get("delta", msg.get("message", msg.get("text", "")))
                if not isinstance(msg, dict):
                    msg = {"content": msg}
            elif "data" in msg:
                ## Tease out ['data'][0]...['embedding']
                msg = msg.get("data", [{}])[0]
            content_holder = msg
            for k, v in msg.items():
                if k in ("content",) and k in content_buffer:
                    content_buffer[k] += v
                else:
                    content_buffer[k] = v
            if is_stopped:
                break
        content_holder = {**content_holder, **content_buffer}
        if usage_holder:
            content_holder.update(token_usage=usage_holder)  ####
        return content_holder, is_stopped
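
    # A minimal illustration of the aggregation: streamed deltas
    # {"content": "Hel"} and {"content": "lo"} combine to {"content": "Hello"},
    # while non-"content" keys (e.g. "role") are overwritten by later chunks.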

    def _early_stop_msg(
        self, msg: dict, is_stopped: bool, stop: Optional[Sequence[str]] = None
    ) -> Tuple[dict, bool]:
        """Try to early-terminate streaming or generation by iterating over stop list"""
        content = msg.get("content", "")
        if content and stop:
            for stop_str in stop:
                if stop_str and stop_str in content:
                    msg["content"] = content[: content.find(stop_str) + 1]
                    is_stopped = True
        return msg, is_stopped

    ####################################################################################
    ## Streaming interface to allow you to iterate through progressive generations

    def get_req_stream(
        self,
        model: Optional[str] = None,
        payload: dict = {},
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
        endpoint: str = "infer",
    ) -> Iterator:
        invoke_url = self._get_invoke_url(model, invoke_url, endpoint=endpoint)
        if payload.get("stream", True) is False:
            payload = {**payload, "stream": True}
        self.last_inputs = {
            "url": invoke_url,
            "headers": self.headers["stream"],
            "json": self.payload_fn(payload),
            "stream": True,
        }
        response = self.get_session_fn().post(**self.last_inputs)
        self._try_raise(response)
        call = self.copy()

        def out_gen() -> Generator[dict, Any, Any]:
            ## Uses the copied client (`call`) so the stream is decoupled from
            ## subsequent mutations of `self` (e.g. `self.last_inputs`)
            for line in response.iter_lines():
                if line and line.strip() != b"data: [DONE]":
                    line = line.decode("utf-8")
                    msg, final_line = call.postprocess(line, stop=stop)
                    yield msg
                    if final_line:
                        break
                self._try_raise(response)

        return out_gen()
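
    # Illustrative use (payload keys depend on the model and endpoint):
    #     for chunk in nve.get_req_stream("llama2_13b", payload={"messages": [...]}):
    #         print(chunk.get("content", ""), end="")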

    ####################################################################################
    ## Asynchronous streaming interface to allow multiple generations to happen at once.

    async def get_req_astream(
        self,
        model: Optional[str] = None,
        payload: dict = {},
        invoke_url: Optional[str] = None,
        stop: Optional[Sequence[str]] = None,
        endpoint: str = "infer",
    ) -> AsyncIterator:
        invoke_url = self._get_invoke_url(model, invoke_url, endpoint=endpoint)
        if payload.get("stream", True) is False:
            payload = {**payload, "stream": True}
        self.last_inputs = {
            "url": invoke_url,
            "headers": self.headers["stream"],
            "json": self.payload_fn(payload),
        }
        async with self.get_asession_fn() as session:
            async with session.post(**self.last_inputs) as response:
                self._try_raise(response)
                async for line in response.content.iter_any():
                    if line and line.strip() != b"data: [DONE]":
                        line = line.decode("utf-8")
                        msg, final_line = self.postprocess(line, stop=stop)
                        yield msg
                        if final_line:
                            break
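
    # Illustrative async use (names are hypothetical; run inside an event loop):
    #     async for chunk in nve.get_req_astream("llama2_13b", payload={...}):
    #         print(chunk.get("content", ""), end="")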


class _NVIDIAClient(BaseModel):
    """
    Higher-Level AI Foundation Model Function API Client with argument defaults.
    Is subclassed by ChatNVIDIA to provide a simple LangChain interface.
    """

    client: NVEModel = Field(NVEModel)

    _default_model: str = ""
    model: str = Field(description="Name of the model to invoke")
    infer_endpoint: str = Field("{base_url}/chat/completions")
    curr_mode: _MODE_TYPE = Field("nvidia")  # todo: remove this in 0.1
    is_hosted: bool = Field(True)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        if "base_url" in kwargs:
            self.is_hosted = False
            self.curr_mode = "nim"
            self.client.endpoints["infer"] = self.infer_endpoint
            self.client.endpoints["models"] = "{base_url}/models"

    ####################################################################################

    @root_validator(pre=True)
    def validate_client(cls, values: Any) -> Any:
        """Validate and update client arguments, including API key and formatting"""
        if not values.get("client"):
            values["client"] = NVEModel(**values)
        elif isinstance(values["client"], NVEModel):
            values["client"] = values["client"].__class__(**values["client"].dict())
        if not values.get("model"):
            values["model"] = cls._default_model
            assert values["model"], "No model given, with no default to fall back on."

        # the only model that doesn't support a stream parameter is kosmos_2.
        # to address this, we'll use the payload_fn to remove the stream parameter for
        # kosmos_2. if a user tries to set their own payload_fn, this patch will be
        # overwritten.
        # todo: get kosmos_2 api updated to support stream parameter
        if values["model"] == "kosmos_2":

            def kosmos_patch(payload: dict) -> dict:
                payload.pop("stream", None)
                return payload

            values["client"].payload_fn = kosmos_patch

        return values

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": self.client._api_key_var}

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}
        if getattr(self.client, "base_url"):
            attributes["base_url"] = self.client.base_url

        if self.model:
            attributes["model"] = self.model

        if getattr(self.client, "endpoints"):
            attributes["endpoints"] = self.client.endpoints

        return attributes

    @property
    @deprecated(since="0.0.15", removal="0.1.0", alternative="available_models")
    def available_functions(self) -> List[dict]:
        """Map the available functions that can be invoked."""
        return self.__class__.get_available_functions(client=self)

    @property
    def available_models(self) -> List[Model]:
        """Map the available models that can be invoked."""
        if self.curr_mode == "nim" or not self.is_hosted:
            return self.__class__.get_available_models(
                client=self, base_url=self.client.base_url
            )
        else:
            return self.__class__.get_available_models(client=self)

    @classmethod
    @deprecated(since="0.0.15", removal="0.1.0", alternative="get_available_models")
    def get_available_functions(
        cls,
        mode: Optional[_MODE_TYPE] = None,
        client: Optional[_NVIDIAClient] = None,
        **kwargs: Any,
    ) -> List[dict]:
        """Map the available functions that can be invoked. Callable from class"""
        nveclient = (client or cls(**kwargs)).mode(mode, **kwargs).client
        nveclient.reset_method_cache()
        return nveclient.available_functions

    @classmethod
    def get_available_models(
        cls,
        mode: Optional[_MODE_TYPE] = None,
        client: Optional[_NVIDIAClient] = None,
        list_all: bool = False,
        filter: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Model]:
        """Map the available models that can be invoked. Callable from class"""
        if mode is not None:
            warn_deprecated(
                name="mode", since="0.0.17", removal="0.1.0", alternative="`base_url`"
            )
        self = client or cls(**kwargs)
        nveclient = self.client
        nveclient.reset_method_cache()
        out = sorted(
            [
                Model(id=k.replace("playground_", ""), path=v, **MODEL_SPECS.get(k, {}))
                for k, v in nveclient.available_models.items()
            ],
            key=lambda x: f"{x.client or 'Z'}{x.id}{cls}",
        )
        # nim model listing does not provide the type and we cannot know
        # the model name ahead of time to guess the type.
        # so we need to list all models.
        if mode == "nim" or not self.is_hosted:
            list_all = True
        if not filter:
            filter = cls.__name__
        if not list_all:
            out = [m for m in out if m.client == filter]
        return out

    @deprecated(since="0.0.15", removal="0.1.0")
    def get_model_details(self, model: Optional[str] = None) -> dict:
        """Get more meta-details about a model retrieved by a given name"""
        if model is None:
            model = self.model
        model_key = self.client._get_invoke_url(model).split("/")[-1]
        known_fns = self.client.available_functions
        fn_spec = next((f for f in known_fns if f.get("id") == model_key), None)
        if fn_spec is None:
            raise ValueError(f"No function found with id {model_key}.")
        return fn_spec

    @deprecated(since="0.0.15", removal="0.1.0")
    def get_binding_model(self) -> Optional[str]:
        """Get the model to bind to the client as default payload argument"""
        # if a model is configured with a model_name, always use that
        # todo: move from search of available_models to a Model property
        matches = [model for model in self.available_models if model.id == self.model]
        if matches:
            if matches[0].model_name:
                return matches[0].model_name
        if self.curr_mode == "nvidia":
            return ""
        return self.model

    @deprecated(
        since="0.0.17",
        removal="0.1.0",
        alternative="`base_url` in constructor",
    )
    def mode(
        self,
        mode: Optional[_MODE_TYPE] = "nvidia",
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        infer_path: Optional[str] = None,
        models_path: Optional[str] = "{base_url}/models",
        force_mode: bool = False,
        force_clone: bool = True,
        **kwargs: Any,
    ) -> Any:  # todo: in python 3.11+ this should be typing.Self
        """Deprecated: pass `base_url=...` to constructor instead."""
        if isinstance(self, str):
            raise ValueError("Please construct the model before calling mode()")
        out = self if not force_clone else deepcopy(self)

        if mode is None:
            return out

        out.model = model or out.model

        if base_url and not force_mode:
            ## If a user tries to set base_url, assume custom openapi unless forced
            mode = "nim"

        if mode == "nvidia":
            key_var = "NVIDIA_API_KEY"
            if not api_key or not api_key.startswith("nvapi-"):
                api_key = os.getenv(key_var) or out.client.api_key.get_secret_value()
            if not api_key.startswith("nvapi-"):
                raise ValueError(f"No {key_var} in env/fed as api_key. (nvapi-...)")

        out.curr_mode = mode
        if api_key:
            out.client.api_key = SecretStr(api_key)

        nvcf_base = "https://api.nvcf.nvidia.com/v2/nvcf"  ## NVCF Main URL
        nvcf_infer = "{base_url}/pexec/functions/{model_id}"  ## Inference endpoints
        nvcf_status = "{base_url}/pexec/status/{request_id}"  ## 202 wait handle
        nvcf_models = "{base_url}/functions"  ## Model listing

        if mode == "nvidia":
            ## Classic support for nvcf-backed foundation model endpoints.
            out.client.base_url = base_url or nvcf_base
            out.client.endpoints = {
                "infer": nvcf_infer,  ## Per-model inference
                "status": nvcf_status,  ## 202 wait handle
                "models": nvcf_models,  ## Model listing
            }

        elif mode == "nim":
            ## OpenAPI-style specs to connect to NeMo Inference Microservices etc.
            ## Most generic option, requires specifying base_url
            assert base_url, "Base URL must be specified for nim mode"
            out.client.base_url = base_url
            out.client.endpoints["infer"] = infer_path or out.infer_endpoint
            out.client.endpoints["models"] = models_path or "{base_url}/models"

        else:
            options = ["nvidia", "nim"]
            raise ValueError(f"Unknown mode: `{mode}`. Expected one of {options}.")

        out.client.reset_method_cache()

        return out
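
    # Deprecated-path illustration (URL is hypothetical):
    #     client.mode("nim", base_url="http://localhost:8000/v1")
    # is equivalent to constructing with base_url="http://localhost:8000/v1",
    # which __init__ above maps onto the "nim" endpoints directly.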