| author | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
|---|---|---|
| committer | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
| commit | 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch) | |
| tree | 8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_dask/expr.py | |
| parent | b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff) | |
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_dask/expr.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_dask/expr.py | 675 |
1 file changed, 675 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_dask/expr.py b/venv/lib/python3.8/site-packages/narwhals/_dask/expr.py
new file mode 100644
index 0000000..aa51997
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_dask/expr.py
@@ -0,0 +1,675 @@
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
+
+from narwhals._compliant import LazyExpr
+from narwhals._compliant.expr import DepthTrackingExpr
+from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
+from narwhals._dask.expr_str import DaskExprStringNamespace
+from narwhals._dask.utils import (
+    add_row_index,
+    maybe_evaluate_expr,
+    narwhals_to_native_dtype,
+)
+from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
+from narwhals._pandas_like.utils import native_to_narwhals_dtype
+from narwhals._utils import (
+    Implementation,
+    generate_temporary_column_name,
+    not_implemented,
+)
+from narwhals.exceptions import InvalidOperationError
+
+if TYPE_CHECKING:
+    import dask.dataframe.dask_expr as dx
+    from typing_extensions import Self
+
+    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
+    from narwhals._dask.dataframe import DaskLazyFrame
+    from narwhals._dask.namespace import DaskNamespace
+    from narwhals._expression_parsing import ExprKind, ExprMetadata
+    from narwhals._utils import Version, _FullContext
+    from narwhals.typing import (
+        FillNullStrategy,
+        IntoDType,
+        NonNestedLiteral,
+        NumericLiteral,
+        RollingInterpolationMethod,
+        TemporalLiteral,
+    )
+
+
+class DaskExpr(
+    LazyExpr["DaskLazyFrame", "dx.Series"],
+    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
+):
+    _implementation: Implementation = Implementation.DASK
+
+    def __init__(
+        self,
+        call: EvalSeries[DaskLazyFrame, dx.Series],
+        *,
+        depth: int,
+        function_name: str,
+        evaluate_output_names: EvalNames[DaskLazyFrame],
+        alias_output_names: AliasNames | None,
+        backend_version: tuple[int, ...],
+        version: Version,
+        scalar_kwargs: ScalarKwargs | None = None,
+    ) -> None:
+        self._call = call
+        self._depth = depth
+        self._function_name = function_name
+        self._evaluate_output_names = evaluate_output_names
+        self._alias_output_names = alias_output_names
+        self._backend_version = backend_version
+        self._version = version
+        self._scalar_kwargs = scalar_kwargs or {}
+        self._metadata: ExprMetadata | None = None
+
+    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
+        return self._call(df)
+
+    def __narwhals_expr__(self) -> None: ...
+
+    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
+        # Unused, just for compatibility with PandasLikeExpr
+        from narwhals._dask.namespace import DaskNamespace
+
+        return DaskNamespace(backend_version=self._backend_version, version=self._version)
+
+    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
+            # that raised a KeyError for result[0] during collection.
+            return [result.loc[0][0] for result in self(df)]
+
+        return self.__class__(
+            func,
+            depth=self._depth,
+            function_name=self._function_name,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
+            backend_version=self._backend_version,
+            version=self._version,
+            scalar_kwargs=self._scalar_kwargs,
+        )
+
+    @classmethod
+    def from_column_names(
+        cls: type[Self],
+        evaluate_column_names: EvalNames[DaskLazyFrame],
+        /,
+        *,
+        context: _FullContext,
+        function_name: str = "",
+    ) -> Self:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            try:
+                return [
+                    df._native_frame[column_name]
+                    for column_name in evaluate_column_names(df)
+                ]
+            except KeyError as e:
+                if error := df._check_columns_exist(evaluate_column_names(df)):
+                    raise error from e
+                raise
+
+        return cls(
+            func,
+            depth=0,
+            function_name=function_name,
+            evaluate_output_names=evaluate_column_names,
+            alias_output_names=None,
+            backend_version=context._backend_version,
+            version=context._version,
+        )
+
+    @classmethod
+    def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            return [df.native.iloc[:, i] for i in column_indices]
+
+        return cls(
+            func,
+            depth=0,
+            function_name="nth",
+            evaluate_output_names=cls._eval_names_indices(column_indices),
+            alias_output_names=None,
+            backend_version=context._backend_version,
+            version=context._version,
+        )
+
+    def _with_callable(
+        self,
+        # First argument to `call` should be `dx.Series`
+        call: Callable[..., dx.Series],
+        /,
+        expr_name: str = "",
+        scalar_kwargs: ScalarKwargs | None = None,
+        **expressifiable_args: Self | Any,
+    ) -> Self:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            native_results: list[dx.Series] = []
+            native_series_list = self._call(df)
+            other_native_series = {
+                key: maybe_evaluate_expr(df, value)
+                for key, value in expressifiable_args.items()
+            }
+            for native_series in native_series_list:
+                result_native = call(native_series, **other_native_series)
+                native_results.append(result_native)
+            return native_results
+
+        return self.__class__(
+            func,
+            depth=self._depth + 1,
+            function_name=f"{self._function_name}->{expr_name}",
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
+            backend_version=self._backend_version,
+            version=self._version,
+            scalar_kwargs=scalar_kwargs,
+        )
+
+    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
+        return type(self)(
+            call=self._call,
+            depth=self._depth,
+            function_name=self._function_name,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=func,
+            backend_version=self._backend_version,
+            version=self._version,
+            scalar_kwargs=self._scalar_kwargs,
+        )
+
+    def __add__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__add__(other), "__add__", other=other
+        )
+
+    def __sub__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__sub__(other), "__sub__", other=other
+        )
+
+    def __rsub__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: other - expr, "__rsub__", other=other
+        ).alias("literal")
+
+    def __mul__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__mul__(other), "__mul__", other=other
+        )
+
+    def __truediv__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
+        )
+
+    def __rtruediv__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: other / expr, "__rtruediv__", other=other
+        ).alias("literal")
+
+    def __floordiv__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
+        )
+
+    def __rfloordiv__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: other // expr, "__rfloordiv__", other=other
+        ).alias("literal")
+
+    def __pow__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__pow__(other), "__pow__", other=other
+        )
+
+    def __rpow__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: other**expr, "__rpow__", other=other
+        ).alias("literal")
+
+    def __mod__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__mod__(other), "__mod__", other=other
+        )
+
+    def __rmod__(self, other: Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: other % expr, "__rmod__", other=other
+        ).alias("literal")
+
+    def __eq__(self, other: DaskExpr) -> Self:  # type: ignore[override]
+        return self._with_callable(
+            lambda expr, other: expr.__eq__(other), "__eq__", other=other
+        )
+
+    def __ne__(self, other: DaskExpr) -> Self:  # type: ignore[override]
+        return self._with_callable(
+            lambda expr, other: expr.__ne__(other), "__ne__", other=other
+        )
+
+    def __ge__(self, other: DaskExpr | Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__ge__(other), "__ge__", other=other
+        )
+
+    def __gt__(self, other: DaskExpr) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__gt__(other), "__gt__", other=other
+        )
+
+    def __le__(self, other: DaskExpr) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__le__(other), "__le__", other=other
+        )
+
+    def __lt__(self, other: DaskExpr) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__lt__(other), "__lt__", other=other
+        )
+
+    def __and__(self, other: DaskExpr | Any) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__and__(other), "__and__", other=other
+        )
+
+    def __or__(self, other: DaskExpr) -> Self:
+        return self._with_callable(
+            lambda expr, other: expr.__or__(other), "__or__", other=other
+        )
+
+    def __invert__(self) -> Self:
+        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")
+
+    def mean(self) -> Self:
+        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")
+
+    def median(self) -> Self:
+        from narwhals.exceptions import InvalidOperationError
+
+        def func(s: dx.Series) -> dx.Series:
+            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
+            if not dtype.is_numeric():
+                msg = "`median` operation not supported for non-numeric input type."
+                raise InvalidOperationError(msg)
+            return s.median_approximate().to_series()
+
+        return self._with_callable(func, "median")
+
+    def min(self) -> Self:
+        return self._with_callable(lambda expr: expr.min().to_series(), "min")
+
+    def max(self) -> Self:
+        return self._with_callable(lambda expr: expr.max().to_series(), "max")
+
+    def std(self, ddof: int) -> Self:
+        return self._with_callable(
+            lambda expr: expr.std(ddof=ddof).to_series(),
+            "std",
+            scalar_kwargs={"ddof": ddof},
+        )
+
+    def var(self, ddof: int) -> Self:
+        return self._with_callable(
+            lambda expr: expr.var(ddof=ddof).to_series(),
+            "var",
+            scalar_kwargs={"ddof": ddof},
+        )
+
+    def skew(self) -> Self:
+        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")
+
+    def shift(self, n: int) -> Self:
+        return self._with_callable(lambda expr: expr.shift(n), "shift")
+
+    def cum_sum(self, *, reverse: bool) -> Self:
+        if reverse:  # pragma: no cover
+            # https://github.com/dask/dask/issues/11802
+            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
+            raise NotImplementedError(msg)
+
+        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")
+
+    def cum_count(self, *, reverse: bool) -> Self:
+        if reverse:  # pragma: no cover
+            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
+            raise NotImplementedError(msg)
+
+        return self._with_callable(
+            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
+        )
+
+    def cum_min(self, *, reverse: bool) -> Self:
+        if reverse:  # pragma: no cover
+            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
+            raise NotImplementedError(msg)
+
+        return self._with_callable(lambda expr: expr.cummin(), "cum_min")
+
+    def cum_max(self, *, reverse: bool) -> Self:
+        if reverse:  # pragma: no cover
+            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
+            raise NotImplementedError(msg)
+
+        return self._with_callable(lambda expr: expr.cummax(), "cum_max")
+
+    def cum_prod(self, *, reverse: bool) -> Self:
+        if reverse:  # pragma: no cover
+            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
+            raise NotImplementedError(msg)
+
+        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")
+
+    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
+        return self._with_callable(
+            lambda expr: expr.rolling(
+                window=window_size, min_periods=min_samples, center=center
+            ).sum(),
+            "rolling_sum",
+        )
+
+    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
+        return self._with_callable(
+            lambda expr: expr.rolling(
+                window=window_size, min_periods=min_samples, center=center
+            ).mean(),
+            "rolling_mean",
+        )
+
+    def rolling_var(
+        self, window_size: int, *, min_samples: int, center: bool, ddof: int
+    ) -> Self:
+        if ddof == 1:
+            return self._with_callable(
+                lambda expr: expr.rolling(
+                    window=window_size, min_periods=min_samples, center=center
+                ).var(),
+                "rolling_var",
+            )
+        else:
+            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
+            raise NotImplementedError(msg)
+
+    def rolling_std(
+        self, window_size: int, *, min_samples: int, center: bool, ddof: int
+    ) -> Self:
+        if ddof == 1:
+            return self._with_callable(
+                lambda expr: expr.rolling(
+                    window=window_size, min_periods=min_samples, center=center
+                ).std(),
+                "rolling_std",
+            )
+        else:
+            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
+            raise NotImplementedError(msg)
+
+    def sum(self) -> Self:
+        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")
+
+    def count(self) -> Self:
+        return self._with_callable(lambda expr: expr.count().to_series(), "count")
+
+    def round(self, decimals: int) -> Self:
+        return self._with_callable(lambda expr: expr.round(decimals), "round")
+
+    def unique(self) -> Self:
+        return self._with_callable(lambda expr: expr.unique(), "unique")
+
+    def drop_nulls(self) -> Self:
+        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")
+
+    def abs(self) -> Self:
+        return self._with_callable(lambda expr: expr.abs(), "abs")
+
+    def all(self) -> Self:
+        return self._with_callable(
+            lambda expr: expr.all(
+                axis=None, skipna=True, split_every=False, out=None
+            ).to_series(),
+            "all",
+        )
+
+    def any(self) -> Self:
+        return self._with_callable(
+            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
+            "any",
+        )
+
+    def fill_null(
+        self,
+        value: Self | NonNestedLiteral,
+        strategy: FillNullStrategy | None,
+        limit: int | None,
+    ) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            if value is not None:
+                res_ser = expr.fillna(value)
+            else:
+                res_ser = (
+                    expr.ffill(limit=limit)
+                    if strategy == "forward"
+                    else expr.bfill(limit=limit)
+                )
+            return res_ser
+
+        return self._with_callable(func, "fillna")
+
+    def clip(
+        self,
+        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
+        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
+    ) -> Self:
+        return self._with_callable(
+            lambda expr, lower_bound, upper_bound: expr.clip(
+                lower=lower_bound, upper=upper_bound
+            ),
+            "clip",
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+        )
+
+    def diff(self) -> Self:
+        return self._with_callable(lambda expr: expr.diff(), "diff")
+
+    def n_unique(self) -> Self:
+        return self._with_callable(
+            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
+        )
+
+    def is_null(self) -> Self:
+        return self._with_callable(lambda expr: expr.isna(), "is_null")
+
+    def is_nan(self) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            dtype = native_to_narwhals_dtype(
+                expr.dtype, self._version, self._implementation
+            )
+            if dtype.is_numeric():
+                return expr != expr  # pyright: ignore[reportReturnType]  # noqa: PLR0124
+            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
+            raise InvalidOperationError(msg)
+
+        return self._with_callable(func, "is_null")
+
+    def len(self) -> Self:
+        return self._with_callable(lambda expr: expr.size.to_series(), "len")
+
+    def quantile(
+        self, quantile: float, interpolation: RollingInterpolationMethod
+    ) -> Self:
+        if interpolation == "linear":
+
+            def func(expr: dx.Series, quantile: float) -> dx.Series:
+                if expr.npartitions > 1:
+                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
+                    raise NotImplementedError(msg)
+                return expr.quantile(
+                    q=quantile, method="dask"
+                ).to_series()  # pragma: no cover
+
+            return self._with_callable(func, "quantile", quantile=quantile)
+        else:
+            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
+            raise NotImplementedError(msg)
+
+    def is_first_distinct(self) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            _name = expr.name
+            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
+            frame = add_row_index(
+                expr.to_frame(), col_token, self._backend_version, self._implementation
+            )
+            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
+            return frame[col_token].isin(first_distinct_index)
+
+        return self._with_callable(func, "is_first_distinct")
+
+    def is_last_distinct(self) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            _name = expr.name
+            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
+            frame = add_row_index(
+                expr.to_frame(), col_token, self._backend_version, self._implementation
+            )
+            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
+            return frame[col_token].isin(last_distinct_index)
+
+        return self._with_callable(func, "is_last_distinct")
+
+    def is_unique(self) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            _name = expr.name
+            return (
+                expr.to_frame()
+                .groupby(_name, dropna=False)
+                .transform("size", meta=(_name, int))
+                == 1
+            )
+
+        return self._with_callable(func, "is_unique")
+
+    def is_in(self, other: Any) -> Self:
+        return self._with_callable(lambda expr: expr.isin(other), "is_in")
+
+    def null_count(self) -> Self:
+        return self._with_callable(
+            lambda expr: expr.isna().sum().to_series(), "null_count"
+        )
+
+    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
+        # pandas is a required dependency of dask so it's safe to import this
+        from narwhals._pandas_like.group_by import PandasLikeGroupBy
+
+        if not partition_by:
+            assert order_by  # noqa: S101
+
+            # This is something like `nw.col('a').cum_sum().order_by(key)`
+            # which we can always easily support, as it doesn't require grouping.
+            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
+                return self(df.sort(*order_by, descending=False, nulls_last=False))
+        elif not self._is_elementary():  # pragma: no cover
+            msg = (
+                "Only elementary expressions are supported for `.over` in dask.\n\n"
+                "Please see: "
+                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
+            )
+            raise NotImplementedError(msg)
+        elif order_by:
+            # Wrong results https://github.com/dask/dask/issues/11806.
+            msg = "`over` with `order_by` is not yet supported in Dask."
+            raise NotImplementedError(msg)
+        else:
+            function_name = PandasLikeGroupBy._leaf_name(self)
+            try:
+                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
+            except KeyError:
+                # window functions are unsupported: https://github.com/dask/dask/issues/11806
+                msg = (
+                    f"Unsupported function: {function_name} in `over` context.\n\n"
+                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
+                )
+                raise NotImplementedError(msg) from None
+
+            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
+                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
+
+                with warnings.catch_warnings():
+                    # https://github.com/dask/dask/issues/11804
+                    warnings.filterwarnings(
+                        "ignore",
+                        message=".*`meta` is not specified",
+                        category=UserWarning,
+                    )
+                    grouped = df.native.groupby(partition_by)
+                    if dask_function_name == "size":
+                        if len(output_names) != 1:  # pragma: no cover
+                            msg = "Safety check failed, please report a bug."
+                            raise AssertionError(msg)
+                        res_native = grouped.transform(
+                            dask_function_name, **self._scalar_kwargs
+                        ).to_frame(output_names[0])
+                    else:
+                        res_native = grouped[list(output_names)].transform(
+                            dask_function_name, **self._scalar_kwargs
+                        )
+                result_frame = df._with_native(
+                    res_native.rename(columns=dict(zip(output_names, aliases)))
+                ).native
+                return [result_frame[name] for name in aliases]
+
+        return self.__class__(
+            func,
+            depth=self._depth + 1,
+            function_name=self._function_name + "->over",
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def cast(self, dtype: IntoDType) -> Self:
+        def func(expr: dx.Series) -> dx.Series:
+            native_dtype = narwhals_to_native_dtype(dtype, self._version)
+            return expr.astype(native_dtype)
+
+        return self._with_callable(func, "cast")
+
+    def is_finite(self) -> Self:
+        import dask.array as da
+
+        return self._with_callable(da.isfinite, "is_finite")
+
+    def log(self, base: float) -> Self:
+        import dask.array as da
+
+        def _log(expr: dx.Series) -> dx.Series:
+            return da.log(expr) / da.log(base)
+
+        return self._with_callable(_log, "log")
+
+    def exp(self) -> Self:
+        import dask.array as da
+
+        return self._with_callable(da.exp, "exp")
+
+    @property
+    def str(self) -> DaskExprStringNamespace:
+        return DaskExprStringNamespace(self)
+
+    @property
+    def dt(self) -> DaskExprDateTimeNamespace:
+        return DaskExprDateTimeNamespace(self)
+
+    list = not_implemented()  # pyright: ignore[reportAssignmentType]
+    struct = not_implemented()  # pyright: ignore[reportAssignmentType]
+    rank = not_implemented()  # pyright: ignore[reportAssignmentType]
+    _alias_native = not_implemented()
+    window_function = not_implemented()  # pyright: ignore[reportAssignmentType]
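
The diff above vendors narwhals' Dask expression backend unchanged; the commit does not call `DaskExpr` directly. For orientation, the short sketch below (not part of the commit) shows how this backend is typically reached indirectly through narwhals' public API when the native input is a Dask DataFrame. The column names, sample values, and the `filled`/`total` aliases are illustrative assumptions, not anything defined in this file.

```python
# Minimal sketch, assuming narwhals, dask[dataframe], and pandas are installed
# in the same environment. Expressions built with nw.col(...) are compiled to
# the DaskExpr methods added in this file when the wrapped frame is a dask one.
import dask.dataframe as dd
import narwhals as nw
import pandas as pd

pdf = pd.DataFrame({"group": ["a", "a", "b"], "value": [1.0, None, 3.0]})
ddf = dd.from_pandas(pdf, npartitions=2)

lf = nw.from_native(ddf)  # wrapped as a narwhals LazyFrame over the Dask backend
result = (
    lf.with_columns(nw.col("value").fill_null(0).alias("filled"))  # -> Series.fillna
    .group_by("group")
    .agg(nw.col("filled").sum().alias("total"))
    .to_native()  # unwrap back to a dask DataFrame
)
print(result.compute())
```

Each expression method used in the sketch corresponds to one of the methods in the diff, e.g. `fill_null` to `DaskExpr.fill_null` and `sum` to `DaskExpr.sum`, with evaluation deferred until the Dask graph is computed.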