aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/narwhals/_dask/expr_dt.py
blob: 14481a3b2b6c5301dfdd80d9b6aa2087f38d972f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._duration import parse_interval_string
from narwhals._pandas_like.utils import (
    UNIT_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace:
    """Datetime/duration (`.dt`) operations for a Dask-backed expression.

    Every public method builds a callable that works on the `.dt` accessor of its
    input and hands it, together with an operation label, to the wrapped
    expression's `_with_callable`; whatever that call returns is the result.
    """

    def __init__(self, expr: DaskExpr) -> None:
        """Store the expression whose `_with_callable` the methods delegate to."""
        self._compliant_expr = expr

    @staticmethod
    def _whole_units(seconds: dx.Series, multiplier: int, divisor: int = 1) -> dx.Series:
        """Scale `seconds` by `multiplier / divisor` and round toward zero.

        Plain floor division rounds negative values away from zero (-90 seconds
        would become -2 minutes). Flooring the magnitude and then restoring the
        sign yields -1 instead, while non-negative inputs are unaffected.
        """
        magnitude = (seconds.abs() * multiplier) // divisor
        # Wherever the comparison is not true (negative or missing input) take the
        # negated magnitude; negating a missing value leaves it missing.
        return magnitude.where(seconds >= 0, -magnitude)

    def date(self) -> DaskExpr:
        """Return the `.dt.date` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        """Return the `.dt.year` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        """Return the `.dt.month` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        """Return the `.dt.day` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        """Return the `.dt.hour` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        """Return the `.dt.minute` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        """Return the `.dt.second` component."""
        return self._compliant_expr._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        """Return the sub-second part in whole milliseconds (microseconds // 1000)."""
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        """Return the `.dt.microsecond` component."""
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
        """Return the sub-second part in nanoseconds.

        Combines the microsecond component (scaled by 1000) with the native
        nanosecond component.
        """
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
        )

    def ordinal_day(self) -> DaskExpr:
        """Return the day of the year (`.dt.dayofyear`)."""
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.dayofyear, "ordinal_day"
        )

    def weekday(self) -> DaskExpr:
        """Return the day of the week, shifted by one from the native value."""
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.weekday + 1,  # Dask is 0-6
            "weekday",
        )

    def to_string(self, format: str) -> DaskExpr:
        """Format values with `.dt.strftime`.

        Any `"%.f"` in `format` is rewritten to `".%f"` before formatting.
        NOTE(review): presumably this translates a Polars-style fractional-second
        directive into the strftime equivalent - confirm.
        """
        return self._compliant_expr._with_callable(
            lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
            "strftime",
            format=format,
        )

    def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
        """Replace the time zone without converting the wall-clock values.

        Any existing zone is first dropped with `tz_localize(None)`; when
        `time_zone` is not None the values are then localized to it.
        """
        return self._compliant_expr._with_callable(
            lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
            if time_zone is not None
            else expr.dt.tz_localize(None),
            "tz_localize",
            time_zone=time_zone,
        )

    def convert_time_zone(self, time_zone: str) -> DaskExpr:
        """Convert values to `time_zone`.

        When the input dtype carries no time zone, values are localized as UTC
        before being converted.
        """

        def func(s: dx.Series, time_zone: str) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self._compliant_expr._version, Implementation.DASK
            )
            if dtype.time_zone is None:  # type: ignore[attr-defined]
                return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]
            else:
                return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]

        return self._compliant_expr._with_callable(
            func, "tz_convert", time_zone=time_zone
        )

    def timestamp(self, time_unit: TimeUnit) -> DaskExpr:
        """Return integer timestamps expressed in `time_unit`.

        Raises:
            TypeError: (from the wrapped callable) if the input is neither a Date
                nor a Datetime.
        """

        def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self._compliant_expr._version, Implementation.DASK
            )
            # Inspect the *native* dtype: the narwhals dtype's string form never
            # mentions "pyarrow", so checking it would always pick the "int64" cast.
            is_pyarrow_dtype = "pyarrow" in str(s.dtype)
            mask_na = s.isna()
            dtypes = self._compliant_expr._version.dtypes
            if dtype == dtypes.Date:
                # Date is only supported in pandas dtypes if pyarrow-backed
                s_cast = s.astype("Int32[pyarrow]")
                result = calculate_timestamp_date(s_cast, time_unit)
            elif isinstance(dtype, dtypes.Datetime):
                original_time_unit = dtype.time_unit
                s_cast = (
                    s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
                )
                result = calculate_timestamp_datetime(
                    s_cast, original_time_unit, time_unit
                )
            else:
                msg = "Input should be either of Date or Datetime type"
                raise TypeError(msg)
            # Blank out positions that were missing in the input.
            return result.where(~mask_na)  # pyright: ignore[reportReturnType]

        return self._compliant_expr._with_callable(func, "datetime", time_unit=time_unit)

    def total_minutes(self) -> DaskExpr:
        """Return the duration in whole minutes, rounded toward zero."""
        return self._compliant_expr._with_callable(
            lambda expr: self._whole_units(expr.dt.total_seconds(), 1, 60),
            "total_minutes",
        )

    def total_seconds(self) -> DaskExpr:
        """Return the duration in whole seconds, rounded toward zero."""
        return self._compliant_expr._with_callable(
            lambda expr: self._whole_units(expr.dt.total_seconds(), 1),
            "total_seconds",
        )

    def total_milliseconds(self) -> DaskExpr:
        """Return the duration in whole milliseconds, rounded toward zero."""
        return self._compliant_expr._with_callable(
            lambda expr: self._whole_units(expr.dt.total_seconds(), 1000),
            "total_milliseconds",
        )

    def total_microseconds(self) -> DaskExpr:
        """Return the duration in whole microseconds, rounded toward zero."""
        return self._compliant_expr._with_callable(
            lambda expr: self._whole_units(expr.dt.total_seconds(), 1_000_000),
            "total_microseconds",
        )

    def total_nanoseconds(self) -> DaskExpr:
        """Return the duration in whole nanoseconds, rounded toward zero."""
        return self._compliant_expr._with_callable(
            lambda expr: self._whole_units(expr.dt.total_seconds(), 1_000_000_000),
            "total_nanoseconds",
        )

    def truncate(self, every: str) -> DaskExpr:
        """Floor values to a multiple of the interval described by `every`.

        Raises:
            NotImplementedError: if the parsed unit is "mo", "q" or "y".
        """
        multiple, unit = parse_interval_string(every)
        if unit in {"mo", "q", "y"}:
            msg = f"Truncating to {unit} is not supported yet for dask."
            raise NotImplementedError(msg)
        # Translate the unit through UNIT_DICT when it has an entry there;
        # otherwise the parsed unit is used verbatim.
        freq = f"{multiple}{UNIT_DICT.get(unit, unit)}"
        return self._compliant_expr._with_callable(
            lambda expr: expr.dt.floor(freq), "truncate"
        )