author     sotech117 <michael_foiani@brown.edu>  2025-07-31 17:27:24 -0400
committer  sotech117 <michael_foiani@brown.edu>  2025-07-31 17:27:24 -0400
commit     5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch)
tree       8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py
parent     b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff)
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py')
-rw-r--r--  venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py  500
1 file changed, 500 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py
new file mode 100644
index 0000000..5f21055
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py
@@ -0,0 +1,500 @@
+from __future__ import annotations
+
+from itertools import chain
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterator,
+ Literal,
+ Mapping,
+ Protocol,
+ Sequence,
+ Sized,
+ TypeVar,
+ overload,
+)
+
+from narwhals._compliant.typing import (
+ CompliantDataFrameAny,
+ CompliantExprT_contra,
+ CompliantLazyFrameAny,
+ CompliantSeriesT,
+ EagerExprT,
+ EagerSeriesT,
+ NativeExprT,
+ NativeFrameT,
+)
+from narwhals._translate import (
+ ArrowConvertible,
+ DictConvertible,
+ FromNative,
+ NumpyConvertible,
+ ToNarwhals,
+ ToNarwhalsT_co,
+)
+from narwhals._typing_compat import deprecated
+from narwhals._utils import (
+ Version,
+ _StoresNative,
+ check_columns_exist,
+ is_compliant_series,
+ is_index_selector,
+ is_range,
+ is_sequence_like,
+ is_sized_multi_index_selector,
+ is_slice_index,
+ is_slice_none,
+)
+
+if TYPE_CHECKING:
+ from io import BytesIO
+ from pathlib import Path
+
+ import pandas as pd
+ import polars as pl
+ import pyarrow as pa
+ from typing_extensions import Self, TypeAlias
+
+ from narwhals._compliant.expr import LazyExpr
+ from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
+ from narwhals._compliant.namespace import EagerNamespace
+ from narwhals._compliant.window import WindowInputs
+ from narwhals._translate import IntoArrowTable
+ from narwhals._utils import Implementation, _FullContext
+ from narwhals.dataframe import DataFrame
+ from narwhals.dtypes import DType
+ from narwhals.exceptions import ColumnNotFoundError
+ from narwhals.schema import Schema
+ from narwhals.typing import (
+ AsofJoinStrategy,
+ JoinStrategy,
+ LazyUniqueKeepStrategy,
+ MultiColSelector,
+ MultiIndexSelector,
+ PivotAgg,
+ SingleIndexSelector,
+ SizedMultiIndexSelector,
+ SizedMultiNameSelector,
+ SizeUnit,
+ UniqueKeepStrategy,
+ _2DArray,
+ _SliceIndex,
+ _SliceName,
+ )
+
+ Incomplete: TypeAlias = Any
+
+__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]
+
+T = TypeVar("T")
+
+_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
+
+
+class CompliantDataFrame(
+ NumpyConvertible["_2DArray", "_2DArray"],
+ DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
+ ArrowConvertible["pa.Table", "IntoArrowTable"],
+ _StoresNative[NativeFrameT],
+ FromNative[NativeFrameT],
+ ToNarwhals[ToNarwhalsT_co],
+ Sized,
+ Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
+):
+ _native_frame: NativeFrameT
+ _implementation: Implementation
+ _backend_version: tuple[int, ...]
+ _version: Version
+
+ def __narwhals_dataframe__(self) -> Self: ...
+ def __narwhals_namespace__(self) -> Any: ...
+ @classmethod
+ def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: ...
+ @classmethod
+ def from_dict(
+ cls,
+ data: Mapping[str, Any],
+ /,
+ *,
+ context: _FullContext,
+ schema: Mapping[str, DType] | Schema | None,
+ ) -> Self: ...
+ @classmethod
+ def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
+ @classmethod
+ def from_numpy(
+ cls,
+ data: _2DArray,
+ /,
+ *,
+ context: _FullContext,
+ schema: Mapping[str, DType] | Schema | Sequence[str] | None,
+ ) -> Self: ...
+
+ def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
+ def __getitem__(
+ self,
+ item: tuple[
+ SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
+ MultiColSelector[CompliantSeriesT],
+ ],
+ ) -> Self: ...
+ def simple_select(self, *column_names: str) -> Self:
+ """`select` where all args are column names."""
+ ...
+
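+ # Added illustration (not in the original source): `aggregate` behaves like `select`
+ # when every expression is an aggregation or literal. For example, a call such as
+ # `df.aggregate(plx.col("a").sum(), plx.col("b").mean())` (with `plx` a hypothetical
+ # compliant namespace) produces only length-1 outputs, so no broadcasting is needed.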
+ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
+ """`select` where all args are aggregations or literals, so no broadcasting is necessary."""
+ # NOTE: Ignore is to avoid an intermittent false positive
+ return self.select(*exprs) # pyright: ignore[reportArgumentType]
+
+ def _with_version(self, version: Version) -> Self: ...
+
+ @property
+ def native(self) -> NativeFrameT:
+ return self._native_frame
+
+ @property
+ def columns(self) -> Sequence[str]: ...
+ @property
+ def schema(self) -> Mapping[str, DType]: ...
+ @property
+ def shape(self) -> tuple[int, int]: ...
+ def clone(self) -> Self: ...
+ def collect(
+ self, backend: Implementation | None, **kwargs: Any
+ ) -> CompliantDataFrameAny: ...
+ def collect_schema(self) -> Mapping[str, DType]: ...
+ def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
+ def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
+ def estimated_size(self, unit: SizeUnit) -> int | float: ...
+ def explode(self, columns: Sequence[str]) -> Self: ...
+ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
+ def gather_every(self, n: int, offset: int) -> Self: ...
+ def get_column(self, name: str) -> CompliantSeriesT: ...
+ def group_by(
+ self,
+ keys: Sequence[str] | Sequence[CompliantExprT_contra],
+ *,
+ drop_null_keys: bool,
+ ) -> DataFrameGroupBy[Self, Any]: ...
+ def head(self, n: int) -> Self: ...
+ def item(self, row: int | None, column: int | str | None) -> Any: ...
+ def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
+ def iter_rows(
+ self, *, named: bool, buffer_size: int
+ ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
+ def is_unique(self) -> CompliantSeriesT: ...
+ def join(
+ self,
+ other: Self,
+ *,
+ how: JoinStrategy,
+ left_on: Sequence[str] | None,
+ right_on: Sequence[str] | None,
+ suffix: str,
+ ) -> Self: ...
+ def join_asof(
+ self,
+ other: Self,
+ *,
+ left_on: str,
+ right_on: str,
+ by_left: Sequence[str] | None,
+ by_right: Sequence[str] | None,
+ strategy: AsofJoinStrategy,
+ suffix: str,
+ ) -> Self: ...
+ def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
+ def pivot(
+ self,
+ on: Sequence[str],
+ *,
+ index: Sequence[str] | None,
+ values: Sequence[str] | None,
+ aggregate_function: PivotAgg | None,
+ sort_columns: bool,
+ separator: str,
+ ) -> Self: ...
+ def rename(self, mapping: Mapping[str, str]) -> Self: ...
+ def row(self, index: int) -> tuple[Any, ...]: ...
+ def rows(
+ self, *, named: bool
+ ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
+ def sample(
+ self,
+ n: int | None,
+ *,
+ fraction: float | None,
+ with_replacement: bool,
+ seed: int | None,
+ ) -> Self: ...
+ def select(self, *exprs: CompliantExprT_contra) -> Self: ...
+ def sort(
+ self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
+ ) -> Self: ...
+ def tail(self, n: int) -> Self: ...
+ def to_arrow(self) -> pa.Table: ...
+ def to_pandas(self) -> pd.DataFrame: ...
+ def to_polars(self) -> pl.DataFrame: ...
+ @overload
+ def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
+ @overload
+ def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
+ def to_dict(
+ self, *, as_series: bool
+ ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
+ def unique(
+ self,
+ subset: Sequence[str] | None,
+ *,
+ keep: UniqueKeepStrategy,
+ maintain_order: bool | None = None,
+ ) -> Self: ...
+ def unpivot(
+ self,
+ on: Sequence[str] | None,
+ index: Sequence[str] | None,
+ variable_name: str,
+ value_name: str,
+ ) -> Self: ...
+ def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
+ def with_row_index(self, name: str) -> Self: ...
+ @overload
+ def write_csv(self, file: None) -> str: ...
+ @overload
+ def write_csv(self, file: str | Path | BytesIO) -> None: ...
+ def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
+ def write_parquet(self, file: str | Path | BytesIO) -> None: ...
+
+ def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
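+ # Added illustration: each expr reports its output aliases against `self`; e.g.
+ # exprs like `col("a")` and `col("b").alias("c")` would yield `["a", "c"]`.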
+ it = (expr._evaluate_aliases(self) for expr in exprs)
+ return list(chain.from_iterable(it))
+
+ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
+ return check_columns_exist(subset, available=self.columns)
+
+
+class CompliantLazyFrame(
+ _StoresNative[NativeFrameT],
+ FromNative[NativeFrameT],
+ ToNarwhals[ToNarwhalsT_co],
+ Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
+):
+ _native_frame: NativeFrameT
+ _implementation: Implementation
+ _backend_version: tuple[int, ...]
+ _version: Version
+
+ def __narwhals_lazyframe__(self) -> Self: ...
+ def __narwhals_namespace__(self) -> Any: ...
+
+ @classmethod
+ def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
+
+ def simple_select(self, *column_names: str) -> Self:
+ """`select` where all args are column names."""
+ ...
+
+ def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
+ """`select` where all args are aggregations or literals, so no broadcasting is necessary."""
+ ...
+
+ def _with_version(self, version: Version) -> Self: ...
+
+ @property
+ def native(self) -> NativeFrameT:
+ return self._native_frame
+
+ @property
+ def columns(self) -> Sequence[str]: ...
+ @property
+ def schema(self) -> Mapping[str, DType]: ...
+ def _iter_columns(self) -> Iterator[Any]: ...
+ def collect(
+ self, backend: Implementation | None, **kwargs: Any
+ ) -> CompliantDataFrameAny: ...
+ def collect_schema(self) -> Mapping[str, DType]: ...
+ def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
+ def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
+ def explode(self, columns: Sequence[str]) -> Self: ...
+ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
+ @deprecated(
+ "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
+ )
+ def gather_every(self, n: int, offset: int) -> Self: ...
+ def group_by(
+ self,
+ keys: Sequence[str] | Sequence[CompliantExprT_contra],
+ *,
+ drop_null_keys: bool,
+ ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
+ def head(self, n: int) -> Self: ...
+ def join(
+ self,
+ other: Self,
+ *,
+ how: Literal["left", "inner", "cross", "anti", "semi"],
+ left_on: Sequence[str] | None,
+ right_on: Sequence[str] | None,
+ suffix: str,
+ ) -> Self: ...
+ def join_asof(
+ self,
+ other: Self,
+ *,
+ left_on: str,
+ right_on: str,
+ by_left: Sequence[str] | None,
+ by_right: Sequence[str] | None,
+ strategy: AsofJoinStrategy,
+ suffix: str,
+ ) -> Self: ...
+ def rename(self, mapping: Mapping[str, str]) -> Self: ...
+ def select(self, *exprs: CompliantExprT_contra) -> Self: ...
+ def sort(
+ self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
+ ) -> Self: ...
+ @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
+ def tail(self, n: int) -> Self: ...
+ def unique(
+ self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
+ ) -> Self: ...
+ def unpivot(
+ self,
+ on: Sequence[str] | None,
+ index: Sequence[str] | None,
+ variable_name: str,
+ value_name: str,
+ ) -> Self: ...
+ def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
+ def with_row_index(self, name: str) -> Self: ...
+ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
+ result = expr(self)
+ assert len(result) == 1 # debug assertion # noqa: S101
+ return result[0]
+
+ def _evaluate_window_expr(
+ self,
+ expr: LazyExpr[Self, NativeExprT],
+ /,
+ window_inputs: WindowInputs[NativeExprT],
+ ) -> NativeExprT:
+ result = expr.window_function(self, window_inputs)
+ assert len(result) == 1 # debug assertion # noqa: S101
+ return result[0]
+
+ def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
+ it = (expr._evaluate_aliases(self) for expr in exprs)
+ return list(chain.from_iterable(it))
+
+ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
+ return check_columns_exist(subset, available=self.columns)
+
+
+class EagerDataFrame(
+ CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
+ CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
+ Protocol[EagerSeriesT, EagerExprT, NativeFrameT],
+):
+ def __narwhals_namespace__(
+ self,
+ ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ...
+
+ def to_narwhals(self) -> DataFrame[NativeFrameT]:
+ return self._version.dataframe(self, level="full")
+
+ def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
+ """Evaluate `expr` and ensure it has a **single** output."""
+ result: Sequence[EagerSeriesT] = expr(self)
+ assert len(result) == 1 # debug assertion # noqa: S101
+ return result[0]
+
+ def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
+ # NOTE: Ignore is to avoid an intermittent false positive
+ return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
+
+ def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
+ """Return a list of raw columns.
+
+ For eager backends we alias operations at each step.
+
+ As a safety precaution, we check here that the result names match those we
+ expected from the various `evaluate_output_names` / `alias_output_names` calls.
+
+ Note that for PySpark / DuckDB, we cannot set aliases as liberally.
+ """
+ aliases = expr._evaluate_aliases(self)
+ result = expr(self)
+ if list(aliases) != (
+ result_aliases := [s.name for s in result]
+ ): # pragma: no cover
+ msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
+ raise AssertionError(msg)
+ return result
+
+ def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
+ """Extract native Series, broadcasting to `len(self)` if necessary."""
+ ...
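+ # Added note: implementations typically return `other.native` unchanged when its
+ # length already matches, and broadcast a length-1 series to `len(self)` otherwise.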
+
+ @staticmethod
+ def _numpy_column_names(
+ data: _2DArray, columns: Sequence[str] | None, /
+ ) -> list[str]:
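+ # Added note: when `columns` is None (or empty), default names are generated from
+ # the array width, e.g. a (3, 2) array yields ["column_0", "column_1"].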
+ return list(columns or (f"column_{x}" for x in range(data.shape[1])))
+
+ def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ...
+ def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
+ def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ...
+ def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ...
+ def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
+ def _select_slice_name(self, columns: _SliceName) -> Self: ...
+ def __getitem__( # noqa: C901, PLR0912
+ self,
+ item: tuple[
+ SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
+ MultiColSelector[EagerSeriesT],
+ ],
+ ) -> Self:
+ rows, columns = item
+ compliant = self
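+ # Added note: column selection is applied first, then row selection; each branch
+ # below dispatches on the selector type (slice, index sequence, name sequence, series).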
+ if not is_slice_none(columns):
+ if isinstance(columns, Sized) and len(columns) == 0:
+ return compliant.select()
+ if is_index_selector(columns):
+ if is_slice_index(columns) or is_range(columns):
+ compliant = compliant._select_slice_index(columns)
+ elif is_compliant_series(columns):
+ compliant = self._select_multi_index(columns.native)
+ else:
+ compliant = compliant._select_multi_index(columns)
+ elif isinstance(columns, slice):
+ compliant = compliant._select_slice_name(columns)
+ elif is_compliant_series(columns):
+ compliant = self._select_multi_name(columns.native)
+ elif is_sequence_like(columns):
+ compliant = self._select_multi_name(columns)
+ else: # pragma: no cover
+ msg = f"Unreachable code, got unexpected type: {type(columns)}"
+ raise AssertionError(msg)
+
+ if not is_slice_none(rows):
+ if isinstance(rows, int):
+ compliant = compliant._gather([rows])
+ elif isinstance(rows, (slice, range)):
+ compliant = compliant._gather_slice(rows)
+ elif is_compliant_series(rows):
+ compliant = compliant._gather(rows.native)
+ elif is_sized_multi_index_selector(rows):
+ compliant = compliant._gather(rows)
+ else: # pragma: no cover
+ msg = f"Unreachable code, got unexpected type: {type(rows)}"
+ raise AssertionError(msg)
+
+ return compliant
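
For context, a minimal usage sketch (not part of this diff) of the positional `__getitem__` dispatch defined above, assuming a hypothetical `EagerDataFrame` implementation named `PandasLikeDataFrame`, a native frame `native_df`, and an already-built `context`:

    df = PandasLikeDataFrame.from_native(native_df, context=context)
    # (rows, columns): row positions plus column names
    subset = df[[0, 2], ["a", "b"]]      # -> _gather([0, 2]) and _select_multi_name(["a", "b"])
    # slice of rows, all columns (slice(None) is skipped via is_slice_none)
    head = df[slice(0, 5), slice(None)]  # -> _gather_slice(slice(0, 5))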