from __future__ import annotations

import contextlib
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence

import duckdb
from duckdb import FunctionExpression, StarExpression

from narwhals._duckdb.utils import (
    DeferredTimeZone,
    col,
    evaluate_exprs,
    generate_partition_by_sql,
    lit,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    Version,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    parse_version,
    validate_backend_version,
)
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from types import ModuleType

    import pandas as pd
    import pyarrow as pa
    from duckdb import Expression
    from duckdb.typing import DuckDBPyType
    from typing_extensions import Self, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals._duckdb.group_by import DuckDBGroupBy
    from narwhals._duckdb.namespace import DuckDBNamespace
    from narwhals._duckdb.series import DuckDBInterchangeSeries
    from narwhals._utils import _FullContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.stable.v1 import DataFrame as DataFrameV1
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

with contextlib.suppress(ImportError):  # requires duckdb>=1.3.0
    from duckdb import SQLExpression


class DuckDBLazyFrame(
    CompliantLazyFrame[
        "DuckDBExpr",
        "duckdb.DuckDBPyRelation",
        "LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
    ]
):
    _implementation = Implementation.DUCKDB

    def __init__(
        self,
        df: duckdb.DuckDBPyRelation,
        *,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        self._native_frame: duckdb.DuckDBPyRelation = df
        self._version = version
        self._backend_version = backend_version
        self._cached_native_schema: dict[str, DuckDBPyType] | None = None
        self._cached_columns: list[str] | None = None
        validate_backend_version(self._implementation, self._backend_version)

    @staticmethod
    def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
        return isinstance(obj, duckdb.DuckDBPyRelation)

    @classmethod
    def from_native(
        cls, data: duckdb.DuckDBPyRelation, /, *, context: _FullContext
    ) -> Self:
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    def to_narwhals(
        self, *args: Any, **kwds: Any
    ) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
        if self._version is Version.MAIN:
            return self._version.lazyframe(self, level="lazy")

        from narwhals.stable.v1 import DataFrame as DataFrameV1

        return DataFrameV1(self, level="interchange")  # type: ignore[no-any-return]

    def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
        # Keep around for backcompat.
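        # `narwhals.stable.v1` wraps a DuckDB relation at "interchange" level
        # (see `to_narwhals` above), so this hook is only meaningful for that
        # version.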
        if self._version is not Version.V1:
            msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
            raise AttributeError(msg)
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return get_duckdb()  # type: ignore[no-any-return]

    def __narwhals_namespace__(self) -> DuckDBNamespace:
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def get_column(self, name: str) -> DuckDBInterchangeSeries:
        from narwhals._duckdb.series import DuckDBInterchangeSeries

        return DuckDBInterchangeSeries(self.native.select(name), version=self._version)

    def _iter_columns(self) -> Iterator[Expression]:
        for name in self.columns:
            yield col(name)

    def collect(
        self, backend: ModuleType | Implementation | str | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is None or backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native.arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.df(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                self.native.pl(), backend_version=parse_version(pl), version=self._version
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def head(self, n: int) -> Self:
        return self._with_native(self.native.limit(n))

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: DuckDBExpr) -> Self:
        selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
        return self._with_native(self.native.aggregate(selection))  # type: ignore[arg-type]

    def select(self, *exprs: DuckDBExpr) -> Self:
        selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
        return self._with_native(self.native.select(*selection))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        selection = (name for name in self.columns if name not in columns_to_drop)
        return self._with_native(self.native.select(*selection))

    def lazy(self, *, backend: Implementation | None = None) -> Self:
        # The `backend` argument has no effect but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1`
        # function `.from_native()` will return a DataFrame for DuckDB.
        if backend is not None:  # pragma: no cover
            msg = "`backend` argument is not supported for DuckDB"
            raise ValueError(msg)
        return self

    def with_columns(self, *exprs: DuckDBExpr) -> Self:
        new_columns_map = dict(evaluate_exprs(self, *exprs))
        result = [
            new_columns_map.pop(name).alias(name) if name in new_columns_map else col(name)
            for name in self.columns
        ]
        result.extend(value.alias(name) for name, value in new_columns_map.items())
        return self._with_native(self.native.select(*result))

    def filter(self, predicate: DuckDBExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.filter(mask))

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_native_schema is None:
            # Note: prefer `self._cached_native_schema` over `functools.cached_property`
            # due to Python3.13 failures.
            self._cached_native_schema = dict(zip(self.columns, self.native.types))

        deferred_time_zone = DeferredTimeZone(self.native)
        return {
            column_name: native_to_narwhals_dtype(
                duckdb_dtype, self._version, deferred_time_zone
            )
            for column_name, duckdb_dtype in zip(self.native.columns, self.native.types)
        }

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_native_schema is not None
                else self.native.columns
            )
        return self._cached_columns

    def to_pandas(self) -> pd.DataFrame:
        # only if version is v1, keep around for backcompat
        import pandas as pd  # ignore-banned-import()

        if parse_version(pd) >= (1, 0, 0):
            return self.native.df()
        else:  # pragma: no cover
            msg = f"Conversion to pandas requires 'pandas>=1.0.0', found {pd.__version__}"
            raise NotImplementedError(msg)

    def to_arrow(self) -> pa.Table:
        # only if version is v1, keep around for backcompat
        return self.native.arrow()

    def _with_version(self, version: Version) -> Self:
        return self.__class__(
            self.native, version=version, backend_version=self._backend_version
        )

    def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
    ) -> DuckDBGroupBy:
        from narwhals._duckdb.group_by import DuckDBGroupBy

        return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        df = self.native
        selection = (
            col(name).alias(mapping[name]) if name in mapping else col(name)
            for name in df.columns
        )
        return self._with_native(self.native.select(*selection))

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        native_how = "outer" if how == "full" else how

        if native_how == "cross":
            if self._backend_version < (1, 1, 4):
                msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
                raise NotImplementedError(msg)
            rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
        else:
            # help mypy
            assert left_on is not None  # noqa: S101
            assert right_on is not None  # noqa: S101

            it = (
                col(f'lhs."{left}"') == col(f'rhs."{right}"')
                for left, right in zip(left_on, right_on)
            )
            condition: Expression = reduce(and_, it)
            rel = self.native.set_alias("lhs").join(
                other.native.set_alias("rhs"),
                # NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
                condition=condition,  # type: ignore[arg-type, unused-ignore]
                how=native_how,
            )

        if native_how in {"inner", "left", "cross", "outer"}:
            select = [col(f'lhs."{x}"') for x in self.columns]
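            # Keep every left-hand column as-is; right-hand columns whose names
            # clash with the left get `suffix` appended, and (except for full
            # joins) right-hand join keys are dropped.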
[col(f'lhs."{x}"') for x in self.columns] for name in other.columns: col_in_lhs: bool = name in self.columns if native_how == "outer" and not col_in_lhs: select.append(col(f'rhs."{name}"')) elif (native_how == "outer") or ( col_in_lhs and (right_on is None or name not in right_on) ): select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}")) elif right_on is None or name not in right_on: select.append(col(name)) res = rel.select(*select).set_alias(self.native.alias) else: # semi, anti res = rel.select("lhs.*").set_alias(self.native.alias) return self._with_native(res) def join_asof( self, other: Self, *, left_on: str, right_on: str, by_left: Sequence[str] | None, by_right: Sequence[str] | None, strategy: AsofJoinStrategy, suffix: str, ) -> Self: lhs = self.native rhs = other.native conditions: list[Expression] = [] if by_left is not None and by_right is not None: conditions.extend( col(f'lhs."{left}"') == col(f'rhs."{right}"') for left, right in zip(by_left, by_right) ) else: by_left = by_right = [] if strategy == "backward": conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"')) elif strategy == "forward": conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"')) else: msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB" raise NotImplementedError(msg) condition: Expression = reduce(and_, conditions) select = ["lhs.*"] for name in rhs.columns: if name in lhs.columns and ( right_on is None or name not in {right_on, *by_right} ): select.append(f'rhs."{name}" as "{name}{suffix}"') elif right_on is None or name not in {right_on, *by_right}: select.append(str(col(name))) # Replace with Python API call once # https://github.com/duckdb/duckdb/discussions/16947 is addressed. query = f""" SELECT {",".join(select)} FROM lhs ASOF LEFT JOIN rhs ON {condition} """ # noqa: S608 return self._with_native(duckdb.sql(query)) def collect_schema(self) -> dict[str, DType]: return self.schema def unique( self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy ) -> Self: if subset_ := subset if keep == "any" else (subset or self.columns): if self._backend_version < (1, 3): msg = ( "At least version 1.3 of DuckDB is required for `unique` operation\n" "with `subset` specified." 
                )
                raise NotImplementedError(msg)

            # Sanitise input
            if error := self._check_columns_exist(subset_):
                raise error

            idx_name = generate_temporary_column_name(8, self.columns)
            count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
            partition_by_sql = generate_partition_by_sql(*subset_)
            name = count_name if keep == "none" else idx_name
            idx_expr = SQLExpression(
                f"{FunctionExpression('row_number')} over ({partition_by_sql})"
            ).alias(idx_name)
            count_expr = SQLExpression(
                f"{FunctionExpression('count', StarExpression())} over ({partition_by_sql})"
            ).alias(count_name)
            return self._with_native(
                self.native.select(StarExpression(), idx_expr, count_expr)
                .filter(col(name) == lit(1))
                .select(StarExpression(exclude=[count_name, idx_name]))
            )
        return self._with_native(self.native.unique(", ".join(self.columns)))

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            descending = [descending] * len(by)
        if nulls_last:
            it = (
                col(name).nulls_last() if not desc else col(name).desc().nulls_last()
                for name, desc in zip(by, descending)
            )
        else:
            it = (
                col(name).nulls_first() if not desc else col(name).desc().nulls_first()
                for name, desc in zip(by, descending)
            )
        return self._with_native(self.native.sort(*it))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        subset_ = subset if subset is not None else self.columns
        keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
        return self._with_native(self.native.filter(keep_condition))

    def explode(self, columns: Sequence[str]) -> Self:
        dtypes = self._version.dtypes
        schema = self.collect_schema()
        for name in columns:
            dtype = schema[name]
            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with DuckDB backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        col_to_explode = col(columns[0])
        rel = self.native
        original_columns = self.columns

        # Rows are exploded only when the list column is non-null and non-empty.
        not_null_condition = col_to_explode.isnotnull() & (
            FunctionExpression("len", col_to_explode) > lit(0)
        )
        non_null_rel = rel.filter(not_null_condition).select(
            *(
                FunctionExpression("unnest", col_to_explode).alias(name)
                if name in columns
                else name
                for name in original_columns
            )
        )

        null_rel = rel.filter(~not_null_condition).select(
            *(
                lit(None).alias(name) if name in columns else name
                for name in original_columns
            )
        )

        return self._with_native(non_null_rel.union(null_rel))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on

        if variable_name == "":
            msg = "`variable_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        if value_name == "":
            msg = "`value_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        unpivot_on = ", ".join(str(col(name)) for name in on_)
        rel = self.native  # noqa: F841
        # Replace with Python API once
        # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
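        # The generated statement has the shape:
        #   UNPIVOT rel ON <on columns> INTO NAME "<variable_name>" VALUE "<value_name>"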
query = f""" unpivot rel on {unpivot_on} into name "{variable_name}" value "{value_name}" """ return self._with_native( duckdb.sql(query).select(*[*index_, variable_name, value_name]) ) gather_every = not_implemented.deprecated( "`LazyFrame.gather_every` is deprecated and will be removed in a future version." ) tail = not_implemented.deprecated( "`LazyFrame.tail` is deprecated and will be removed in a future version." ) with_row_index = not_implemented()