Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py | 500 |
1 file changed, 500 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py
new file mode 100644
index 0000000..5f21055
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_compliant/dataframe.py
@@ -0,0 +1,500 @@
+from __future__ import annotations
+
+from itertools import chain
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterator,
+    Literal,
+    Mapping,
+    Protocol,
+    Sequence,
+    Sized,
+    TypeVar,
+    overload,
+)
+
+from narwhals._compliant.typing import (
+    CompliantDataFrameAny,
+    CompliantExprT_contra,
+    CompliantLazyFrameAny,
+    CompliantSeriesT,
+    EagerExprT,
+    EagerSeriesT,
+    NativeExprT,
+    NativeFrameT,
+)
+from narwhals._translate import (
+    ArrowConvertible,
+    DictConvertible,
+    FromNative,
+    NumpyConvertible,
+    ToNarwhals,
+    ToNarwhalsT_co,
+)
+from narwhals._typing_compat import deprecated
+from narwhals._utils import (
+    Version,
+    _StoresNative,
+    check_columns_exist,
+    is_compliant_series,
+    is_index_selector,
+    is_range,
+    is_sequence_like,
+    is_sized_multi_index_selector,
+    is_slice_index,
+    is_slice_none,
+)
+
+if TYPE_CHECKING:
+    from io import BytesIO
+    from pathlib import Path
+
+    import pandas as pd
+    import polars as pl
+    import pyarrow as pa
+    from typing_extensions import Self, TypeAlias
+
+    from narwhals._compliant.expr import LazyExpr
+    from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
+    from narwhals._compliant.namespace import EagerNamespace
+    from narwhals._compliant.window import WindowInputs
+    from narwhals._translate import IntoArrowTable
+    from narwhals._utils import Implementation, _FullContext
+    from narwhals.dataframe import DataFrame
+    from narwhals.dtypes import DType
+    from narwhals.exceptions import ColumnNotFoundError
+    from narwhals.schema import Schema
+    from narwhals.typing import (
+        AsofJoinStrategy,
+        JoinStrategy,
+        LazyUniqueKeepStrategy,
+        MultiColSelector,
+        MultiIndexSelector,
+        PivotAgg,
+        SingleIndexSelector,
+        SizedMultiIndexSelector,
+        SizedMultiNameSelector,
+        SizeUnit,
+        UniqueKeepStrategy,
+        _2DArray,
+        _SliceIndex,
+        _SliceName,
+    )
+
+    Incomplete: TypeAlias = Any
+
+__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]
+
+T = TypeVar("T")
+
+_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047
+
+
+class CompliantDataFrame(
+    NumpyConvertible["_2DArray", "_2DArray"],
+    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
+    ArrowConvertible["pa.Table", "IntoArrowTable"],
+    _StoresNative[NativeFrameT],
+    FromNative[NativeFrameT],
+    ToNarwhals[ToNarwhalsT_co],
+    Sized,
+    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
+):
+    _native_frame: NativeFrameT
+    _implementation: Implementation
+    _backend_version: tuple[int, ...]
+    _version: Version
+
+    def __narwhals_dataframe__(self) -> Self: ...
+    def __narwhals_namespace__(self) -> Any: ...
+    @classmethod
+    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: ...
+    @classmethod
+    def from_dict(
+        cls,
+        data: Mapping[str, Any],
+        /,
+        *,
+        context: _FullContext,
+        schema: Mapping[str, DType] | Schema | None,
+    ) -> Self: ...
+    @classmethod
+    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
+    @classmethod
+    def from_numpy(
+        cls,
+        data: _2DArray,
+        /,
+        *,
+        context: _FullContext,
+        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
+    ) -> Self: ...
+
+    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
+    def __getitem__(
+        self,
+        item: tuple[
+            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
+            MultiColSelector[CompliantSeriesT],
+        ],
+    ) -> Self: ...
+    def simple_select(self, *column_names: str) -> Self:
+        """`select` where all args are column names."""
+        ...
+
+    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
+        """`select` where all args are aggregations or literals.
+
+        (so, no broadcasting is necessary).
+        """
+        # NOTE: Ignore is to avoid an intermittent false positive
+        return self.select(*exprs)  # pyright: ignore[reportArgumentType]
+
+    def _with_version(self, version: Version) -> Self: ...
+
+    @property
+    def native(self) -> NativeFrameT:
+        return self._native_frame
+
+    @property
+    def columns(self) -> Sequence[str]: ...
+    @property
+    def schema(self) -> Mapping[str, DType]: ...
+    @property
+    def shape(self) -> tuple[int, int]: ...
+    def clone(self) -> Self: ...
+    def collect(
+        self, backend: Implementation | None, **kwargs: Any
+    ) -> CompliantDataFrameAny: ...
+    def collect_schema(self) -> Mapping[str, DType]: ...
+    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
+    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
+    def estimated_size(self, unit: SizeUnit) -> int | float: ...
+    def explode(self, columns: Sequence[str]) -> Self: ...
+    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
+    def gather_every(self, n: int, offset: int) -> Self: ...
+    def get_column(self, name: str) -> CompliantSeriesT: ...
+    def group_by(
+        self,
+        keys: Sequence[str] | Sequence[CompliantExprT_contra],
+        *,
+        drop_null_keys: bool,
+    ) -> DataFrameGroupBy[Self, Any]: ...
+    def head(self, n: int) -> Self: ...
+    def item(self, row: int | None, column: int | str | None) -> Any: ...
+    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
+    def iter_rows(
+        self, *, named: bool, buffer_size: int
+    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
+    def is_unique(self) -> CompliantSeriesT: ...
+    def join(
+        self,
+        other: Self,
+        *,
+        how: JoinStrategy,
+        left_on: Sequence[str] | None,
+        right_on: Sequence[str] | None,
+        suffix: str,
+    ) -> Self: ...
+    def join_asof(
+        self,
+        other: Self,
+        *,
+        left_on: str,
+        right_on: str,
+        by_left: Sequence[str] | None,
+        by_right: Sequence[str] | None,
+        strategy: AsofJoinStrategy,
+        suffix: str,
+    ) -> Self: ...
+    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
+    def pivot(
+        self,
+        on: Sequence[str],
+        *,
+        index: Sequence[str] | None,
+        values: Sequence[str] | None,
+        aggregate_function: PivotAgg | None,
+        sort_columns: bool,
+        separator: str,
+    ) -> Self: ...
+    def rename(self, mapping: Mapping[str, str]) -> Self: ...
+    def row(self, index: int) -> tuple[Any, ...]: ...
+    def rows(
+        self, *, named: bool
+    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
+    def sample(
+        self,
+        n: int | None,
+        *,
+        fraction: float | None,
+        with_replacement: bool,
+        seed: int | None,
+    ) -> Self: ...
+    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
+    def sort(
+        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
+    ) -> Self: ...
+    def tail(self, n: int) -> Self: ...
+    def to_arrow(self) -> pa.Table: ...
+    def to_pandas(self) -> pd.DataFrame: ...
+    def to_polars(self) -> pl.DataFrame: ...
+    @overload
+    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
+    @overload
+    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
+    def to_dict(
+        self, *, as_series: bool
+    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
+    def unique(
+        self,
+        subset: Sequence[str] | None,
+        *,
+        keep: UniqueKeepStrategy,
+        maintain_order: bool | None = None,
+    ) -> Self: ...
+    def unpivot(
+        self,
+        on: Sequence[str] | None,
+        index: Sequence[str] | None,
+        variable_name: str,
+        value_name: str,
+    ) -> Self: ...
+    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
+    def with_row_index(self, name: str) -> Self: ...
+    @overload
+    def write_csv(self, file: None) -> str: ...
+    @overload
+    def write_csv(self, file: str | Path | BytesIO) -> None: ...
+    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
+    def write_parquet(self, file: str | Path | BytesIO) -> None: ...
+
+    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
+        it = (expr._evaluate_aliases(self) for expr in exprs)
+        return list(chain.from_iterable(it))
+
+    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
+        return check_columns_exist(subset, available=self.columns)
+
+
+class CompliantLazyFrame(
+    _StoresNative[NativeFrameT],
+    FromNative[NativeFrameT],
+    ToNarwhals[ToNarwhalsT_co],
+    Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
+):
+    _native_frame: NativeFrameT
+    _implementation: Implementation
+    _backend_version: tuple[int, ...]
+    _version: Version
+
+    def __narwhals_lazyframe__(self) -> Self: ...
+    def __narwhals_namespace__(self) -> Any: ...
+
+    @classmethod
+    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
+
+    def simple_select(self, *column_names: str) -> Self:
+        """`select` where all args are column names."""
+        ...
+
+    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
+        """`select` where all args are aggregations or literals.
+
+        (so, no broadcasting is necessary).
+        """
+        ...
+
+    def _with_version(self, version: Version) -> Self: ...
+
+    @property
+    def native(self) -> NativeFrameT:
+        return self._native_frame
+
+    @property
+    def columns(self) -> Sequence[str]: ...
+    @property
+    def schema(self) -> Mapping[str, DType]: ...
+    def _iter_columns(self) -> Iterator[Any]: ...
+    def collect(
+        self, backend: Implementation | None, **kwargs: Any
+    ) -> CompliantDataFrameAny: ...
+    def collect_schema(self) -> Mapping[str, DType]: ...
+    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
+    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
+    def explode(self, columns: Sequence[str]) -> Self: ...
+    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
+    @deprecated(
+        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
+    )
+    def gather_every(self, n: int, offset: int) -> Self: ...
+    def group_by(
+        self,
+        keys: Sequence[str] | Sequence[CompliantExprT_contra],
+        *,
+        drop_null_keys: bool,
+    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
+    def head(self, n: int) -> Self: ...
+    def join(
+        self,
+        other: Self,
+        *,
+        how: Literal["left", "inner", "cross", "anti", "semi"],
+        left_on: Sequence[str] | None,
+        right_on: Sequence[str] | None,
+        suffix: str,
+    ) -> Self: ...
+    def join_asof(
+        self,
+        other: Self,
+        *,
+        left_on: str,
+        right_on: str,
+        by_left: Sequence[str] | None,
+        by_right: Sequence[str] | None,
+        strategy: AsofJoinStrategy,
+        suffix: str,
+    ) -> Self: ...
+    def rename(self, mapping: Mapping[str, str]) -> Self: ...
+    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
+    def sort(
+        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
+    ) -> Self: ...
+    @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
+    def tail(self, n: int) -> Self: ...
+    def unique(
+        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
+    ) -> Self: ...
+    def unpivot(
+        self,
+        on: Sequence[str] | None,
+        index: Sequence[str] | None,
+        variable_name: str,
+        value_name: str,
+    ) -> Self: ...
+    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
+    def with_row_index(self, name: str) -> Self: ...
+    def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
+        result = expr(self)
+        assert len(result) == 1  # debug assertion  # noqa: S101
+        return result[0]
+
+    def _evaluate_window_expr(
+        self,
+        expr: LazyExpr[Self, NativeExprT],
+        /,
+        window_inputs: WindowInputs[NativeExprT],
+    ) -> NativeExprT:
+        result = expr.window_function(self, window_inputs)
+        assert len(result) == 1  # debug assertion  # noqa: S101
+        return result[0]
+
+    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
+        it = (expr._evaluate_aliases(self) for expr in exprs)
+        return list(chain.from_iterable(it))
+
+    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
+        return check_columns_exist(subset, available=self.columns)
+
+
+class EagerDataFrame(
+    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
+    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
+    Protocol[EagerSeriesT, EagerExprT, NativeFrameT],
+):
+    def __narwhals_namespace__(
+        self,
+    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ...
+
+    def to_narwhals(self) -> DataFrame[NativeFrameT]:
+        return self._version.dataframe(self, level="full")
+
+    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
+        """Evaluate `expr` and ensure it has a **single** output."""
+        result: Sequence[EagerSeriesT] = expr(self)
+        assert len(result) == 1  # debug assertion  # noqa: S101
+        return result[0]
+
+    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
+        # NOTE: Ignore is to avoid an intermittent false positive
+        return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))  # pyright: ignore[reportArgumentType]
+
+    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
+        """Return list of raw columns.
+
+        For eager backends we alias operations at each step.
+
+        As a safety precaution, here we can check that the expected result names match those
+        we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
+
+        Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
+        """
+        aliases = expr._evaluate_aliases(self)
+        result = expr(self)
+        if list(aliases) != (
+            result_aliases := [s.name for s in result]
+        ):  # pragma: no cover
+            msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
+            raise AssertionError(msg)
+        return result
+
+    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
+        """Extract native Series, broadcasting to `len(self)` if necessary."""
+        ...
+
+    @staticmethod
+    def _numpy_column_names(
+        data: _2DArray, columns: Sequence[str] | None, /
+    ) -> list[str]:
+        return list(columns or (f"column_{x}" for x in range(data.shape[1])))
+
+    def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ...
+    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
+    def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ...
+    def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ...
+    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
+    def _select_slice_name(self, columns: _SliceName) -> Self: ...
+    def __getitem__(  # noqa: C901, PLR0912
+        self,
+        item: tuple[
+            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
+            MultiColSelector[EagerSeriesT],
+        ],
+    ) -> Self:
+        rows, columns = item
+        compliant = self
+        if not is_slice_none(columns):
+            if isinstance(columns, Sized) and len(columns) == 0:
+                return compliant.select()
+            if is_index_selector(columns):
+                if is_slice_index(columns) or is_range(columns):
+                    compliant = compliant._select_slice_index(columns)
+                elif is_compliant_series(columns):
+                    compliant = self._select_multi_index(columns.native)
+                else:
+                    compliant = compliant._select_multi_index(columns)
+            elif isinstance(columns, slice):
+                compliant = compliant._select_slice_name(columns)
+            elif is_compliant_series(columns):
+                compliant = self._select_multi_name(columns.native)
+            elif is_sequence_like(columns):
+                compliant = self._select_multi_name(columns)
+            else:  # pragma: no cover
+                msg = f"Unreachable code, got unexpected type: {type(columns)}"
+                raise AssertionError(msg)
+
+        if not is_slice_none(rows):
+            if isinstance(rows, int):
+                compliant = compliant._gather([rows])
+            elif isinstance(rows, (slice, range)):
+                compliant = compliant._gather_slice(rows)
+            elif is_compliant_series(rows):
+                compliant = compliant._gather(rows.native)
+            elif is_sized_multi_index_selector(rows):
+                compliant = compliant._gather(rows)
+            else:  # pragma: no cover
+                msg = f"Unreachable code, got unexpected type: {type(rows)}"
+                raise AssertionError(msg)
+
+        return compliant
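The file added above is almost entirely interface: Protocol classes whose methods are stubs, with the only concrete logic living in small helpers such as `_numpy_column_names`, `_evaluate_aliases`, and `EagerDataFrame.__getitem__`. As a rough illustration of the `(rows, columns)` selector convention that `__getitem__` dispatches on, here is a minimal, self-contained sketch over a plain dict of lists. It is hypothetical and not narwhals' actual implementation, but it mirrors the same order of operations (column selection first, then row gathering, with an int row gathered as a one-row frame) and the `column_0`, `column_1`, ... fallback naming used by `_numpy_column_names`.

# Hypothetical sketch, not narwhals' implementation: the (rows, columns) tuple
# dispatch of EagerDataFrame.__getitem__, applied to a plain dict of lists.
from __future__ import annotations

from typing import Any, Mapping, Sequence


def default_column_names(n_columns: int, names: Sequence[str] | None) -> list[str]:
    # Mirrors _numpy_column_names: fall back to "column_0", "column_1", ...
    return list(names or (f"column_{i}" for i in range(n_columns)))


def select_item(
    data: Mapping[str, list[Any]],
    rows: int | slice,
    columns: slice | Sequence[str],
) -> dict[str, list[Any]]:
    # Column axis first: slice(None) keeps every column, otherwise the
    # selector is treated as a sequence of column names.
    if isinstance(columns, slice) and columns == slice(None):
        selected = dict(data)
    else:
        selected = {name: data[name] for name in columns}
    # Row axis second: a single int is gathered as a one-row frame, matching
    # compliant._gather([rows]) above; slices are applied directly.
    if isinstance(rows, int):
        return {name: [values[rows]] for name, values in selected.items()}
    return {name: values[rows] for name, values in selected.items()}


frame = {"a": [1, 2, 3], "b": [4, 5, 6]}
print(default_column_names(2, None))                    # ['column_0', 'column_1']
print(select_item(frame, 1, ["a"]))                     # {'a': [2]}
print(select_item(frame, slice(1, None), slice(None)))  # {'a': [2, 3], 'b': [5, 6]}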