diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py | 205 |
1 files changed, 205 insertions, 0 deletions
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._utils import Version, _FullContext
    from narwhals.typing import RankMethod


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    """Eager expression for the PyArrow backend.

    An instance wraps a callable that, given an ``ArrowDataFrame``, produces
    the ``ArrowSeries`` the expression evaluates to, together with the
    bookkeeping (depth, function name, output-name resolvers, versions)
    that is threaded through when expressions are derived from one another.
    """

    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        """Store the evaluation callable and its associated metadata.

        Args:
            call: Callable evaluated against a dataframe to produce series.
            depth: Nesting depth of the expression (``0`` for plain column
                selections created by the classmethods of this class).
            function_name: Name describing the operation chain.
            evaluate_output_names: Resolver for the output column names.
            alias_output_names: Optional renaming applied to output names.
            backend_version: Version tuple of the backend library.
            version: Narwhals API version.
            scalar_kwargs: Optional extra keyword arguments; an empty dict
                is stored when omitted or falsy.
            implementation: Accepted for signature compatibility; this
                constructor does not store it.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        # Populated elsewhere; `over` asserts it has been set before use.
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build an expression that selects columns by name.

        Args:
            evaluate_column_names: Callable returning, for a dataframe, the
                names of the columns to select.
            context: Object supplying ``_backend_version`` and ``_version``.
            function_name: Name recorded on the resulting expression.

        Returns:
            A depth-0 expression yielding one series per resolved name.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name],
                        name=column_name,
                        backend_version=df._backend_version,
                        version=df._version,
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                # Prefer the dataframe's own missing-column error when it
                # can produce one; otherwise surface the original KeyError.
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
        """Build an expression that selects columns by position.

        Args:
            *column_indices: Positions of the columns to select.
            context: Object supplying ``_backend_version`` and ``_version``.

        Returns:
            A depth-0 expression named ``"nth"`` yielding one series per index.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an ``ArrowNamespace`` sharing this expression's versions."""
        # Imported locally; at module level this name is only imported
        # under TYPE_CHECKING.
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        """Return extra keyword arguments for a reused series method.

        When ``returns_scalar`` is true the result is
        ``{"_return_py_scalar": False}``; otherwise it is an empty dict.
        """
        return {"_return_py_scalar": False} if returns_scalar else {}

    def cum_sum(self, *, reverse: bool) -> Self:
        """Delegate to the series-level ``cum_sum``."""
        return self._reuse_series("cum_sum", reverse=reverse)

    def shift(self, n: int) -> Self:
        """Delegate to the series-level ``shift``."""
        return self._reuse_series("shift", n=n)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Evaluate this expression over partitions and/or in a given order.

        Args:
            partition_by: Column names defining the groups. May be empty,
                in which case ``order_by`` must be non-empty.
            order_by: Column names to sort by before evaluating.

        Returns:
            A new expression one level deeper, with ``"->over"`` appended to
            the function name.

        Raises:
            NotImplementedError: If ``partition_by`` is given and the
                expression is not scalar-like, or (at evaluation time) if an
                output name coincides with a partition key.
        """
        assert self._metadata is not None  # noqa: S101
        if partition_by and not self._metadata.is_scalar_like:
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                # Tag each row with a temporary index column, then sort.
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token).sort(
                    *order_by, descending=False, nulls_last=False
                )
                # Evaluate on the sorted frame without the temporary column.
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                # Reorder the results using the indices that sort the token.
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                # Aggregate per group, then left-join the aggregates back
                # onto the partition-key columns of the original frame.
                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Delegate to the series-level ``cum_count``."""
        return self._reuse_series("cum_count", reverse=reverse)

    def cum_min(self, *, reverse: bool) -> Self:
        """Delegate to the series-level ``cum_min``."""
        return self._reuse_series("cum_min", reverse=reverse)

    def cum_max(self, *, reverse: bool) -> Self:
        """Delegate to the series-level ``cum_max``."""
        return self._reuse_series("cum_max", reverse=reverse)

    def cum_prod(self, *, reverse: bool) -> Self:
        """Delegate to the series-level ``cum_prod``."""
        return self._reuse_series("cum_prod", reverse=reverse)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Delegate to the series-level ``rank``."""
        return self._reuse_series("rank", method=method, descending=descending)

    def log(self, base: float) -> Self:
        """Delegate to the series-level ``log``."""
        return self._reuse_series("log", base=base)

    def exp(self) -> Self:
        """Delegate to the series-level ``exp``."""
        return self._reuse_series("exp")

    # NOTE(review): presumably marks `ewm_mean` as unsupported on this
    # backend via the `not_implemented` helper — confirm against its docs.
    ewm_mean = not_implemented()