about summary refs log tree commit diff
path: root/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py')
-rw-r--r-- venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py 283
1 files changed, 283 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py
new file mode 100644
index 0000000..02d4c69
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+import operator
+from functools import reduce
+from itertools import chain
+from typing import TYPE_CHECKING, Literal, Sequence
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from narwhals._arrow.dataframe import ArrowDataFrame
+from narwhals._arrow.expr import ArrowExpr
+from narwhals._arrow.selectors import ArrowSelectorNamespace
+from narwhals._arrow.series import ArrowSeries
+from narwhals._arrow.utils import (
+ align_series_full_broadcast,
+ cast_to_comparable_string_types,
+)
+from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
+from narwhals._expression_parsing import (
+ combine_alias_output_names,
+ combine_evaluate_output_names,
+)
+from narwhals._utils import Implementation
+
+if TYPE_CHECKING:
+ from narwhals._arrow.typing import Incomplete
+ from narwhals._utils import Version
+ from narwhals.typing import IntoDType, NonNestedLiteral
+
+
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]):
    """Expression constructors and table concatenation for the PyArrow backend.

    Expressions built here are `ArrowExpr` objects; the `_concat_*` helpers
    operate directly on native `pa.Table` objects.
    """

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        """Dataframe class paired with this namespace."""
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        """Expression class paired with this namespace."""
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        """Series class paired with this namespace."""
        return ArrowSeries

    # --- not in spec ---
    def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None:
        """Store the backend version and API version; the implementation is PYARROW."""
        self._implementation = Implementation.PYARROW
        self._backend_version = backend_version
        self._version = version

    def len(self) -> ArrowExpr:
        """Build an expression yielding a single one-element series named ``"len"``.

        The element is ``len(df.native)`` of the frame the expression is run on.
        """

        def _count_rows(df: ArrowDataFrame) -> list[ArrowSeries]:  # pragma: no cover
            counted = ArrowSeries.from_iterable(
                [len(df.native)], name="len", context=self
            )
            return [counted]

        # Coverage tooling reports this as unexecuted although it is definitely hit.
        return self._expr(  # pragma: no cover
            _count_rows,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
        """Build an expression yielding a one-element series named ``"literal"``.

        Arguments:
            value: The literal placed in the series.
            dtype: When truthy, the series is cast to this dtype before being
                returned; otherwise the series is returned as constructed.
        """

        def _evaluate(_df: ArrowDataFrame) -> list[ArrowSeries]:
            literal = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            return [literal.cast(dtype) if dtype else literal]

        return self._expr(
            _evaluate,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def all_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Combine the outputs of `exprs` row-wise with ``&``."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            evaluated = [s for expr in exprs for s in expr(df)]
            aligned = align_series_full_broadcast(*evaluated)
            return [reduce(operator.and_, aligned)]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Combine the outputs of `exprs` row-wise with ``|``."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            evaluated = [s for expr in exprs for s in expr(df)]
            aligned = align_series_full_broadcast(*evaluated)
            return [reduce(operator.or_, aligned)]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Add the outputs of `exprs` row-wise, with nulls replaced by 0 first."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            filled = [
                s.fill_null(0, strategy=None, limit=None)
                for expr in exprs
                for s in expr(df)
            ]
            aligned = align_series_full_broadcast(*filled)
            return [reduce(operator.add, aligned)]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Row-wise mean of the outputs of `exprs`, counting only non-null entries.

        The numerator is the row-wise sum with nulls replaced by 0; the
        denominator is the row-wise sum of ``1 - is_null`` indicators.
        """
        int_64 = self._version.dtypes.Int64()

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            evaluated = [s for expr in exprs for s in expr(df)]
            values = align_series_full_broadcast(
                *[s.fill_null(0, strategy=None, limit=None) for s in evaluated]
            )
            present = align_series_full_broadcast(
                *[1 - s.is_null().cast(int_64) for s in evaluated]
            )
            total = reduce(operator.add, values)
            count = reduce(operator.add, present)
            return [total / count]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Row-wise minimum of the outputs of `exprs`, named after the first output."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            head, *tail = [s for expr in exprs for s in expr(df)]
            first, *others = align_series_full_broadcast(head, *tail)
            natives = [s.native for s in (first, *others)]
            # Folding without an initial value starts from the first native array.
            smallest = reduce(pc.min_element_wise, natives)
            result = ArrowSeries(
                smallest,
                name=first.name,
                backend_version=self._backend_version,
                version=self._version,
            )
            return [result]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        """Row-wise maximum of the outputs of `exprs`, named after the first output."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            head, *tail = [s for expr in exprs for s in expr(df)]
            first, *others = align_series_full_broadcast(head, *tail)
            natives = [s.native for s in (first, *others)]
            # Folding without an initial value starts from the first native array.
            largest = reduce(pc.max_element_wise, natives)
            result = ArrowSeries(
                largest,
                name=first.name,
                backend_version=self._backend_version,
                version=self._version,
            )
            return [result]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        """Concatenate tables with promotion enabled.

        Backend versions below 14 receive ``promote=True``; version 14 and
        above receive ``promote_options="default"``.
        """
        if self._backend_version < (14,):
            return pa.concat_tables(dfs, promote=True)  # pragma: no cover
        return pa.concat_tables(dfs, promote_options="default")

    def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        """Place the columns of every table side by side in one table."""
        column_names = [name for table in dfs for name in table.column_names]
        columns = [column for table in dfs for column in table.itercolumns()]
        return pa.Table.from_arrays(columns, names=column_names)

    def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        """Stack tables on top of each other.

        Raises:
            TypeError: If any table's column names differ from those of the
                first table.
        """
        cols_0 = dfs[0].column_names
        for i, table in enumerate(dfs[1:], start=1):
            cols_current = table.column_names
            if cols_current == cols_0:
                continue
            msg = (
                "unable to vstack, column names don't match:\n"
                f" - dataframe 0: {cols_0}\n"
                f" - dataframe {i}: {cols_current}\n"
            )
            raise TypeError(msg)
        return pa.concat_tables(dfs)

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        """Selector namespace created from this namespace."""
        return ArrowSelectorNamespace.from_namespace(self)

    def when(self, predicate: ArrowExpr) -> ArrowWhen:
        """Start a when/then chain from `predicate`."""
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
    ) -> ArrowExpr:
        """Join the outputs of `exprs` row-wise into one string series.

        Arguments:
            exprs: Expressions whose outputs are joined; the result takes the
                name of the first output.
            separator: Text inserted between joined values.
            ignore_nulls: When true, ``null_handling="skip"`` is passed to
                `pc.binary_join_element_wise`; otherwise ``"emit_null"``.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            evaluated = [s for expr in exprs for s in expr(df)]
            aligned = align_series_full_broadcast(*evaluated)
            name = aligned[0].name
            null_handling: Literal["skip", "emit_null"]
            if ignore_nulls:
                null_handling = "skip"
            else:
                null_handling = "emit_null"
            natives, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in aligned), separator=separator
            )
            # NOTE: the stubs claim `separator` must be a `ChunkedArray`, but a
            # `str` works in practice, so the function is typed loosely here.
            join_element_wise: Incomplete = pc.binary_join_element_wise
            joined = join_element_wise(
                *natives, separator_scalar, null_handling=null_handling
            )
            return [
                self._series(
                    joined,
                    name=name,
                    backend_version=self._backend_version,
                    version=self._version,
                )
            ]

        return self._expr._from_callable(
            func=func,
            depth=1 + max(e._depth for e in exprs),
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
+
+
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]):
    """When/then/otherwise evaluation on top of `pc.if_else`."""

    @property
    def _then(self) -> type[ArrowThen]:
        """Class used for the ``then`` stage of the chain."""
        return ArrowThen

    def _if_then_else(
        self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, /
    ) -> ArrowSeries:
        """Select from `then` where `when` holds and from `otherwise` elsewhere.

        With no `otherwise`, the fallback is an all-null array of the same
        length as `when` and the same type as `then`. The result is wrapped
        via ``then._with_native``.
        """
        if otherwise is not None:
            when, then, otherwise = align_series_full_broadcast(when, then, otherwise)
            fallback = otherwise.native
        else:
            when, then = align_series_full_broadcast(when, then)
            fallback = pa.nulls(len(when.native), then.native.type)
        selected = pc.if_else(when.native, then.native, fallback)
        return then._with_native(selected)
+
+
class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr):
    """`ArrowExpr` combined with `CompliantThen`; defines no members of its own."""