aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_polars/group_by.py
blob: e29c3e24f414b4cce74dfd560dfc52f50a4a9226 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

from typing import TYPE_CHECKING, Iterator, Sequence, cast

from narwhals._utils import is_sequence_of

if TYPE_CHECKING:
    from polars.dataframe.group_by import GroupBy as NativeGroupBy
    from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy

    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr


class PolarsGroupBy:
    """Group-by wrapper around a native Polars eager ``GroupBy``.

    Holds the (optionally null-filtered) compliant frame together with the
    native grouped object built from it. Results of ``agg`` and ``__iter__``
    are re-wrapped as compliant frames via ``_with_native``.
    """

    _compliant_frame: PolarsDataFrame
    _grouped: NativeGroupBy
    _drop_null_keys: bool
    # NOTE(review): declared but not assigned anywhere in this class.
    _output_names: Sequence[str]
    # Materialised copy of the keys passed to ``__init__``.
    _keys: Sequence[PolarsExpr] | Sequence[str]

    @property
    def compliant(self) -> PolarsDataFrame:
        """Return the frame being grouped (after null-key filtering, if requested)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsDataFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Build the native group-by for ``df``.

        Args:
            df: Compliant frame to group.
            keys: Column names, or compliant expressions, to group by.
            drop_null_keys: When True, ``df.drop_nulls(keys)`` is applied
                before grouping.
        """
        self._keys = list(keys)
        # Store the flag so the declared ``_drop_null_keys`` attribute exists.
        self._drop_null_keys = drop_null_keys
        self._compliant_frame = (
            df.drop_nulls(self._keys) if drop_null_keys else df
        )
        native = self.compliant.native
        # Column names go to Polars as-is; compliant expressions must first be
        # unwrapped to their native Polars expressions.
        self._grouped = (
            native.group_by(self._keys)
            if is_sequence_of(self._keys, str)
            else native.group_by(arg.native for arg in self._keys)
        )

    def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
        """Aggregate each group with ``aggs`` and wrap the native result.

        Args:
            *aggs: Compliant expressions; their ``.native`` expressions are
                forwarded to the native ``agg`` call.

        Returns:
            The aggregated frame, re-wrapped via ``_with_native``.
        """
        agg_result = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(agg_result)

    def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        """Yield ``(group_key, sub_frame)`` pairs from the native group-by.

        Each key is normalised with ``tuple(...)``. Each native sub-frame is
        re-wrapped with ``_with_native``.
        """
        for key, df in self._grouped:
            # ``cast`` is a no-op at runtime; it only informs the type checker.
            # NOTE(review): assumes Polars yields each key as a tuple, since the
            # keys are always passed as a sequence. Confirm for the minimum
            # supported Polars version.
            group_key = tuple(cast("tuple[str, ...]", key))
            yield group_key, self.compliant._with_native(df)


class PolarsLazyGroupBy:
    """Group-by wrapper around a native Polars ``LazyGroupBy``.

    Holds the (optionally null-filtered) compliant lazy frame together with the
    native grouped object built from it. Only aggregation is exposed; results
    are re-wrapped as compliant frames via ``_with_native``.
    """

    _compliant_frame: PolarsLazyFrame
    _grouped: NativeLazyGroupBy
    _drop_null_keys: bool
    # NOTE(review): declared but not assigned anywhere in this class.
    _output_names: Sequence[str]
    # Materialised copy of the keys passed to ``__init__``.
    _keys: Sequence[PolarsExpr] | Sequence[str]

    @property
    def compliant(self) -> PolarsLazyFrame:
        """Return the lazy frame being grouped (after null-key filtering, if requested)."""
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsLazyFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Build the native lazy group-by for ``df``.

        Args:
            df: Compliant lazy frame to group.
            keys: Column names, or compliant expressions, to group by.
            drop_null_keys: When True, ``df.drop_nulls(keys)`` is applied
                before grouping.
        """
        self._keys = list(keys)
        # Store the flag so the declared ``_drop_null_keys`` attribute exists.
        self._drop_null_keys = drop_null_keys
        self._compliant_frame = (
            df.drop_nulls(self._keys) if drop_null_keys else df
        )
        native = self.compliant.native
        # Column names go to Polars as-is; compliant expressions must first be
        # unwrapped to their native Polars expressions.
        self._grouped = (
            native.group_by(self._keys)
            if is_sequence_of(self._keys, str)
            else native.group_by(arg.native for arg in self._keys)
        )

    def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        """Aggregate each group with ``aggs`` and wrap the native result.

        Args:
            *aggs: Compliant expressions; their ``.native`` expressions are
                forwarded to the native ``agg`` call.

        Returns:
            The aggregated lazy frame, re-wrapped via ``_with_native``.
        """
        agg_result = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(agg_result)