from __future__ import annotations

import warnings
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence

from narwhals._namespace import is_native_spark_like
from narwhals._spark_like.utils import (
    evaluate_exprs,
    import_functions,
    import_native_dtypes,
    import_window,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    find_stacklevel,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    parse_version,
    validate_backend_version,
)
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from types import ModuleType

    import pyarrow as pa
    from sqlframe.base.column import Column
    from sqlframe.base.dataframe import BaseDataFrame
    from sqlframe.base.window import Window
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._utils import Version, _FullContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy

    SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]

Incomplete: TypeAlias = Any  # pragma: no cover
"""Marker for working code that fails type checking."""


class SparkLikeLazyFrame(
    CompliantLazyFrame[
        "SparkLikeExpr", "SQLFrameDataFrame", "LazyFrame[SQLFrameDataFrame]"
    ]
):
    def __init__(
        self,
        native_dataframe: SQLFrameDataFrame,
        *,
        backend_version: tuple[int, ...],
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._native_frame: SQLFrameDataFrame = native_dataframe
        self._backend_version = backend_version
        self._implementation = implementation
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        validate_backend_version(self._implementation, self._backend_version)

    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        else:
            return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        else:
            return import_native_dtypes(self._implementation)

    @property
    def _Window(self) -> type[Window]:  # noqa: N802
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        else:
            return import_window(self._implementation)

    @staticmethod
    def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]:
        return is_native_spark_like(obj)

    @classmethod
    def from_native(cls, data: SQLFrameDataFrame, /, *, context: _FullContext) -> Self:
        return cls(
            data,
            backend_version=context._backend_version,
            version=context._version,
            implementation=context._implementation,
        )

    def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:  # pragma: no cover
        return self._implementation.to_native_namespace()

    def __narwhals_namespace__(self) -> SparkLikeNamespace:
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def __narwhals_lazyframe__(self) -> Self:
        return self
    def _with_version(self, version: Version) -> Self:
        return self.__class__(
            self.native,
            backend_version=self._backend_version,
            version=version,
            implementation=self._implementation,
        )

    def _with_native(self, df: SQLFrameDataFrame) -> Self:
        return self.__class__(
            df,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def _to_arrow_schema(self) -> pa.Schema:  # pragma: no cover
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.utils import narwhals_to_native_dtype

        schema: list[tuple[str, pa.DataType]] = []
        nw_schema = self.collect_schema()
        native_schema = self.native.schema
        for key, value in nw_schema.items():
            try:
                native_dtype = narwhals_to_native_dtype(value, self._version)
            except Exception as exc:  # noqa: BLE001,PERF203
                native_spark_dtype = native_schema[key].dataType  # type: ignore[index]
                # If we can't convert the type, just set it to `pa.null`, and warn.
                # Avoid the warning if we're starting from PySpark's void type.
                # We can avoid the check when we introduce `nw.Null` dtype.
                null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
                if not isinstance(native_spark_dtype, null_type):
                    warnings.warn(
                        f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
                        stacklevel=find_stacklevel(),
                    )
                schema.append((key, pa.null()))
            else:
                schema.append((key, native_dtype))
        return pa.schema(schema)

    def _collect_to_arrow(self) -> pa.Table:
        if self._implementation.is_pyspark() and self._backend_version < (4,):
            import pyarrow as pa  # ignore-banned-import

            try:
                return pa.Table.from_batches(self.native._collect_as_arrow())
            except ValueError as exc:
                if "at least one RecordBatch" in str(exc):
                    # Empty dataframe
                    data: dict[str, list[Any]] = {k: [] for k in self.columns}
                    pa_schema = self._to_arrow_schema()
                    return pa.Table.from_pydict(data, schema=pa_schema)
                else:  # pragma: no cover
                    raise
        elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
            import pyarrow as pa  # ignore-banned-import

            pa_schema = self._to_arrow_schema()
            return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
        else:
            return self.native.toArrow()

    def _iter_columns(self) -> Iterator[Column]:
        for col in self.columns:
            yield self._F.col(col)

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns
            )
        return self._cached_columns

    def collect(
        self, backend: ModuleType | Implementation | str | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.toPandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=True,
            )

        elif backend is None or backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self._collect_to_arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=True,
            )

        elif backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import
            import pyarrow as pa  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_arrow(self._collect_to_arrow()),  # type: ignore[arg-type]
                backend_version=parse_version(pl),
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
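    # Note on the projection methods that follow (`aggregate`, `select`,
    # `with_columns`): they share one pattern — `evaluate_exprs` turns narwhals
    # expressions into `(output_name, native Column)` pairs, which are then aliased
    # and handed to the native dataframe. Rough sketch of that pattern, with `e`
    # standing in for an arbitrary expression (illustrative, not extra behavior):
    #
    #     new_columns = evaluate_exprs(self, e)
    #     self.native.select(*[col.alias(name) for name, col in new_columns])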
    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)

        new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
        return self._with_native(self.native.agg(*new_columns_list))

    def select(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)
        new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
        return self._with_native(self.native.select(*new_columns_list))

    def with_columns(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.withColumns(dict(new_columns)))

    def filter(self, predicate: SparkLikeExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        condition = predicate._call(self)[0]
        spark_df = self.native.where(condition)
        return self._with_native(spark_df)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            self._cached_schema = {
                field.name: native_to_narwhals_dtype(
                    field.dataType,
                    self._version,
                    self._native_dtypes,
                    self.native.sparkSession,
                )
                for field in self.native.schema
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(*columns_to_drop))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.limit(n))

    def group_by(
        self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
    ) -> SparkLikeLazyGroupBy:
        from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

        return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            descending = [descending] * len(by)

        if nulls_last:
            sort_funcs = (
                self._F.desc_nulls_last if d else self._F.asc_nulls_last
                for d in descending
            )
        else:
            sort_funcs = (
                self._F.desc_nulls_first if d else self._F.asc_nulls_first
                for d in descending
            )

        sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
        return self._with_native(self.native.sort(*sort_cols))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        subset = list(subset) if subset else None
        return self._with_native(self.native.dropna(subset=subset))

    def rename(self, mapping: Mapping[str, str]) -> Self:
        rename_mapping = {
            colname: mapping.get(colname, colname) for colname in self.columns
        }
        return self._with_native(
            self.native.select(
                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
            )
        )

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        subset = list(subset) if subset else None
        if keep == "none":
            tmp = generate_temporary_column_name(8, self.columns)
            window = self._Window.partitionBy(subset or self.columns)
            df = (
                self.native.withColumn(tmp, self._F.count("*").over(window))
                .filter(self._F.col(tmp) == self._F.lit(1))
                .drop(self._F.col(tmp))
            )
            return self._with_native(df)
        return self._with_native(self.native.dropDuplicates(subset=subset))
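    # Illustration of the column-renaming scheme used by `join` below (hypothetical
    # column names, not taken from this module): with left columns ["id", "x"],
    # right columns ["id", "x", "y"], left_on=["id"], right_on=["id"] and
    # suffix="_right", the right frame is re-selected with the mapping
    # {"id": "id", "x": "x_right", "y": "y"}. For a "full" join the `right_on`
    # columns are kept as well (suffixed if they clash with a left column) so the
    # join condition can still reference both sides.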
    def join(
        self,
        other: Self,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        left_columns = self.columns
        right_columns = other.columns

        right_on_: list[str] = list(right_on) if right_on is not None else []
        left_on_: list[str] = list(left_on) if left_on is not None else []

        # Create a mapping for columns on `other`:
        # `right_on` columns will be renamed as `left_on`; the remaining columns
        # will either have the suffix added or be left unchanged.
        right_cols_to_rename = (
            [c for c in right_columns if c not in right_on_]
            if how != "full"
            else right_columns
        )

        rename_mapping = {
            **dict(zip(right_on_, left_on_)),
            **{
                colname: f"{colname}{suffix}" if colname in left_columns else colname
                for colname in right_cols_to_rename
            },
        }
        other_native = other.native.select(
            [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
        )

        # If how in {"semi", "anti"}, then resulting columns are same as left columns.
        # Otherwise, we add the right columns with the new mapping, while keeping the
        # original order of right_columns.
        col_order = left_columns.copy()

        if how in {"inner", "left", "cross"}:
            col_order.extend(
                rename_mapping[colname]
                for colname in right_columns
                if colname not in right_on_
            )
        elif how == "full":
            col_order.extend(rename_mapping.values())

        right_on_remapped = [rename_mapping[c] for c in right_on_]
        on_ = (
            reduce(
                and_,
                (
                    getattr(self.native, left_key) == getattr(other_native, right_key)
                    for left_key, right_key in zip(left_on_, right_on_remapped)
                ),
            )
            if how == "full"
            else None
            if how == "cross"
            else left_on_
        )
        how_native = "full_outer" if how == "full" else how
        return self._with_native(
            self.native.join(other_native, on=on_, how=how_native).select(col_order)
        )
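    # `explode` below follows Polars semantics (sketch, assuming a single List
    # column as enforced by the method): a column holding [[1, 2], [], None]
    # explodes to the rows 1, 2, null, null, i.e. an empty array and a null entry
    # each yield a single null row.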
    def explode(self, columns: Sequence[str]) -> Self:
        dtypes = self._version.dtypes

        schema = self.collect_schema()
        for col_to_explode in columns:
            dtype = schema[col_to_explode]

            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        column_names = self.columns

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with SparkLike backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
            return self._with_native(
                self.native.select(
                    *[
                        self._F.col(col_name).alias(col_name)
                        if col_name != columns[0]
                        else self._F.explode_outer(col_name).alias(col_name)
                        for col_name in column_names
                    ]
                )
            )
        elif self._implementation.is_sqlframe():
            # Not every sqlframe dialect supports the `explode_outer` function
            # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289),
            # therefore we simply explode the array column, which ignores nulls and
            # zero-sized arrays, and then append the rows matching those conditions
            # with nulls (to match polars behavior).

            def null_condition(col_name: str) -> Column:
                return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)

            return self._with_native(
                self.native.select(
                    *[
                        self._F.col(col_name).alias(col_name)
                        if col_name != columns[0]
                        else self._F.explode(col_name).alias(col_name)
                        for col_name in column_names
                    ]
                ).union(
                    self.native.filter(null_condition(columns[0])).select(
                        *[
                            self._F.col(col_name).alias(col_name)
                            if col_name != columns[0]
                            else self._F.lit(None).alias(col_name)
                            for col_name in column_names
                        ]
                    )
                )
            )
        else:  # pragma: no cover
            msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues"
            raise AssertionError(msg)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        if self._implementation.is_sqlframe():
            if variable_name == "":
                msg = "`variable_name` cannot be empty string for sqlframe backend."
                raise NotImplementedError(msg)

            if value_name == "":
                msg = "`value_name` cannot be empty string for sqlframe backend."
                raise NotImplementedError(msg)
        else:  # pragma: no cover
            pass

        ids = tuple(index) if index else ()
        values = (
            tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
        )
        unpivoted_native_frame = self.native.unpivot(
            ids=ids,
            values=values,
            variableColumnName=variable_name,
            valueColumnName=value_name,
        )
        if index is None:
            unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
        return self._with_native(unpivoted_native_frame)

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    join_asof = not_implemented()
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )
    with_row_index = not_implemented()
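# Rough end-to-end sketch of how this class is typically reached (illustrative
# only; the session setup is an assumption about a sqlframe DuckDB environment,
# not something this module provides or guarantees):
#
#     import narwhals as nw
#     from sqlframe.duckdb import DuckDBSession
#
#     session = DuckDBSession()
#     native = session.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
#     lf = nw.from_native(native)  # internally wrapped as a SparkLikeLazyFrame
#     print(lf.select(nw.col("id")).to_native())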