aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py470
1 files changed, 470 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py
new file mode 100644
index 0000000..d100448
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py
@@ -0,0 +1,470 @@
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, cast
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from narwhals._compliant.series import _SeriesNamespace
+from narwhals._utils import isinstance_or_issubclass
+from narwhals.exceptions import ShapeError
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeAlias, TypeIs
+
+ from narwhals._arrow.series import ArrowSeries
+ from narwhals._arrow.typing import (
+ ArrayAny,
+ ArrayOrScalar,
+ ArrayOrScalarT1,
+ ArrayOrScalarT2,
+ ChunkedArrayAny,
+ NativeIntervalUnit,
+ ScalarAny,
+ )
+ from narwhals._duration import IntervalUnit
+ from narwhals._utils import Version
+ from narwhals.dtypes import DType
+ from narwhals.typing import IntoDType, PythonLiteral
+
+ # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
+ # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
+ ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny
+
+ def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
+ def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
+ def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
+ def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
+ def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
+ def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
+ def extract_regex(
+ strings: ChunkedArrayAny,
+ /,
+ pattern: str,
+ *,
+ options: Any = None,
+ memory_pool: Any = None,
+ ) -> ChunkedArrayStructArray: ...
+else:
+ from pyarrow.compute import extract_regex
+ from pyarrow.types import (
+ is_dictionary, # noqa: F401
+ is_duration,
+ is_fixed_size_list,
+ is_large_list,
+ is_list,
+ is_timestamp,
+ )
+
+UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
+ "y": "year",
+ "q": "quarter",
+ "mo": "month",
+ "d": "day",
+ "h": "hour",
+ "m": "minute",
+ "s": "second",
+ "ms": "millisecond",
+ "us": "microsecond",
+ "ns": "nanosecond",
+}
+
+lit = pa.scalar
+"""Alias for `pyarrow.scalar`."""
+
+
+def extract_py_scalar(value: Any, /) -> Any:
+ from narwhals._arrow.series import maybe_extract_py_scalar
+
+ return maybe_extract_py_scalar(value, return_py_scalar=True)
+
+
+def chunked_array(
+ arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
+) -> ChunkedArrayAny:
+ if isinstance(arr, pa.ChunkedArray):
+ return arr
+ if isinstance(arr, list):
+ return pa.chunked_array(arr, dtype)
+ else:
+ return pa.chunked_array([arr], arr.type)
+
+
+def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
+ """Create a strongly-typed Array instance with all elements null.
+
+ Uses the type of `series`, without upseting `mypy`.
+ """
+ return pa.nulls(n, series.native.type)
+
+
+@lru_cache(maxsize=16)
+def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: # noqa: C901, PLR0912
+ dtypes = version.dtypes
+ if pa.types.is_int64(dtype):
+ return dtypes.Int64()
+ if pa.types.is_int32(dtype):
+ return dtypes.Int32()
+ if pa.types.is_int16(dtype):
+ return dtypes.Int16()
+ if pa.types.is_int8(dtype):
+ return dtypes.Int8()
+ if pa.types.is_uint64(dtype):
+ return dtypes.UInt64()
+ if pa.types.is_uint32(dtype):
+ return dtypes.UInt32()
+ if pa.types.is_uint16(dtype):
+ return dtypes.UInt16()
+ if pa.types.is_uint8(dtype):
+ return dtypes.UInt8()
+ if pa.types.is_boolean(dtype):
+ return dtypes.Boolean()
+ if pa.types.is_float64(dtype):
+ return dtypes.Float64()
+ if pa.types.is_float32(dtype):
+ return dtypes.Float32()
+ # bug in coverage? it shows `31->exit` (where `31` is currently the line number of
+ # the next line), even though both when the if condition is true and false are covered
+ if ( # pragma: no cover
+ pa.types.is_string(dtype)
+ or pa.types.is_large_string(dtype)
+ or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
+ ):
+ return dtypes.String()
+ if pa.types.is_date32(dtype):
+ return dtypes.Date()
+ if is_timestamp(dtype):
+ return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
+ if is_duration(dtype):
+ return dtypes.Duration(time_unit=dtype.unit)
+ if pa.types.is_dictionary(dtype):
+ return dtypes.Categorical()
+ if pa.types.is_struct(dtype):
+ return dtypes.Struct(
+ [
+ dtypes.Field(
+ dtype.field(i).name,
+ native_to_narwhals_dtype(dtype.field(i).type, version),
+ )
+ for i in range(dtype.num_fields)
+ ]
+ )
+ if is_list(dtype) or is_large_list(dtype):
+ return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
+ if is_fixed_size_list(dtype):
+ return dtypes.Array(
+ native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
+ )
+ if pa.types.is_decimal(dtype):
+ return dtypes.Decimal()
+ if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
+ return dtypes.Time()
+ if pa.types.is_binary(dtype):
+ return dtypes.Binary()
+ return dtypes.Unknown() # pragma: no cover
+
+
+def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType: # noqa: C901, PLR0912
+ dtypes = version.dtypes
+ if isinstance_or_issubclass(dtype, dtypes.Decimal):
+ msg = "Casting to Decimal is not supported yet."
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Float64):
+ return pa.float64()
+ if isinstance_or_issubclass(dtype, dtypes.Float32):
+ return pa.float32()
+ if isinstance_or_issubclass(dtype, dtypes.Int64):
+ return pa.int64()
+ if isinstance_or_issubclass(dtype, dtypes.Int32):
+ return pa.int32()
+ if isinstance_or_issubclass(dtype, dtypes.Int16):
+ return pa.int16()
+ if isinstance_or_issubclass(dtype, dtypes.Int8):
+ return pa.int8()
+ if isinstance_or_issubclass(dtype, dtypes.UInt64):
+ return pa.uint64()
+ if isinstance_or_issubclass(dtype, dtypes.UInt32):
+ return pa.uint32()
+ if isinstance_or_issubclass(dtype, dtypes.UInt16):
+ return pa.uint16()
+ if isinstance_or_issubclass(dtype, dtypes.UInt8):
+ return pa.uint8()
+ if isinstance_or_issubclass(dtype, dtypes.String):
+ return pa.string()
+ if isinstance_or_issubclass(dtype, dtypes.Boolean):
+ return pa.bool_()
+ if isinstance_or_issubclass(dtype, dtypes.Categorical):
+ return pa.dictionary(pa.uint32(), pa.string())
+ if isinstance_or_issubclass(dtype, dtypes.Datetime):
+ unit = dtype.time_unit
+ return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
+ if isinstance_or_issubclass(dtype, dtypes.Duration):
+ return pa.duration(dtype.time_unit)
+ if isinstance_or_issubclass(dtype, dtypes.Date):
+ return pa.date32()
+ if isinstance_or_issubclass(dtype, dtypes.List):
+ return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
+ if isinstance_or_issubclass(dtype, dtypes.Struct):
+ return pa.struct(
+ [
+ (field.name, narwhals_to_native_dtype(field.dtype, version=version))
+ for field in dtype.fields
+ ]
+ )
+ if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
+ inner = narwhals_to_native_dtype(dtype.inner, version=version)
+ list_size = dtype.size
+ return pa.list_(inner, list_size=list_size)
+ if isinstance_or_issubclass(dtype, dtypes.Time):
+ return pa.time64("ns")
+ if isinstance_or_issubclass(dtype, dtypes.Binary):
+ return pa.binary()
+
+ msg = f"Unknown dtype: {dtype}" # pragma: no cover
+ raise AssertionError(msg)
+
+
+def extract_native(
+ lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
+) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
+ """Extract native objects in binary operation.
+
+ If the comparison isn't supported, return `NotImplemented` so that the
+ "right-hand-side" operation (e.g. `__radd__`) can be tried.
+
+ If one of the two sides has a `_broadcast` flag, then extract the scalar
+ underneath it so that PyArrow can do its own broadcasting.
+ """
+ from narwhals._arrow.dataframe import ArrowDataFrame
+ from narwhals._arrow.series import ArrowSeries
+
+ if rhs is None: # pragma: no cover
+ return lhs.native, lit(None, type=lhs._type)
+
+ if isinstance(rhs, ArrowDataFrame):
+ return NotImplemented
+
+ if isinstance(rhs, ArrowSeries):
+ if lhs._broadcast and not rhs._broadcast:
+ return lhs.native[0], rhs.native
+ if rhs._broadcast:
+ return lhs.native, rhs.native[0]
+ return lhs.native, rhs.native
+
+ if isinstance(rhs, list):
+ msg = "Expected Series or scalar, got list."
+ raise TypeError(msg)
+
+ return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)
+
+
+def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
+ # Ensure all of `series` are of the same length.
+ lengths = [len(s) for s in series]
+ max_length = max(lengths)
+ fast_path = all(_len == max_length for _len in lengths)
+
+ if fast_path:
+ return series
+
+ reshaped = []
+ for s in series:
+ if s._broadcast:
+ value = s.native[0]
+ if s._backend_version < (13,) and hasattr(value, "as_py"):
+ value = value.as_py()
+ reshaped.append(s._with_native(pa.array([value] * max_length, type=s._type)))
+ else:
+ if (actual_len := len(s)) != max_length:
+ msg = f"Expected object of length {max_length}, got {actual_len}."
+ raise ShapeError(msg)
+ reshaped.append(s)
+
+ return reshaped
+
+
+def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
+ # The following lines are adapted from pandas' pyarrow implementation.
+ # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
+
+ if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
+ divided = pc.divide_checked(left, right)
+ # TODO @dangotbanned: Use a `TypeVar` in guards
+ # Narrowing to a `Union` isn't interacting well with the rest of the stubs
+ # https://github.com/zen-xu/pyarrow-stubs/pull/215
+ if pa.types.is_signed_integer(divided.type):
+ div_type = cast("pa._lib.Int64Type", divided.type)
+ has_remainder = pc.not_equal(pc.multiply(divided, right), left)
+ has_one_negative_operand = pc.less(
+ pc.bit_wise_xor(left, right), lit(0, div_type)
+ )
+ result = pc.if_else(
+ pc.and_(has_remainder, has_one_negative_operand),
+ pc.subtract(divided, lit(1, div_type)),
+ divided,
+ )
+ else:
+ result = divided # pragma: no cover
+ result = result.cast(left.type)
+ else:
+ divided = pc.divide(left, right)
+ result = pc.floor(divided)
+ return result
+
+
+def cast_for_truediv(
+ arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
+) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
+ # Lifted from:
+ # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
+ # Ensure int / int -> float mirroring Python/Numpy behavior
+ # as pc.divide_checked(int, int) -> int
+ if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
+ # GH: 56645. # noqa: ERA001
+ # https://github.com/apache/arrow/issues/35563
+ # NOTE: `pyarrow==11.*` doesn't allow keywords in `Array.cast`
+ return pc.cast(arrow_array, pa.float64(), safe=False), pc.cast(
+ pa_object, pa.float64(), safe=False
+ )
+
+ return arrow_array, pa_object
+
+
+# Regex for date, time, separator and timezone components
+DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
+SEP_RE = r"(?P<sep>\s|T)"
+TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)" # \s*(?P<period>[AP]M)?)?
+HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
+HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
+HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
+TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})" # Matches 'Z', '+02:00', '+0200', '+02', etc.
+FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"
+
+# Separate regexes for different date formats
+YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
+DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
+MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
+YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"
+
+DATE_FORMATS = (
+ (YMD_RE_NO_SEP, "%Y%m%d"),
+ (YMD_RE, "%Y-%m-%d"),
+ (DMY_RE, "%d-%m-%Y"),
+ (MDY_RE, "%m-%d-%Y"),
+)
+TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))
+
+
+def _extract_regex_concat_arrays(
+ strings: ChunkedArrayAny,
+ /,
+ pattern: str,
+ *,
+ options: Any = None,
+ memory_pool: Any = None,
+) -> pa.StructArray:
+ r = pa.concat_arrays(
+ extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
+ )
+ return cast("pa.StructArray", r)
+
+
+def parse_datetime_format(arr: ChunkedArrayAny) -> str:
+ """Try to infer datetime format from StringArray."""
+ matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
+ if not pc.all(matches.is_valid()).as_py():
+ msg = (
+ "Unable to infer datetime format, provided format is not supported. "
+ "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
+ )
+ raise NotImplementedError(msg)
+
+ separators = matches.field("sep")
+ tz = matches.field("tz")
+
+ # separators and time zones must be unique
+ if pc.count(pc.unique(separators)).as_py() > 1:
+ msg = "Found multiple separator values while inferring datetime format."
+ raise ValueError(msg)
+
+ if pc.count(pc.unique(tz)).as_py() > 1:
+ msg = "Found multiple timezone values while inferring datetime format."
+ raise ValueError(msg)
+
+ date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
+ time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))
+
+ sep_value = separators[0].as_py()
+ tz_value = "%z" if tz[0].as_py() else ""
+
+ return f"{date_value}{sep_value}{time_value}{tz_value}"
+
+
+def _parse_date_format(arr: pc.StringArray) -> str:
+ for date_rgx, date_fmt in DATE_FORMATS:
+ matches = pc.extract_regex(arr, pattern=date_rgx)
+ if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
+ return date_fmt
+ elif (
+ pc.all(matches.is_valid()).as_py()
+ and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
+ and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
+ and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
+ ):
+ return date_fmt.replace("-", date_sep_value)
+
+ msg = (
+ "Unable to infer datetime format. "
+ "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
+ )
+ raise ValueError(msg)
+
+
+def _parse_time_format(arr: pc.StringArray) -> str:
+ for time_rgx, time_fmt in TIME_FORMATS:
+ matches = pc.extract_regex(arr, pattern=time_rgx)
+ if pc.all(matches.is_valid()).as_py():
+ return time_fmt
+ return ""
+
+
+def pad_series(
+ series: ArrowSeries, *, window_size: int, center: bool
+) -> tuple[ArrowSeries, int]:
+ """Pad series with None values on the left and/or right side, depending on the specified parameters.
+
+ Arguments:
+ series: The input ArrowSeries to be padded.
+ window_size: The desired size of the window.
+ center: Specifies whether to center the padding or not.
+
+ Returns:
+ A tuple containing the padded ArrowSeries and the offset value.
+ """
+ if not center:
+ return series, 0
+ offset_left = window_size // 2
+ # subtract one if window_size is even
+ offset_right = offset_left - (window_size % 2 == 0)
+ pad_left = pa.array([None] * offset_left, type=series._type)
+ pad_right = pa.array([None] * offset_right, type=series._type)
+ concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
+ return series._with_native(concat), offset_left + offset_right
+
+
+def cast_to_comparable_string_types(
+ *chunked_arrays: ChunkedArrayAny, separator: str
+) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
+ # Ensure `chunked_arrays` are either all `string` or all `large_string`.
+ dtype = (
+ pa.string() # (PyArrow default)
+ if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
+ else pa.large_string()
+ )
+ return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)
+
+
+class ArrowSeriesNamespace(_SeriesNamespace["ArrowSeries", "ChunkedArrayAny"]):
+ def __init__(self, series: ArrowSeries, /) -> None:
+ self._compliant_series = series