aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py
diff options
context:
space:
mode:
authorsotech117 <michael_foiani@brown.edu>2025-07-31 17:27:24 -0400
committersotech117 <michael_foiani@brown.edu>2025-07-31 17:27:24 -0400
commit5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch)
tree8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py
parentb832d364da8c2efe09e3f75828caf73c50d01ce3 (diff)
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py443
1 files changed, 443 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py
new file mode 100644
index 0000000..f03c763
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_dask/dataframe.py
@@ -0,0 +1,443 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence
+
+import dask.dataframe as dd
+import pandas as pd
+
+from narwhals._dask.utils import add_row_index, evaluate_exprs
+from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
+from narwhals._utils import (
+ Implementation,
+ _remap_full_join_keys,
+ check_column_names_are_unique,
+ generate_temporary_column_name,
+ not_implemented,
+ parse_columns_to_drop,
+ parse_version,
+ validate_backend_version,
+)
+from narwhals.typing import CompliantLazyFrame
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ import dask.dataframe.dask_expr as dx
+ from typing_extensions import Self, TypeIs
+
+ from narwhals._compliant.typing import CompliantDataFrameAny
+ from narwhals._dask.expr import DaskExpr
+ from narwhals._dask.group_by import DaskLazyGroupBy
+ from narwhals._dask.namespace import DaskNamespace
+ from narwhals._utils import Version, _FullContext
+ from narwhals.dataframe import LazyFrame
+ from narwhals.dtypes import DType
+ from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
+
+
+class DaskLazyFrame(
+ CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"]
+):
+ def __init__(
+ self,
+ native_dataframe: dd.DataFrame,
+ *,
+ backend_version: tuple[int, ...],
+ version: Version,
+ ) -> None:
+ self._native_frame: dd.DataFrame = native_dataframe
+ self._backend_version = backend_version
+ self._implementation = Implementation.DASK
+ self._version = version
+ self._cached_schema: dict[str, DType] | None = None
+ self._cached_columns: list[str] | None = None
+ validate_backend_version(self._implementation, self._backend_version)
+
+ @staticmethod
+ def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
+ return isinstance(obj, dd.DataFrame)
+
+ @classmethod
+ def from_native(cls, data: dd.DataFrame, /, *, context: _FullContext) -> Self:
+ return cls(
+ data, backend_version=context._backend_version, version=context._version
+ )
+
+ def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
+ return self._version.lazyframe(self, level="lazy")
+
+ def __native_namespace__(self) -> ModuleType:
+ if self._implementation is Implementation.DASK:
+ return self._implementation.to_native_namespace()
+
+ msg = f"Expected dask, got: {type(self._implementation)}" # pragma: no cover
+ raise AssertionError(msg)
+
+ def __narwhals_namespace__(self) -> DaskNamespace:
+ from narwhals._dask.namespace import DaskNamespace
+
+ return DaskNamespace(backend_version=self._backend_version, version=self._version)
+
+ def __narwhals_lazyframe__(self) -> Self:
+ return self
+
+ def _with_version(self, version: Version) -> Self:
+ return self.__class__(
+ self.native, backend_version=self._backend_version, version=version
+ )
+
+ def _with_native(self, df: Any) -> Self:
+ return self.__class__(
+ df, backend_version=self._backend_version, version=self._version
+ )
+
+ def _iter_columns(self) -> Iterator[dx.Series]:
+ for _col, ser in self.native.items(): # noqa: PERF102
+ yield ser
+
+ def with_columns(self, *exprs: DaskExpr) -> Self:
+ new_series = evaluate_exprs(self, *exprs)
+ return self._with_native(self.native.assign(**dict(new_series)))
+
+ def collect(
+ self, backend: Implementation | None, **kwargs: Any
+ ) -> CompliantDataFrameAny:
+ result = self.native.compute(**kwargs)
+
+ if backend is None or backend is Implementation.PANDAS:
+ from narwhals._pandas_like.dataframe import PandasLikeDataFrame
+
+ return PandasLikeDataFrame(
+ result,
+ implementation=Implementation.PANDAS,
+ backend_version=parse_version(pd),
+ version=self._version,
+ validate_column_names=True,
+ )
+
+ if backend is Implementation.POLARS:
+ import polars as pl # ignore-banned-import
+
+ from narwhals._polars.dataframe import PolarsDataFrame
+
+ return PolarsDataFrame(
+ pl.from_pandas(result),
+ backend_version=parse_version(pl),
+ version=self._version,
+ )
+
+ if backend is Implementation.PYARROW:
+ import pyarrow as pa # ignore-banned-import
+
+ from narwhals._arrow.dataframe import ArrowDataFrame
+
+ return ArrowDataFrame(
+ pa.Table.from_pandas(result),
+ backend_version=parse_version(pa),
+ version=self._version,
+ validate_column_names=True,
+ )
+
+ msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
+ raise ValueError(msg) # pragma: no cover
+
+ @property
+ def columns(self) -> list[str]:
+ if self._cached_columns is None:
+ self._cached_columns = (
+ list(self.schema)
+ if self._cached_schema is not None
+ else self.native.columns.tolist()
+ )
+ return self._cached_columns
+
+ def filter(self, predicate: DaskExpr) -> Self:
+ # `[0]` is safe as the predicate's expression only returns a single column
+ mask = predicate(self)[0]
+ return self._with_native(self.native.loc[mask])
+
+ def simple_select(self, *column_names: str) -> Self:
+ native = select_columns_by_name(
+ self.native, list(column_names), self._backend_version, self._implementation
+ )
+ return self._with_native(native)
+
+ def aggregate(self, *exprs: DaskExpr) -> Self:
+ new_series = evaluate_exprs(self, *exprs)
+ df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
+ return self._with_native(df)
+
+ def select(self, *exprs: DaskExpr) -> Self:
+ new_series = evaluate_exprs(self, *exprs)
+ df = select_columns_by_name(
+ self.native.assign(**dict(new_series)),
+ [s[0] for s in new_series],
+ self._backend_version,
+ self._implementation,
+ )
+ return self._with_native(df)
+
+ def drop_nulls(self, subset: Sequence[str] | None) -> Self:
+ if subset is None:
+ return self._with_native(self.native.dropna())
+ plx = self.__narwhals_namespace__()
+ return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))
+
+ @property
+ def schema(self) -> dict[str, DType]:
+ if self._cached_schema is None:
+ native_dtypes = self.native.dtypes
+ self._cached_schema = {
+ col: native_to_narwhals_dtype(
+ native_dtypes[col], self._version, self._implementation
+ )
+ for col in self.native.columns
+ }
+ return self._cached_schema
+
+ def collect_schema(self) -> dict[str, DType]:
+ return self.schema
+
+ def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
+ to_drop = parse_columns_to_drop(self, columns, strict=strict)
+
+ return self._with_native(self.native.drop(columns=to_drop))
+
+ def with_row_index(self, name: str) -> Self:
+ # Implementation is based on the following StackOverflow reply:
+ # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
+ return self._with_native(
+ add_row_index(self.native, name, self._backend_version, self._implementation)
+ )
+
+ def rename(self, mapping: Mapping[str, str]) -> Self:
+ return self._with_native(self.native.rename(columns=mapping))
+
+ def head(self, n: int) -> Self:
+ return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
+
+ def unique(
+ self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
+ ) -> Self:
+ if subset and (error := self._check_columns_exist(subset)):
+ raise error
+ if keep == "none":
+ subset = subset or self.columns
+ token = generate_temporary_column_name(n_bytes=8, columns=subset)
+ ser = self.native.groupby(subset).size().rename(token)
+ ser = ser[ser == 1]
+ unique = ser.reset_index().drop(columns=token)
+ result = self.native.merge(unique, on=subset, how="inner")
+ else:
+ mapped_keep = {"any": "first"}.get(keep, keep)
+ result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
+ return self._with_native(result)
+
+ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
+ if isinstance(descending, bool):
+ ascending: bool | list[bool] = not descending
+ else:
+ ascending = [not d for d in descending]
+ position = "last" if nulls_last else "first"
+ return self._with_native(
+ self.native.sort_values(list(by), ascending=ascending, na_position=position)
+ )
+
+ def join( # noqa: C901
+ self,
+ other: Self,
+ *,
+ how: JoinStrategy,
+ left_on: Sequence[str] | None,
+ right_on: Sequence[str] | None,
+ suffix: str,
+ ) -> Self:
+ if how == "cross":
+ key_token = generate_temporary_column_name(
+ n_bytes=8, columns=[*self.columns, *other.columns]
+ )
+
+ return self._with_native(
+ self.native.assign(**{key_token: 0})
+ .merge(
+ other.native.assign(**{key_token: 0}),
+ how="inner",
+ left_on=key_token,
+ right_on=key_token,
+ suffixes=("", suffix),
+ )
+ .drop(columns=key_token)
+ )
+
+ if how == "anti":
+ indicator_token = generate_temporary_column_name(
+ n_bytes=8, columns=[*self.columns, *other.columns]
+ )
+
+ if right_on is None: # pragma: no cover
+ msg = "`right_on` cannot be `None` in anti-join"
+ raise TypeError(msg)
+ other_native = (
+ select_columns_by_name(
+ other.native,
+ list(right_on),
+ self._backend_version,
+ self._implementation,
+ )
+ .rename( # rename to avoid creating extra columns in join
+ columns=dict(zip(right_on, left_on)) # type: ignore[arg-type]
+ )
+ .drop_duplicates()
+ )
+ df = self.native.merge(
+ other_native,
+ how="outer",
+ indicator=indicator_token, # pyright: ignore[reportArgumentType]
+ left_on=left_on,
+ right_on=left_on,
+ )
+ return self._with_native(
+ df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
+ )
+
+ if how == "semi":
+ if right_on is None: # pragma: no cover
+ msg = "`right_on` cannot be `None` in semi-join"
+ raise TypeError(msg)
+ other_native = (
+ select_columns_by_name(
+ other.native,
+ list(right_on),
+ self._backend_version,
+ self._implementation,
+ )
+ .rename( # rename to avoid creating extra columns in join
+ columns=dict(zip(right_on, left_on)) # type: ignore[arg-type]
+ )
+ .drop_duplicates() # avoids potential rows duplication from inner join
+ )
+ return self._with_native(
+ self.native.merge(
+ other_native, how="inner", left_on=left_on, right_on=left_on
+ )
+ )
+
+ if how == "left":
+ result_native = self.native.merge(
+ other.native,
+ how="left",
+ left_on=left_on,
+ right_on=right_on,
+ suffixes=("", suffix),
+ )
+ extra = []
+ for left_key, right_key in zip(left_on, right_on): # type: ignore[arg-type]
+ if right_key != left_key and right_key not in self.columns:
+ extra.append(right_key)
+ elif right_key != left_key:
+ extra.append(f"{right_key}_right")
+ return self._with_native(result_native.drop(columns=extra))
+
+ if how == "full":
+ # dask does not retain keys post-join
+ # we must append the suffix to each key before-hand
+
+ # help mypy
+ assert left_on is not None # noqa: S101
+ assert right_on is not None # noqa: S101
+
+ right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
+ other_native = other.native.rename(columns=right_on_mapper)
+ check_column_names_are_unique(other_native.columns)
+ right_on = list(right_on_mapper.values()) # we now have the suffixed keys
+ return self._with_native(
+ self.native.merge(
+ other_native,
+ left_on=left_on,
+ right_on=right_on,
+ how="outer",
+ suffixes=("", suffix),
+ )
+ )
+
+ return self._with_native(
+ self.native.merge(
+ other.native,
+ left_on=left_on,
+ right_on=right_on,
+ how=how,
+ suffixes=("", suffix),
+ )
+ )
+
+ def join_asof(
+ self,
+ other: Self,
+ *,
+ left_on: str,
+ right_on: str,
+ by_left: Sequence[str] | None,
+ by_right: Sequence[str] | None,
+ strategy: AsofJoinStrategy,
+ suffix: str,
+ ) -> Self:
+ plx = self.__native_namespace__()
+ return self._with_native(
+ plx.merge_asof(
+ self.native,
+ other.native,
+ left_on=left_on,
+ right_on=right_on,
+ left_by=by_left,
+ right_by=by_right,
+ direction=strategy,
+ suffixes=("", suffix),
+ )
+ )
+
+ def group_by(
+ self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
+ ) -> DaskLazyGroupBy:
+ from narwhals._dask.group_by import DaskLazyGroupBy
+
+ return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
+
+ def tail(self, n: int) -> Self: # pragma: no cover
+ native_frame = self.native
+ n_partitions = native_frame.npartitions
+
+ if n_partitions == 1:
+ return self._with_native(self.native.tail(n=n, compute=False))
+ else:
+ msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
+ raise NotImplementedError(msg)
+
+ def gather_every(self, n: int, offset: int) -> Self:
+ row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
+ plx = self.__narwhals_namespace__()
+ return (
+ self.with_row_index(row_index_token)
+ .filter(
+ (plx.col(row_index_token) >= offset)
+ & ((plx.col(row_index_token) - offset) % n == 0)
+ )
+ .drop([row_index_token], strict=False)
+ )
+
+ def unpivot(
+ self,
+ on: Sequence[str] | None,
+ index: Sequence[str] | None,
+ variable_name: str,
+ value_name: str,
+ ) -> Self:
+ return self._with_native(
+ self.native.melt(
+ id_vars=index,
+ value_vars=on,
+ var_name=variable_name,
+ value_name=value_name,
+ )
+ )
+
+ explode = not_implemented()