aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_spark_like/group_by.py
blob: 4c63d77595debb32b6d093bac813e8ad12811502 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

from narwhals._compliant import LazyGroupBy

if TYPE_CHECKING:
    from sqlframe.base.column import Column  # noqa: F401

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeLazyGroupBy(LazyGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
    """Group-by over a ``SparkLikeLazyFrame``.

    The grouping keys are resolved once at construction time; ``agg`` then
    builds the grouped native frame and renames the key columns to their
    output names.
    """

    def __init__(
        self,
        df: SparkLikeLazyFrame,
        keys: Sequence[SparkLikeExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        """Resolve the grouping keys and store the frame to aggregate.

        Args:
            df: Frame to group.
            keys: Grouping keys, given either as expressions or column names.
            drop_null_keys: When true, ``drop_nulls`` is applied on the key
                columns of the parsed frame before it is stored.
        """
        parsed_frame, key_names, output_key_names = self._parse_keys(df, keys=keys)
        self._keys = key_names
        self._output_key_names = output_key_names

        if drop_null_keys:
            self._compliant_frame = parsed_frame.drop_nulls(self._keys)
        else:
            self._compliant_frame = parsed_frame

    def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
        """Aggregate each group with ``exprs`` and return the resulting frame.

        When ``exprs`` evaluate to no aggregation columns, the result is the
        de-duplicated selection of the key columns instead of a grouped
        aggregation.
        """
        # Evaluate the expressions first; an empty list selects the fallback.
        aggregations = list(self._evaluate_exprs(exprs))

        if aggregations:
            grouped = self.compliant.native.groupBy(*self._keys).agg(*aggregations)
        else:
            # Nothing to aggregate: keep one row per distinct key combination.
            grouped = self.compliant.native.select(*self._keys).dropDuplicates()

        # Map each internal key name to the name it must carry in the output.
        wrapped = self.compliant._with_native(grouped)
        return wrapped.rename(dict(zip(self._keys, self._output_key_names)))