aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py')
-rw-r--r--venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py287
1 files changed, 287 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py b/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py
new file mode 100644
index 0000000..c5d4872
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py
@@ -0,0 +1,287 @@
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any
+
+import duckdb
+
+from narwhals._utils import Version, isinstance_or_issubclass
+
+if TYPE_CHECKING:
+ from duckdb import DuckDBPyRelation, Expression
+ from duckdb.typing import DuckDBPyType
+
+ from narwhals._duckdb.dataframe import DuckDBLazyFrame
+ from narwhals._duckdb.expr import DuckDBExpr
+ from narwhals.dtypes import DType
+ from narwhals.typing import IntoDType
+
+UNITS_DICT = {
+ "y": "year",
+ "q": "quarter",
+ "mo": "month",
+ "d": "day",
+ "h": "hour",
+ "m": "minute",
+ "s": "second",
+ "ms": "millisecond",
+ "us": "microsecond",
+ "ns": "nanosecond",
+}
+
+col = duckdb.ColumnExpression
+"""Alias for `duckdb.ColumnExpression`."""
+
+lit = duckdb.ConstantExpression
+"""Alias for `duckdb.ConstantExpression`."""
+
+when = duckdb.CaseExpression
+"""Alias for `duckdb.CaseExpression`."""
+
+
+def concat_str(*exprs: Expression, separator: str = "") -> Expression:
+ """Concatenate many strings, NULL inputs are skipped.
+
+ Wraps [concat] and [concat_ws] `FunctionExpression`(s).
+
+ Arguments:
+ exprs: Native columns.
+ separator: String that will be used to separate the values of each column.
+
+ Returns:
+ A new native expression.
+
+ [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
+ [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
+ """
+ return (
+ duckdb.FunctionExpression("concat_ws", lit(separator), *exprs)
+ if separator
+ else duckdb.FunctionExpression("concat", *exprs)
+ )
+
+
+def evaluate_exprs(
+ df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
+) -> list[tuple[str, Expression]]:
+ native_results: list[tuple[str, Expression]] = []
+ for expr in exprs:
+ native_series_list = expr._call(df)
+ output_names = expr._evaluate_output_names(df)
+ if expr._alias_output_names is not None:
+ output_names = expr._alias_output_names(output_names)
+ if len(output_names) != len(native_series_list): # pragma: no cover
+ msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
+ raise AssertionError(msg)
+ native_results.extend(zip(output_names, native_series_list))
+ return native_results
+
+
+class DeferredTimeZone:
+ """Object which gets passed between `native_to_narwhals_dtype` calls.
+
+ DuckDB stores the time zone in the connection, rather than in the dtypes, so
+ this ensures that when calculating the schema of a dataframe with multiple
+ timezone-aware columns, that the connection's time zone is only fetched once.
+
+ Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
+ the time zone can be modified after `DuckDBLazyFrame` creation:
+
+ ```python
+ df = nw.from_native(rel)
+ print(df.collect_schema())
+ rel.query("set timezone = 'Asia/Kolkata'")
+ print(df.collect_schema()) # should change to reflect new time zone
+ ```
+ """
+
+ _cached_time_zone: str | None = None
+
+ def __init__(self, rel: DuckDBPyRelation) -> None:
+ self._rel = rel
+
+ @property
+ def time_zone(self) -> str:
+ """Fetch relation time zone (if it wasn't calculated already)."""
+ if self._cached_time_zone is None:
+ self._cached_time_zone = fetch_rel_time_zone(self._rel)
+ return self._cached_time_zone
+
+
+def native_to_narwhals_dtype(
+ duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
+) -> DType:
+ duckdb_dtype_id = duckdb_dtype.id
+ dtypes = version.dtypes
+
+ # Handle nested data types first
+ if duckdb_dtype_id == "list":
+ return dtypes.List(
+ native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
+ )
+
+ if duckdb_dtype_id == "struct":
+ children = duckdb_dtype.children
+ return dtypes.Struct(
+ [
+ dtypes.Field(
+ name=child[0],
+ dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
+ )
+ for child in children
+ ]
+ )
+
+ if duckdb_dtype_id == "array":
+ child, size = duckdb_dtype.children
+ shape: list[int] = [size[1]]
+
+ while child[1].id == "array":
+ child, size = child[1].children
+ shape.insert(0, size[1])
+
+ inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
+ return dtypes.Array(inner=inner, shape=tuple(shape))
+
+ if duckdb_dtype_id == "enum":
+ if version is Version.V1:
+ return dtypes.Enum() # type: ignore[call-arg]
+ categories = duckdb_dtype.children[0][1]
+ return dtypes.Enum(categories=categories)
+
+ if duckdb_dtype_id == "timestamp with time zone":
+ return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)
+
+ return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)
+
+
+def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
+ result = rel.query(
+ "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
+ ).fetchone()
+ assert result is not None # noqa: S101
+ return result[0] # type: ignore[no-any-return]
+
+
+@lru_cache(maxsize=16)
+def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
+ dtypes = version.dtypes
+ return {
+ "hugeint": dtypes.Int128(),
+ "bigint": dtypes.Int64(),
+ "integer": dtypes.Int32(),
+ "smallint": dtypes.Int16(),
+ "tinyint": dtypes.Int8(),
+ "uhugeint": dtypes.UInt128(),
+ "ubigint": dtypes.UInt64(),
+ "uinteger": dtypes.UInt32(),
+ "usmallint": dtypes.UInt16(),
+ "utinyint": dtypes.UInt8(),
+ "double": dtypes.Float64(),
+ "float": dtypes.Float32(),
+ "varchar": dtypes.String(),
+ "date": dtypes.Date(),
+ "timestamp": dtypes.Datetime(),
+ "boolean": dtypes.Boolean(),
+ "interval": dtypes.Duration(),
+ "decimal": dtypes.Decimal(),
+ "time": dtypes.Time(),
+ "blob": dtypes.Binary(),
+ }.get(duckdb_dtype_id, dtypes.Unknown())
+
+
+def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str: # noqa: C901, PLR0912, PLR0915
+ dtypes = version.dtypes
+ if isinstance_or_issubclass(dtype, dtypes.Decimal):
+ msg = "Casting to Decimal is not supported yet."
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Float64):
+ return "DOUBLE"
+ if isinstance_or_issubclass(dtype, dtypes.Float32):
+ return "FLOAT"
+ if isinstance_or_issubclass(dtype, dtypes.Int128):
+ return "INT128"
+ if isinstance_or_issubclass(dtype, dtypes.Int64):
+ return "BIGINT"
+ if isinstance_or_issubclass(dtype, dtypes.Int32):
+ return "INTEGER"
+ if isinstance_or_issubclass(dtype, dtypes.Int16):
+ return "SMALLINT"
+ if isinstance_or_issubclass(dtype, dtypes.Int8):
+ return "TINYINT"
+ if isinstance_or_issubclass(dtype, dtypes.UInt128):
+ return "UINT128"
+ if isinstance_or_issubclass(dtype, dtypes.UInt64):
+ return "UBIGINT"
+ if isinstance_or_issubclass(dtype, dtypes.UInt32):
+ return "UINTEGER"
+ if isinstance_or_issubclass(dtype, dtypes.UInt16): # pragma: no cover
+ return "USMALLINT"
+ if isinstance_or_issubclass(dtype, dtypes.UInt8): # pragma: no cover
+ return "UTINYINT"
+ if isinstance_or_issubclass(dtype, dtypes.String):
+ return "VARCHAR"
+ if isinstance_or_issubclass(dtype, dtypes.Boolean): # pragma: no cover
+ return "BOOLEAN"
+ if isinstance_or_issubclass(dtype, dtypes.Time):
+ return "TIME"
+ if isinstance_or_issubclass(dtype, dtypes.Binary):
+ return "BLOB"
+ if isinstance_or_issubclass(dtype, dtypes.Categorical):
+ msg = "Categorical not supported by DuckDB"
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Enum):
+ if version is Version.V1:
+ msg = "Converting to Enum is not supported in narwhals.stable.v1"
+ raise NotImplementedError(msg)
+ if isinstance(dtype, dtypes.Enum):
+ categories = "'" + "', '".join(dtype.categories) + "'"
+ return f"ENUM ({categories})"
+ msg = "Can not cast / initialize Enum without categories present"
+ raise ValueError(msg)
+
+ if isinstance_or_issubclass(dtype, dtypes.Datetime):
+ _time_unit = dtype.time_unit
+ _time_zone = dtype.time_zone
+ msg = "todo"
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover
+ _time_unit = dtype.time_unit
+ msg = "todo"
+ raise NotImplementedError(msg)
+ if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover
+ return "DATE"
+ if isinstance_or_issubclass(dtype, dtypes.List):
+ inner = narwhals_to_native_dtype(dtype.inner, version)
+ return f"{inner}[]"
+ if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
+ inner = ", ".join(
+ f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version)}'
+ for field in dtype.fields
+ )
+ return f"STRUCT({inner})"
+ if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
+ shape = dtype.shape
+ duckdb_shape_fmt = "".join(f"[{item}]" for item in shape)
+ inner_dtype: Any = dtype
+ for _ in shape:
+ inner_dtype = inner_dtype.inner
+ duckdb_inner = narwhals_to_native_dtype(inner_dtype, version)
+ return f"{duckdb_inner}{duckdb_shape_fmt}"
+ msg = f"Unknown dtype: {dtype}" # pragma: no cover
+ raise AssertionError(msg)
+
+
+def generate_partition_by_sql(*partition_by: str | Expression) -> str:
+ if not partition_by:
+ return ""
+ by_sql = ", ".join([f"{col(x) if isinstance(x, str) else x}" for x in partition_by])
+ return f"partition by {by_sql}"
+
+
+def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
+ if ascending:
+ by_sql = ", ".join([f"{col(x)} asc nulls first" for x in order_by])
+ else:
+ by_sql = ", ".join([f"{col(x)} desc nulls last" for x in order_by])
+ return f"order by {by_sql}"