from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, Callable, Iterable, Sequence

from narwhals._compliant import LazyNamespace, LazyThen, LazyWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
from narwhals._spark_like.utils import (
    import_functions,
    import_native_dtypes,
    narwhals_to_native_dtype,
)

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.dataframe import SQLFrameDataFrame  # noqa: F401
    from narwhals._spark_like.expr import SparkWindowInputs
    from narwhals._utils import Implementation, Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class SparkLikeNamespace(
    LazyNamespace[SparkLikeLazyFrame, SparkLikeExpr, "SQLFrameDataFrame"]
):
    def __init__(
        self,
        *,
        backend_version: tuple[int, ...],
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._backend_version = backend_version
        self._version = version
        self._implementation = implementation

    @property
    def selectors(self) -> SparkLikeSelectorNamespace:
        return SparkLikeSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[SparkLikeExpr]:
        return SparkLikeExpr

    @property
    def _lazyframe(self) -> type[SparkLikeLazyFrame]:
        return SparkLikeLazyFrame

    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        # Resolve the backend's `functions` module lazily at runtime; under
        # TYPE_CHECKING, SQLFrame's module stands in so annotations resolve.
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        else:
            return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        # Same pattern as `_F`, but for the backend's `types` module.
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        else:
            return import_native_dtypes(self._implementation)

    def _with_elementwise(
        self, func: Callable[[Iterable[Column]], Column], *exprs: SparkLikeExpr
    ) -> SparkLikeExpr:
        # Flatten the output columns of `exprs` and fold them into a single
        # column via `func`, wiring up both the plain and the windowed paths.
        def call(df: SparkLikeLazyFrame) -> list[Column]:
            cols = (col for _expr in exprs for col in _expr(df))
            return [func(cols)]

        def window_function(
            df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
        ) -> list[Column]:
            cols = (
                col
                for _expr in exprs
                for col in _expr.window_function(df, window_inputs)
            )
            return [func(cols)]

        return self._expr(
            call=call,
            window_function=window_function,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr:
        def _lit(df: SparkLikeLazyFrame) -> list[Column]:
            column = df._F.lit(value)
            if dtype:
                native_dtype = narwhals_to_native_dtype(
                    dtype, version=self._version, spark_types=df._native_dtypes
                )
                column = column.cast(native_dtype)
            return [column]

        return self._expr(
            call=_lit,
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def len(self) -> SparkLikeExpr:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            return [df._F.count("*")]

        return self._expr(
            func,
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )
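
    # The horizontal reducers below all funnel through `_with_elementwise`:
    # each expression is evaluated to its output columns, the columns are
    # flattened into one iterable, and `func` folds them into a single column.
    # A rough sketch of what `all_horizontal(a, b)` emits (illustrative only;
    # `F` stands for the backend's functions module, not a name in scope here):
    #
    #     F.col("a") & F.col("b")  # i.e. reduce(operator.and_, cols)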

    def all_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            return reduce(operator.and_, cols)

        return self._with_elementwise(func, *exprs)

    def any_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            return reduce(operator.or_, cols)

        return self._with_elementwise(func, *exprs)

    def max_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            return self._F.greatest(*cols)

        return self._with_elementwise(func, *exprs)

    def min_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            return self._F.least(*cols)

        return self._with_elementwise(func, *exprs)

    def sum_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            # Treat nulls as zero so a single null doesn't null out the sum.
            return reduce(
                operator.add, (self._F.coalesce(col, self._F.lit(0)) for col in cols)
            )

        return self._with_elementwise(func, *exprs)

    def mean_horizontal(self, *exprs: SparkLikeExpr) -> SparkLikeExpr:
        def func(cols: Iterable[Column]) -> Column:
            cols = list(cols)
            F = exprs[0]._F  # noqa: N806
            # PySpark before 3.5 and SQLFrame both lack `try_divide`; fall
            # back to plain division there.
            divide = getattr(F, "try_divide", operator.truediv)
            # Null-safe mean: sum of coalesced values over the count of
            # non-null values.
            return divide(
                reduce(
                    operator.add,
                    (self._F.coalesce(col, self._F.lit(0)) for col in cols),
                ),
                reduce(
                    operator.add,
                    (
                        col.isNotNull().cast(self._native_dtypes.IntegerType())
                        for col in cols
                    ),
                ),
            )

        return self._with_elementwise(func, *exprs)

    def concat(
        self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod
    ) -> SparkLikeLazyFrame:
        dfs = [item._native_frame for item in items]
        if how == "vertical":
            cols_0 = dfs[0].columns
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                if cols_current != cols_0:
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f" - dataframe 0: {cols_0}\n"
                        f" - dataframe {i}: {cols_current}\n"
                    )
                    raise TypeError(msg)

            return SparkLikeLazyFrame(
                native_dataframe=reduce(lambda x, y: x.union(y), dfs),
                backend_version=self._backend_version,
                version=self._version,
                implementation=self._implementation,
            )

        if how == "diagonal":
            return SparkLikeLazyFrame(
                native_dataframe=reduce(
                    lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs
                ),
                backend_version=self._backend_version,
                version=self._version,
                implementation=self._implementation,
            )
        raise NotImplementedError

    def concat_str(
        self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool
    ) -> SparkLikeExpr:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            cols = [s for _expr in exprs for s in _expr(df)]
            cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols]
            null_mask = [df._F.isnull(s) for s in cols]

            if not ignore_nulls:
                # Any null input makes the whole result null.
                null_mask_result = reduce(operator.or_, null_mask)
                result = df._F.when(
                    ~null_mask_result,
                    reduce(
                        lambda x, y: df._F.format_string(f"%s{separator}%s", x, y),
                        cols_casted,
                    ),
                ).otherwise(df._F.lit(None))
            else:
                # Null values become "", and the separator that would follow a
                # null column is blanked too, so nulls drop out cleanly.
                init_value, *values = [
                    df._F.when(~nm, col).otherwise(df._F.lit(""))
                    for col, nm in zip(cols_casted, null_mask)
                ]
                separators = (
                    df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator))
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    lambda x, y: df._F.format_string("%s%s", x, y),
                    (
                        df._F.format_string("%s%s", s, v)
                        for s, v in zip(separators, values)
                    ),
                    init_value,
                )

            return [result]

        return self._expr(
            call=func,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )
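
    # Worked example for `concat_str` above (illustrative): with separator="-"
    # and ignore_nulls=True, the row values ["a", None, "c"] concatenate to
    # "a-c": the null becomes "" and the separator after it is blanked, so the
    # naive "a--c" is avoided. With ignore_nulls=False the same row yields a
    # null result instead.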

    def when(self, predicate: SparkLikeExpr) -> SparkLikeWhen:
        return SparkLikeWhen.from_expr(predicate, context=self)


class SparkLikeWhen(LazyWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]):
    @property
    def _then(self) -> type[SparkLikeThen]:
        return SparkLikeThen

    def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]:
        self.when = df._F.when
        self.lit = df._F.lit
        return super().__call__(df)

    def _window_function(
        self, df: SparkLikeLazyFrame, window_inputs: SparkWindowInputs
    ) -> Sequence[Column]:
        self.when = df._F.when
        self.lit = df._F.lit
        return super()._window_function(df, window_inputs)


class SparkLikeThen(
    LazyThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr
): ...
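
# A minimal sketch of the `when`/`then`/`otherwise` flow (illustrative; `plx`,
# `pred`, and `value` are hypothetical stand-ins, not names defined here):
#
#     then_ = plx.when(pred).then(value)  # a SparkLikeThen, itself an expression
#     expr = then_.otherwise(plx.lit(0, None))
#
# `SparkLikeWhen` rebinds `self.when` / `self.lit` to `df._F.when` / `df._F.lit`
# at call time because the functions namespace is backend-dependent and only
# available once a concrete frame is in hand.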