aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/plotly/express/_imshow.py
blob: ce6ddb8428631327e1fa9479c2042c2189f08890 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade, init_figure, configure_animation_controls
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import narwhals.stable.v1 as nw
import numpy as np
import itertools
from plotly.utils import image_array_to_data_uri

try:
    import xarray

    xarray_imported = True
except ImportError:
    xarray_imported = False

_float_types = []


def _vectorize_zvalue(z, mode="max"):
    alpha = 255 if mode == "max" else 0
    if z is None:
        return z
    elif np.isscalar(z):
        return [z] * 3 + [alpha]
    elif len(z) == 1:
        return list(z) * 3 + [alpha]
    elif len(z) == 3:
        return list(z) + [alpha]
    elif len(z) == 4:
        return z
    else:
        raise ValueError(
            "zmax can be a scalar, or an iterable of length 1, 3 or 4. "
            "A value of %s was passed for zmax." % str(z)
        )


def _infer_zmax_from_type(img):
    dt = img.dtype.type
    rtol = 1.05
    if dt in _integer_types:
        return _integer_ranges[dt][1]
    else:
        im_max = img[np.isfinite(img)].max()
        if im_max <= 1 * rtol:
            return 1
        elif im_max <= 255 * rtol:
            return 255
        elif im_max <= 65535 * rtol:
            return 65535
        else:
            return 2**32


def imshow(
    img,
    zmin=None,
    zmax=None,
    origin=None,
    labels={},
    x=None,
    y=None,
    animation_frame=None,
    facet_col=None,
    facet_col_wrap=None,
    facet_col_spacing=None,
    facet_row_spacing=None,
    color_continuous_scale=None,
    color_continuous_midpoint=None,
    range_color=None,
    title=None,
    template=None,
    width=None,
    height=None,
    aspect=None,
    contrast_rescaling=None,
    binary_string=None,
    binary_backend="auto",
    binary_compression_level=4,
    binary_format="png",
    text_auto=False,
) -> go.Figure:
    """
    Display an image, i.e. data on a 2D regular raster.

    Parameters
    ----------

    img: array-like image, or xarray
        The image data. Supported array shapes are

        - (M, N): an image with scalar data. The data is visualized
          using a colormap.
        - (M, N, 3): an image with RGB values.
        - (M, N, 4): an image with RGBA values, i.e. including transparency.

    zmin, zmax : scalar or iterable, optional
        zmin and zmax define the scalar range that the colormap covers. By default,
        zmin and zmax correspond to the min and max values of the datatype for integer
        datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
        a multichannel image of floats, the max of the image is computed and zmax is the
        smallest power of 256 (1, 255, 65535) greater than this max value,
        with a 5% tolerance. For a single-channel image, the max of the image is used.
        Overridden by range_color.

    origin : str, 'upper' or 'lower' (default 'upper')
        position of the [0, 0] pixel of the image array, in the upper left or lower left
        corner. The convention 'upper' is typically used for matrices and images.

    labels : dict with str keys and str values (default `{}`)
        Sets names used in the figure for axis titles (keys ``x`` and ``y``),
        colorbar title and hoverlabel (key ``color``). The values should correspond
        to the desired label to be displayed. If ``img`` is an xarray, dimension
        names are used for axis titles, and long name for the colorbar title
        (unless overridden in ``labels``). Possible keys are: x, y, and color.

    x, y: list-like, optional
        x and y are used to label the axes of single-channel heatmap visualizations and
        their lengths must match the lengths of the second and first dimensions of the
        img argument. They are auto-populated if the input is an xarray.

    animation_frame: int or str, optional (default None)
        axis number along which the image array is sliced to create an animation plot.
        If `img` is an xarray, `animation_frame` can be the name of one the dimensions.

    facet_col: int or str, optional (default None)
        axis number along which the image array is sliced to create a facetted plot.
        If `img` is an xarray, `facet_col` can be the name of one the dimensions.

    facet_col_wrap: int
        Maximum number of facet columns. Wraps the column variable at this width,
        so that the column facets span multiple rows.
        Ignored if `facet_col` is None.

    facet_col_spacing: float between 0 and 1
        Spacing between facet columns, in paper units. Default is 0.02.

    facet_row_spacing: float between 0 and 1
        Spacing between facet rows created when ``facet_col_wrap`` is used, in
        paper units. Default is 0.0.7.

    color_continuous_scale : str or list of str
        colormap used to map scalar data to colors (for a 2D image). This parameter is
        not used for RGB or RGBA images. If a string is provided, it should be the name
        of a known color scale, and if a list is provided, it should be a list of CSS-
        compatible colors.

    color_continuous_midpoint : number
        If set, computes the bounds of the continuous color scale to have the desired
        midpoint. Overridden by range_color or zmin and zmax.

    range_color : list of two numbers
        If provided, overrides auto-scaling on the continuous color scale, including
        overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only
        for single-channel images.

    title : str
        The figure title.

    template : str or dict or plotly.graph_objects.layout.Template instance
        The figure template name or definition.

    width : number
        The figure width in pixels.

    height: number
        The figure height in pixels.

    aspect: 'equal', 'auto', or None
      - 'equal': Ensures an aspect ratio of 1 or pixels (square pixels)
      - 'auto': The axes is kept fixed and the aspect ratio of pixels is
        adjusted so that the data fit in the axes. In general, this will
        result in non-square pixels.
      - if None, 'equal' is used for numpy arrays and 'auto' for xarrays
        (which have typically heterogeneous coordinates)

    contrast_rescaling: 'minmax', 'infer', or None
        how to determine data values corresponding to the bounds of the color
        range, when zmin or zmax are not passed. If `minmax`, the min and max
        values of the image are used. If `infer`, a heuristic based on the image
        data type is used.

    binary_string: bool, default None
        if True, the image data are first rescaled and encoded as uint8 and
        then passed to plotly.js as a b64 PNG string. If False, data are passed
        unchanged as a numerical array. Setting to True may lead to performance
        gains, at the cost of a loss of precision depending on the original data
        type. If None, use_binary_string is set to True for multichannel (eg) RGB
        arrays, and to False for single-channel (2D) arrays. 2D arrays are
        represented as grayscale and with no colorbar if use_binary_string is
        True.

    binary_backend: str, 'auto' (default), 'pil' or 'pypng'
        Third-party package for the transformation of numpy arrays to
        png b64 strings. If 'auto', Pillow is used if installed,  otherwise
        pypng.

    binary_compression_level: int, between 0 and 9 (default 4)
        png compression level to be passed to the backend when transforming an
        array to a png b64 string. Increasing `binary_compression` decreases the
        size of the png string, but the compression step takes more time. For most
        images it is not worth using levels greater than 5, but it's possible to
        test `len(fig.data[0].source)` and to time the execution of `imshow` to
        tune the level of compression. 0 means no compression (not recommended).

    binary_format: str, 'png' (default) or 'jpg'
        compression format used to generate b64 string. 'png' is recommended
        since it uses lossless compression, but 'jpg' (lossy) compression can
        result if smaller binary strings for natural images.

    text_auto: bool or str (default `False`)
        If `True` or a string, single-channel `img` values will be displayed as text.
        A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.

    Returns
    -------
    fig : graph_objects.Figure containing the displayed image

    See also
    --------

    plotly.graph_objects.Image : image trace
    plotly.graph_objects.Heatmap : heatmap trace

    Notes
    -----

    In order to update and customize the returned figure, use
    `go.Figure.update_traces` or `go.Figure.update_layout`.

    If an xarray is passed, dimensions names and coordinates are used for
    axes labels and ticks.
    """
    args = locals()
    apply_default_cascade(args)
    labels = labels.copy()
    nslices_facet = 1
    if facet_col is not None:
        if isinstance(facet_col, str):
            facet_col = img.dims.index(facet_col)
        nslices_facet = img.shape[facet_col]
        facet_slices = range(nslices_facet)
        ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
        nrows = (
            nslices_facet // ncols + 1
            if nslices_facet % ncols
            else nslices_facet // ncols
        )
    else:
        nrows = 1
        ncols = 1
    if animation_frame is not None:
        if isinstance(animation_frame, str):
            animation_frame = img.dims.index(animation_frame)
        nslices_animation = img.shape[animation_frame]
        animation_slices = range(nslices_animation)
    slice_dimensions = (facet_col is not None) + (
        animation_frame is not None
    )  # 0, 1, or 2
    facet_label = None
    animation_label = None
    img_is_xarray = False
    # ----- Define x and y, set labels if img is an xarray -------------------
    if xarray_imported and isinstance(img, xarray.DataArray):
        dims = list(img.dims)
        img_is_xarray = True
        pop_indexes = []
        if facet_col is not None:
            facet_slices = img.coords[img.dims[facet_col]].values
            pop_indexes.append(facet_col)
            facet_label = img.dims[facet_col]
        if animation_frame is not None:
            animation_slices = img.coords[img.dims[animation_frame]].values
            pop_indexes.append(animation_frame)
            animation_label = img.dims[animation_frame]
        # Remove indices in sorted order.
        for index in sorted(pop_indexes, reverse=True):
            _ = dims.pop(index)
        y_label, x_label = dims[0], dims[1]
        # np.datetime64 is not handled correctly by go.Heatmap
        for ax in [x_label, y_label]:
            if np.issubdtype(img.coords[ax].dtype, np.datetime64):
                img.coords[ax] = img.coords[ax].astype(str)
        if x is None:
            x = img.coords[x_label].values
        if y is None:
            y = img.coords[y_label].values
        if aspect is None:
            aspect = "auto"
        if labels.get("x", None) is None:
            labels["x"] = x_label
        if labels.get("y", None) is None:
            labels["y"] = y_label
        if labels.get("animation_frame", None) is None:
            labels["animation_frame"] = animation_label
        if labels.get("facet_col", None) is None:
            labels["facet_col"] = facet_label
        if labels.get("color", None) is None:
            labels["color"] = xarray.plot.utils.label_from_attrs(img)
            labels["color"] = labels["color"].replace("\n", "<br>")
    else:
        if hasattr(img, "columns") and hasattr(img.columns, "__len__"):
            if x is None:
                x = img.columns
            if labels.get("x", None) is None and hasattr(img.columns, "name"):
                labels["x"] = img.columns.name or ""
        if hasattr(img, "index") and hasattr(img.index, "__len__"):
            if y is None:
                y = img.index
            if labels.get("y", None) is None and hasattr(img.index, "name"):
                labels["y"] = img.index.name or ""

        if labels.get("x", None) is None:
            labels["x"] = ""
        if labels.get("y", None) is None:
            labels["y"] = ""
        if labels.get("color", None) is None:
            labels["color"] = ""
        if aspect is None:
            aspect = "equal"

    # --- Set the value of binary_string (forbidden for pandas)
    img = nw.from_native(img, pass_through=True)
    if isinstance(img, nw.DataFrame):
        if binary_string:
            raise ValueError("Binary strings cannot be used with pandas arrays")
        is_dataframe = True
    else:
        is_dataframe = False

    # --------------- Starting from here img is always a numpy array --------
    img = np.asanyarray(img)
    # Reshape array so that animation dimension comes first, then facets, then images
    if facet_col is not None:
        img = np.moveaxis(img, facet_col, 0)
        if animation_frame is not None and animation_frame < facet_col:
            animation_frame += 1
        facet_col = True
    if animation_frame is not None:
        img = np.moveaxis(img, animation_frame, 0)
        animation_frame = True
        args["animation_frame"] = (
            "animation_frame"
            if labels.get("animation_frame") is None
            else labels["animation_frame"]
        )
    iterables = ()
    if animation_frame is not None:
        iterables += (range(nslices_animation),)
    if facet_col is not None:
        iterables += (range(nslices_facet),)

    # Default behaviour of binary_string: True for RGB images, False for 2D
    if binary_string is None:
        binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe

    # Cast bools to uint8 (also one byte)
    if img.dtype == bool:
        img = 255 * img.astype(np.uint8)

    if range_color is not None:
        zmin = range_color[0]
        zmax = range_color[1]

    # -------- Contrast rescaling: either minmax or infer ------------------
    if contrast_rescaling is None:
        contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"

    # We try to set zmin and zmax only if necessary, because traces have good defaults
    if contrast_rescaling == "minmax":
        # When using binary_string and minmax we need to set zmin and zmax to rescale the image
        if (zmin is not None or binary_string) and zmax is None:
            zmax = img.max()
        if (zmax is not None or binary_string) and zmin is None:
            zmin = img.min()
    else:
        # For uint8 data and infer we let zmin and zmax to be None if passed as None
        if zmax is None and img.dtype != np.uint8:
            zmax = _infer_zmax_from_type(img)
        if zmin is None and zmax is not None:
            zmin = 0

    # For 2d data, use Heatmap trace, unless binary_string is True
    if img.ndim == 2 + slice_dimensions and not binary_string:
        y_index = slice_dimensions
        if y is not None and img.shape[y_index] != len(y):
            raise ValueError(
                "The length of the y vector must match the length of the first "
                + "dimension of the img matrix."
            )
        x_index = slice_dimensions + 1
        if x is not None and img.shape[x_index] != len(x):
            raise ValueError(
                "The length of the x vector must match the length of the second "
                + "dimension of the img matrix."
            )

        texttemplate = None
        if text_auto is True:
            texttemplate = "%{z}"
        elif text_auto is not False:
            texttemplate = "%{z:" + text_auto + "}"

        traces = [
            go.Heatmap(
                x=x,
                y=y,
                z=img[index_tup],
                coloraxis="coloraxis1",
                name=str(i),
                texttemplate=texttemplate,
            )
            for i, index_tup in enumerate(itertools.product(*iterables))
        ]
        autorange = True if origin == "lower" else "reversed"
        layout = dict(yaxis=dict(autorange=autorange))
        if aspect == "equal":
            layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
            layout["yaxis"]["constrain"] = "domain"
        colorscale_validator = ColorscaleValidator("colorscale", "imshow")
        layout["coloraxis1"] = dict(
            colorscale=colorscale_validator.validate_coerce(
                args["color_continuous_scale"]
            ),
            cmid=color_continuous_midpoint,
            cmin=zmin,
            cmax=zmax,
        )
        if labels["color"]:
            layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])

    # For 2D+RGB data, use Image trace
    elif (
        img.ndim >= 3
        and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
    ) or (img.ndim == 2 and binary_string):
        rescale_image = True  # to check whether image has been modified
        if zmin is not None and zmax is not None:
            zmin, zmax = (
                _vectorize_zvalue(zmin, mode="min"),
                _vectorize_zvalue(zmax, mode="max"),
            )
        x0, y0, dx, dy = (None,) * 4
        error_msg_xarray = (
            "Non-numerical coordinates were passed with xarray `img`, but "
            "the Image trace cannot handle it. Please use `binary_string=False` "
            "for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
        )
        if x is not None:
            x = np.asanyarray(x)
            if np.issubdtype(x.dtype, np.number):
                x0 = x[0]
                dx = x[1] - x[0]
            else:
                error_msg = (
                    error_msg_xarray
                    if img_is_xarray
                    else (
                        "Only numerical values are accepted for the `x` parameter "
                        "when an Image trace is used."
                    )
                )
                raise ValueError(error_msg)
        if y is not None:
            y = np.asanyarray(y)
            if np.issubdtype(y.dtype, np.number):
                y0 = y[0]
                dy = y[1] - y[0]
            else:
                error_msg = (
                    error_msg_xarray
                    if img_is_xarray
                    else (
                        "Only numerical values are accepted for the `y` parameter "
                        "when an Image trace is used."
                    )
                )
                raise ValueError(error_msg)
        if binary_string:
            if zmin is None and zmax is None:  # no rescaling, faster
                img_rescaled = img
                rescale_image = False
            elif img.ndim == 2 + slice_dimensions:  # single-channel image
                img_rescaled = rescale_intensity(
                    img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
                )
            else:
                img_rescaled = np.stack(
                    [
                        rescale_intensity(
                            img[..., ch],
                            in_range=(zmin[ch], zmax[ch]),
                            out_range=np.uint8,
                        )
                        for ch in range(img.shape[-1])
                    ],
                    axis=-1,
                )
            img_str = [
                image_array_to_data_uri(
                    img_rescaled[index_tup],
                    backend=binary_backend,
                    compression=binary_compression_level,
                    ext=binary_format,
                )
                for index_tup in itertools.product(*iterables)
            ]

            traces = [
                go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy)
                for i, img_str_slice in enumerate(img_str)
            ]
        else:
            colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
            traces = [
                go.Image(
                    z=img[index_tup],
                    zmin=zmin,
                    zmax=zmax,
                    colormodel=colormodel,
                    x0=x0,
                    y0=y0,
                    dx=dx,
                    dy=dy,
                )
                for index_tup in itertools.product(*iterables)
            ]
        layout = {}
        if origin == "lower" or (dy is not None and dy < 0):
            layout["yaxis"] = dict(autorange=True)
        if dx is not None and dx < 0:
            layout["xaxis"] = dict(autorange="reversed")
    else:
        raise ValueError(
            "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
            "An image of shape %s was provided. "
            "Alternatively, 3- or 4-D single or multichannel datasets can be "
            "visualized using the `facet_col` or/and `animation_frame` arguments."
            % str(img.shape)
        )

    # Now build figure
    col_labels = []
    if facet_col is not None:
        slice_label = (
            "facet_col" if labels.get("facet_col") is None else labels["facet_col"]
        )
        col_labels = [f"{slice_label}={i}" for i in facet_slices]
    fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
    for attr_name in ["height", "width"]:
        if args[attr_name]:
            layout[attr_name] = args[attr_name]
    if args["title"]:
        layout["title_text"] = args["title"]
    elif args["template"].layout.margin.t is None:
        layout["margin"] = {"t": 60}

    frame_list = []
    for index, trace in enumerate(traces):
        if (facet_col and index < nrows * ncols) or index == 0:
            fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
    if animation_frame is not None:
        for i, index in zip(range(nslices_animation), animation_slices):
            frame_list.append(
                dict(
                    data=traces[nslices_facet * i : nslices_facet * (i + 1)],
                    layout=layout,
                    name=str(index),
                )
            )
    if animation_frame:
        fig.frames = frame_list
    fig.update_layout(layout)
    # Hover name, z or color
    if binary_string and rescale_image and not np.all(img == img_rescaled):
        # we rescaled the image, hence z is not displayed in hover since it does
        # not correspond to img values
        hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % (
            labels["x"] or "x",
            labels["y"] or "y",
        )
    else:
        if trace["type"] == "heatmap":
            hover_name = "%{z}"
        elif img.ndim == 2:
            hover_name = "%{z[0]}"
        elif img.ndim == 3 and img.shape[-1] == 3:
            hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]"
        else:
            hover_name = "%{z}"
        hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % (
            labels["x"] or "x",
            labels["y"] or "y",
            labels["color"] or "color",
            hover_name,
        )
    fig.update_traces(hovertemplate=hovertemplate)
    if labels["x"]:
        fig.update_xaxes(title_text=labels["x"], row=1)
    if labels["y"]:
        fig.update_yaxes(title_text=labels["y"], col=1)
    configure_animation_controls(args, go.Image, fig)
    fig.update_layout(template=args["template"], overwrite=True)
    return fig