path: root/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py')
-rw-r--r--  venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py  205
1 file changed, 205 insertions(+), 0 deletions(-)
diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py
new file mode 100644
index 0000000..af7993c
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+import pyarrow.compute as pc
+
+from narwhals._arrow.series import ArrowSeries
+from narwhals._compliant import EagerExpr
+from narwhals._expression_parsing import evaluate_output_names_and_aliases
+from narwhals._utils import (
+ Implementation,
+ generate_temporary_column_name,
+ not_implemented,
+)
+
+if TYPE_CHECKING:
+ from typing_extensions import Self
+
+ from narwhals._arrow.dataframe import ArrowDataFrame
+ from narwhals._arrow.namespace import ArrowNamespace
+ from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
+ from narwhals._expression_parsing import ExprMetadata
+ from narwhals._utils import Version, _FullContext
+ from narwhals.typing import RankMethod
+
+
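+# Eager, PyArrow-backed implementation of a Narwhals expression: a callable that
+# maps an ArrowDataFrame to a list of ArrowSeries, plus metadata describing the
+# expression's depth, function name, and output-name resolution.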
+class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
+ _implementation: Implementation = Implementation.PYARROW
+
+ def __init__(
+ self,
+ call: EvalSeries[ArrowDataFrame, ArrowSeries],
+ *,
+ depth: int,
+ function_name: str,
+ evaluate_output_names: EvalNames[ArrowDataFrame],
+ alias_output_names: AliasNames | None,
+ backend_version: tuple[int, ...],
+ version: Version,
+ scalar_kwargs: ScalarKwargs | None = None,
+ implementation: Implementation | None = None,
+ ) -> None:
+ self._call = call
+ self._depth = depth
+ self._function_name = function_name
+ self._evaluate_output_names = evaluate_output_names
+ self._alias_output_names = alias_output_names
+ self._backend_version = backend_version
+ self._version = version
+ self._scalar_kwargs = scalar_kwargs or {}
+ self._metadata: ExprMetadata | None = None
+
+ @classmethod
+ def from_column_names(
+ cls: type[Self],
+ evaluate_column_names: EvalNames[ArrowDataFrame],
+ /,
+ *,
+ context: _FullContext,
+ function_name: str = "",
+ ) -> Self:
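+        # Resolve the requested column names lazily: each native column is
+        # wrapped in an ArrowSeries, and a missing column is re-raised with a
+        # more descriptive error.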
+ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+ try:
+ return [
+ ArrowSeries(
+ df.native[column_name],
+ name=column_name,
+ backend_version=df._backend_version,
+ version=df._version,
+ )
+ for column_name in evaluate_column_names(df)
+ ]
+ except KeyError as e:
+ if error := df._check_columns_exist(evaluate_column_names(df)):
+ raise error from e
+ raise
+
+ return cls(
+ func,
+ depth=0,
+ function_name=function_name,
+ evaluate_output_names=evaluate_column_names,
+ alias_output_names=None,
+ backend_version=context._backend_version,
+ version=context._version,
+ )
+
+ @classmethod
+ def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
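+        # Select columns by position; output names are looked up from the
+        # dataframe's columns at evaluation time.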
+ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+ tbl = df.native
+ cols = df.columns
+ return [
+ ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
+ for i in column_indices
+ ]
+
+ return cls(
+ func,
+ depth=0,
+ function_name="nth",
+ evaluate_output_names=cls._eval_names_indices(column_indices),
+ alias_output_names=None,
+ backend_version=context._backend_version,
+ version=context._version,
+ )
+
+ def __narwhals_namespace__(self) -> ArrowNamespace:
+ from narwhals._arrow.namespace import ArrowNamespace
+
+ return ArrowNamespace(
+ backend_version=self._backend_version, version=self._version
+ )
+
+ def __narwhals_expr__(self) -> None: ...
+
+ def _reuse_series_extra_kwargs(
+ self, *, returns_scalar: bool = False
+ ) -> dict[str, Any]:
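+        # For aggregations, keep the result as a native pyarrow scalar rather
+        # than converting it to a Python scalar.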
+ return {"_return_py_scalar": False} if returns_scalar else {}
+
+ def cum_sum(self, *, reverse: bool) -> Self:
+ return self._reuse_series("cum_sum", reverse=reverse)
+
+ def shift(self, n: int) -> Self:
+ return self._reuse_series("shift", n=n)
+
+ def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
+ assert self._metadata is not None # noqa: S101
+ if partition_by and not self._metadata.is_scalar_like:
+ msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
+ raise NotImplementedError(msg)
+
+ if not partition_by:
+ # e.g. `nw.col('a').cum_sum().order_by(key)`
+ # which we can always easily support, as it doesn't require grouping.
+ assert order_by # noqa: S101
+
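+            # Attach a temporary row index, sort by `order_by`, evaluate the
+            # expression on the sorted frame, then restore the original row
+            # order via the sort indices of that temporary index.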
+ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
+ token = generate_temporary_column_name(8, df.columns)
+ df = df.with_row_index(token).sort(
+ *order_by, descending=False, nulls_last=False
+ )
+ result = self(df.drop([token], strict=True))
+ # TODO(marco): is there a way to do this efficiently without
+ # doing 2 sorts? Here we're sorting the dataframe and then
+ # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
+ sorting_indices = pc.sort_indices(df.get_column(token).native)
+ return [s._with_native(s.native.take(sorting_indices)) for s in result]
+ else:
+
+ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
+ output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
+ if overlap := set(output_names).intersection(partition_by):
+ # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
+ # we just don't support it yet.
+ msg = (
+ f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
+ "This is not yet supported."
+ )
+ raise NotImplementedError(msg)
+
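+                # Aggregate once per partition, then left-join the aggregated
+                # values back onto the partition keys so that every row
+                # receives its group's result.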
+ tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
+ tmp = df.simple_select(*partition_by).join(
+ tmp,
+ how="left",
+ left_on=partition_by,
+ right_on=partition_by,
+ suffix="_right",
+ )
+ return [tmp.get_column(alias) for alias in aliases]
+
+ return self.__class__(
+ func,
+ depth=self._depth + 1,
+ function_name=self._function_name + "->over",
+ evaluate_output_names=self._evaluate_output_names,
+ alias_output_names=self._alias_output_names,
+ backend_version=self._backend_version,
+ version=self._version,
+ )
+
+ def cum_count(self, *, reverse: bool) -> Self:
+ return self._reuse_series("cum_count", reverse=reverse)
+
+ def cum_min(self, *, reverse: bool) -> Self:
+ return self._reuse_series("cum_min", reverse=reverse)
+
+ def cum_max(self, *, reverse: bool) -> Self:
+ return self._reuse_series("cum_max", reverse=reverse)
+
+ def cum_prod(self, *, reverse: bool) -> Self:
+ return self._reuse_series("cum_prod", reverse=reverse)
+
+ def rank(self, method: RankMethod, *, descending: bool) -> Self:
+ return self._reuse_series("rank", method=method, descending=descending)
+
+ def log(self, base: float) -> Self:
+ return self._reuse_series("log", base=base)
+
+ def exp(self) -> Self:
+ return self._reuse_series("exp")
+
+ ewm_mean = not_implemented()