diff options
author | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
---|---|---|
committer | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
commit | 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch) | |
tree | 8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_polars/utils.py | |
parent | b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff) |
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars/utils.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_polars/utils.py | 249 |
1 files changed, 249 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py new file mode 100644 index 0000000..bb15dfb --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Iterator, + Mapping, + TypeVar, + cast, + overload, +) + +import polars as pl + +from narwhals._utils import Version, _DeferredIterable, isinstance_or_issubclass +from narwhals.exceptions import ( + ColumnNotFoundError, + ComputeError, + DuplicateError, + InvalidOperationError, + NarwhalsError, + ShapeError, +) + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._utils import _StoresNative + from narwhals.dtypes import DType + from narwhals.typing import IntoDType + + T = TypeVar("T") + NativeT = TypeVar( + "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr" + ) + + +@overload +def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ... +@overload +def extract_native(obj: T) -> T: ... +def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T: + return obj.native if _is_compliant_polars(obj) else obj + + +def _is_compliant_polars( + obj: _StoresNative[NativeT] | Any, +) -> TypeIs[_StoresNative[NativeT]]: + from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame + from narwhals._polars.expr import PolarsExpr + from narwhals._polars.series import PolarsSeries + + return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)) + + +def extract_args_kwargs( + args: Iterable[Any], kwds: Mapping[str, Any], / +) -> tuple[Iterator[Any], dict[str, Any]]: + it_args = (extract_native(arg) for arg in args) + return it_args, {k: extract_native(v) for k, v in kwds.items()} + + +@lru_cache(maxsize=16) +def native_to_narwhals_dtype( # noqa: C901, PLR0912 + dtype: pl.DataType, version: Version, backend_version: tuple[int, ...] +) -> DType: + dtypes = version.dtypes + if dtype == pl.Float64: + return dtypes.Float64() + if dtype == pl.Float32: + return dtypes.Float32() + if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover + # Not available for Polars pre 1.8.0 + return dtypes.Int128() + if dtype == pl.Int64: + return dtypes.Int64() + if dtype == pl.Int32: + return dtypes.Int32() + if dtype == pl.Int16: + return dtypes.Int16() + if dtype == pl.Int8: + return dtypes.Int8() + if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover + # Not available for Polars pre 1.8.0 + return dtypes.UInt128() + if dtype == pl.UInt64: + return dtypes.UInt64() + if dtype == pl.UInt32: + return dtypes.UInt32() + if dtype == pl.UInt16: + return dtypes.UInt16() + if dtype == pl.UInt8: + return dtypes.UInt8() + if dtype == pl.String: + return dtypes.String() + if dtype == pl.Boolean: + return dtypes.Boolean() + if dtype == pl.Object: + return dtypes.Object() + if dtype == pl.Categorical: + return dtypes.Categorical() + if isinstance_or_issubclass(dtype, pl.Enum): + if version is Version.V1: + return dtypes.Enum() # type: ignore[call-arg] + categories = _DeferredIterable( + dtype.categories.to_list + if backend_version >= (0, 20, 4) + else lambda: cast("list[str]", dtype.categories) + ) + return dtypes.Enum(categories) + if dtype == pl.Date: + return dtypes.Date() + if isinstance_or_issubclass(dtype, pl.Datetime): + return ( + dtypes.Datetime() + if dtype is pl.Datetime + else dtypes.Datetime(dtype.time_unit, dtype.time_zone) + ) + if isinstance_or_issubclass(dtype, pl.Duration): + return ( + dtypes.Duration() + if dtype is pl.Duration + else dtypes.Duration(dtype.time_unit) + ) + if isinstance_or_issubclass(dtype, pl.Struct): + fields = [ + dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version)) + for name, tp in dtype + ] + return dtypes.Struct(fields) + if isinstance_or_issubclass(dtype, pl.List): + return dtypes.List( + native_to_narwhals_dtype(dtype.inner, version, backend_version) + ) + if isinstance_or_issubclass(dtype, pl.Array): + outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size + return dtypes.Array( + native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape + ) + if dtype == pl.Decimal: + return dtypes.Decimal() + if dtype == pl.Time: + return dtypes.Time() + if dtype == pl.Binary: + return dtypes.Binary() + return dtypes.Unknown() + + +def narwhals_to_native_dtype( # noqa: C901, PLR0912 + dtype: IntoDType, version: Version, backend_version: tuple[int, ...] +) -> pl.DataType: + dtypes = version.dtypes + if dtype == dtypes.Float64: + return pl.Float64() + if dtype == dtypes.Float32: + return pl.Float32() + if dtype == dtypes.Int128 and hasattr(pl, "Int128"): + # Not available for Polars pre 1.8.0 + return pl.Int128() + if dtype == dtypes.Int64: + return pl.Int64() + if dtype == dtypes.Int32: + return pl.Int32() + if dtype == dtypes.Int16: + return pl.Int16() + if dtype == dtypes.Int8: + return pl.Int8() + if dtype == dtypes.UInt64: + return pl.UInt64() + if dtype == dtypes.UInt32: + return pl.UInt32() + if dtype == dtypes.UInt16: + return pl.UInt16() + if dtype == dtypes.UInt8: + return pl.UInt8() + if dtype == dtypes.String: + return pl.String() + if dtype == dtypes.Boolean: + return pl.Boolean() + if dtype == dtypes.Object: # pragma: no cover + return pl.Object() + if dtype == dtypes.Categorical: + return pl.Categorical() + if isinstance_or_issubclass(dtype, dtypes.Enum): + if version is Version.V1: + msg = "Converting to Enum is not supported in narwhals.stable.v1" + raise NotImplementedError(msg) + if isinstance(dtype, dtypes.Enum): + return pl.Enum(dtype.categories) + msg = "Can not cast / initialize Enum without categories present" + raise ValueError(msg) + if dtype == dtypes.Date: + return pl.Date() + if dtype == dtypes.Time: + return pl.Time() + if dtype == dtypes.Binary: + return pl.Binary() + if dtype == dtypes.Decimal: + msg = "Casting to Decimal is not supported yet." + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Datetime): + return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.Duration): + return pl.Duration(dtype.time_unit) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.List): + return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) + if isinstance_or_issubclass(dtype, dtypes.Struct): + fields = [ + pl.Field( + field.name, + narwhals_to_native_dtype(field.dtype, version, backend_version), + ) + for field in dtype.fields + ] + return pl.Struct(fields) + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + size = dtype.size + kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} + return pl.Array( + narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs + ) + return pl.Unknown() # pragma: no cover + + +def catch_polars_exception( + exception: Exception, backend_version: tuple[int, ...] +) -> NarwhalsError | Exception: + if isinstance(exception, pl.exceptions.ColumnNotFoundError): + return ColumnNotFoundError(str(exception)) + elif isinstance(exception, pl.exceptions.ShapeError): + return ShapeError(str(exception)) + elif isinstance(exception, pl.exceptions.InvalidOperationError): + return InvalidOperationError(str(exception)) + elif isinstance(exception, pl.exceptions.DuplicateError): + return DuplicateError(str(exception)) + elif isinstance(exception, pl.exceptions.ComputeError): + return ComputeError(str(exception)) + if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError): + # Old versions of Polars didn't have PolarsError. + return NarwhalsError(str(exception)) # pragma: no cover + elif backend_version < (1,) and "polars.exceptions" in str( + type(exception) + ): # pragma: no cover + # Last attempt, for old Polars versions. + return NarwhalsError(str(exception)) + # Just return exception as-is. + return exception |