aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/group_by.py
diff options
context:
space:
mode:
authorsotech117 <michael_foiani@brown.edu>2025-07-31 17:27:24 -0400
committersotech117 <michael_foiani@brown.edu>2025-07-31 17:27:24 -0400
commit5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch)
tree8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/group_by.py
parentb832d364da8c2efe09e3f75828caf73c50d01ce3 (diff)
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/group_by.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/group_by.py190
1 files changed, 190 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/group_by.py b/venv/lib/python3.8/site-packages/narwhals/group_by.py
new file mode 100644
index 0000000..6a06a17
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/group_by.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Generic, Iterable, Iterator, Sequence, TypeVar
+
+from narwhals._expression_parsing import all_exprs_are_scalar_like
+from narwhals._utils import flatten, tupleify
+from narwhals.exceptions import InvalidOperationError
+from narwhals.typing import DataFrameT
+
+if TYPE_CHECKING:
+ from narwhals._compliant.typing import CompliantExprAny
+ from narwhals.dataframe import LazyFrame
+ from narwhals.expr import Expr
+
+LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]")
+
+
+class GroupBy(Generic[DataFrameT]):
+ def __init__(
+ self,
+ df: DataFrameT,
+ keys: Sequence[str] | Sequence[CompliantExprAny],
+ /,
+ *,
+ drop_null_keys: bool,
+ ) -> None:
+ self._df: DataFrameT = df
+ self._keys = keys
+ self._grouped = self._df._compliant_frame.group_by(
+ self._keys, drop_null_keys=drop_null_keys
+ )
+
+ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
+ """Compute aggregations for each group of a group by operation.
+
+ Arguments:
+ aggs: Aggregations to compute for each group of the group by operation,
+ specified as positional arguments.
+ named_aggs: Additional aggregations, specified as keyword arguments.
+
+ Returns:
+ A new Dataframe.
+
+ Examples:
+ Group by one column or by multiple columns and call `agg` to compute
+ the grouped sum of another column.
+
+ >>> import pandas as pd
+ >>> import narwhals as nw
+ >>> df_native = pd.DataFrame(
+ ... {
+ ... "a": ["a", "b", "a", "b", "c"],
+ ... "b": [1, 2, 1, 3, 3],
+ ... "c": [5, 4, 3, 2, 1],
+ ... }
+ ... )
+ >>> df = nw.from_native(df_native)
+ >>>
+ >>> df.group_by("a").agg(nw.col("b").sum()).sort("a")
+ ┌──────────────────┐
+ |Narwhals DataFrame|
+ |------------------|
+ | a b |
+ | 0 a 2 |
+ | 1 b 5 |
+ | 2 c 3 |
+ └──────────────────┘
+ >>>
+ >>> df.group_by("a", "b").agg(nw.col("c").sum()).sort("a", "b").to_native()
+ a b c
+ 0 a 1 8
+ 1 b 2 4
+ 2 b 3 2
+ 3 c 3 1
+ """
+ flat_aggs = tuple(flatten(aggs))
+ if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
+ msg = (
+ "Found expression which does not aggregate.\n\n"
+ "All expressions passed to GroupBy.agg must aggregate.\n"
+ "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
+ "but `df.group_by('a').agg(nw.col('b'))` is not."
+ )
+ raise InvalidOperationError(msg)
+ plx = self._df.__narwhals_namespace__()
+ compliant_aggs = (
+ *(x._to_compliant_expr(plx) for x in flat_aggs),
+ *(
+ value.alias(key)._to_compliant_expr(plx)
+ for key, value in named_aggs.items()
+ ),
+ )
+ return self._df._with_compliant(self._grouped.agg(*compliant_aggs))
+
+ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
+ yield from (
+ (tupleify(key), self._df._with_compliant(df))
+ for (key, df) in self._grouped.__iter__()
+ )
+
+
+class LazyGroupBy(Generic[LazyFrameT]):
+ def __init__(
+ self,
+ df: LazyFrameT,
+ keys: Sequence[str] | Sequence[CompliantExprAny],
+ /,
+ *,
+ drop_null_keys: bool,
+ ) -> None:
+ self._df: LazyFrameT = df
+ self._keys = keys
+ self._grouped = self._df._compliant_frame.group_by(
+ self._keys, drop_null_keys=drop_null_keys
+ )
+
+ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
+ """Compute aggregations for each group of a group by operation.
+
+ Arguments:
+ aggs: Aggregations to compute for each group of the group by operation,
+ specified as positional arguments.
+ named_aggs: Additional aggregations, specified as keyword arguments.
+
+ Returns:
+ A new LazyFrame.
+
+ Examples:
+ Group by one column or by multiple columns and call `agg` to compute
+ the grouped sum of another column.
+
+ >>> import polars as pl
+ >>> import narwhals as nw
+ >>> from narwhals.typing import IntoFrameT
+ >>> lf_native = pl.LazyFrame(
+ ... {
+ ... "a": ["a", "b", "a", "b", "c"],
+ ... "b": [1, 2, 1, 3, 3],
+ ... "c": [5, 4, 3, 2, 1],
+ ... }
+ ... )
+ >>> lf = nw.from_native(lf_native)
+ >>>
+ >>> nw.to_native(lf.group_by("a").agg(nw.col("b").sum()).sort("a")).collect()
+ shape: (3, 2)
+ ┌─────┬─────┐
+ │ a ┆ b │
+ │ --- ┆ --- │
+ │ str ┆ i64 │
+ ╞═════╪═════╡
+ │ a ┆ 2 │
+ │ b ┆ 5 │
+ │ c ┆ 3 │
+ └─────┴─────┘
+ >>>
+ >>> lf.group_by("a", "b").agg(nw.sum("c")).sort("a", "b").collect()
+ ┌───────────────────┐
+ |Narwhals DataFrame |
+ |-------------------|
+ |shape: (4, 3) |
+ |┌─────┬─────┬─────┐|
+ |│ a ┆ b ┆ c │|
+ |│ --- ┆ --- ┆ --- │|
+ |│ str ┆ i64 ┆ i64 │|
+ |╞═════╪═════╪═════╡|
+ |│ a ┆ 1 ┆ 8 │|
+ |│ b ┆ 2 ┆ 4 │|
+ |│ b ┆ 3 ┆ 2 │|
+ |│ c ┆ 3 ┆ 1 │|
+ |└─────┴─────┴─────┘|
+ └───────────────────┘
+ """
+ flat_aggs = tuple(flatten(aggs))
+ if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
+ msg = (
+ "Found expression which does not aggregate.\n\n"
+ "All expressions passed to GroupBy.agg must aggregate.\n"
+ "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
+ "but `df.group_by('a').agg(nw.col('b'))` is not."
+ )
+ raise InvalidOperationError(msg)
+ plx = self._df.__narwhals_namespace__()
+ compliant_aggs = (
+ *(x._to_compliant_expr(plx) for x in flat_aggs),
+ *(
+ value.alias(key)._to_compliant_expr(plx)
+ for key, value in named_aggs.items()
+ ),
+ )
+ return self._df._with_compliant(self._grouped.agg(*compliant_aggs))