aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_arrow/group_by.py
blob: d61906a609a978897f2b90f1584c8e940f6ae783 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Mapping, Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name

if TYPE_CHECKING:
    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions,
        Aggregation,
        Incomplete,
    )
    from narwhals._compliant.group_by import NarwhalsAggregation
    from narwhals.typing import UniqueKeepStrategy


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    """Group-by for PyArrow-backed frames, built on `pa.TableGroupBy`.

    `agg` delegates to `pa.TableGroupBy.aggregate` and then renames the
    resulting columns; `__iter__` yields one sub-frame per distinct key
    combination by filtering the table on a concatenated string key.
    """

    # Narwhals aggregation name -> aggregation name handed to PyArrow
    # (presumably applied through `_remap_expr_name`, defined outside this class body).
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
    }
    # `unique` keep-strategy -> PyArrow aggregation name.
    # Not referenced inside this class body; presumably read by other modules.
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Resolve the grouping keys and build the PyArrow group-by object.

        Arguments:
            df: Frame to group. Kept as-is so `__iter__` can restrict its
                output to this frame's columns.
            keys: Grouping keys, given as expressions or as column names.
            drop_null_keys: If true, `drop_nulls` is applied on the key columns
                before grouping.
        """
        self._df = df
        # `_parse_keys` returns the frame to group, the key column names in that
        # frame, and the names the keys should carry in the final output.
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
        self._drop_null_keys = drop_null_keys

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Aggregate every group with the given expressions.

        Each expression is translated into one or more
        `(column, function, options)` triples for `pa.TableGroupBy.aggregate`.
        The columns PyArrow returns are then matched back, by their expected
        `<column>_<function>` names, to the aliases requested by the caller.

        Arguments:
            *exprs: Aggregation expressions; they are first passed to
                `_ensure_all_simple`.

        Returns:
            A frame holding the key columns (renamed to the output key names)
            and one column per requested alias.

        Raises:
            AssertionError: If a depth-0 expression is not `len`, or if the
                column names returned by PyArrow are not exactly the expected
                ones (both indicate an internal bug).
        """
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        # Two parallel lists: position `i` of `expected_pyarrow_column_names` is the
        # name PyArrow is expected to produce, position `i` of `new_column_names`
        # is the name that column must end up with.
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                # Group length: count all rows (nulls included) of the first key column.
                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful
        # The same expected name may occur more than once, so keep a FIFO queue of
        # positions per name (deque: O(1) removal from the left).
        expected_old_names_indices: dict[str, collections.deque[int]] = (
            collections.defaultdict(collections.deque)
        )
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        # Compare as multisets: equal sets plus equal lengths would still let
        # e.g. [a, a, b] vs [a, b, b] through and fail later with an opaque
        # IndexError instead of this assertion.
        if collections.Counter(result_simple.column_names) != collections.Counter(
            expected_pyarrow_column_names
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        # For each column PyArrow actually returned (in its order), find the position
        # it was expected at, and from that the name it should be given.
        index_map: list[int] = [
            expected_old_names_indices[item].popleft()
            for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        if self.compliant._backend_version < (12, 0, 0):
            # Move the key columns to the front for older backends
            # (presumably those versions return the keys after the aggregates).
            columns = result_simple.column_names
            result_simple = result_simple.select(
                [*self._keys, *[col for col in columns if col not in self._keys]]
            )

        return self.compliant._with_native(result_simple).rename(
            dict(zip(self._keys, self._output_key_names))
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        """Yield `(key, frame)` pairs, one per distinct key combination.

        Yields:
            A tuple whose first element is the tuple of key values (each passed
            through `extract_py_scalar`) taken from the first row of the group,
            and whose second element is the group's rows restricted to the
            columns of the frame originally passed to `__init__`.
        """
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=""
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        # Build one string per row identifying its group; nulls become `null_token`.
        # NOTE(review): with an empty separator, different key tuples may concatenate
        # to the same string (e.g. ("a", "bc") and ("ab", "c")), and a real value
        # equal to `null_token` cannot be told apart from a null. This assumes the
        # helper above only casts the key columns to strings - confirm.
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            # Rows of this group, with the temporary key column removed again.
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            row = t.simple_select(*self._keys).row(0)
            yield (
                tuple(extract_py_scalar(el) for el in row),
                t.simple_select(*self._df.columns),
            )