1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
|
from __future__ import annotations
import collections
from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Mapping, Sequence
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name
if TYPE_CHECKING:
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
AggregateOptions,
Aggregation,
Incomplete,
)
from narwhals._compliant.group_by import NarwhalsAggregation
from narwhals.typing import UniqueKeepStrategy
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    """Group-by for the PyArrow backend, built on top of `pyarrow.TableGroupBy`.

    Supports aggregating with simple expressions (`agg`) and iterating over
    the groups as `(key_tuple, sub_frame)` pairs (`__iter__`).
    """

    # Narwhals aggregation name -> name of the PyArrow aggregation function.
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
    }
    # `keep` strategy -> PyArrow aggregation.
    # NOTE(review): not referenced in this class body; presumably consumed by a
    # `unique` implementation elsewhere -- confirm against callers.
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Prepare the grouped PyArrow table.

        Arguments:
            df: Frame to group.
            keys: Grouping keys, either expressions or column names.
            drop_null_keys: Whether rows with a null in any key column are
                dropped before grouping.
        """
        self._df = df
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
        self._drop_null_keys = drop_null_keys

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Aggregate every group with the given simple expressions.

        Arguments:
            *exprs: Aggregation expressions; validated by `_ensure_all_simple`.

        Returns:
            A frame with the key columns (under their output names) followed
            by one column per aggregation alias.

        Raises:
            AssertionError: If an internal invariant is violated (a depth-0
                expression other than `len`, or unexpected column names in
                the PyArrow result).
        """
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        # Names we expect PyArrow to produce, and, position by position, the
        # names we want the user to see instead.
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                # `len` has no input column: count all rows (nulls included)
                # of the first key column instead.
                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful: the result's column order is not
        # assumed, and the same PyArrow name may occur more than once, so for
        # each name we keep a FIFO queue of the positions where it is expected.
        expected_old_names_indices: dict[str, collections.deque[int]] = (
            collections.defaultdict(collections.deque)
        )
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        if not (
            set(result_simple.column_names) == set(expected_pyarrow_column_names)
            and len(result_simple.column_names) == len(expected_pyarrow_column_names)
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        # For each actual result column, the position of its expected entry.
        index_map: list[int] = [
            expected_old_names_indices[item].popleft()
            for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        if self.compliant._backend_version < (12, 0, 0):
            # On PyArrow < 12, explicitly reorder so the key columns come first.
            columns = result_simple.column_names
            result_simple = result_simple.select(
                [*self._keys, *[col for col in columns if col not in self._keys]]
            )

        return self.compliant._with_native(result_simple).rename(
            dict(zip(self._keys, self._output_key_names))
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        """Yield `(key_values, sub_frame)` for each distinct key combination.

        Each row gets a string token built by joining its stringified key
        values; rows sharing a token form one group. `sub_frame` contains the
        columns of the frame originally passed to the constructor.
        """
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105
        # Join the key values with a non-empty separator. With an empty one,
        # distinct key tuples such as ("a", "bc") and ("ab", "c") produce the
        # same token and are wrongly merged into a single group. The ASCII unit
        # separator is used because it is unlikely to occur in real data; keys
        # that contain it (or that equal `null_token`) can still collide.
        key_separator = "\x1f"

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=key_separator
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            # Rows of this group, with the temporary token column removed.
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            # All rows in the group share the same keys; read them off row 0.
            row = t.simple_select(*self._keys).row(0)
            yield (
                tuple(extract_py_scalar(el) for el in row),
                t.simple_select(*self._df.columns),
            )
|