Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py')
-rw-r--r--  venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py  320
1 file changed, 320 insertions(+), 0 deletions(-)
diff --git a/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py
new file mode 100644
index 0000000..3e0506d
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_dask/namespace.py
@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import operator
+from functools import reduce
+from typing import TYPE_CHECKING, Iterable, Sequence, cast
+
+import dask.dataframe as dd
+import pandas as pd
+
+from narwhals._compliant import CompliantThen, CompliantWhen, LazyNamespace
+from narwhals._compliant.namespace import DepthTrackingNamespace
+from narwhals._dask.dataframe import DaskLazyFrame
+from narwhals._dask.expr import DaskExpr
+from narwhals._dask.selectors import DaskSelectorNamespace
+from narwhals._dask.utils import (
+ align_series_full_broadcast,
+ narwhals_to_native_dtype,
+ validate_comparand,
+)
+from narwhals._expression_parsing import (
+ ExprKind,
+ combine_alias_output_names,
+ combine_evaluate_output_names,
+)
+from narwhals._utils import Implementation
+
+if TYPE_CHECKING:
+ import dask.dataframe.dask_expr as dx
+
+ from narwhals._utils import Version
+ from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
+
+
+class DaskNamespace(
+ LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
+ DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
+):
+ _implementation: Implementation = Implementation.DASK
+
+ @property
+ def selectors(self) -> DaskSelectorNamespace:
+ return DaskSelectorNamespace.from_namespace(self)
+
+ @property
+ def _expr(self) -> type[DaskExpr]:
+ return DaskExpr
+
+ @property
+ def _lazyframe(self) -> type[DaskLazyFrame]:
+ return DaskLazyFrame
+
+ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
+ self._backend_version = backend_version
+ self._version = version
+
+ def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ if dtype is not None:
+ native_dtype = narwhals_to_native_dtype(dtype, self._version)
+ native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
+ else:
+ native_pd_series = pd.Series([value], name="literal")
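+            # Build the literal as a one-element dask series, using the frame's partition count.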
+ npartitions = df._native_frame.npartitions
+ dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
+ return [dask_series[0].to_series()]
+
+ return self._expr(
+ func,
+ depth=0,
+ function_name="lit",
+ evaluate_output_names=lambda _df: ["literal"],
+ alias_output_names=None,
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def len(self) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ # We don't allow dataframes with 0 columns, so `[0]` is safe.
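+            # `.size` is computed lazily; `.to_series()` wraps the scalar so it can be returned as a column.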
+ return [df._native_frame[df.columns[0]].size.to_series()]
+
+ return self._expr(
+ func,
+ depth=0,
+ function_name="len",
+ evaluate_output_names=lambda _df: ["len"],
+ alias_output_names=None,
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def all_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ series = align_series_full_broadcast(
+ df, *(s for _expr in exprs for s in _expr(df))
+ )
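+            # Broadcast every input to a common shape, then fold them together with bitwise AND.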
+ return [reduce(operator.and_, series)]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="all_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def any_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ series = align_series_full_broadcast(
+ df, *(s for _expr in exprs for s in _expr(df))
+ )
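+            # Same broadcasting as `all_horizontal`, but folded with bitwise OR.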
+ return [reduce(operator.or_, series)]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="any_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ series = align_series_full_broadcast(
+ df, *(s for _expr in exprs for s in _expr(df))
+ )
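+            # Stack the aligned series as columns and take the row-wise sum.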
+ return [dd.concat(series, axis=1).sum(axis=1)]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="sum_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def concat(
+ self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
+ ) -> DaskLazyFrame:
+ if not items:
+ msg = "No items to concatenate" # pragma: no cover
+ raise AssertionError(msg)
+ dfs = [i._native_frame for i in items]
+ cols_0 = dfs[0].columns
+ if how == "vertical":
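+            # Vertical concatenation requires every frame to share the first frame's column names, in order.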
+ for i, df in enumerate(dfs[1:], start=1):
+ cols_current = df.columns
+ if not (
+ (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
+ ):
+ msg = (
+ "unable to vstack, column names don't match:\n"
+ f" - dataframe 0: {cols_0.to_list()}\n"
+ f" - dataframe {i}: {cols_current.to_list()}\n"
+ )
+ raise TypeError(msg)
+ return DaskLazyFrame(
+ dd.concat(dfs, axis=0, join="inner"),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+ if how == "diagonal":
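+            # Diagonal concatenation keeps the union of columns; missing values are filled with nulls.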
+ return DaskLazyFrame(
+ dd.concat(dfs, axis=0, join="outer"),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ raise NotImplementedError
+
+ def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ expr_results = [s for _expr in exprs for s in _expr(df)]
+ series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
+ non_na = align_series_full_broadcast(
+ df, *(1 - s.isna() for s in expr_results)
+ )
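+            # Row-wise mean: sum the values (nulls counted as 0) and divide by the number of non-null entries.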
+ num = reduce(lambda x, y: x + y, series) # pyright: ignore[reportOperatorIssue]
+ den = reduce(lambda x, y: x + y, non_na) # pyright: ignore[reportOperatorIssue]
+ return [cast("dx.Series", num / den)] # pyright: ignore[reportOperatorIssue]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="mean_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ series = align_series_full_broadcast(
+ df, *(s for _expr in exprs for s in _expr(df))
+ )
+
+ return [dd.concat(series, axis=1).min(axis=1)]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="min_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ series = align_series_full_broadcast(
+ df, *(s for _expr in exprs for s in _expr(df))
+ )
+
+ return [dd.concat(series, axis=1).max(axis=1)]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="max_horizontal",
+ evaluate_output_names=combine_evaluate_output_names(*exprs),
+ alias_output_names=combine_alias_output_names(*exprs),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def when(self, predicate: DaskExpr) -> DaskWhen:
+ return DaskWhen.from_expr(predicate, context=self)
+
+ def concat_str(
+ self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
+ ) -> DaskExpr:
+ def func(df: DaskLazyFrame) -> list[dx.Series]:
+ expr_results = [s for _expr in exprs for s in _expr(df)]
+ series = (
+ s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
+ )
+ null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]
+
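+            # If nulls are not ignored, any null input makes the whole result null;
+            # otherwise nulls become empty strings and the adjacent separator is dropped.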
+ if not ignore_nulls:
+ null_mask_result = reduce(operator.or_, null_mask)
+ result = reduce(lambda x, y: x + separator + y, series).where(
+ ~null_mask_result, None
+ )
+ else:
+ init_value, *values = [
+ s.where(~nm, "") for s, nm in zip(series, null_mask)
+ ]
+
+ separators = (
+ nm.map({True: "", False: separator}, meta=str)
+ for nm in null_mask[:-1]
+ )
+ result = reduce(
+ operator.add, (s + v for s, v in zip(separators, values)), init_value
+ )
+
+ return [result]
+
+ return self._expr(
+ call=func,
+ depth=max(x._depth for x in exprs) + 1,
+ function_name="concat_str",
+ evaluate_output_names=getattr(
+ exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
+ ),
+ alias_output_names=getattr(exprs[0], "_alias_output_names", None),
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+
+class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
+ @property
+ def _then(self) -> type[DaskThen]:
+ return DaskThen
+
+ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
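+        # `then`/`otherwise` values may be expressions (evaluated against the frame) or plain scalars.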
+ then_value = (
+ self._then_value(df)[0]
+ if isinstance(self._then_value, DaskExpr)
+ else self._then_value
+ )
+ otherwise_value = (
+ self._otherwise_value(df)[0]
+ if isinstance(self._otherwise_value, DaskExpr)
+ else self._otherwise_value
+ )
+
+ condition = self._condition(df)[0]
+ # re-evaluate DataFrame if the condition aggregates to force
+ # then/otherwise to be evaluated against the aggregated frame
+ assert self._condition._metadata is not None # noqa: S101
+ if self._condition._metadata.is_scalar_like:
+ new_df = df._with_native(condition.to_frame())
+ condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
+ df = new_df
+
+ if self._otherwise_value is None:
+ (condition, then_series) = align_series_full_broadcast(
+ df, condition, then_value
+ )
+ validate_comparand(condition, then_series)
+ return [then_series.where(condition)] # pyright: ignore[reportArgumentType]
+ (condition, then_series, otherwise_series) = align_series_full_broadcast(
+ df, condition, then_value, otherwise_value
+ )
+ validate_comparand(condition, then_series)
+ validate_comparand(condition, otherwise_series)
+ return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
+
+
+class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr): ...
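
For orientation, a minimal sketch of how this namespace is reached through narwhals' public API; the frame, column names, and values below are invented for illustration:

    import dask.dataframe as dd
    import pandas as pd
    import narwhals as nw

    # A tiny Dask frame wrapped as a narwhals LazyFrame.
    ddf = dd.from_pandas(pd.DataFrame({"a": [1, None, 3], "b": [4, 5, 6]}), npartitions=2)
    lf = nw.from_native(ddf)

    # sum_horizontal and when/then/otherwise resolve to the DaskNamespace methods above.
    out = lf.select(
        nw.sum_horizontal("a", "b").alias("row_sum"),
        nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(0).alias("b_if_a_gt_1"),
    )
    print(out.to_native().compute())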