diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_arrow/utils.py | 470 |
1 file changed, 470 insertions, 0 deletions
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._compliant.series import _SeriesNamespace
from narwhals._utils import isinstance_or_issubclass
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import (
        ArrayAny,
        ArrayOrScalar,
        ArrayOrScalarT1,
        ArrayOrScalarT2,
        ChunkedArrayAny,
        NativeIntervalUnit,
        ScalarAny,
    )
    from narwhals._duration import IntervalUnit
    from narwhals._utils import Version
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`
    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`
    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny

    # Typed stand-ins for the runtime imports in the `else` branch below, so
    # that type checkers see precise `TypeIs` narrowing signatures.
    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
    def extract_regex(
        strings: ChunkedArrayAny,
        /,
        pattern: str,
        *,
        options: Any = None,
        memory_pool: Any = None,
    ) -> ChunkedArrayStructArray: ...
else:
    from pyarrow.compute import extract_regex
    from pyarrow.types import (
        is_dictionary,  # noqa: F401
        is_duration,
        is_fixed_size_list,
        is_large_list,
        is_list,
        is_timestamp,
    )

# Maps narwhals interval-unit abbreviations to PyArrow's spelled-out names.
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

lit = pa.scalar
"""Alias for `pyarrow.scalar`."""


def extract_py_scalar(value: Any, /) -> Any:
    """Unconditionally convert a PyArrow scalar to a plain Python object."""
    from narwhals._arrow.series import maybe_extract_py_scalar

    return maybe_extract_py_scalar(value, return_py_scalar=True)


def chunked_array(
    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
    """Coerce `arr` into a `pa.ChunkedArray`, passing through existing ones."""
    if isinstance(arr, pa.ChunkedArray):
        return arr
    if isinstance(arr, list):
        return pa.chunked_array(arr, dtype)
    else:
        return pa.chunked_array([arr], arr.type)


def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
    """Create a strongly-typed Array instance with all elements null.

    Uses the type of `series`, without upsetting `mypy`.
    """
    return pa.nulls(n, series.native.type)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
    """Map a PyArrow `DataType` to the corresponding narwhals `DType`."""
    dtypes = version.dtypes
    if pa.types.is_int64(dtype):
        return dtypes.Int64()
    if pa.types.is_int32(dtype):
        return dtypes.Int32()
    if pa.types.is_int16(dtype):
        return dtypes.Int16()
    if pa.types.is_int8(dtype):
        return dtypes.Int8()
    if pa.types.is_uint64(dtype):
        return dtypes.UInt64()
    if pa.types.is_uint32(dtype):
        return dtypes.UInt32()
    if pa.types.is_uint16(dtype):
        return dtypes.UInt16()
    if pa.types.is_uint8(dtype):
        return dtypes.UInt8()
    if pa.types.is_boolean(dtype):
        return dtypes.Boolean()
    if pa.types.is_float64(dtype):
        return dtypes.Float64()
    if pa.types.is_float32(dtype):
        return dtypes.Float32()
    # bug in coverage? it shows `31->exit` (where `31` is currently the line number of
    # the next line), even though both when the if condition is true and false are covered
    if (  # pragma: no cover
        pa.types.is_string(dtype)
        or pa.types.is_large_string(dtype)
        # `is_string_view` is not available in all supported PyArrow versions.
        or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
    ):
        return dtypes.String()
    if pa.types.is_date32(dtype):
        return dtypes.Date()
    if is_timestamp(dtype):
        return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
    if is_duration(dtype):
        return dtypes.Duration(time_unit=dtype.unit)
    if pa.types.is_dictionary(dtype):
        return dtypes.Categorical()
    if pa.types.is_struct(dtype):
        return dtypes.Struct(
            [
                dtypes.Field(
                    dtype.field(i).name,
                    native_to_narwhals_dtype(dtype.field(i).type, version),
                )
                for i in range(dtype.num_fields)
            ]
        )
    if is_list(dtype) or is_large_list(dtype):
        return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
    if is_fixed_size_list(dtype):
        return dtypes.Array(
            native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
        )
    if pa.types.is_decimal(dtype):
        return dtypes.Decimal()
    if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
        return dtypes.Time()
    if pa.types.is_binary(dtype):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:  # noqa: C901, PLR0912
    """Map a narwhals dtype to the corresponding PyArrow `DataType`.

    Raises:
        NotImplementedError: For `Decimal`, which is not yet supported as a cast target.
        AssertionError: For any dtype with no known PyArrow equivalent.
    """
    dtypes = version.dtypes
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return pa.float64()
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return pa.float32()
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return pa.int64()
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return pa.int32()
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return pa.int16()
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return pa.int8()
    if isinstance_or_issubclass(dtype, dtypes.UInt64):
        return pa.uint64()
    if isinstance_or_issubclass(dtype, dtypes.UInt32):
        return pa.uint32()
    if isinstance_or_issubclass(dtype, dtypes.UInt16):
        return pa.uint16()
    if isinstance_or_issubclass(dtype, dtypes.UInt8):
        return pa.uint8()
    if isinstance_or_issubclass(dtype, dtypes.String):
        return pa.string()
    if isinstance_or_issubclass(dtype, dtypes.Boolean):
        return pa.bool_()
    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        return pa.dictionary(pa.uint32(), pa.string())
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        unit = dtype.time_unit
        return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pa.duration(dtype.time_unit)
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return pa.date32()
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        return pa.struct(
            [
                (field.name, narwhals_to_native_dtype(field.dtype, version=version))
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        inner = narwhals_to_native_dtype(dtype.inner, version=version)
        list_size = dtype.size
        return pa.list_(inner, list_size=list_size)
    if isinstance_or_issubclass(dtype, dtypes.Time):
        return pa.time64("ns")
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return pa.binary()

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def extract_native(
    lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
    """Extract native objects in binary operation.

    If the comparison isn't supported, return `NotImplemented` so that the
    "right-hand-side" operation (e.g. `__radd__`) can be tried.

    If one of the two sides has a `_broadcast` flag, then extract the scalar
    underneath it so that PyArrow can do its own broadcasting.
    """
    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.series import ArrowSeries

    if rhs is None:  # pragma: no cover
        return lhs.native, lit(None, type=lhs._type)

    if isinstance(rhs, ArrowDataFrame):
        return NotImplemented

    if isinstance(rhs, ArrowSeries):
        if lhs._broadcast and not rhs._broadcast:
            return lhs.native[0], rhs.native
        if rhs._broadcast:
            return lhs.native, rhs.native[0]
        return lhs.native, rhs.native

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)

    return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)


def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
    """Broadcast length-1 (`_broadcast`) series so all of `series` share one length.

    Raises:
        ShapeError: If a non-broadcast series has a different length from the longest.
    """
    # Ensure all of `series` are of the same length.
    lengths = [len(s) for s in series]
    max_length = max(lengths)
    fast_path = all(_len == max_length for _len in lengths)

    if fast_path:
        return series

    reshaped = []
    for s in series:
        if s._broadcast:
            value = s.native[0]
            # pyarrow < 13 cannot build an array from a pyarrow scalar directly.
            if s._backend_version < (13,) and hasattr(value, "as_py"):
                value = value.as_py()
            reshaped.append(s._with_native(pa.array([value] * max_length, type=s._type)))
        else:
            if (actual_len := len(s)) != max_length:
                msg = f"Expected object of length {max_length}, got {actual_len}."
                raise ShapeError(msg)
            reshaped.append(s)

    return reshaped


def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
    """Floor division matching Python semantics (rounds toward negative infinity)."""
    # The following lines are adapted from pandas' pyarrow implementation.
    # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

    if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
        divided = pc.divide_checked(left, right)
        # TODO @dangotbanned: Use a `TypeVar` in guards
        # Narrowing to a `Union` isn't interacting well with the rest of the stubs
        # https://github.com/zen-xu/pyarrow-stubs/pull/215
        if pa.types.is_signed_integer(divided.type):
            div_type = cast("pa._lib.Int64Type", divided.type)
            has_remainder = pc.not_equal(pc.multiply(divided, right), left)
            has_one_negative_operand = pc.less(
                pc.bit_wise_xor(left, right), lit(0, div_type)
            )
            result = pc.if_else(
                pc.and_(has_remainder, has_one_negative_operand),
                pc.subtract(divided, lit(1, div_type)),
                divided,
            )
        else:
            result = divided  # pragma: no cover
        result = result.cast(left.type)
    else:
        divided = pc.divide(left, right)
        result = pc.floor(divided)
    return result


def cast_for_truediv(
    arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
    """Cast both operands to float64 when both are integers, for true division."""
    # Lifted from:
    # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
    # Ensure int / int -> float mirroring Python/Numpy behavior
    # as pc.divide_checked(int, int) -> int
    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
        # GH: 56645.  # noqa: ERA001
        # https://github.com/apache/arrow/issues/35563
        # NOTE: `pyarrow==11.*` doesn't allow keywords in `Array.cast`
        return pc.cast(arrow_array, pa.float64(), safe=False), pc.cast(
            pa_object, pa.float64(), safe=False
        )

    return arrow_array, pa_object


# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)"  # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})"  # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"

DATE_FORMATS = (
    (YMD_RE_NO_SEP, "%Y%m%d"),
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))


def _extract_regex_concat_arrays(
    strings: ChunkedArrayAny,
    /,
    pattern: str,
    *,
    options: Any = None,
    memory_pool: Any = None,
) -> pa.StructArray:
    """Run `extract_regex` over a chunked array and concatenate chunks into one StructArray."""
    r = pa.concat_arrays(
        extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
    )
    return cast("pa.StructArray", r)


def parse_datetime_format(arr: ChunkedArrayAny) -> str:
    """Try to infer datetime format from StringArray."""
    # Only the first 10 non-null values are sampled for inference.
    matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
    if not pc.all(matches.is_valid()).as_py():
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    separators = matches.field("sep")
    tz = matches.field("tz")

    # separators and time zones must be unique
    if pc.count(pc.unique(separators)).as_py() > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    if pc.count(pc.unique(tz)).as_py() > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
    time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))

    sep_value = separators[0].as_py()
    tz_value = "%z" if tz[0].as_py() else ""

    return f"{date_value}{sep_value}{time_value}{tz_value}"


def _parse_date_format(arr: pc.StringArray) -> str:
    """Infer the strptime date format (and separator) for all values in `arr`."""
    for date_rgx, date_fmt in DATE_FORMATS:
        matches = pc.extract_regex(arr, pattern=date_rgx)
        if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
            return date_fmt
        elif (
            pc.all(matches.is_valid()).as_py()
            and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
            and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
            and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
        ):
            return date_fmt.replace("-", date_sep_value)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)


def _parse_time_format(arr: pc.StringArray) -> str:
    """Infer the strptime time format for `arr`; empty string if none matches."""
    for time_rgx, time_fmt in TIME_FORMATS:
        matches = pc.extract_regex(arr, pattern=time_rgx)
        if pc.all(matches.is_valid()).as_py():
            return time_fmt
    return ""


def pad_series(
    series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
    """Pad series with None values on the left and/or right side, depending on the specified parameters.

    Arguments:
        series: The input ArrowSeries to be padded.
        window_size: The desired size of the window.
        center: Specifies whether to center the padding or not.

    Returns:
        A tuple containing the padded ArrowSeries and the offset value.
    """
    if not center:
        return series, 0
    offset_left = window_size // 2
    # subtract one if window_size is even
    offset_right = offset_left - (window_size % 2 == 0)
    pad_left = pa.array([None] * offset_left, type=series._type)
    pad_right = pa.array([None] * offset_right, type=series._type)
    concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
    return series._with_native(concat), offset_left + offset_right


def cast_to_comparable_string_types(
    *chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
    """Cast all arrays to a common string type and return them with a typed separator scalar."""
    # Ensure `chunked_arrays` are either all `string` or all `large_string`.
    dtype = (
        pa.string()  # (PyArrow default)
        if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
        else pa.large_string()
    )
    return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)


class ArrowSeriesNamespace(_SeriesNamespace["ArrowSeries", "ChunkedArrayAny"]):
    """Base namespace wrapping an `ArrowSeries` for accessor-style methods."""

    def __init__(self, series: ArrowSeries, /) -> None:
        self._compliant_series = series