diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_compliant/namespace.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_compliant/namespace.py | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_compliant/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_compliant/namespace.py new file mode 100644 index 0000000..e73ccc2 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_compliant/namespace.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Container, + Iterable, + Mapping, + Protocol, + Sequence, + overload, +) + +from narwhals._compliant.typing import ( + CompliantExprT, + CompliantFrameT, + CompliantLazyFrameT, + DepthTrackingExprT, + EagerDataFrameT, + EagerExprT, + EagerSeriesT, + LazyExprT, + NativeFrameT, + NativeFrameT_co, +) +from narwhals._utils import ( + exclude_column_names, + get_column_names, + passthrough_column_names, +) +from narwhals.dependencies import is_numpy_array_2d + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from narwhals._compliant.selectors import CompliantSelectorNamespace + from narwhals._compliant.when_then import CompliantWhen, EagerWhen + from narwhals._utils import Implementation, Version + from narwhals.dtypes import DType + from narwhals.schema import Schema + from narwhals.typing import ( + ConcatMethod, + Into1DArray, + IntoDType, + NonNestedLiteral, + _2DArray, + ) + + Incomplete: TypeAlias = Any + +__all__ = ["CompliantNamespace", "EagerNamespace"] + + +class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def all(self) -> CompliantExprT: + return self._expr.from_column_names(get_column_names, context=self) + + def col(self, *column_names: str) -> CompliantExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), context=self + ) + + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), context=self + ) + + def nth(self, *column_indices: int) -> CompliantExprT: + return self._expr.from_column_indices(*column_indices, context=self) + + def len(self) -> CompliantExprT: ... + def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ... + def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def concat( + self, items: Iterable[CompliantFrameT], *, how: ConcatMethod + ) -> CompliantFrameT: ... + def when( + self, predicate: CompliantExprT + ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ... + def concat_str( + self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool + ) -> CompliantExprT: ... + @property + def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... + @property + def _expr(self) -> type[CompliantExprT]: ... + + +class DepthTrackingNamespace( + CompliantNamespace[CompliantFrameT, DepthTrackingExprT], + Protocol[CompliantFrameT, DepthTrackingExprT], +): + def all(self) -> DepthTrackingExprT: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> DepthTrackingExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + +class LazyNamespace( + CompliantNamespace[CompliantLazyFrameT, LazyExprT], + Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co], +): + @property + def _lazyframe(self) -> type[CompliantLazyFrameT]: ... + + def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: + if self._lazyframe._is_native(data): + return self._lazyframe.from_native(data, context=self) + else: # pragma: no cover + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) + + +class EagerNamespace( + DepthTrackingNamespace[EagerDataFrameT, EagerExprT], + Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT], +): + @property + def _dataframe(self) -> type[EagerDataFrameT]: ... + @property + def _series(self) -> type[EagerSeriesT]: ... + def when( + self, predicate: EagerExprT + ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT]: ... + + def from_native(self, data: Any, /) -> EagerDataFrameT | EagerSeriesT: + if self._dataframe._is_native(data): + return self._dataframe.from_native(data, context=self) + elif self._series._is_native(data): + return self._series.from_native(data, context=self) + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) + + @overload + def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ... + + @overload + def from_numpy( + self, + data: _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None, + ) -> EagerDataFrameT: ... + + def from_numpy( + self, + data: Into1DArray | _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None = None, + ) -> EagerDataFrameT | EagerSeriesT: + if is_numpy_array_2d(data): + return self._dataframe.from_numpy(data, schema=schema, context=self) + return self._series.from_numpy(data, context=self) + + def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def _concat_horizontal( + self, dfs: Sequence[NativeFrameT | Any], / + ) -> NativeFrameT: ... + def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ... + def concat( + self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod + ) -> EagerDataFrameT: + dfs = [item.native for item in items] + if how == "horizontal": + native = self._concat_horizontal(dfs) + elif how == "vertical": + native = self._concat_vertical(dfs) + elif how == "diagonal": + native = self._concat_diagonal(dfs) + else: # pragma: no cover + raise NotImplementedError + return self._dataframe.from_native(native, context=self) |