diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py')
-rw-r--r-- | venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py | 287 |
1 files changed, 287 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py b/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py new file mode 100644 index 0000000..c5d4872 --- /dev/null +++ b/venv/lib/python3.8/site-packages/narwhals/_duckdb/utils.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +import duckdb + +from narwhals._utils import Version, isinstance_or_issubclass + +if TYPE_CHECKING: + from duckdb import DuckDBPyRelation, Expression + from duckdb.typing import DuckDBPyType + + from narwhals._duckdb.dataframe import DuckDBLazyFrame + from narwhals._duckdb.expr import DuckDBExpr + from narwhals.dtypes import DType + from narwhals.typing import IntoDType + +UNITS_DICT = { + "y": "year", + "q": "quarter", + "mo": "month", + "d": "day", + "h": "hour", + "m": "minute", + "s": "second", + "ms": "millisecond", + "us": "microsecond", + "ns": "nanosecond", +} + +col = duckdb.ColumnExpression +"""Alias for `duckdb.ColumnExpression`.""" + +lit = duckdb.ConstantExpression +"""Alias for `duckdb.ConstantExpression`.""" + +when = duckdb.CaseExpression +"""Alias for `duckdb.CaseExpression`.""" + + +def concat_str(*exprs: Expression, separator: str = "") -> Expression: + """Concatenate many strings, NULL inputs are skipped. + + Wraps [concat] and [concat_ws] `FunctionExpression`(s). + + Arguments: + exprs: Native columns. + separator: String that will be used to separate the values of each column. + + Returns: + A new native expression. + + [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring- + [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string- + """ + return ( + duckdb.FunctionExpression("concat_ws", lit(separator), *exprs) + if separator + else duckdb.FunctionExpression("concat", *exprs) + ) + + +def evaluate_exprs( + df: DuckDBLazyFrame, /, *exprs: DuckDBExpr +) -> list[tuple[str, Expression]]: + native_results: list[tuple[str, Expression]] = [] + for expr in exprs: + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.extend(zip(output_names, native_series_list)) + return native_results + + +class DeferredTimeZone: + """Object which gets passed between `native_to_narwhals_dtype` calls. + + DuckDB stores the time zone in the connection, rather than in the dtypes, so + this ensures that when calculating the schema of a dataframe with multiple + timezone-aware columns, that the connection's time zone is only fetched once. + + Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because + the time zone can be modified after `DuckDBLazyFrame` creation: + + ```python + df = nw.from_native(rel) + print(df.collect_schema()) + rel.query("set timezone = 'Asia/Kolkata'") + print(df.collect_schema()) # should change to reflect new time zone + ``` + """ + + _cached_time_zone: str | None = None + + def __init__(self, rel: DuckDBPyRelation) -> None: + self._rel = rel + + @property + def time_zone(self) -> str: + """Fetch relation time zone (if it wasn't calculated already).""" + if self._cached_time_zone is None: + self._cached_time_zone = fetch_rel_time_zone(self._rel) + return self._cached_time_zone + + +def native_to_narwhals_dtype( + duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone +) -> DType: + duckdb_dtype_id = duckdb_dtype.id + dtypes = version.dtypes + + # Handle nested data types first + if duckdb_dtype_id == "list": + return dtypes.List( + native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone) + ) + + if duckdb_dtype_id == "struct": + children = duckdb_dtype.children + return dtypes.Struct( + [ + dtypes.Field( + name=child[0], + dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone), + ) + for child in children + ] + ) + + if duckdb_dtype_id == "array": + child, size = duckdb_dtype.children + shape: list[int] = [size[1]] + + while child[1].id == "array": + child, size = child[1].children + shape.insert(0, size[1]) + + inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone) + return dtypes.Array(inner=inner, shape=tuple(shape)) + + if duckdb_dtype_id == "enum": + if version is Version.V1: + return dtypes.Enum() # type: ignore[call-arg] + categories = duckdb_dtype.children[0][1] + return dtypes.Enum(categories=categories) + + if duckdb_dtype_id == "timestamp with time zone": + return dtypes.Datetime(time_zone=deferred_time_zone.time_zone) + + return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version) + + +def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str: + result = rel.query( + "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'" + ).fetchone() + assert result is not None # noqa: S101 + return result[0] # type: ignore[no-any-return] + + +@lru_cache(maxsize=16) +def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType: + dtypes = version.dtypes + return { + "hugeint": dtypes.Int128(), + "bigint": dtypes.Int64(), + "integer": dtypes.Int32(), + "smallint": dtypes.Int16(), + "tinyint": dtypes.Int8(), + "uhugeint": dtypes.UInt128(), + "ubigint": dtypes.UInt64(), + "uinteger": dtypes.UInt32(), + "usmallint": dtypes.UInt16(), + "utinyint": dtypes.UInt8(), + "double": dtypes.Float64(), + "float": dtypes.Float32(), + "varchar": dtypes.String(), + "date": dtypes.Date(), + "timestamp": dtypes.Datetime(), + "boolean": dtypes.Boolean(), + "interval": dtypes.Duration(), + "decimal": dtypes.Decimal(), + "time": dtypes.Time(), + "blob": dtypes.Binary(), + }.get(duckdb_dtype_id, dtypes.Unknown()) + + +def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> str: # noqa: C901, PLR0912, PLR0915 + dtypes = version.dtypes + if isinstance_or_issubclass(dtype, dtypes.Decimal): + msg = "Casting to Decimal is not supported yet." + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Float64): + return "DOUBLE" + if isinstance_or_issubclass(dtype, dtypes.Float32): + return "FLOAT" + if isinstance_or_issubclass(dtype, dtypes.Int128): + return "INT128" + if isinstance_or_issubclass(dtype, dtypes.Int64): + return "BIGINT" + if isinstance_or_issubclass(dtype, dtypes.Int32): + return "INTEGER" + if isinstance_or_issubclass(dtype, dtypes.Int16): + return "SMALLINT" + if isinstance_or_issubclass(dtype, dtypes.Int8): + return "TINYINT" + if isinstance_or_issubclass(dtype, dtypes.UInt128): + return "UINT128" + if isinstance_or_issubclass(dtype, dtypes.UInt64): + return "UBIGINT" + if isinstance_or_issubclass(dtype, dtypes.UInt32): + return "UINTEGER" + if isinstance_or_issubclass(dtype, dtypes.UInt16): # pragma: no cover + return "USMALLINT" + if isinstance_or_issubclass(dtype, dtypes.UInt8): # pragma: no cover + return "UTINYINT" + if isinstance_or_issubclass(dtype, dtypes.String): + return "VARCHAR" + if isinstance_or_issubclass(dtype, dtypes.Boolean): # pragma: no cover + return "BOOLEAN" + if isinstance_or_issubclass(dtype, dtypes.Time): + return "TIME" + if isinstance_or_issubclass(dtype, dtypes.Binary): + return "BLOB" + if isinstance_or_issubclass(dtype, dtypes.Categorical): + msg = "Categorical not supported by DuckDB" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Enum): + if version is Version.V1: + msg = "Converting to Enum is not supported in narwhals.stable.v1" + raise NotImplementedError(msg) + if isinstance(dtype, dtypes.Enum): + categories = "'" + "', '".join(dtype.categories) + "'" + return f"ENUM ({categories})" + msg = "Can not cast / initialize Enum without categories present" + raise ValueError(msg) + + if isinstance_or_issubclass(dtype, dtypes.Datetime): + _time_unit = dtype.time_unit + _time_zone = dtype.time_zone + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover + _time_unit = dtype.time_unit + msg = "todo" + raise NotImplementedError(msg) + if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover + return "DATE" + if isinstance_or_issubclass(dtype, dtypes.List): + inner = narwhals_to_native_dtype(dtype.inner, version) + return f"{inner}[]" + if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover + inner = ", ".join( + f'"{field.name}" {narwhals_to_native_dtype(field.dtype, version)}' + for field in dtype.fields + ) + return f"STRUCT({inner})" + if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover + shape = dtype.shape + duckdb_shape_fmt = "".join(f"[{item}]" for item in shape) + inner_dtype: Any = dtype + for _ in shape: + inner_dtype = inner_dtype.inner + duckdb_inner = narwhals_to_native_dtype(inner_dtype, version) + return f"{duckdb_inner}{duckdb_shape_fmt}" + msg = f"Unknown dtype: {dtype}" # pragma: no cover + raise AssertionError(msg) + + +def generate_partition_by_sql(*partition_by: str | Expression) -> str: + if not partition_by: + return "" + by_sql = ", ".join([f"{col(x) if isinstance(x, str) else x}" for x in partition_by]) + return f"partition by {by_sql}" + + +def generate_order_by_sql(*order_by: str, ascending: bool) -> str: + if ascending: + by_sql = ", ".join([f"{col(x)} asc nulls first" for x in order_by]) + else: + by_sql = ", ".join([f"{col(x)} desc nulls last" for x in order_by]) + return f"order by {by_sql}" |