aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_polars/utils.py')
-rw-r--r-- venv/lib/python3.8/site-packages/narwhals/_polars/utils.py 249
1 files changed, 249 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py
new file mode 100644
index 0000000..bb15dfb
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_polars/utils.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Iterable,
+ Iterator,
+ Mapping,
+ TypeVar,
+ cast,
+ overload,
+)
+
+import polars as pl
+
+from narwhals._utils import Version, _DeferredIterable, isinstance_or_issubclass
+from narwhals.exceptions import (
+ ColumnNotFoundError,
+ ComputeError,
+ DuplicateError,
+ InvalidOperationError,
+ NarwhalsError,
+ ShapeError,
+)
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeIs
+
+ from narwhals._utils import _StoresNative
+ from narwhals.dtypes import DType
+ from narwhals.typing import IntoDType
+
+ T = TypeVar("T")
+ NativeT = TypeVar(
+ "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
+ )
+
+
+@overload
+def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
+@overload
+def extract_native(obj: T) -> T: ...
+def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
+ return obj.native if _is_compliant_polars(obj) else obj
+
+
+def _is_compliant_polars(
+ obj: _StoresNative[NativeT] | Any,
+) -> TypeIs[_StoresNative[NativeT]]:
+ from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
+ from narwhals._polars.expr import PolarsExpr
+ from narwhals._polars.series import PolarsSeries
+
+ return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr))
+
+
+def extract_args_kwargs(
+ args: Iterable[Any], kwds: Mapping[str, Any], /
+) -> tuple[Iterator[Any], dict[str, Any]]:
+ it_args = (extract_native(arg) for arg in args)
+ return it_args, {k: extract_native(v) for k, v in kwds.items()}
+
+
+@lru_cache(maxsize=16)
+def native_to_narwhals_dtype( # noqa: C901, PLR0912
+ dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
+) -> DType:
+ dtypes = version.dtypes
+ if dtype == pl.Float64:
+ return dtypes.Float64()
+ if dtype == pl.Float32:
+ return dtypes.Float32()
+ if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover
+ # Not available for Polars pre 1.8.0
+ return dtypes.Int128()
+ if dtype == pl.Int64:
+ return dtypes.Int64()
+ if dtype == pl.Int32:
+ return dtypes.Int32()
+ if dtype == pl.Int16:
+ return dtypes.Int16()
+ if dtype == pl.Int8:
+ return dtypes.Int8()
+ if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover
+ # Not available for Polars pre 1.8.0
+ return dtypes.UInt128()
+ if dtype == pl.UInt64:
+ return dtypes.UInt64()
+ if dtype == pl.UInt32:
+ return dtypes.UInt32()
+ if dtype == pl.UInt16:
+ return dtypes.UInt16()
+ if dtype == pl.UInt8:
+ return dtypes.UInt8()
+ if dtype == pl.String:
+ return dtypes.String()
+ if dtype == pl.Boolean:
+ return dtypes.Boolean()
+ if dtype == pl.Object:
+ return dtypes.Object()
+ if dtype == pl.Categorical:
+ return dtypes.Categorical()
+ if isinstance_or_issubclass(dtype, pl.Enum):
+ if version is Version.V1:
+ return dtypes.Enum() # type: ignore[call-arg]
+ categories = _DeferredIterable(
+ dtype.categories.to_list
+ if backend_version >= (0, 20, 4)
+ else lambda: cast("list[str]", dtype.categories)
+ )
+ return dtypes.Enum(categories)
+ if dtype == pl.Date:
+ return dtypes.Date()
+ if isinstance_or_issubclass(dtype, pl.Datetime):
+ return (
+ dtypes.Datetime()
+ if dtype is pl.Datetime
+ else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
+ )
+ if isinstance_or_issubclass(dtype, pl.Duration):
+ return (
+ dtypes.Duration()
+ if dtype is pl.Duration
+ else dtypes.Duration(dtype.time_unit)
+ )
+ if isinstance_or_issubclass(dtype, pl.Struct):
+ fields = [
+ dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version))
+ for name, tp in dtype
+ ]
+ return dtypes.Struct(fields)
+ if isinstance_or_issubclass(dtype, pl.List):
+ return dtypes.List(
+ native_to_narwhals_dtype(dtype.inner, version, backend_version)
+ )
+ if isinstance_or_issubclass(dtype, pl.Array):
+ outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size
+ return dtypes.Array(
+ native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape
+ )
+ if dtype == pl.Decimal:
+ return dtypes.Decimal()
+ if dtype == pl.Time:
+ return dtypes.Time()
+ if dtype == pl.Binary:
+ return dtypes.Binary()
+ return dtypes.Unknown()
+
+
+def narwhals_to_native_dtype( # noqa: C901, PLR0912
+ dtype: IntoDType, version: Version, backend_version: tuple[int, ...]
+) -> pl.DataType:
+ dtypes = version.dtypes
+ if dtype == dtypes.Float64:
+ return pl.Float64()
+ if dtype == dtypes.Float32:
+ return pl.Float32()
+ if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
+ # Not available for Polars pre 1.8.0
+ return pl.Int128()
+ if dtype == dtypes.Int64:
+ return pl.Int64()
+ if dtype == dtypes.Int32:
+ return pl.Int32()
+ if dtype == dtypes.Int16:
+ return pl.Int16()
+ if dtype == dtypes.Int8:
+ return pl.Int8()
+ if dtype == dtypes.UInt64:
+ return pl.UInt64()
+ if dtype == dtypes.UInt32:
+ return pl.UInt32()
+ if dtype == dtypes.UInt16:
+ return pl.UInt16()
+ if dtype == dtypes.UInt8:
+ return pl.UInt8()
+ if dtype == dtypes.String:
+ return pl.String()
+ if dtype == dtypes.Boolean:
+ return pl.Boolean()
+ if dtype == dtypes.Object: # pragma: no cover
+ return pl.Object()
+ if dtype == dtypes.Categorical:
+ return pl.Categorical()
+ if isinstance_or_issubclass(dtype, dtypes.Enum):
+ if version is Version.V1:
+ msg = "Converting to Enum is not supported in narwhals.stable.v1"
+ raise NotImplementedError(msg)
+ if isinstance(dtype, dtypes.Enum):
+ return pl.Enum(dtype.categories)
+ msg = "Can not cast / initialize Enum without categories present"
+ raise ValueError(msg)
+ if dtype == dtypes.Date:
+ return pl.Date()
+ if dtype == dtypes.Time:
+ return pl.Time()
+ if dtype == dtypes.Binary:
+ return pl.Binary()
+ if dtype == dtypes.Decimal:
+ msg = "Casting to Decimal is not supported yet."
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Datetime):
+ return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
+ if isinstance_or_issubclass(dtype, dtypes.Duration):
+ return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
+ if isinstance_or_issubclass(dtype, dtypes.List):
+ return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
+ if isinstance_or_issubclass(dtype, dtypes.Struct):
+ fields = [
+ pl.Field(
+ field.name,
+ narwhals_to_native_dtype(field.dtype, version, backend_version),
+ )
+ for field in dtype.fields
+ ]
+ return pl.Struct(fields)
+ if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
+ size = dtype.size
+ kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
+ return pl.Array(
+ narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
+ )
+ return pl.Unknown() # pragma: no cover
+
+
+def catch_polars_exception(
+ exception: Exception, backend_version: tuple[int, ...]
+) -> NarwhalsError | Exception:
+ if isinstance(exception, pl.exceptions.ColumnNotFoundError):
+ return ColumnNotFoundError(str(exception))
+ elif isinstance(exception, pl.exceptions.ShapeError):
+ return ShapeError(str(exception))
+ elif isinstance(exception, pl.exceptions.InvalidOperationError):
+ return InvalidOperationError(str(exception))
+ elif isinstance(exception, pl.exceptions.DuplicateError):
+ return DuplicateError(str(exception))
+ elif isinstance(exception, pl.exceptions.ComputeError):
+ return ComputeError(str(exception))
+ if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError):
+ # Old versions of Polars didn't have PolarsError.
+ return NarwhalsError(str(exception)) # pragma: no cover
+ elif backend_version < (1,) and "polars.exceptions" in str(
+ type(exception)
+ ): # pragma: no cover
+ # Last attempt, for old Polars versions.
+ return NarwhalsError(str(exception))
+ # Just return exception as-is.
+ return exception