aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars/expr.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/_polars/expr.py415
1 files changed, 415 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py b/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py
new file mode 100644
index 0000000..eb5b5f2
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py
@@ -0,0 +1,415 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Sequence
+
+import polars as pl
+
+from narwhals._duration import parse_interval_string
+from narwhals._polars.utils import (
+ extract_args_kwargs,
+ extract_native,
+ narwhals_to_native_dtype,
+)
+from narwhals._utils import Implementation, requires
+
+if TYPE_CHECKING:
+ from typing_extensions import Self
+
+ from narwhals._expression_parsing import ExprKind, ExprMetadata
+ from narwhals._polars.dataframe import Method
+ from narwhals._polars.namespace import PolarsNamespace
+ from narwhals._utils import Version
+ from narwhals.typing import IntoDType
+
+
+class PolarsExpr:
+ def __init__(
+ self, expr: pl.Expr, version: Version, backend_version: tuple[int, ...]
+ ) -> None:
+ self._native_expr = expr
+ self._implementation = Implementation.POLARS
+ self._version = version
+ self._backend_version = backend_version
+ self._metadata: ExprMetadata | None = None
+
+ @property
+ def native(self) -> pl.Expr:
+ return self._native_expr
+
+ def __repr__(self) -> str: # pragma: no cover
+ return "PolarsExpr"
+
+ def _with_native(self, expr: pl.Expr) -> Self:
+ return self.__class__(expr, self._version, self._backend_version)
+
+ @classmethod
+ def _from_series(cls, series: Any) -> Self:
+ return cls(series.native, series._version, series._backend_version)
+
+ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
+ # Let Polars do its thing.
+ return self
+
+ def __getattr__(self, attr: str) -> Any:
+ def func(*args: Any, **kwargs: Any) -> Any:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._with_native(getattr(self.native, attr)(*pos, **kwds))
+
+ return func
+
+ def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
+ name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
+ return {name: min_samples}
+
+ def cast(self, dtype: IntoDType) -> Self:
+ dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
+ return self._with_native(self.native.cast(dtype_pl))
+
+ def ewm_mean(
+ self,
+ *,
+ com: float | None,
+ span: float | None,
+ half_life: float | None,
+ alpha: float | None,
+ adjust: bool,
+ min_samples: int,
+ ignore_nulls: bool,
+ ) -> Self:
+ native = self.native.ewm_mean(
+ com=com,
+ span=span,
+ half_life=half_life,
+ alpha=alpha,
+ adjust=adjust,
+ ignore_nulls=ignore_nulls,
+ **self._renamed_min_periods(min_samples),
+ )
+ if self._backend_version < (1,): # pragma: no cover
+ native = pl.when(~self.native.is_null()).then(native).otherwise(None)
+ return self._with_native(native)
+
+ def is_nan(self) -> Self:
+ if self._backend_version >= (1, 18):
+ native = self.native.is_nan()
+ else: # pragma: no cover
+ native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
+ return self._with_native(native)
+
+ def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
+ if self._backend_version < (1, 9):
+ if order_by:
+ msg = "`order_by` in Polars requires version 1.10 or greater"
+ raise NotImplementedError(msg)
+ native = self.native.over(partition_by or pl.lit(1))
+ else:
+ native = self.native.over(
+ partition_by or pl.lit(1), order_by=order_by or None
+ )
+ return self._with_native(native)
+
+ @requires.backend_version((1,))
+ def rolling_var(
+ self, window_size: int, *, min_samples: int, center: bool, ddof: int
+ ) -> Self:
+ kwds = self._renamed_min_periods(min_samples)
+ native = self.native.rolling_var(
+ window_size=window_size, center=center, ddof=ddof, **kwds
+ )
+ return self._with_native(native)
+
+ @requires.backend_version((1,))
+ def rolling_std(
+ self, window_size: int, *, min_samples: int, center: bool, ddof: int
+ ) -> Self:
+ kwds = self._renamed_min_periods(min_samples)
+ native = self.native.rolling_std(
+ window_size=window_size, center=center, ddof=ddof, **kwds
+ )
+ return self._with_native(native)
+
+ def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
+ kwds = self._renamed_min_periods(min_samples)
+ native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
+ return self._with_native(native)
+
+ def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
+ kwds = self._renamed_min_periods(min_samples)
+ native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
+ return self._with_native(native)
+
+ def map_batches(
+ self, function: Callable[[Any], Any], return_dtype: IntoDType | None
+ ) -> Self:
+ return_dtype_pl = (
+ narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
+ if return_dtype
+ else None
+ )
+ native = self.native.map_batches(function, return_dtype_pl)
+ return self._with_native(native)
+
+ @requires.backend_version((1,))
+ def replace_strict(
+ self,
+ old: Sequence[Any] | Mapping[Any, Any],
+ new: Sequence[Any],
+ *,
+ return_dtype: IntoDType | None,
+ ) -> Self:
+ return_dtype_pl = (
+ narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
+ if return_dtype
+ else None
+ )
+ native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
+ return self._with_native(native)
+
+ def __eq__(self, other: object) -> Self: # type: ignore[override]
+ return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator]
+
+ def __ne__(self, other: object) -> Self: # type: ignore[override]
+ return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator]
+
+ def __ge__(self, other: Any) -> Self:
+ return self._with_native(self.native.__ge__(extract_native(other)))
+
+ def __gt__(self, other: Any) -> Self:
+ return self._with_native(self.native.__gt__(extract_native(other)))
+
+ def __le__(self, other: Any) -> Self:
+ return self._with_native(self.native.__le__(extract_native(other)))
+
+ def __lt__(self, other: Any) -> Self:
+ return self._with_native(self.native.__lt__(extract_native(other)))
+
+ def __and__(self, other: PolarsExpr | bool | Any) -> Self:
+ return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator]
+
+ def __or__(self, other: PolarsExpr | bool | Any) -> Self:
+ return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator]
+
+ def __add__(self, other: Any) -> Self:
+ return self._with_native(self.native.__add__(extract_native(other)))
+
+ def __sub__(self, other: Any) -> Self:
+ return self._with_native(self.native.__sub__(extract_native(other)))
+
+ def __mul__(self, other: Any) -> Self:
+ return self._with_native(self.native.__mul__(extract_native(other)))
+
+ def __pow__(self, other: Any) -> Self:
+ return self._with_native(self.native.__pow__(extract_native(other)))
+
+ def __truediv__(self, other: Any) -> Self:
+ return self._with_native(self.native.__truediv__(extract_native(other)))
+
+ def __floordiv__(self, other: Any) -> Self:
+ return self._with_native(self.native.__floordiv__(extract_native(other)))
+
+ def __mod__(self, other: Any) -> Self:
+ return self._with_native(self.native.__mod__(extract_native(other)))
+
+ def __invert__(self) -> Self:
+ return self._with_native(self.native.__invert__())
+
+ def cum_count(self, *, reverse: bool) -> Self:
+ if self._backend_version < (0, 20, 4):
+ result = (~self.native.is_null()).cum_sum(reverse=reverse)
+ else:
+ result = self.native.cum_count(reverse=reverse)
+ return self._with_native(result)
+
+ def __narwhals_expr__(self) -> None: ...
+ def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
+ from narwhals._polars.namespace import PolarsNamespace
+
+ return PolarsNamespace(
+ backend_version=self._backend_version, version=self._version
+ )
+
+ @property
+ def dt(self) -> PolarsExprDateTimeNamespace:
+ return PolarsExprDateTimeNamespace(self)
+
+ @property
+ def str(self) -> PolarsExprStringNamespace:
+ return PolarsExprStringNamespace(self)
+
+ @property
+ def cat(self) -> PolarsExprCatNamespace:
+ return PolarsExprCatNamespace(self)
+
+ @property
+ def name(self) -> PolarsExprNameNamespace:
+ return PolarsExprNameNamespace(self)
+
+ @property
+ def list(self) -> PolarsExprListNamespace:
+ return PolarsExprListNamespace(self)
+
+ @property
+ def struct(self) -> PolarsExprStructNamespace:
+ return PolarsExprStructNamespace(self)
+
+ # CompliantExpr
+ _alias_output_names: Any
+ _evaluate_aliases: Any
+ _evaluate_output_names: Any
+ _is_multi_output_unnamed: Any
+ __call__: Any
+ from_column_names: Any
+ from_column_indices: Any
+ _eval_names_indices: Any
+
+ # Polars
+ abs: Method[Self]
+ all: Method[Self]
+ any: Method[Self]
+ alias: Method[Self]
+ arg_max: Method[Self]
+ arg_min: Method[Self]
+ arg_true: Method[Self]
+ clip: Method[Self]
+ count: Method[Self]
+ cum_max: Method[Self]
+ cum_min: Method[Self]
+ cum_prod: Method[Self]
+ cum_sum: Method[Self]
+ diff: Method[Self]
+ drop_nulls: Method[Self]
+ exp: Method[Self]
+ fill_null: Method[Self]
+ gather_every: Method[Self]
+ head: Method[Self]
+ is_finite: Method[Self]
+ is_first_distinct: Method[Self]
+ is_in: Method[Self]
+ is_last_distinct: Method[Self]
+ is_null: Method[Self]
+ is_unique: Method[Self]
+ len: Method[Self]
+ log: Method[Self]
+ max: Method[Self]
+ mean: Method[Self]
+ median: Method[Self]
+ min: Method[Self]
+ mode: Method[Self]
+ n_unique: Method[Self]
+ null_count: Method[Self]
+ quantile: Method[Self]
+ rank: Method[Self]
+ round: Method[Self]
+ sample: Method[Self]
+ shift: Method[Self]
+ skew: Method[Self]
+ std: Method[Self]
+ sum: Method[Self]
+ sort: Method[Self]
+ tail: Method[Self]
+ unique: Method[Self]
+ var: Method[Self]
+
+
+class PolarsExprDateTimeNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._compliant_expr = expr
+
+ def truncate(self, every: str) -> PolarsExpr:
+ parse_interval_string(every) # Ensure consistent error message is raised.
+ return self._compliant_expr._with_native(
+ self._compliant_expr.native.dt.truncate(every)
+ )
+
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._compliant_expr._with_native(
+ getattr(self._compliant_expr.native.dt, attr)(*pos, **kwds)
+ )
+
+ return func
+
+
+class PolarsExprStringNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._compliant_expr = expr
+
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._compliant_expr._with_native(
+ getattr(self._compliant_expr.native.str, attr)(*pos, **kwds)
+ )
+
+ return func
+
+
+class PolarsExprCatNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._compliant_expr = expr
+
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._compliant_expr._with_native(
+ getattr(self._compliant_expr.native.cat, attr)(*pos, **kwds)
+ )
+
+ return func
+
+
+class PolarsExprNameNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._compliant_expr = expr
+
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._compliant_expr._with_native(
+ getattr(self._compliant_expr.native.name, attr)(*pos, **kwds)
+ )
+
+ return func
+
+
+class PolarsExprListNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._expr = expr
+
+ def len(self) -> PolarsExpr:
+ native_expr = self._expr._native_expr
+ native_result = native_expr.list.len()
+
+ if self._expr._backend_version < (1, 16): # pragma: no cover
+ native_result = (
+ pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
+ )
+ elif self._expr._backend_version < (1, 17): # pragma: no cover
+ native_result = native_result.cast(pl.UInt32())
+
+ return self._expr._with_native(native_result)
+
+ # TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._expr._with_native(
+ getattr(self._expr.native.list, attr)(*pos, **kwds)
+ )
+
+ return func
+
+
+class PolarsExprStructNamespace:
+ def __init__(self, expr: PolarsExpr) -> None:
+ self._expr = expr
+
+ def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]: # pragma: no cover
+ def func(*args: Any, **kwargs: Any) -> PolarsExpr:
+ pos, kwds = extract_args_kwargs(args, kwargs)
+ return self._expr._with_native(
+ getattr(self._expr.native.struct, attr)(*pos, **kwds)
+ )
+
+ return func