diff options
author | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
---|---|---|
committer | sotech117 <michael_foiani@brown.edu> | 2025-07-31 17:27:24 -0400 |
commit | 5bf22fc7e3c392c8bd44315ca2d06d7dca7d084e (patch) | |
tree | 8dacb0f195df1c0788d36dd0064f6bbaa3143ede /venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py | |
parent | b832d364da8c2efe09e3f75828caf73c50d01ce3 (diff) |
add code for analysis of data
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py | 283 |
1 files changed, 283 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py new file mode 100644 index 0000000..02d4c69 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_arrow/namespace.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import operator +from functools import reduce +from itertools import chain +from typing import TYPE_CHECKING, Literal, Sequence + +import pyarrow as pa +import pyarrow.compute as pc + +from narwhals._arrow.dataframe import ArrowDataFrame +from narwhals._arrow.expr import ArrowExpr +from narwhals._arrow.selectors import ArrowSelectorNamespace +from narwhals._arrow.series import ArrowSeries +from narwhals._arrow.utils import ( + align_series_full_broadcast, + cast_to_comparable_string_types, +) +from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen +from narwhals._expression_parsing import ( + combine_alias_output_names, + combine_evaluate_output_names, +) +from narwhals._utils import Implementation + +if TYPE_CHECKING: + from narwhals._arrow.typing import Incomplete + from narwhals._utils import Version + from narwhals.typing import IntoDType, NonNestedLiteral + + +class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]): + @property + def _dataframe(self) -> type[ArrowDataFrame]: + return ArrowDataFrame + + @property + def _expr(self) -> type[ArrowExpr]: + return ArrowExpr + + @property + def _series(self) -> type[ArrowSeries]: + return ArrowSeries + + # --- not in spec --- + def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: + self._backend_version = backend_version + self._implementation = Implementation.PYARROW + self._version = version + + def len(self) -> ArrowExpr: + # coverage bug? this is definitely hit + return self._expr( # pragma: no cover + lambda df: [ + ArrowSeries.from_iterable([len(df.native)], name="len", context=self) + ], + depth=0, + function_name="len", + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr: + def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: + arrow_series = ArrowSeries.from_iterable( + data=[value], name="literal", context=self + ) + if dtype: + return arrow_series.cast(dtype) + return arrow_series + + return self._expr( + lambda df: [_lit_arrow_series(df)], + depth=0, + function_name="lit", + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + + def all_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + series = chain.from_iterable(expr(df) for expr in exprs) + return [reduce(operator.and_, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="all_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def any_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + series = chain.from_iterable(expr(df) for expr in exprs) + return [reduce(operator.or_, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="any_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + it = chain.from_iterable(expr(df) for expr in exprs) + series = (s.fill_null(0, strategy=None, limit=None) for s in it) + return [reduce(operator.add, align_series_full_broadcast(*series))] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="sum_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + int_64 = self._version.dtypes.Int64() + + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + expr_results = list(chain.from_iterable(expr(df) for expr in exprs)) + series = align_series_full_broadcast( + *(s.fill_null(0, strategy=None, limit=None) for s in expr_results) + ) + non_na = align_series_full_broadcast( + *(1 - s.is_null().cast(int_64) for s in expr_results) + ) + return [reduce(operator.add, series) / reduce(operator.add, non_na)] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="mean_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs)) + init_series, *series = align_series_full_broadcast(init_series, *series) + native_series = reduce( + pc.min_element_wise, [s.native for s in series], init_series.native + ) + return [ + ArrowSeries( + native_series, + name=init_series.name, + backend_version=self._backend_version, + version=self._version, + ) + ] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="min_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs)) + init_series, *series = align_series_full_broadcast(init_series, *series) + native_series = reduce( + pc.max_element_wise, [s.native for s in series], init_series.native + ) + return [ + ArrowSeries( + native_series, + name=init_series.name, + backend_version=self._backend_version, + version=self._version, + ) + ] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="max_horizontal", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + if self._backend_version >= (14,): + return pa.concat_tables(dfs, promote_options="default") + return pa.concat_tables(dfs, promote=True) # pragma: no cover + + def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table: + names = list(chain.from_iterable(df.column_names for df in dfs)) + arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) + return pa.Table.from_arrays(arrays, names=names) + + def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table: + cols_0 = dfs[0].column_names + for i, df in enumerate(dfs[1:], start=1): + cols_current = df.column_names + if cols_current != cols_0: + msg = ( + "unable to vstack, column names don't match:\n" + f" - dataframe 0: {cols_0}\n" + f" - dataframe {i}: {cols_current}\n" + ) + raise TypeError(msg) + return pa.concat_tables(dfs) + + @property + def selectors(self) -> ArrowSelectorNamespace: + return ArrowSelectorNamespace.from_namespace(self) + + def when(self, predicate: ArrowExpr) -> ArrowWhen: + return ArrowWhen.from_expr(predicate, context=self) + + def concat_str( + self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool + ) -> ArrowExpr: + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + compliant_series_list = align_series_full_broadcast( + *(chain.from_iterable(expr(df) for expr in exprs)) + ) + name = compliant_series_list[0].name + null_handling: Literal["skip", "emit_null"] = ( + "skip" if ignore_nulls else "emit_null" + ) + it, separator_scalar = cast_to_comparable_string_types( + *(s.native for s in compliant_series_list), separator=separator + ) + # NOTE: stubs indicate `separator` must also be a `ChunkedArray` + # Reality: `str` is fine + concat_str: Incomplete = pc.binary_join_element_wise + compliant = self._series( + concat_str(*it, separator_scalar, null_handling=null_handling), + name=name, + backend_version=self._backend_version, + version=self._version, + ) + return [compliant] + + return self._expr._from_callable( + func=func, + depth=max(x._depth for x in exprs) + 1, + function_name="concat_str", + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), + context=self, + ) + + +class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]): + @property + def _then(self) -> type[ArrowThen]: + return ArrowThen + + def _if_then_else( + self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, / + ) -> ArrowSeries: + if otherwise is None: + when, then = align_series_full_broadcast(when, then) + res_native = pc.if_else( + when.native, then.native, pa.nulls(len(when.native), then.native.type) + ) + else: + when, then, otherwise = align_series_full_broadcast(when, then, otherwise) + res_native = pc.if_else(when.native, then.native, otherwise.native) + return then._with_native(res_native) + + +class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ... |