from __future__ import annotations

from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader

class SnowflakeLoader(BaseLoader):
    """Load from `Snowflake` API.

    Each document represents one row of the result. The `page_content_columns`
    are written into the `page_content` of the document. The `metadata_columns`
    are written into the `metadata` of the document. By default, all columns
    are written into the `page_content` and none into the `metadata`.
    """

    def __init__(
        self,
        query: str,
        user: str,
        password: str,
        account: str,
        warehouse: str,
        role: str,
        database: str,
        schema: str,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Snowflake document loader.

        Args:
            query: The query to run in Snowflake.
            user: Snowflake user.
            password: Snowflake password.
            account: Snowflake account.
            warehouse: Snowflake warehouse.
            role: Snowflake role.
            database: Snowflake database
            schema: Snowflake schema
            parameters: Optional. Parameters to pass to the query.
            page_content_columns: Optional. Columns written to Document
                `page_content`. Defaults to ``["*"]``, meaning every column.
            metadata_columns: Optional. Columns written to Document `metadata`.
                Defaults to no columns.
        """
        self.query = query
        self.user = user
        self.password = password
        self.account = account
        self.warehouse = warehouse
        self.role = role
        self.database = database
        self.schema = schema
        self.parameters = parameters
        # "*" is a sentinel that `lazy_load` expands to all result columns.
        self.page_content_columns = (
            page_content_columns if page_content_columns is not None else ["*"]
        )
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

    def _execute_query(self) -> List[Dict[str, Any]]:
        """Run the configured query and return its rows as dictionaries.

        Returns:
            One ``{column_name: value}`` dict per result row. An empty list is
            returned when the query yields no rows or when executing it fails
            (the error is printed, not raised).

        Raises:
            ImportError: If `snowflake-connector-python` is not installed.
        """
        try:
            import snowflake.connector
        except ImportError as ex:
            raise ImportError(
                "Could not import snowflake-connector-python package. "
                "Please install it with `pip install snowflake-connector-python`."
            ) from ex

        # NOTE(review): `self.parameters` is passed both here and as the bind
        # parameters of the query below -- confirm both uses are intended.
        conn = snowflake.connector.connect(
            user=self.user,
            password=self.password,
            account=self.account,
            warehouse=self.warehouse,
            role=self.role,
            database=self.database,
            schema=self.schema,
            parameters=self.parameters,
        )
        # Pre-bind the cursor so the cleanup below cannot hit an unbound name
        # when `conn.cursor()` itself is the call that fails.
        cur = None
        try:
            cur = conn.cursor()
            cur.execute("USE DATABASE " + self.database)
            cur.execute("USE SCHEMA " + self.schema)
            cur.execute(self.query, self.parameters)
            rows = cur.fetchall()
            column_names = [column[0] for column in cur.description]
            query_result = [dict(zip(column_names, row)) for row in rows]
        except Exception as e:
            # Deliberate best-effort: report the failure and load nothing.
            print(f"An error occurred: {e}")  # noqa: T201
            query_result = []
        finally:
            try:
                if cur is not None:
                    cur.close()
            finally:
                # The connection is local to this call, so release it here
                # even if closing the cursor failed.
                conn.close()
        return query_result

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        """Return the configured page-content and metadata column lists.

        Args:
            query_result: Rows returned by the query. Currently unused; kept
                so the method signature stays stable.

        Returns:
            A ``(page_content_columns, metadata_columns)`` tuple in which a
            missing or empty configuration is normalised to an empty list.
            The ``"*"`` sentinel is returned unexpanded.
        """
        page_content_columns = self.page_content_columns or []
        metadata_columns = self.metadata_columns or []
        return page_content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        """Lazily yield one `Document` per row of the query result.

        The page content is made of ``"column: value"`` lines, one per
        selected column, in the column order of the row. Nothing is yielded
        when the query returns no rows or fails.

        Yields:
            A `Document` for each result row.
        """
        query_result = self._execute_query()
        if not query_result:
            # No rows (or a failed query): there is no first row from which
            # the "*" sentinel could be expanded, so stop cleanly.
            return

        page_content_columns, metadata_columns = self._get_columns(query_result)
        if "*" in page_content_columns:
            page_content_columns = list(query_result[0].keys())

        # Build the lookup sets once instead of scanning lists for every cell.
        content_keys = set(page_content_columns)
        metadata_keys = set(metadata_columns)

        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_keys
            )
            metadata = {k: v for k, v in row.items() if k in metadata_keys}
            yield Document(page_content=page_content, metadata=metadata)