diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py new file mode 100644 index 0000000..4dec34c --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/namespace.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import operator +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Literal, + Mapping, + Sequence, + cast, + overload, +) + +import polars as pl + +from narwhals._polars.expr import PolarsExpr +from narwhals._polars.series import PolarsSeries +from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype +from narwhals._utils import Implementation, requires +from narwhals.dependencies import is_numpy_array_2d +from narwhals.dtypes import DType + +if TYPE_CHECKING: + from datetime import timezone + + from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen + from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.typing import FrameT + from narwhals._utils import Version, _FullContext + from narwhals.schema import Schema + from narwhals.typing import Into1DArray, IntoDType, TimeUnit, _2DArray + + +class PolarsNamespace: + all: Method[PolarsExpr] + col: Method[PolarsExpr] + exclude: Method[PolarsExpr] + all_horizontal: Method[PolarsExpr] + any_horizontal: Method[PolarsExpr] + sum_horizontal: Method[PolarsExpr] + min_horizontal: Method[PolarsExpr] + max_horizontal: Method[PolarsExpr] + + # NOTE: `pyright` accepts, `mypy` doesn't highlight the issue + # error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]" + when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var] + + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._implementation = Implementation.POLARS + self._version = version + + def __getattr__(self, attr: str) -> Any: + def func(*args: Any, **kwargs: Any) -> Any: + pos, kwds = extract_args_kwargs(args, kwargs) + return self._expr( + getattr(pl, attr)(*pos, **kwds), + version=self._version, + backend_version=self._backend_version, + ) + + return func + + @property + def _dataframe(self) -> type[PolarsDataFrame]: + from narwhals._polars.dataframe import PolarsDataFrame + + return PolarsDataFrame + + @property + def _lazyframe(self) -> type[PolarsLazyFrame]: + from narwhals._polars.dataframe import PolarsLazyFrame + + return PolarsLazyFrame + + @property + def _expr(self) -> type[PolarsExpr]: + return PolarsExpr + + @property + def _series(self) -> type[PolarsSeries]: + return PolarsSeries + + @overload + def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ... + @overload + def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ... + @overload + def from_native(self, data: pl.Series, /) -> PolarsSeries: ... + def from_native( + self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, / + ) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries: + if self._dataframe._is_native(data): + return self._dataframe.from_native(data, context=self) + elif self._series._is_native(data): + return self._series.from_native(data, context=self) + elif self._lazyframe._is_native(data): + return self._lazyframe.from_native(data, context=self) + else: # pragma: no cover + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) + + @overload + def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ... + + @overload + def from_numpy( + self, + data: _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None, + ) -> PolarsDataFrame: ... + + def from_numpy( + self, + data: Into1DArray | _2DArray, + /, + schema: Mapping[str, DType] | Schema | Sequence[str] | None = None, + ) -> PolarsDataFrame | PolarsSeries: + if is_numpy_array_2d(data): + return self._dataframe.from_numpy(data, schema=schema, context=self) + return self._series.from_numpy(data, context=self) # pragma: no cover + + @requires.backend_version( + (1, 0, 0), "Please use `col` for columns selection instead." + ) + def nth(self, *indices: int) -> PolarsExpr: + return self._expr( + pl.nth(*indices), version=self._version, backend_version=self._backend_version + ) + + def len(self) -> PolarsExpr: + if self._backend_version < (0, 20, 5): + return self._expr( + pl.count().alias("len"), + version=self._version, + backend_version=self._backend_version, + ) + return self._expr( + pl.len(), version=self._version, backend_version=self._backend_version + ) + + def concat( + self, + items: Iterable[FrameT], + *, + how: Literal["vertical", "horizontal", "diagonal"], + ) -> PolarsDataFrame | PolarsLazyFrame: + result = pl.concat((item.native for item in items), how=how) + if isinstance(result, pl.DataFrame): + return self._dataframe( + result, backend_version=self._backend_version, version=self._version + ) + return self._lazyframe.from_native(result, context=self) + + def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr: + if dtype is not None: + return self._expr( + pl.lit( + value, + dtype=narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ), + ), + version=self._version, + backend_version=self._backend_version, + ) + return self._expr( + pl.lit(value), version=self._version, backend_version=self._backend_version + ) + + def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr: + if self._backend_version < (0, 20, 8): + return self._expr( + pl.sum_horizontal(e._native_expr for e in exprs) + / pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs), + version=self._version, + backend_version=self._backend_version, + ) + + return self._expr( + pl.mean_horizontal(e._native_expr for e in exprs), + version=self._version, + backend_version=self._backend_version, + ) + + def concat_str( + self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool + ) -> PolarsExpr: + pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs] + + if self._backend_version < (0, 20, 6): + null_mask = [expr.is_null() for expr in pl_exprs] + sep = pl.lit(separator) + + if not ignore_nulls: + null_mask_result = pl.any_horizontal(*null_mask) + output_expr = pl.reduce( + lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value] + pl_exprs, + ) + result = pl.when(~null_mask_result).then(output_expr) + else: + init_value, *values = [ + pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String())) + for expr, nm in zip(pl_exprs, null_mask) + ] + separators = [ + pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1] + ] + + result = pl.fold( # type: ignore[assignment] + acc=init_value, + function=operator.add, + exprs=[s + v for s, v in zip(separators, values)], + ) + + return self._expr( + result, version=self._version, backend_version=self._backend_version + ) + + return self._expr( + pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls), + version=self._version, + backend_version=self._backend_version, + ) + + # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`) + # 1. Others have lots of private stuff for code reuse + # i. None of that is useful here + # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr` + @property + def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: + return cast( + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", + PolarsSelectorNamespace(self), + ) + + +class PolarsSelectorNamespace: + def __init__(self, context: _FullContext, /) -> None: + self._implementation = context._implementation + self._backend_version = context._backend_version + self._version = context._version + + def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr: + native_dtypes = [ + narwhals_to_native_dtype( + dtype, self._version, self._backend_version + ).__class__ + if isinstance(dtype, type) and issubclass(dtype, DType) + else narwhals_to_native_dtype(dtype, self._version, self._backend_version) + for dtype in dtypes + ] + return PolarsExpr( + pl.selectors.by_dtype(native_dtypes), + version=self._version, + backend_version=self._backend_version, + ) + + def matches(self, pattern: str) -> PolarsExpr: + return PolarsExpr( + pl.selectors.matches(pattern=pattern), + version=self._version, + backend_version=self._backend_version, + ) + + def numeric(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.numeric(), + version=self._version, + backend_version=self._backend_version, + ) + + def boolean(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.boolean(), + version=self._version, + backend_version=self._backend_version, + ) + + def string(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.string(), + version=self._version, + backend_version=self._backend_version, + ) + + def categorical(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.categorical(), + version=self._version, + backend_version=self._backend_version, + ) + + def all(self) -> PolarsExpr: + return PolarsExpr( + pl.selectors.all(), + version=self._version, + backend_version=self._backend_version, + ) + + def datetime( + self, + time_unit: TimeUnit | Iterable[TimeUnit] | None, + time_zone: str | timezone | Iterable[str | timezone | None] | None, + ) -> PolarsExpr: + return PolarsExpr( + pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone), # type: ignore[arg-type] + version=self._version, + backend_version=self._backend_version, + ) |