aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py512
1 files changed, 512 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py b/venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py
new file mode 100644
index 0000000..6b4b197
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_duckdb/dataframe.py
@@ -0,0 +1,512 @@
+from __future__ import annotations
+
+import contextlib
+from functools import reduce
+from operator import and_
+from typing import TYPE_CHECKING, Any, Iterator, Mapping, Sequence
+
+import duckdb
+from duckdb import FunctionExpression, StarExpression
+
+from narwhals._duckdb.utils import (
+ DeferredTimeZone,
+ col,
+ evaluate_exprs,
+ generate_partition_by_sql,
+ lit,
+ native_to_narwhals_dtype,
+)
+from narwhals._utils import (
+ Implementation,
+ Version,
+ generate_temporary_column_name,
+ not_implemented,
+ parse_columns_to_drop,
+ parse_version,
+ validate_backend_version,
+)
+from narwhals.dependencies import get_duckdb
+from narwhals.exceptions import InvalidOperationError
+from narwhals.typing import CompliantLazyFrame
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ import pandas as pd
+ import pyarrow as pa
+ from duckdb import Expression
+ from duckdb.typing import DuckDBPyType
+ from typing_extensions import Self, TypeIs
+
+ from narwhals._compliant.typing import CompliantDataFrameAny
+ from narwhals._duckdb.expr import DuckDBExpr
+ from narwhals._duckdb.group_by import DuckDBGroupBy
+ from narwhals._duckdb.namespace import DuckDBNamespace
+ from narwhals._duckdb.series import DuckDBInterchangeSeries
+ from narwhals._utils import _FullContext
+ from narwhals.dataframe import LazyFrame
+ from narwhals.dtypes import DType
+ from narwhals.stable.v1 import DataFrame as DataFrameV1
+ from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
+
+with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
+ from duckdb import SQLExpression
+
+
+class DuckDBLazyFrame(
+ CompliantLazyFrame[
+ "DuckDBExpr",
+ "duckdb.DuckDBPyRelation",
+ "LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
+ ]
+):
+ _implementation = Implementation.DUCKDB
+
+ def __init__(
+     self,
+     df: duckdb.DuckDBPyRelation,
+     *,
+     backend_version: tuple[int, ...],
+     version: Version,
+ ) -> None:
+     """Wrap a native DuckDB relation.
+
+     Arguments:
+         df: The native relation stored as `_native_frame`.
+         backend_version: Installed DuckDB version, checked by
+             `validate_backend_version` once the attributes are set.
+         version: Narwhals API version this frame belongs to.
+     """
+     self._native_frame: duckdb.DuckDBPyRelation = df
+     self._backend_version = backend_version
+     self._version = version
+     # Both caches start empty; the `schema` and `columns` properties fill them.
+     self._cached_columns: list[str] | None = None
+     self._cached_native_schema: dict[str, DuckDBPyType] | None = None
+     validate_backend_version(self._implementation, self._backend_version)
+
+ @staticmethod
+ def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
+     """Return whether `obj` is a `duckdb.DuckDBPyRelation` instance."""
+     relation_type = duckdb.DuckDBPyRelation
+     return isinstance(obj, relation_type)
+
+ @classmethod
+ def from_native(
+     cls, data: duckdb.DuckDBPyRelation, /, *, context: _FullContext
+ ) -> Self:
+     """Build a frame around `data`, copying both versions from `context`."""
+     return cls(
+         data, version=context._version, backend_version=context._backend_version
+     )
+
+ def to_narwhals(
+     self, *args: Any, **kwds: Any
+ ) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
+     """Wrap this compliant frame in the public Narwhals object for its version.
+
+     With `Version.MAIN` the result comes from
+     `self._version.lazyframe(self, level="lazy")`. Any other version yields a
+     `narwhals.stable.v1.DataFrame` at the "interchange" level. The extra
+     positional and keyword arguments are accepted but not used.
+     """
+     if self._version is not Version.MAIN:
+         from narwhals.stable.v1 import DataFrame as DataFrameV1
+
+         return DataFrameV1(self, level="interchange")  # type: ignore[no-any-return]
+     return self._version.lazyframe(self, level="lazy")
+
+ def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
+     """Return `self` under the V1 API only; kept around for backcompat.
+
+     Raises:
+         AttributeError: If the frame's version is not `Version.V1`.
+     """
+     if self._version is Version.V1:
+         return self
+     msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
+     raise AttributeError(msg)
+
+ def __narwhals_lazyframe__(self) -> Self:
+     """Return this object unchanged."""
+     return self
+
+ def __native_namespace__(self) -> ModuleType:
+     """Return whatever `get_duckdb()` provides as the native namespace."""
+     native_namespace = get_duckdb()
+     return native_namespace  # type: ignore[no-any-return]
+
+ def __narwhals_namespace__(self) -> DuckDBNamespace:
+     """Create a `DuckDBNamespace` carrying this frame's backend and API versions."""
+     from narwhals._duckdb.namespace import DuckDBNamespace
+
+     return DuckDBNamespace(
+         version=self._version, backend_version=self._backend_version
+     )
+
+ def get_column(self, name: str) -> DuckDBInterchangeSeries:
+     """Select column `name` and wrap the one-column relation as an interchange series."""
+     from narwhals._duckdb.series import DuckDBInterchangeSeries
+
+     single_column = self.native.select(name)
+     return DuckDBInterchangeSeries(single_column, version=self._version)
+
+ def _iter_columns(self) -> Iterator[Expression]:
+     """Yield one `col(name)` expression per column, in `self.columns` order."""
+     yield from map(col, self.columns)
+
+ def collect(
+     self, backend: ModuleType | Implementation | str | None, **kwargs: Any
+ ) -> CompliantDataFrameAny:
+     """Materialise the relation into an eager compliant dataframe.
+
+     Arguments:
+         backend: Target eager backend. `None` and `Implementation.PYARROW`
+             produce an `ArrowDataFrame`, `Implementation.PANDAS` a
+             `PandasLikeDataFrame`, and `Implementation.POLARS` a
+             `PolarsDataFrame`.
+         **kwargs: Accepted but not used here.
+
+     Raises:
+         ValueError: If `backend` is none of the values listed above.
+     """
+     if backend is None or backend is Implementation.PYARROW:
+         import pyarrow as pa  # ignore-banned-import
+
+         from narwhals._arrow.dataframe import ArrowDataFrame
+
+         arrow_table = self.native.arrow()
+         return ArrowDataFrame(
+             arrow_table,
+             backend_version=parse_version(pa),
+             version=self._version,
+             validate_column_names=True,
+         )
+
+     if backend is Implementation.PANDAS:
+         import pandas as pd  # ignore-banned-import
+
+         from narwhals._pandas_like.dataframe import PandasLikeDataFrame
+
+         pandas_frame = self.native.df()
+         return PandasLikeDataFrame(
+             pandas_frame,
+             implementation=Implementation.PANDAS,
+             backend_version=parse_version(pd),
+             version=self._version,
+             validate_column_names=True,
+         )
+
+     if backend is Implementation.POLARS:
+         import polars as pl  # ignore-banned-import
+
+         from narwhals._polars.dataframe import PolarsDataFrame
+
+         polars_frame = self.native.pl()
+         return PolarsDataFrame(
+             polars_frame, backend_version=parse_version(pl), version=self._version
+         )
+
+     msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
+     raise ValueError(msg)  # pragma: no cover
+
+ def head(self, n: int) -> Self:
+     """Return a frame whose relation is limited to `n` rows via `limit`."""
+     limited = self.native.limit(n)
+     return self._with_native(limited)
+
+ def simple_select(self, *column_names: str) -> Self:
+     """Select the given columns by name, passing the names straight to `select`."""
+     projected = self.native.select(*column_names)
+     return self._with_native(projected)
+
+ def aggregate(self, *exprs: DuckDBExpr) -> Self:
+     """Evaluate `exprs` against this frame and aggregate the relation with them.
+
+     Each evaluated expression is aliased to its output name first.
+     """
+     aliased = []
+     for output_name, native_expr in evaluate_exprs(self, *exprs):
+         aliased.append(native_expr.alias(output_name))
+     return self._with_native(self.native.aggregate(aliased))  # type: ignore[arg-type]
+
+ def select(self, *exprs: DuckDBExpr) -> Self:
+     """Evaluate `exprs` against this frame and project the relation onto them.
+
+     Each evaluated expression is aliased to its output name first.
+     """
+     aliased = [
+         native_expr.alias(output_name)
+         for output_name, native_expr in evaluate_exprs(self, *exprs)
+     ]
+     return self._with_native(self.native.select(*aliased))
+
+ def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
+     """Drop `columns` by selecting every other column.
+
+     The set of names to drop is resolved by `parse_columns_to_drop`, which
+     receives `strict` unchanged.
+     """
+     to_drop = parse_columns_to_drop(self, columns, strict=strict)
+     kept = [name for name in self.columns if name not in to_drop]
+     return self._with_native(self.native.select(*kept))
+
+ def lazy(self, *, backend: Implementation | None = None) -> Self:
+     """Return `self`; the frame is already lazy.
+
+     Raises:
+         ValueError: If `backend` is anything other than `None`.
+     """
+     # The `backend` argument has no effect but we keep it here for
+     # backwards compatibility because in `narwhals.stable.v1`
+     # function `.from_native()` will return a DataFrame for DuckDB.
+     if backend is None:
+         return self
+     msg = "`backend` argument is not supported for DuckDB"  # pragma: no cover
+     raise ValueError(msg)  # pragma: no cover
+
+ def with_columns(self, *exprs: DuckDBExpr) -> Self:
+     """Add or replace columns with the evaluated `exprs`.
+
+     Existing columns keep their position: one whose name matches an
+     evaluated expression is replaced in place, the rest pass through.
+     Expressions with new names are appended after them.
+     """
+     pending = dict(evaluate_exprs(self, *exprs))
+     projection = []
+     for name in self.columns:
+         if name in pending:
+             # Consume the entry so it is not appended again below.
+             projection.append(pending.pop(name).alias(name))
+         else:
+             projection.append(col(name))
+     for name, value in pending.items():
+         projection.append(value.alias(name))
+     return self._with_native(self.native.select(*projection))
+
+ def filter(self, predicate: DuckDBExpr) -> Self:
+     """Keep the rows for which `predicate` holds."""
+     evaluated = predicate(self)
+     # `[0]` is safe as the predicate's expression only returns a single column
+     return self._with_native(self.native.filter(evaluated[0]))
+
+ @property
+ def schema(self) -> dict[str, DType]:
+     """Map each column name to its Narwhals dtype.
+
+     The native name-to-type mapping is computed once and stored in
+     `_cached_native_schema`. Later accesses convert from that cache instead
+     of asking the relation for its columns and types again.
+     """
+     if self._cached_native_schema is None:
+         # Note: prefer `self._cached_native_schema` over `functools.cached_property`
+         # due to Python3.13 failures.
+         self._cached_native_schema = dict(zip(self.columns, self.native.types))
+
+     deferred_time_zone = DeferredTimeZone(self.native)
+     return {
+         column_name: native_to_narwhals_dtype(
+             duckdb_dtype, self._version, deferred_time_zone
+         )
+         for column_name, duckdb_dtype in self._cached_native_schema.items()
+     }
+
+ @property
+ def columns(self) -> list[str]:
+     """Return the column names, computing and caching them on first access.
+
+     If the native schema is already cached the names come from `self.schema`;
+     otherwise they come from the native relation's `columns`.
+     """
+     if self._cached_columns is not None:
+         return self._cached_columns
+     if self._cached_native_schema is not None:
+         self._cached_columns = list(self.schema)
+     else:
+         self._cached_columns = self.native.columns
+     return self._cached_columns
+
+ def to_pandas(self) -> pd.DataFrame:
+     """Convert the relation to a pandas DataFrame via the native `df()` call.
+
+     Raises:
+         NotImplementedError: If the installed pandas is older than 1.0.0.
+     """
+     # only if version is v1, keep around for backcompat
+     import pandas as pd  # ignore-banned-import()
+
+     if parse_version(pd) < (1, 0, 0):  # pragma: no cover
+         msg = f"Conversion to pandas requires 'pandas>=1.0.0', found {pd.__version__}"
+         raise NotImplementedError(msg)
+     return self.native.df()
+
+ def to_arrow(self) -> pa.Table:
+     """Convert the relation to a PyArrow table via the native `arrow()` call."""
+     # only if version is v1, keep around for backcompat
+     arrow_table = self.native.arrow()
+     return arrow_table
+
+ def _with_version(self, version: Version) -> Self:
+     """Return a new frame over the same relation, tagged with `version`."""
+     frame_cls = self.__class__
+     return frame_cls(
+         self.native, backend_version=self._backend_version, version=version
+     )
+
+ def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
+     """Return a new frame over `df`, keeping this frame's backend and API versions."""
+     frame_cls = self.__class__
+     return frame_cls(
+         df, version=self._version, backend_version=self._backend_version
+     )
+
+ def group_by(
+     self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
+ ) -> DuckDBGroupBy:
+     """Create a `DuckDBGroupBy` for this frame with the given keys."""
+     from narwhals._duckdb.group_by import DuckDBGroupBy
+
+     grouped = DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)
+     return grouped
+
+ def rename(self, mapping: Mapping[str, str]) -> Self:
+     """Rename columns according to `mapping`; unmapped columns keep their name."""
+     native = self.native
+     projection = []
+     for name in native.columns:
+         expr = col(name)
+         if name in mapping:
+             expr = expr.alias(mapping[name])
+         projection.append(expr)
+     return self._with_native(native.select(*projection))
+
+ def join(
+ self,
+ other: Self,
+ *,
+ how: JoinStrategy,
+ left_on: Sequence[str] | None,
+ right_on: Sequence[str] | None,
+ suffix: str,
+ ) -> Self:
+ """Join this frame with `other`.
+
+ Both relations are aliased (`lhs` for this frame, `rhs` for `other`) so
+ that columns can be addressed unambiguously while building the join
+ condition and the output projection.
+
+ Arguments:
+ other: Right-hand frame.
+ how: Join strategy; "full" is passed to DuckDB as "outer".
+ left_on: Key columns of this frame (asserted non-None unless cross join).
+ right_on: Key columns of `other` (asserted non-None unless cross join).
+ suffix: Appended to right-hand column names that are suffixed (see below).
+
+ Raises:
+ NotImplementedError: For a cross join when the backend version is
+ below (1, 1, 4).
+ """
+ native_how = "outer" if how == "full" else how
+
+ if native_how == "cross":
+ if self._backend_version < (1, 1, 4):
+ msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
+ raise NotImplementedError(msg)
+ rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
+ else:
+ # help mypy
+ assert left_on is not None # noqa: S101
+ assert right_on is not None # noqa: S101
+ # One equality per key pair, AND-ed together into a single condition.
+ it = (
+ col(f'lhs."{left}"') == col(f'rhs."{right}"')
+ for left, right in zip(left_on, right_on)
+ )
+ condition: Expression = reduce(and_, it)
+ rel = self.native.set_alias("lhs").join(
+ other.native.set_alias("rhs"),
+ # NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
+ condition=condition, # type: ignore[arg-type, unused-ignore]
+ how=native_how,
+ )
+
+ if native_how in {"inner", "left", "cross", "outer"}:
+ # Start from every left-hand column, then decide per right-hand column.
+ select = [col(f'lhs."{x}"') for x in self.columns]
+ for name in other.columns:
+ col_in_lhs: bool = name in self.columns
+ if native_how == "outer" and not col_in_lhs:
+ # Outer join, no name clash: keep the right-hand column as is.
+ select.append(col(f'rhs."{name}"'))
+ elif (native_how == "outer") or (
+ col_in_lhs and (right_on is None or name not in right_on)
+ ):
+ # Outer join with a clash, or a clashing non-key column in any
+ # other join: keep the right-hand column under a suffixed name.
+ select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
+ elif right_on is None or name not in right_on:
+ # Non-clashing, non-key right-hand column: keep unchanged.
+ select.append(col(name))
+ # Remaining case: right-hand key columns of non-outer joins are dropped.
+ res = rel.select(*select).set_alias(self.native.alias)
+ else: # semi, anti
+ res = rel.select("lhs.*").set_alias(self.native.alias)
+
+ return self._with_native(res)
+
+ def join_asof(
+ self,
+ other: Self,
+ *,
+ left_on: str,
+ right_on: str,
+ by_left: Sequence[str] | None,
+ by_right: Sequence[str] | None,
+ strategy: AsofJoinStrategy,
+ suffix: str,
+ ) -> Self:
+ """As-of left join of this frame with `other`, built as a SQL string.
+
+ Arguments:
+ other: Right-hand frame.
+ left_on: As-of key column of this frame.
+ right_on: As-of key column of `other`.
+ by_left: Optional equality-key columns of this frame.
+ by_right: Optional equality-key columns of `other`; only used when
+ `by_left` is given as well.
+ strategy: "backward" (`lhs >= rhs`) or "forward" (`lhs <= rhs`).
+ suffix: Appended to right-hand column names that clash with left-hand ones.
+
+ Raises:
+ NotImplementedError: For any other `strategy`.
+ """
+ # NOTE(review): the SQL below names `lhs` and `rhs` directly, so these two
+ # locals presumably have to keep exactly these names (DuckDB looking up
+ # Python variables by name) - confirm before renaming.
+ lhs = self.native
+ rhs = other.native
+ conditions: list[Expression] = []
+ if by_left is not None and by_right is not None:
+ conditions.extend(
+ col(f'lhs."{left}"') == col(f'rhs."{right}"')
+ for left, right in zip(by_left, by_right)
+ )
+ else:
+ by_left = by_right = []
+ if strategy == "backward":
+ conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
+ elif strategy == "forward":
+ conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
+ else:
+ msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
+ raise NotImplementedError(msg)
+ condition: Expression = reduce(and_, conditions)
+ # Every left-hand column is kept; right-hand key columns (`right_on` and
+ # `by_right`) are dropped, clashing right-hand columns get `suffix`.
+ select = ["lhs.*"]
+ for name in rhs.columns:
+ if name in lhs.columns and (
+ right_on is None or name not in {right_on, *by_right}
+ ):
+ select.append(f'rhs."{name}" as "{name}{suffix}"')
+ elif right_on is None or name not in {right_on, *by_right}:
+ select.append(str(col(name)))
+ # Replace with Python API call once
+ # https://github.com/duckdb/duckdb/discussions/16947 is addressed.
+ query = f"""
+ SELECT {",".join(select)}
+ FROM lhs
+ ASOF LEFT JOIN rhs
+ ON {condition}
+ """ # noqa: S608
+ return self._with_native(duckdb.sql(query))
+
+ def collect_schema(self) -> dict[str, DType]:
+     """Return the same mapping as the `schema` property."""
+     resolved_schema = self.schema
+     return resolved_schema
+
+ def unique(
+ self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
+ ) -> Self:
+ """Drop duplicate rows.
+
+ With `keep == "any"` and no `subset`, the native `unique` is applied over
+ all columns. Otherwise rows are partitioned by the subset (all columns
+ when `subset` is empty) and filtered with a window expression: for
+ `keep == "none"` only rows whose partition count is 1 survive, for any
+ other `keep` the row with `row_number` 1 of each partition survives.
+
+ Raises:
+ NotImplementedError: If the window path is taken and the backend
+ version is below (1, 3).
+ """
+ if subset_ := subset if keep == "any" else (subset or self.columns):
+ if self._backend_version < (1, 3):
+ msg = (
+ "At least version 1.3 of DuckDB is required for `unique` operation\n"
+ "with `subset` specified."
+ )
+ raise NotImplementedError(msg)
+ # Sanitise input
+ if error := self._check_columns_exist(subset_):
+ raise error
+ # Temporary helper columns; names are generated so as not to clash
+ # with existing columns or with each other.
+ idx_name = generate_temporary_column_name(8, self.columns)
+ count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
+ partition_by_sql = generate_partition_by_sql(*(subset_))
+ # Column compared against 1 in the filter below.
+ name = count_name if keep == "none" else idx_name
+ idx_expr = SQLExpression(
+ f"{FunctionExpression('row_number')} over ({partition_by_sql})"
+ ).alias(idx_name)
+ count_expr = SQLExpression(
+ f"{FunctionExpression('count', StarExpression())} over ({partition_by_sql})"
+ ).alias(count_name)
+ # Add both helper columns, filter, then project the helpers away again.
+ return self._with_native(
+ self.native.select(StarExpression(), idx_expr, count_expr)
+ .filter(col(name) == lit(1))
+ .select(StarExpression(exclude=[count_name, idx_name]))
+ )
+ return self._with_native(self.native.unique(", ".join(self.columns)))
+
+ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
+     """Sort by the columns in `by`.
+
+     Arguments:
+         *by: Column names to sort by, in priority order.
+         descending: One flag for all columns, or one flag per column.
+         nulls_last: Whether nulls are placed last (`True`) or first (`False`).
+     """
+     flags = [descending] * len(by) if isinstance(descending, bool) else descending
+     ordering = []
+     for name, desc in zip(by, flags):
+         key = col(name).desc() if desc else col(name)
+         ordering.append(key.nulls_last() if nulls_last else key.nulls_first())
+     return self._with_native(self.native.sort(*ordering))
+
+ def drop_nulls(self, subset: Sequence[str] | None) -> Self:
+     """Keep only rows that are non-null in every column of `subset`.
+
+     When `subset` is None, all columns are checked.
+     """
+     targets = self.columns if subset is None else subset
+     not_null_checks = (col(name).isnotnull() for name in targets)
+     all_present = reduce(and_, not_null_checks)
+     return self._with_native(self.native.filter(all_present))
+
+ def explode(self, columns: Sequence[str]) -> Self:
+     """Explode a single list column into one row per element.
+
+     Rows whose list is null or empty are kept, with a null in the exploded
+     column. The result is the union of the unnested non-empty rows and
+     those null rows.
+
+     Arguments:
+         columns: Names of the columns to explode; exactly one is supported.
+
+     Raises:
+         InvalidOperationError: If any requested column is not of List dtype.
+         NotImplementedError: If more than one column is requested.
+     """
+     dtypes = self._version.dtypes
+     schema = self.collect_schema()
+     for name in columns:
+         dtype = schema[name]
+         if dtype != dtypes.List:
+             msg = (
+                 f"`explode` operation not supported for dtype `{dtype}`, "
+                 "expected List type"
+             )
+             raise InvalidOperationError(msg)
+
+     if len(columns) != 1:
+         msg = (
+             "Exploding on multiple columns is not supported with DuckDB backend since "
+             "we cannot guarantee that the exploded columns have matching element counts."
+         )
+         raise NotImplementedError(msg)
+
+     col_to_explode = col(columns[0])
+     rel = self.native
+     original_columns = self.columns
+
+     # `&` binds tighter than `>` in Python, so the length comparison needs its
+     # own parentheses; without them the expression is `(notnull & len) > 0`.
+     not_null_condition = col_to_explode.isnotnull() & (
+         FunctionExpression("len", col_to_explode) > lit(0)
+     )
+     non_null_rel = rel.filter(not_null_condition).select(
+         *(
+             FunctionExpression("unnest", col_to_explode).alias(name)
+             if name in columns
+             else name
+             for name in original_columns
+         )
+     )
+
+     # Null or empty lists cannot be unnested: emit them with a null value.
+     null_rel = rel.filter(~not_null_condition).select(
+         *(
+             lit(None).alias(name) if name in columns else name
+             for name in original_columns
+         )
+     )
+
+     return self._with_native(non_null_rel.union(null_rel))
+
+ def unpivot(
+ self,
+ on: Sequence[str] | None,
+ index: Sequence[str] | None,
+ variable_name: str,
+ value_name: str,
+ ) -> Self:
+ """Unpivot the frame from wide to long format, built as a SQL string.
+
+ Arguments:
+ on: Columns to unpivot; defaults to every column not in `index`.
+ index: Columns kept as identifiers; defaults to none.
+ variable_name: Name of the output column holding the former column names.
+ value_name: Name of the output column holding the values.
+
+ Raises:
+ NotImplementedError: If `variable_name` or `value_name` is an empty string.
+ """
+ index_ = [] if index is None else index
+ on_ = [c for c in self.columns if c not in index_] if on is None else on
+
+ if variable_name == "":
+ msg = "`variable_name` cannot be empty string for duckdb backend."
+ raise NotImplementedError(msg)
+
+ if value_name == "":
+ msg = "`value_name` cannot be empty string for duckdb backend."
+ raise NotImplementedError(msg)
+
+ unpivot_on = ", ".join(str(col(name)) for name in on_)
+ # NOTE(review): `rel` is only referenced by name inside the SQL text below
+ # (hence the unused-variable noqa); presumably DuckDB looks it up as a
+ # Python variable - confirm before renaming or removing it.
+ rel = self.native # noqa: F841
+ # Replace with Python API once
+ # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
+ query = f"""
+ unpivot rel
+ on {unpivot_on}
+ into
+ name "{variable_name}"
+ value "{value_name}"
+ """
+ # Output column order: index columns, then the variable and value columns.
+ return self._with_native(
+ duckdb.sql(query).select(*[*index_, variable_name, value_name])
+ )
+
+ # Methods this backend does not provide. They are bound to `not_implemented`
+ # helpers from `narwhals._utils` instead of real implementations.
+ # NOTE(review): the exact behaviour on access/call is defined by that helper,
+ # not visible here - confirm there.
+ gather_every = not_implemented.deprecated(
+ "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
+ )
+ tail = not_implemented.deprecated(
+ "`LazyFrame.tail` is deprecated and will be removed in a future version."
+ )
+ with_row_index = not_implemented()