Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py | 320 |
1 file changed, 320 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py
new file mode 100644
index 0000000..3e0506d
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py
@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import operator
+from functools import reduce
+from typing import TYPE_CHECKING, Iterable, Sequence, cast
+
+import dask.dataframe as dd
+import pandas as pd
+
+from narwhals._compliant import CompliantThen, CompliantWhen, LazyNamespace
+from narwhals._compliant.namespace import DepthTrackingNamespace
+from narwhals._dask.dataframe import DaskLazyFrame
+from narwhals._dask.expr import DaskExpr
+from narwhals._dask.selectors import DaskSelectorNamespace
+from narwhals._dask.utils import (
+    align_series_full_broadcast,
+    narwhals_to_native_dtype,
+    validate_comparand,
+)
+from narwhals._expression_parsing import (
+    ExprKind,
+    combine_alias_output_names,
+    combine_evaluate_output_names,
+)
+from narwhals._utils import Implementation
+
+if TYPE_CHECKING:
+    import dask.dataframe.dask_expr as dx
+
+    from narwhals._utils import Version
+    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
+
+
+class DaskNamespace(
+    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
+    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
+):
+    _implementation: Implementation = Implementation.DASK
+
+    @property
+    def selectors(self) -> DaskSelectorNamespace:
+        return DaskSelectorNamespace.from_namespace(self)
+
+    @property
+    def _expr(self) -> type[DaskExpr]:
+        return DaskExpr
+
+    @property
+    def _lazyframe(self) -> type[DaskLazyFrame]:
+        return DaskLazyFrame
+
+    def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
+        self._backend_version = backend_version
+        self._version = version
+
+    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            if dtype is not None:
+                native_dtype = narwhals_to_native_dtype(dtype, self._version)
+                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
+            else:
+                native_pd_series = pd.Series([value], name="literal")
+            npartitions = df._native_frame.npartitions
+            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
+            return [dask_series[0].to_series()]
+
+        return self._expr(
+            func,
+            depth=0,
+            function_name="lit",
+            evaluate_output_names=lambda _df: ["literal"],
+            alias_output_names=None,
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def len(self) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            # We don't allow dataframes with 0 columns, so `[0]` is safe.
+            return [df._native_frame[df.columns[0]].size.to_series()]
+
+        return self._expr(
+            func,
+            depth=0,
+            function_name="len",
+            evaluate_output_names=lambda _df: ["len"],
+            alias_output_names=None,
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def all_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            series = align_series_full_broadcast(
+                df, *(s for _expr in exprs for s in _expr(df))
+            )
+            return [reduce(operator.and_, series)]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="all_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def any_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            series = align_series_full_broadcast(
+                df, *(s for _expr in exprs for s in _expr(df))
+            )
+            return [reduce(operator.or_, series)]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="any_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            series = align_series_full_broadcast(
+                df, *(s for _expr in exprs for s in _expr(df))
+            )
+            return [dd.concat(series, axis=1).sum(axis=1)]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="sum_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def concat(
+        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
+    ) -> DaskLazyFrame:
+        if not items:
+            msg = "No items to concatenate"  # pragma: no cover
+            raise AssertionError(msg)
+        dfs = [i._native_frame for i in items]
+        cols_0 = dfs[0].columns
+        if how == "vertical":
+            for i, df in enumerate(dfs[1:], start=1):
+                cols_current = df.columns
+                if not (
+                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
+                ):
+                    msg = (
+                        "unable to vstack, column names don't match:\n"
+                        f"   - dataframe 0: {cols_0.to_list()}\n"
+                        f"   - dataframe {i}: {cols_current.to_list()}\n"
+                    )
+                    raise TypeError(msg)
+            return DaskLazyFrame(
+                dd.concat(dfs, axis=0, join="inner"),
+                backend_version=self._backend_version,
+                version=self._version,
+            )
+        if how == "diagonal":
+            return DaskLazyFrame(
+                dd.concat(dfs, axis=0, join="outer"),
+                backend_version=self._backend_version,
+                version=self._version,
+            )
+
+        raise NotImplementedError
+
+    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            expr_results = [s for _expr in exprs for s in _expr(df)]
+            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
+            non_na = align_series_full_broadcast(
+                df, *(1 - s.isna() for s in expr_results)
+            )
+            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
+            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
+            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="mean_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            series = align_series_full_broadcast(
+                df, *(s for _expr in exprs for s in _expr(df))
+            )
+
+            return [dd.concat(series, axis=1).min(axis=1)]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="min_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            series = align_series_full_broadcast(
+                df, *(s for _expr in exprs for s in _expr(df))
+            )
+
+            return [dd.concat(series, axis=1).max(axis=1)]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="max_horizontal",
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+    def when(self, predicate: DaskExpr) -> DaskWhen:
+        return DaskWhen.from_expr(predicate, context=self)
+
+    def concat_str(
+        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
+    ) -> DaskExpr:
+        def func(df: DaskLazyFrame) -> list[dx.Series]:
+            expr_results = [s for _expr in exprs for s in _expr(df)]
+            series = (
+                s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
+            )
+            null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]
+
+            if not ignore_nulls:
+                null_mask_result = reduce(operator.or_, null_mask)
+                result = reduce(lambda x, y: x + separator + y, series).where(
+                    ~null_mask_result, None
+                )
+            else:
+                init_value, *values = [
+                    s.where(~nm, "") for s, nm in zip(series, null_mask)
+                ]
+
+                separators = (
+                    nm.map({True: "", False: separator}, meta=str)
+                    for nm in null_mask[:-1]
+                )
+                result = reduce(
+                    operator.add, (s + v for s, v in zip(separators, values)), init_value
+                )
+
+            return [result]
+
+        return self._expr(
+            call=func,
+            depth=max(x._depth for x in exprs) + 1,
+            function_name="concat_str",
+            evaluate_output_names=getattr(
+                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
+            ),
+            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
+            backend_version=self._backend_version,
+            version=self._version,
+        )
+
+
+class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
+    @property
+    def _then(self) -> type[DaskThen]:
+        return DaskThen
+
+    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
+        then_value = (
+            self._then_value(df)[0]
+            if isinstance(self._then_value, DaskExpr)
+            else self._then_value
+        )
+        otherwise_value = (
+            self._otherwise_value(df)[0]
+            if isinstance(self._otherwise_value, DaskExpr)
+            else self._otherwise_value
+        )
+
+        condition = self._condition(df)[0]
+        # re-evaluate DataFrame if the condition aggregates to force
+        # then/otherwise to be evaluated against the aggregated frame
+        assert self._condition._metadata is not None  # noqa: S101
+        if self._condition._metadata.is_scalar_like:
+            new_df = df._with_native(condition.to_frame())
+            condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
+            df = new_df
+
+        if self._otherwise_value is None:
+            (condition, then_series) = align_series_full_broadcast(
+                df, condition, then_value
+            )
+            validate_comparand(condition, then_series)
+            return [then_series.where(condition)]  # pyright: ignore[reportArgumentType]
+        (condition, then_series, otherwise_series) = align_series_full_broadcast(
+            df, condition, then_value, otherwise_value
+        )
+        validate_comparand(condition, then_series)
+        validate_comparand(condition, otherwise_series)
+        return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType]
+
+
+class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr): ...
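For context, a minimal, hypothetical usage sketch (not part of the diff above): this is how the namespace methods added here are typically reached through narwhals' public expression API when the native frame is a Dask DataFrame. The `nw.*` entry points and the example column names are assumptions about the public API, not something this file defines.

# Hypothetical sketch: exercising the Dask backend via narwhals' public API.
import dask.dataframe as dd
import pandas as pd

import narwhals as nw

# Assumed input: a small Dask DataFrame wrapped as a narwhals LazyFrame.
ddf = dd.from_pandas(pd.DataFrame({"a": [1.0, 2.0, None], "b": [4, 5, 6]}), npartitions=1)
lf = nw.from_native(ddf)

result = lf.with_columns(
    nw.sum_horizontal("a", "b").alias("a_plus_b"),  # would dispatch to DaskNamespace.sum_horizontal
    nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit(0)).alias("picked"),  # DaskWhen / DaskThen
    nw.concat_str(nw.col("a"), nw.col("b"), separator="-").alias("joined"),  # DaskNamespace.concat_str
)
print(result.to_native().compute())  # to_native() unwraps back to the dask.dataframe object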