aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_arrow/expr.py
blob: af7993c54984aec5147b075db122148b41e24af1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._utils import Version, _FullContext
    from narwhals.typing import RankMethod


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    """Eager compliant expression for the PyArrow backend.

    An instance wraps a callable (`_call`) that takes an ``ArrowDataFrame`` and
    returns the list of ``ArrowSeries`` the expression evaluates to. Most of the
    methods below delegate to the ``ArrowSeries`` method of the same name via
    ``_reuse_series`` (provided by the ``EagerExpr`` base class).
    """

    # Fixed for this class; the `implementation` argument of `__init__` is ignored.
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        """Store the evaluation callable together with its bookkeeping metadata.

        Arguments:
            call: Function mapping an ``ArrowDataFrame`` to the output series.
            depth: Nesting depth of the expression; expressions built directly
                from columns use 0 and `over` increments it by one.
            function_name: Name of the operation chain (e.g. ``"nth"``, or a name
                suffixed with ``"->over"``).
            evaluate_output_names: Function returning the output names for a frame.
            alias_output_names: Optional function renaming those output names.
            backend_version: Version tuple of the backend library.
            version: Narwhals API version.
            scalar_kwargs: Extra keyword arguments; ``None`` is stored as ``{}``.
            implementation: Accepted for signature compatibility only. It is
                ignored, since this class always uses ``Implementation.PYARROW``.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        # Not set here; `over` asserts that it has been assigned before use
        # (presumably by the caller that builds the expression).
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build an expression that selects columns by name.

        Arguments:
            evaluate_column_names: Function returning the column names to select
                for a given frame.
            context: Object supplying ``_backend_version`` and ``_version``.
            function_name: Name recorded on the resulting expression.

        Raises:
            On evaluation, if looking up a column raises ``KeyError``, the error
            returned by ``df._check_columns_exist`` is raised (chained from the
            ``KeyError``) when there is one; otherwise the ``KeyError`` propagates.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name],
                        name=column_name,
                        backend_version=df._backend_version,
                        version=df._version,
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                # Turn a raw lookup failure into a friendlier "missing columns"
                # error when the frame can describe it.
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self:
        """Build an expression that selects columns by position (function name ``"nth"``).

        Each selected series is named after the frame's column at that index.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an ``ArrowNamespace`` carrying this expression's versions."""
        # Imported lazily, presumably to avoid a circular import at module load.
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        """Return extra kwargs to forward to a reused series method.

        For scalar-returning operations this requests ``_return_py_scalar=False``
        (i.e. do not convert the result to a Python scalar); otherwise it adds
        nothing.
        """
        return {"_return_py_scalar": False} if returns_scalar else {}

    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative sum, delegated to ``ArrowSeries.cum_sum``."""
        return self._reuse_series("cum_sum", reverse=reverse)

    def shift(self, n: int) -> Self:
        """Shift values by ``n`` positions, delegated to ``ArrowSeries.shift``."""
        return self._reuse_series("shift", n=n)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Evaluate this expression as a window over `partition_by` / `order_by`.

        * Without `partition_by`, the frame is sorted by `order_by` (ascending,
          nulls first). The expression is evaluated on the sorted frame, and the
          results are permuted back into the original row order.
        * With `partition_by`, the expression must be scalar-like (an aggregation
          or a literal). It is aggregated per group, and the per-group results
          are left-joined back onto each row's partition keys. `order_by` is not
          used on this path.

        Raises:
            NotImplementedError: If `partition_by` is given and the expression is
                not scalar-like. Also raised at evaluation time if an output name
                also appears in `partition_by`.
        """
        assert self._metadata is not None  # noqa: S101
        if partition_by and not self._metadata.is_scalar_like:
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                # The temporary row-index column remembers each row's original
                # position across the sort.
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # After the sort, `token` holds the original positions in sorted
                # order. Sorting them gives the permutation that restores the
                # original row order.
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                # One aggregated row per group (null keys kept as their own group),
                # then broadcast back to every row of the group via a left join.
                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count, delegated to ``ArrowSeries.cum_count``."""
        return self._reuse_series("cum_count", reverse=reverse)

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative minimum, delegated to ``ArrowSeries.cum_min``."""
        return self._reuse_series("cum_min", reverse=reverse)

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative maximum, delegated to ``ArrowSeries.cum_max``."""
        return self._reuse_series("cum_max", reverse=reverse)

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative product, delegated to ``ArrowSeries.cum_prod``."""
        return self._reuse_series("cum_prod", reverse=reverse)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Rank values, delegated to ``ArrowSeries.rank``."""
        return self._reuse_series("rank", method=method, descending=descending)

    def log(self, base: float) -> Self:
        """Logarithm with the given ``base``, delegated to ``ArrowSeries.log``."""
        return self._reuse_series("log", base=base)

    def exp(self) -> Self:
        """Exponential, delegated to ``ArrowSeries.exp``."""
        return self._reuse_series("exp")

    # Not supported for the PyArrow backend.
    ewm_mean = not_implemented()