diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py | 512 |
1 files changed, 512 insertions, 0 deletions
from __future__ import annotations

import contextlib
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence

import duckdb
from duckdb import FunctionExpression, StarExpression

from narwhals._duckdb.utils import (
    DeferredTimeZone,
    col,
    evaluate_exprs,
    generate_partition_by_sql,
    lit,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    Version,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    parse_version,
    validate_backend_version,
)
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from types import ModuleType

    import pandas as pd
    import pyarrow as pa
    from duckdb import Expression
    from duckdb.typing import DuckDBPyType
    from typing_extensions import Self, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals._duckdb.group_by import DuckDBGroupBy
    from narwhals._duckdb.namespace import DuckDBNamespace
    from narwhals._duckdb.series import DuckDBInterchangeSeries
    from narwhals._utils import _FullContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.stable.v1 import DataFrame as DataFrameV1
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

with contextlib.suppress(ImportError):  # requires duckdb>=1.3.0
    from duckdb import SQLExpression


class DuckDBLazyFrame(
    CompliantLazyFrame[
        "DuckDBExpr",
        "duckdb.DuckDBPyRelation",
        "LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
    ]
):
    """Narwhals-compliant lazy frame wrapping a `duckdb.DuckDBPyRelation`.

    Every transformation builds a new relation and wraps it in a new instance
    via `_with_native`; the wrapped relation is stored in `_native_frame`.

    NOTE(review): `self.native` is not defined in this class; it is presumably
    provided by `CompliantLazyFrame` and returns `_native_frame` - confirm.
    """

    _implementation = Implementation.DUCKDB

    def __init__(
        self,
        df: duckdb.DuckDBPyRelation,
        *,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        """Wrap `df` and validate that `backend_version` is supported."""
        self._native_frame: duckdb.DuckDBPyRelation = df
        self._version = version
        self._backend_version = backend_version
        # Lazily populated by the `schema` / `columns` properties.
        self._cached_native_schema: dict[str, DuckDBPyType] | None = None
        self._cached_columns: list[str] | None = None
        validate_backend_version(self._implementation, self._backend_version)

    @staticmethod
    def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
        """Return whether `obj` is a `duckdb.DuckDBPyRelation`."""
        return isinstance(obj, duckdb.DuckDBPyRelation)

    @classmethod
    def from_native(
        cls, data: duckdb.DuckDBPyRelation, /, *, context: _FullContext
    ) -> Self:
        """Build a frame from `data`, taking versions from `context`."""
        return cls(
            data, backend_version=context._backend_version, version=context._version
        )

    def to_narwhals(
        self, *args: Any, **kwds: Any
    ) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
        """Wrap this frame in the public narwhals object for its API version.

        `Version.MAIN` yields a lazy-level `LazyFrame`; any other version yields
        a `stable.v1` `DataFrame` at interchange level. Extra arguments are
        accepted but ignored.
        """
        if self._version is Version.MAIN:
            return self._version.lazyframe(self, level="lazy")

        from narwhals.stable.v1 import DataFrame as DataFrameV1

        return DataFrameV1(self, level="interchange")  # type: ignore[no-any-return]

    def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
        """Return `self` for the V1 API only; raise `AttributeError` otherwise."""
        # Keep around for backcompat.
        if self._version is not Version.V1:
            msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
            raise AttributeError(msg)
        return self

    def __narwhals_lazyframe__(self) -> Self:
        """Return `self`."""
        return self

    def __native_namespace__(self) -> ModuleType:
        """Return the `duckdb` module as obtained from `get_duckdb()`."""
        return get_duckdb()  # type: ignore[no-any-return]

    def __narwhals_namespace__(self) -> DuckDBNamespace:
        """Return a `DuckDBNamespace` sharing this frame's versions."""
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def get_column(self, name: str) -> DuckDBInterchangeSeries:
        """Return column `name` as a `DuckDBInterchangeSeries`."""
        from narwhals._duckdb.series import DuckDBInterchangeSeries

        return DuckDBInterchangeSeries(self.native.select(name), version=self._version)

    def _iter_columns(self) -> Iterator[Expression]:
        """Yield one column expression per column, in frame order."""
        for name in self.columns:
            yield col(name)

    def collect(
        self, backend: ModuleType | Implementation | str | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Materialise the relation into an eager compliant dataframe.

        Arguments:
            backend: `None` or `Implementation.PYARROW` (Arrow table),
                `Implementation.PANDAS`, or `Implementation.POLARS`.
            **kwargs: Accepted but not used by this implementation.

        Raises:
            ValueError: If `backend` is none of the supported values.
        """
        if backend is None or backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native.arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.df(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                self.native.pl(), backend_version=parse_version(pl), version=self._version
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def head(self, n: int) -> Self:
        """Limit the relation to `n` rows."""
        return self._with_native(self.native.limit(n))

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name, without evaluating narwhals expressions."""
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: DuckDBExpr) -> Self:
        """Evaluate `exprs` and pass the aliased results to the native `aggregate`."""
        selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
        return self._with_native(self.native.aggregate(selection))  # type: ignore[arg-type]

    def select(self, *exprs: DuckDBExpr) -> Self:
        """Evaluate `exprs` and keep only the resulting aliased columns."""
        selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
        return self._with_native(self.native.select(*selection))

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns`; `strict` is forwarded to `parse_columns_to_drop`."""
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        selection = (name for name in self.columns if name not in columns_to_drop)
        return self._with_native(self.native.select(*selection))

    def lazy(self, *, backend: Implementation | None = None) -> Self:
        """Return `self`; raise `ValueError` if `backend` is not `None`."""
        # The `backend` argument has no effect but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1`
        # function `.from_native()` will return a DataFrame for DuckDB.

        if backend is not None:  # pragma: no cover
            msg = "`backend` argument is not supported for DuckDB"
            raise ValueError(msg)
        return self

    def with_columns(self, *exprs: DuckDBExpr) -> Self:
        """Add or replace columns.

        Columns whose name already exists are replaced in their original
        position; genuinely new columns are appended after the existing ones.
        """
        new_columns_map = dict(evaluate_exprs(self, *exprs))
        # `pop` removes replaced names so that only new columns remain afterwards.
        result = [
            new_columns_map.pop(name).alias(name)
            if name in new_columns_map
            else col(name)
            for name in self.columns
        ]
        result.extend(value.alias(name) for name, value in new_columns_map.items())
        return self._with_native(self.native.select(*result))

    def filter(self, predicate: DuckDBExpr) -> Self:
        """Keep the rows for which `predicate` holds."""
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.filter(mask))

    @property
    def schema(self) -> dict[str, DType]:
        """Map each column name to its narwhals dtype."""
        if self._cached_native_schema is None:
            # Note: prefer `self._cached_native_schema` over `functools.cached_property`
            # due to Python3.13 failures.
            self._cached_native_schema = dict(zip(self.columns, self.native.types))

        deferred_time_zone = DeferredTimeZone(self.native)
        # Read from the cache populated above rather than querying the native
        # relation for its columns and types a second time.
        return {
            column_name: native_to_narwhals_dtype(
                duckdb_dtype, self._version, deferred_time_zone
            )
            for column_name, duckdb_dtype in self._cached_native_schema.items()
        }

    @property
    def columns(self) -> list[str]:
        """Column names, computed once and cached on the instance."""
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_native_schema is not None
                else self.native.columns
            )
        return self._cached_columns

    def to_pandas(self) -> pd.DataFrame:
        """Convert to a pandas DataFrame; requires `pandas>=1.0.0`.

        Raises:
            NotImplementedError: If the installed pandas is older than 1.0.0.
        """
        # only if version is v1, keep around for backcompat
        import pandas as pd  # ignore-banned-import()

        if parse_version(pd) >= (1, 0, 0):
            return self.native.df()
        else:  # pragma: no cover
            msg = f"Conversion to pandas requires 'pandas>=1.0.0', found {pd.__version__}"
            raise NotImplementedError(msg)

    def to_arrow(self) -> pa.Table:
        """Convert to a pyarrow Table."""
        # only if version is v1, keep around for backcompat
        return self.native.arrow()

    def _with_version(self, version: Version) -> Self:
        """Return a new frame over the same relation with a different `version`."""
        return self.__class__(
            self.native, version=version, backend_version=self._backend_version
        )

    def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
        """Return a new frame wrapping `df` with this frame's versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
    ) -> DuckDBGroupBy:
        """Return a `DuckDBGroupBy` over `keys`."""
        from narwhals._duckdb.group_by import DuckDBGroupBy

        return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        """Rename the columns found in `mapping`; other columns are kept as-is."""
        df = self.native
        selection = (
            col(name).alias(mapping[name]) if name in mapping else col(name)
            for name in df.columns
        )
        return self._with_native(self.native.select(*selection))

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other`.

        Arguments:
            other: Right-hand frame.
            how: Join strategy; `"full"` is passed to DuckDB as `"outer"`.
            left_on: Key columns of this frame (not used for cross joins).
            right_on: Key columns of `other` (not used for cross joins).
            suffix: Appended to right-hand column names that clash with
                left-hand ones.

        Raises:
            NotImplementedError: For a cross join with DuckDB older than 1.1.4.
        """
        native_how = "outer" if how == "full" else how

        if native_how == "cross":
            if self._backend_version < (1, 1, 4):
                msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
                raise NotImplementedError(msg)
            rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
        else:
            # help mypy
            assert left_on is not None  # noqa: S101
            assert right_on is not None  # noqa: S101
            it = (
                col(f'lhs."{left}"') == col(f'rhs."{right}"')
                for left, right in zip(left_on, right_on)
            )
            condition: Expression = reduce(and_, it)
            rel = self.native.set_alias("lhs").join(
                other.native.set_alias("rhs"),
                # NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
                condition=condition,  # type: ignore[arg-type, unused-ignore]
                how=native_how,
            )

        if native_how in {"inner", "left", "cross", "outer"}:
            # All left-hand columns come first, under their original names.
            select = [col(f'lhs."{x}"') for x in self.columns]
            for name in other.columns:
                col_in_lhs: bool = name in self.columns
                if native_how == "outer" and not col_in_lhs:
                    select.append(col(f'rhs."{name}"'))
                elif (native_how == "outer") or (
                    col_in_lhs and (right_on is None or name not in right_on)
                ):
                    # Clashing right-hand column: disambiguate with `suffix`.
                    select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
                elif right_on is None or name not in right_on:
                    select.append(col(name))
            res = rel.select(*select).set_alias(self.native.alias)
        else:  # semi, anti
            res = rel.select("lhs.*").set_alias(self.native.alias)

        return self._with_native(res)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        """As-of left join with `other`.

        Arguments:
            other: Right-hand frame.
            left_on: As-of key column of this frame.
            right_on: As-of key column of `other`.
            by_left: Equality-key columns of this frame; used only when
                `by_right` is also given.
            by_right: Equality-key columns of `other`.
            strategy: `"backward"` or `"forward"`.
            suffix: Appended to right-hand column names that clash with
                left-hand ones.

        Raises:
            NotImplementedError: For any other `strategy`.
        """
        # NOTE(review): the SQL below refers to `lhs` / `rhs` by name, so these
        # locals look load-bearing (presumably resolved by DuckDB's replacement
        # scans) - do not rename without confirming.
        lhs = self.native
        rhs = other.native
        conditions: list[Expression] = []
        if by_left is not None and by_right is not None:
            conditions.extend(
                col(f'lhs."{left}"') == col(f'rhs."{right}"')
                for left, right in zip(by_left, by_right)
            )
        else:
            by_left = by_right = []
        if strategy == "backward":
            conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
        elif strategy == "forward":
            conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
        else:
            msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
            raise NotImplementedError(msg)
        condition: Expression = reduce(and_, conditions)
        select = ["lhs.*"]
        for name in rhs.columns:
            if name in lhs.columns and (
                right_on is None or name not in {right_on, *by_right}
            ):
                select.append(f'rhs."{name}" as "{name}{suffix}"')
            elif right_on is None or name not in {right_on, *by_right}:
                select.append(str(col(name)))
        # Replace with Python API call once
        # https://github.com/duckdb/duckdb/discussions/16947 is addressed.
        query = f"""
            SELECT {",".join(select)}
            FROM lhs
            ASOF LEFT JOIN rhs
            ON {condition}
            """  # noqa: S608
        return self._with_native(duckdb.sql(query))

    def collect_schema(self) -> dict[str, DType]:
        """Return `self.schema`."""
        return self.schema

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        """Drop duplicate rows.

        Arguments:
            subset: Columns that define a duplicate. For `keep="any"` a falsy
                `subset` takes the native `unique` path over all columns; for
                other `keep` values a falsy `subset` means all columns.
            keep: With `"none"`, only rows whose `subset` combination occurs
                exactly once are kept; otherwise one row per combination is
                kept.

        Raises:
            NotImplementedError: If the windowed path is needed with DuckDB
                older than 1.3.
        """
        if subset_ := subset if keep == "any" else (subset or self.columns):
            if self._backend_version < (1, 3):
                msg = (
                    "At least version 1.3 of DuckDB is required for `unique` operation\n"
                    "with `subset` specified."
                )
                raise NotImplementedError(msg)
            # Sanitise input
            if error := self._check_columns_exist(subset_):
                raise error
            idx_name = generate_temporary_column_name(8, self.columns)
            count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
            partition_by_sql = generate_partition_by_sql(*(subset_))
            # `keep="none"` filters on the group size; otherwise filter on the
            # row number within the group.
            name = count_name if keep == "none" else idx_name
            idx_expr = SQLExpression(
                f"{FunctionExpression('row_number')} over ({partition_by_sql})"
            ).alias(idx_name)
            count_expr = SQLExpression(
                f"{FunctionExpression('count', StarExpression())} over ({partition_by_sql})"
            ).alias(count_name)
            return self._with_native(
                self.native.select(StarExpression(), idx_expr, count_expr)
                .filter(col(name) == lit(1))
                .select(StarExpression(exclude=[count_name, idx_name]))
            )
        # Render names through `col`, as `unpivot` and `join_asof` do, instead
        # of splicing raw column names into the SQL column list.
        return self._with_native(
            self.native.unique(", ".join(str(col(name)) for name in self.columns))
        )

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        """Sort by the `by` columns.

        A single `descending` flag is applied to every column in `by`;
        `nulls_last` selects null placement for every column.
        """
        if isinstance(descending, bool):
            descending = [descending] * len(by)
        if nulls_last:
            it = (
                col(name).nulls_last() if not desc else col(name).desc().nulls_last()
                for name, desc in zip(by, descending)
            )
        else:
            it = (
                col(name).nulls_first() if not desc else col(name).desc().nulls_first()
                for name, desc in zip(by, descending)
            )
        return self._with_native(self.native.sort(*it))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        """Drop rows containing a null in any of the `subset` columns.

        `subset=None` means all columns. With no columns to check, the frame
        is returned unchanged.
        """
        subset_ = subset if subset is not None else self.columns
        if not subset_:
            # `reduce` without an initial value raises `TypeError` on an empty
            # iterable; with nothing to check, no row can be dropped.
            return self
        keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
        return self._with_native(self.native.filter(keep_condition))

    def explode(self, columns: Sequence[str]) -> Self:
        """Explode a single List column into one row per element.

        Rows whose list is null or empty produce a single row with a null in
        the exploded column.

        Raises:
            InvalidOperationError: If a requested column is not of List dtype.
            NotImplementedError: If `columns` does not contain exactly one name.
        """
        dtypes = self._version.dtypes
        schema = self.collect_schema()
        for name in columns:
            dtype = schema[name]
            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with DuckDB backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        col_to_explode = col(columns[0])
        rel = self.native
        original_columns = self.columns

        # Parenthesised on purpose: `&` binds tighter than `>`, so without the
        # parentheses this would parse as `(isnotnull & len(...)) > 0`.
        not_null_condition = col_to_explode.isnotnull() & (
            FunctionExpression("len", col_to_explode) > lit(0)
        )
        non_null_rel = rel.filter(not_null_condition).select(
            *(
                FunctionExpression("unnest", col_to_explode).alias(name)
                if name in columns
                else name
                for name in original_columns
            )
        )

        null_rel = rel.filter(~not_null_condition).select(
            *(
                lit(None).alias(name) if name in columns else name
                for name in original_columns
            )
        )

        return self._with_native(non_null_rel.union(null_rel))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot from wide to long format.

        Arguments:
            on: Columns to unpivot; `None` means every column not in `index`.
            index: Identifier columns kept in the output; `None` means none.
            variable_name: Name of the output column holding source column names.
            value_name: Name of the output column holding the values.

        Raises:
            NotImplementedError: If `variable_name` or `value_name` is empty.
        """
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on

        if variable_name == "":
            msg = "`variable_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        if value_name == "":
            msg = "`value_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        unpivot_on = ", ".join(str(col(name)) for name in on_)
        # `rel` looks unused but the query below refers to it by name.
        rel = self.native  # noqa: F841
        # Replace with Python API once
        # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
        query = f"""
            unpivot rel
            on {unpivot_on}
            into
                name "{variable_name}"
                value "{value_name}"
            """
        return self._with_native(
            duckdb.sql(query).select(*index_, variable_name, value_name)
        )

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )
    with_row_index = not_implemented()