diff options
author | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
---|---|---|
committer | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
commit | 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch) | |
tree | 8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/group_by.py | |
parent | b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff) |
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/group_by.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/group_by.py | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/group_by.py b/venv/lib/python3.8/site-packages/narwhals/group_by.py new file mode 100644 index 0000000..6a06a17 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/group_by.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, Iterable, Iterator, Sequence, TypeVar + +from narwhals._expression_parsing import all_exprs_are_scalar_like +from narwhals._utils import flatten, tupleify +from narwhals.exceptions import InvalidOperationError +from narwhals.typing import DataFrameT + +if TYPE_CHECKING: + from narwhals._compliant.typing import CompliantExprAny + from narwhals.dataframe import LazyFrame + from narwhals.expr import Expr + +LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]") + + +class GroupBy(Generic[DataFrameT]): + def __init__( + self, + df: DataFrameT, + keys: Sequence[str] | Sequence[CompliantExprAny], + /, + *, + drop_null_keys: bool, + ) -> None: + self._df: DataFrameT = df + self._keys = keys + self._grouped = self._df._compliant_frame.group_by( + self._keys, drop_null_keys=drop_null_keys + ) + + def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: + """Compute aggregations for each group of a group by operation. + + Arguments: + aggs: Aggregations to compute for each group of the group by operation, + specified as positional arguments. + named_aggs: Additional aggregations, specified as keyword arguments. + + Returns: + A new Dataframe. + + Examples: + Group by one column or by multiple columns and call `agg` to compute + the grouped sum of another column. + + >>> import pandas as pd + >>> import narwhals as nw + >>> df_native = pd.DataFrame( + ... { + ... "a": ["a", "b", "a", "b", "c"], + ... "b": [1, 2, 1, 3, 3], + ... "c": [5, 4, 3, 2, 1], + ... } + ... ) + >>> df = nw.from_native(df_native) + >>> + >>> df.group_by("a").agg(nw.col("b").sum()).sort("a") + ┌──────────────────┐ + |Narwhals DataFrame| + |------------------| + | a b | + | 0 a 2 | + | 1 b 5 | + | 2 c 3 | + └──────────────────┘ + >>> + >>> df.group_by("a", "b").agg(nw.col("c").sum()).sort("a", "b").to_native() + a b c + 0 a 1 8 + 1 b 2 4 + 2 b 3 2 + 3 c 3 1 + """ + flat_aggs = tuple(flatten(aggs)) + if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + msg = ( + "Found expression which does not aggregate.\n\n" + "All expressions passed to GroupBy.agg must aggregate.\n" + "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n" + "but `df.group_by('a').agg(nw.col('b'))` is not." + ) + raise InvalidOperationError(msg) + plx = self._df.__narwhals_namespace__() + compliant_aggs = ( + *(x._to_compliant_expr(plx) for x in flat_aggs), + *( + value.alias(key)._to_compliant_expr(plx) + for key, value in named_aggs.items() + ), + ) + return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) + + def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: + yield from ( + (tupleify(key), self._df._with_compliant(df)) + for (key, df) in self._grouped.__iter__() + ) + + +class LazyGroupBy(Generic[LazyFrameT]): + def __init__( + self, + df: LazyFrameT, + keys: Sequence[str] | Sequence[CompliantExprAny], + /, + *, + drop_null_keys: bool, + ) -> None: + self._df: LazyFrameT = df + self._keys = keys + self._grouped = self._df._compliant_frame.group_by( + self._keys, drop_null_keys=drop_null_keys + ) + + def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: + """Compute aggregations for each group of a group by operation. + + Arguments: + aggs: Aggregations to compute for each group of the group by operation, + specified as positional arguments. + named_aggs: Additional aggregations, specified as keyword arguments. + + Returns: + A new LazyFrame. + + Examples: + Group by one column or by multiple columns and call `agg` to compute + the grouped sum of another column. + + >>> import polars as pl + >>> import narwhals as nw + >>> from narwhals.typing import IntoFrameT + >>> lf_native = pl.LazyFrame( + ... { + ... "a": ["a", "b", "a", "b", "c"], + ... "b": [1, 2, 1, 3, 3], + ... "c": [5, 4, 3, 2, 1], + ... } + ... ) + >>> lf = nw.from_native(lf_native) + >>> + >>> nw.to_native(lf.group_by("a").agg(nw.col("b").sum()).sort("a")).collect() + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ str ┆ i64 │ + ╞═════╪═════╡ + │ a ┆ 2 │ + │ b ┆ 5 │ + │ c ┆ 3 │ + └─────┴─────┘ + >>> + >>> lf.group_by("a", "b").agg(nw.sum("c")).sort("a", "b").collect() + ┌───────────────────┐ + |Narwhals DataFrame | + |-------------------| + |shape: (4, 3) | + |┌─────┬─────┬─────┐| + |│ a ┆ b ┆ c │| + |│ --- ┆ --- ┆ --- │| + |│ str ┆ i64 ┆ i64 │| + |╞═════╪═════╪═════╡| + |│ a ┆ 1 ┆ 8 │| + |│ b ┆ 2 ┆ 4 │| + |│ b ┆ 3 ┆ 2 │| + |│ c ┆ 3 ┆ 1 │| + |└─────┴─────┴─────┘| + └───────────────────┘ + """ + flat_aggs = tuple(flatten(aggs)) + if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + msg = ( + "Found expression which does not aggregate.\n\n" + "All expressions passed to GroupBy.agg must aggregate.\n" + "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n" + "but `df.group_by('a').agg(nw.col('b'))` is not." + ) + raise InvalidOperationError(msg) + plx = self._df.__narwhals_namespace__() + compliant_aggs = ( + *(x._to_compliant_expr(plx) for x in flat_aggs), + *( + value.alias(key)._to_compliant_expr(plx) + for key, value in named_aggs.items() + ), + ) + return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) |