Source code for

"""Internal representation of a structured query language."""
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence, Union

from langchain_core.pydantic_v1 import BaseModel

class Visitor(ABC):
    """Defines interface for IR translation using visitor pattern.

    Subclasses implement one ``visit_*`` method per IR node type and may
    restrict which operators/comparators they accept by setting
    ``allowed_operators`` / ``allowed_comparators``.
    """

    # Comparators this visitor accepts; ``None`` disables the check.
    allowed_comparators: Optional[Sequence[Comparator]] = None
    # Operators this visitor accepts; ``None`` disables the check.
    allowed_operators: Optional[Sequence[Operator]] = None

    def _validate_func(self, func: Union[Operator, Comparator]) -> None:
        """Check that ``func`` is permitted by this visitor.

        Args:
            func: The operator or comparator to validate.

        Raises:
            ValueError: If ``func`` is an ``Operator`` not listed in
                ``allowed_operators``, or a ``Comparator`` not listed in
                ``allowed_comparators`` (each check only applies when the
                corresponding allow-list is not ``None``).
        """
        if isinstance(func, Operator) and self.allowed_operators is not None:
            if func not in self.allowed_operators:
                raise ValueError(
                    f"Received disallowed operator {func}. Allowed "
                    f"operators are {self.allowed_operators}"
                )
        if isinstance(func, Comparator) and self.allowed_comparators is not None:
            if func not in self.allowed_comparators:
                raise ValueError(
                    f"Received disallowed comparator {func}. Allowed "
                    f"comparators are {self.allowed_comparators}"
                )

    @abstractmethod
    def visit_operation(self, operation: Operation) -> Any:
        """Translate an Operation.

        Args:
            operation: Operation to translate.

        Returns:
            The translated representation (implementation-defined).
        """

    @abstractmethod
    def visit_comparison(self, comparison: Comparison) -> Any:
        """Translate a Comparison.

        Args:
            comparison: Comparison to translate.

        Returns:
            The translated representation (implementation-defined).
        """

    @abstractmethod
    def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
        """Translate a StructuredQuery.

        Args:
            structured_query: StructuredQuery to translate.

        Returns:
            The translated representation (implementation-defined).
        """
def _to_snake_case(name: str) -> str: """Convert a name into snake_case.""" snake_case = "" for i, char in enumerate(name): if char.isupper() and i != 0: snake_case += "_" + char.lower() else: snake_case += char.lower() return snake_case
class Expr(BaseModel):
    """Common base for every node of the structured-query IR."""

    def accept(self, visitor: Visitor) -> Any:
        """Dispatch this expression to the matching ``visit_*`` method.

        The target method is named ``visit_`` followed by the snake_case form
        of this object's class name (e.g. ``Comparison`` ->
        ``visit_comparison``).

        Args:
            visitor: visitor to accept

        Returns:
            result of visiting
        """
        method_name = "visit_" + _to_snake_case(self.__class__.__name__)
        handler = getattr(visitor, method_name)
        return handler(self)
class Operator(str, Enum):
    """Enumerator of the operations.

    Logical operators that an ``Operation`` applies over its ``arguments``.
    Because the enum mixes in ``str``, each member compares equal to its
    string value (e.g. ``Operator.AND == "and"``).
    """

    AND = "and"
    OR = "or"
    NOT = "not"
class Comparator(str, Enum):
    """Enumerator of the comparison operators.

    Used by ``Comparison`` to relate an ``attribute`` to a ``value``. Because
    the enum mixes in ``str``, each member compares equal to its string value
    (e.g. ``Comparator.EQ == "eq"``).
    """

    EQ = "eq"
    NE = "ne"
    GT = "gt"
    GTE = "gte"
    LT = "lt"
    LTE = "lte"
    CONTAIN = "contain"
    LIKE = "like"
    IN = "in"
    NIN = "nin"
class FilterDirective(Expr, ABC):
    """Abstract base for filtering expressions (e.g. ``Comparison``, ``Operation``)."""
class Comparison(FilterDirective):
    """A comparison of an attribute against a value.

    Pairs a ``Comparator`` with the ``attribute`` it is applied to and the
    ``value`` to compare against.
    """

    comparator: Comparator
    attribute: str
    value: Any

    def __init__(
        self, comparator: Comparator, attribute: str, value: Any, **kwargs: Any
    ) -> None:
        # Accept the fields positionally and forward them as keyword arguments
        # to the pydantic ``BaseModel`` initializer, together with any extras.
        fields = {"comparator": comparator, "attribute": attribute, "value": value}
        super().__init__(**fields, **kwargs)
class Operation(FilterDirective):
    """A logical operation over other directives.

    Applies ``operator`` to ``arguments``, each of which is any
    ``FilterDirective`` (so operations may nest comparisons or other
    operations).
    """

    operator: Operator
    arguments: List[FilterDirective]

    def __init__(
        self, operator: Operator, arguments: List[FilterDirective], **kwargs: Any
    ) -> None:
        # Accept the fields positionally and forward them as keyword arguments
        # to the pydantic ``BaseModel`` initializer, together with any extras.
        super().__init__(operator=operator, arguments=arguments, **kwargs)
class StructuredQuery(Expr):
    """A structured query.

    Combines a query string with an optional filter expression and an
    optional result limit.
    """

    query: str
    """Query string."""
    filter: Optional[FilterDirective]
    """Filtering expression."""
    limit: Optional[int]
    """Limit on the number of results."""

    def __init__(
        self,
        query: str,
        filter: Optional[FilterDirective],
        limit: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # ``filter`` shadows the builtin but is part of the public field name,
        # so it is kept as-is; fields are forwarded by keyword to pydantic.
        super().__init__(query=query, filter=filter, limit=limit, **kwargs)