diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars/expr.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_polars/expr.py | 415 |
1 files changed, 415 insertions, 0 deletions
# Provenance: diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py
#             b/venv/lib/python3.8/site-packages/narwhals/_polars/expr.py
#             new file mode 100644, index 0000000..eb5b5f2, hunk @@ -0,0 +1,415 @@
"""Polars-backed expression wrapper (`PolarsExpr`) and its accessor namespaces.

Most of the API is forwarded to the wrapped `pl.Expr` through `__getattr__`;
only methods that need dtype translation, keyword renaming, or a
backend-version-specific workaround are written out explicitly.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Sequence

import polars as pl

from narwhals._duration import parse_interval_string
from narwhals._polars.utils import (
    extract_args_kwargs,
    extract_native,
    narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, requires

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._polars.dataframe import Method
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._utils import Version
    from narwhals.typing import IntoDType


class PolarsExpr:
    """Thin wrapper around a native `pl.Expr`.

    Carries the narwhals `version` and the Polars `backend_version` tuple so
    that individual methods can branch on the installed Polars release.
    """

    def __init__(
        self, expr: pl.Expr, version: Version, backend_version: tuple[int, ...]
    ) -> None:
        self._native_expr = expr
        self._implementation = Implementation.POLARS
        self._version = version
        self._backend_version = backend_version
        self._metadata: ExprMetadata | None = None

    @property
    def native(self) -> pl.Expr:
        """The wrapped Polars expression."""
        return self._native_expr

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsExpr"

    def _with_native(self, expr: pl.Expr) -> Self:
        """Wrap `expr` in a new instance sharing this one's version info."""
        return self.__class__(expr, self._version, self._backend_version)

    @classmethod
    def _from_series(cls, series: Any) -> Self:
        """Build an expression from an object exposing `native`, `_version`
        and `_backend_version` attributes."""
        return cls(series.native, series._version, series._backend_version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        # Let Polars do its thing.
        return self

    def __getattr__(self, attr: str) -> Any:
        """Fallback for names not found by normal attribute lookup.

        Returns a callable that, when invoked, calls the same-named method on
        the native expression and re-wraps the result. The native attribute
        is only resolved at call time, so *any* name yields a callable here:
        `hasattr(expr, name)` is always true and an unknown method name only
        fails once the returned callable is invoked.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            # NOTE(review): `extract_args_kwargs` presumably unwraps narwhals
            # objects into native Polars ones - confirm in `_polars.utils`.
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._with_native(getattr(self.native, attr)(*pos, **kwds))

        return func

    def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
        """Return `min_samples` under the keyword name this backend expects.

        The key is `min_periods` for backend versions below 1.21.0 and
        `min_samples` otherwise.
        """
        name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
        return {name: min_samples}

    def cast(self, dtype: IntoDType) -> Self:
        """Cast to `dtype` after translating it to a native Polars dtype."""
        dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
        return self._with_native(self.native.cast(dtype_pl))

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        """Exponentially-weighted mean, forwarded to `pl.Expr.ewm_mean`."""
        native = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **self._renamed_min_periods(min_samples),
        )
        if self._backend_version < (1,):  # pragma: no cover
            # On pre-1.0 backends, force a null result wherever the input is null.
            native = pl.when(~self.native.is_null()).then(native).otherwise(None)
        return self._with_native(native)

    def is_nan(self) -> Self:
        """NaN check; on backends before 1.18 it is only evaluated where the
        input is non-null."""
        if self._backend_version >= (1, 18):
            native = self.native.is_nan()
        else:  # pragma: no cover
            native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
        return self._with_native(native)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Apply the expression as a window function.

        An empty `partition_by` falls back to the literal `1` as partition
        key. An empty `order_by` is passed to Polars as `None`.

        Raises:
            NotImplementedError: if `order_by` is non-empty on a backend
                older than (1, 9).
        """
        if self._backend_version < (1, 9):
            if order_by:
                # NOTE(review): the guard admits 1.9.x but the message says
                # 1.10 - one of the two is off; confirm against Polars docs.
                msg = "`order_by` in Polars requires version 1.10 or greater"
                raise NotImplementedError(msg)
            native = self.native.over(partition_by or pl.lit(1))
        else:
            native = self.native.over(
                partition_by or pl.lit(1), order_by=order_by or None
            )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling variance, forwarded to `pl.Expr.rolling_var`."""
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_var(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling standard deviation, forwarded to `pl.Expr.rolling_std`."""
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_std(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling sum, forwarded to `pl.Expr.rolling_sum`."""
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Rolling mean, forwarded to `pl.Expr.rolling_mean`."""
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def map_batches(
        self, function: Callable[[Any], Any], return_dtype: IntoDType | None
    ) -> Self:
        """Apply `function` via `pl.Expr.map_batches`.

        `return_dtype` is translated to a native dtype unless it is `None`.
        """
        # Compare against None explicitly: a truthiness test would silently
        # drop any valid dtype object that happens to evaluate as falsy.
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
            if return_dtype is not None
            else None
        )
        native = self.native.map_batches(function, return_dtype_pl)
        return self._with_native(native)

    @requires.backend_version((1,))
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        """Replace values via `pl.Expr.replace_strict`.

        `return_dtype` is translated to a native dtype unless it is `None`.
        """
        # Explicit None check for the same reason as in `map_batches`.
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
            if return_dtype is not None
            else None
        )
        native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
        return self._with_native(native)

    # Operator dunders are spelled out because Python looks special methods up
    # on the type, so the instance-level `__getattr__` fallback never sees them.
    # Each one passes the operand through `extract_native` before delegating.

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))  # type: ignore[operator]

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))  # type: ignore[operator]

    def __ge__(self, other: Any) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Any) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Any) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Any) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __and__(self, other: PolarsExpr | bool | Any) -> Self:
        return self._with_native(self.native.__and__(extract_native(other)))  # type: ignore[operator]

    def __or__(self, other: PolarsExpr | bool | Any) -> Self:
        return self._with_native(self.native.__or__(extract_native(other)))  # type: ignore[operator]

    def __add__(self, other: Any) -> Self:
        return self._with_native(self.native.__add__(extract_native(other)))

    def __sub__(self, other: Any) -> Self:
        return self._with_native(self.native.__sub__(extract_native(other)))

    def __mul__(self, other: Any) -> Self:
        return self._with_native(self.native.__mul__(extract_native(other)))

    def __pow__(self, other: Any) -> Self:
        return self._with_native(self.native.__pow__(extract_native(other)))

    def __truediv__(self, other: Any) -> Self:
        return self._with_native(self.native.__truediv__(extract_native(other)))

    def __floordiv__(self, other: Any) -> Self:
        return self._with_native(self.native.__floordiv__(extract_native(other)))

    def __mod__(self, other: Any) -> Self:
        return self._with_native(self.native.__mod__(extract_native(other)))

    def __invert__(self) -> Self:
        return self._with_native(self.native.__invert__())

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count; backends before 0.20.4 use a cumulative sum of
        the non-null mask instead of `cum_count`."""
        if self._backend_version < (0, 20, 4):
            result = (~self.native.is_null()).cum_sum(reverse=reverse)
        else:
            result = self.native.cum_count(reverse=reverse)
        return self._with_native(result)

    # No-op method; it has no body beyond `...`.
    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> PolarsNamespace:  # pragma: no cover
        # Imported inside the method rather than at module level.
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    @property
    def dt(self) -> PolarsExprDateTimeNamespace:
        return PolarsExprDateTimeNamespace(self)

    @property
    def str(self) -> PolarsExprStringNamespace:
        return PolarsExprStringNamespace(self)

    @property
    def cat(self) -> PolarsExprCatNamespace:
        return PolarsExprCatNamespace(self)

    @property
    def name(self) -> PolarsExprNameNamespace:
        return PolarsExprNameNamespace(self)

    @property
    def list(self) -> PolarsExprListNamespace:
        return PolarsExprListNamespace(self)

    @property
    def struct(self) -> PolarsExprStructNamespace:
        return PolarsExprStructNamespace(self)

    # The names below are annotations only: no value is assigned, so at
    # runtime any access to them still goes through `__getattr__`.

    # CompliantExpr
    _alias_output_names: Any
    _evaluate_aliases: Any
    _evaluate_output_names: Any
    _is_multi_output_unnamed: Any
    __call__: Any
    from_column_names: Any
    from_column_indices: Any
    _eval_names_indices: Any

    # Polars
    abs: Method[Self]
    all: Method[Self]
    any: Method[Self]
    alias: Method[Self]
    arg_max: Method[Self]
    arg_min: Method[Self]
    arg_true: Method[Self]
    clip: Method[Self]
    count: Method[Self]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_finite: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_unique: Method[Self]
    len: Method[Self]
    log: Method[Self]
    max: Method[Self]
    mean: Method[Self]
    median: Method[Self]
    min: Method[Self]
    mode: Method[Self]
    n_unique: Method[Self]
    null_count: Method[Self]
    quantile: Method[Self]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[Self]
    std: Method[Self]
    sum: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    var: Method[Self]


class PolarsExprDateTimeNamespace:
    """`.dt` accessor: forwards to `pl.Expr.dt` and re-wraps the result."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._compliant_expr = expr

    def truncate(self, every: str) -> PolarsExpr:
        """Truncate via `pl.Expr.dt.truncate` after validating `every`."""
        parse_interval_string(every)  # Ensure consistent error message is raised.
        return self._compliant_expr._with_native(
            self._compliant_expr.native.dt.truncate(every)
        )

    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
        # Any method not defined above is forwarded to the native `.dt` namespace.
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._compliant_expr._with_native(
                getattr(self._compliant_expr.native.dt, attr)(*pos, **kwds)
            )

        return func


class PolarsExprStringNamespace:
    """`.str` accessor: forwards every call to `pl.Expr.str`."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._compliant_expr = expr

    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._compliant_expr._with_native(
                getattr(self._compliant_expr.native.str, attr)(*pos, **kwds)
            )

        return func


class PolarsExprCatNamespace:
    """`.cat` accessor: forwards every call to `pl.Expr.cat`."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._compliant_expr = expr

    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._compliant_expr._with_native(
                getattr(self._compliant_expr.native.cat, attr)(*pos, **kwds)
            )

        return func


class PolarsExprNameNamespace:
    """`.name` accessor: forwards every call to `pl.Expr.name`."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._compliant_expr = expr

    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._compliant_expr._with_native(
                getattr(self._compliant_expr.native.name, attr)(*pos, **kwds)
            )

        return func


class PolarsExprListNamespace:
    """`.list` accessor: forwards to `pl.Expr.list` and re-wraps the result."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._expr = expr

    def len(self) -> PolarsExpr:
        """List length, with extra handling on older backends.

        Before 1.16 the result is only evaluated where the input is non-null
        and is cast to `UInt32`; on 1.16.x it is only cast to `UInt32`.
        """
        # Use the public `native` property, as the rest of the module does.
        native_expr = self._expr.native
        native_result = native_expr.list.len()

        if self._expr._backend_version < (1, 16):  # pragma: no cover
            native_result = (
                pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
            )
        elif self._expr._backend_version < (1, 17):  # pragma: no cover
            native_result = native_result.cast(pl.UInt32())

        return self._expr._with_native(native_result)

    # TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added
    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:  # pragma: no cover
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._expr._with_native(
                getattr(self._expr.native.list, attr)(*pos, **kwds)
            )

        return func


class PolarsExprStructNamespace:
    """`.struct` accessor: forwards every call to `pl.Expr.struct`."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._expr = expr

    def __getattr__(self, attr: str) -> Callable[[Any], PolarsExpr]:  # pragma: no cover
        def func(*args: Any, **kwargs: Any) -> PolarsExpr:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._expr._with_native(
                getattr(self._expr.native.struct, attr)(*pos, **kwds)
            )

        return func