aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/plotly/express/_imshow.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/express/_imshow.py')
-rw-r--r--venv/lib/python3.8/site-packages/plotly/express/_imshow.py605
1 files changed, 605 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/express/_imshow.py b/venv/lib/python3.8/site-packages/plotly/express/_imshow.py
new file mode 100644
index 0000000..ce6ddb8
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/express/_imshow.py
@@ -0,0 +1,605 @@
+import plotly.graph_objs as go
+from _plotly_utils.basevalidators import ColorscaleValidator
+from ._core import apply_default_cascade, init_figure, configure_animation_controls
+from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
+import narwhals.stable.v1 as nw
+import numpy as np
+import itertools
+from plotly.utils import image_array_to_data_uri
+
+try:
+ import xarray
+
+ xarray_imported = True
+except ImportError:
+ xarray_imported = False
+
+_float_types = []
+
+
+def _vectorize_zvalue(z, mode="max"):
+ alpha = 255 if mode == "max" else 0
+ if z is None:
+ return z
+ elif np.isscalar(z):
+ return [z] * 3 + [alpha]
+ elif len(z) == 1:
+ return list(z) * 3 + [alpha]
+ elif len(z) == 3:
+ return list(z) + [alpha]
+ elif len(z) == 4:
+ return z
+ else:
+ raise ValueError(
+ "zmax can be a scalar, or an iterable of length 1, 3 or 4. "
+ "A value of %s was passed for zmax." % str(z)
+ )
+
+
+def _infer_zmax_from_type(img):
+ dt = img.dtype.type
+ rtol = 1.05
+ if dt in _integer_types:
+ return _integer_ranges[dt][1]
+ else:
+ im_max = img[np.isfinite(img)].max()
+ if im_max <= 1 * rtol:
+ return 1
+ elif im_max <= 255 * rtol:
+ return 255
+ elif im_max <= 65535 * rtol:
+ return 65535
+ else:
+ return 2**32
+
+
+def imshow(
+ img,
+ zmin=None,
+ zmax=None,
+ origin=None,
+ labels={},
+ x=None,
+ y=None,
+ animation_frame=None,
+ facet_col=None,
+ facet_col_wrap=None,
+ facet_col_spacing=None,
+ facet_row_spacing=None,
+ color_continuous_scale=None,
+ color_continuous_midpoint=None,
+ range_color=None,
+ title=None,
+ template=None,
+ width=None,
+ height=None,
+ aspect=None,
+ contrast_rescaling=None,
+ binary_string=None,
+ binary_backend="auto",
+ binary_compression_level=4,
+ binary_format="png",
+ text_auto=False,
+) -> go.Figure:
+ """
+ Display an image, i.e. data on a 2D regular raster.
+
+ Parameters
+ ----------
+
+ img: array-like image, or xarray
+ The image data. Supported array shapes are
+
+ - (M, N): an image with scalar data. The data is visualized
+ using a colormap.
+ - (M, N, 3): an image with RGB values.
+ - (M, N, 4): an image with RGBA values, i.e. including transparency.
+
+ zmin, zmax : scalar or iterable, optional
+ zmin and zmax define the scalar range that the colormap covers. By default,
+ zmin and zmax correspond to the min and max values of the datatype for integer
+ datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
+ a multichannel image of floats, the max of the image is computed and zmax is the
+ smallest power of 256 (1, 255, 65535) greater than this max value,
+ with a 5% tolerance. For a single-channel image, the max of the image is used.
+ Overridden by range_color.
+
+ origin : str, 'upper' or 'lower' (default 'upper')
+ position of the [0, 0] pixel of the image array, in the upper left or lower left
+ corner. The convention 'upper' is typically used for matrices and images.
+
+ labels : dict with str keys and str values (default `{}`)
+ Sets names used in the figure for axis titles (keys ``x`` and ``y``),
+ colorbar title and hoverlabel (key ``color``). The values should correspond
+ to the desired label to be displayed. If ``img`` is an xarray, dimension
+ names are used for axis titles, and long name for the colorbar title
+ (unless overridden in ``labels``). Possible keys are: x, y, and color.
+
+ x, y: list-like, optional
+ x and y are used to label the axes of single-channel heatmap visualizations and
+ their lengths must match the lengths of the second and first dimensions of the
+ img argument. They are auto-populated if the input is an xarray.
+
+ animation_frame: int or str, optional (default None)
+ axis number along which the image array is sliced to create an animation plot.
+ If `img` is an xarray, `animation_frame` can be the name of one the dimensions.
+
+ facet_col: int or str, optional (default None)
+ axis number along which the image array is sliced to create a facetted plot.
+ If `img` is an xarray, `facet_col` can be the name of one the dimensions.
+
+ facet_col_wrap: int
+ Maximum number of facet columns. Wraps the column variable at this width,
+ so that the column facets span multiple rows.
+ Ignored if `facet_col` is None.
+
+ facet_col_spacing: float between 0 and 1
+ Spacing between facet columns, in paper units. Default is 0.02.
+
+ facet_row_spacing: float between 0 and 1
+ Spacing between facet rows created when ``facet_col_wrap`` is used, in
+ paper units. Default is 0.0.7.
+
+ color_continuous_scale : str or list of str
+ colormap used to map scalar data to colors (for a 2D image). This parameter is
+ not used for RGB or RGBA images. If a string is provided, it should be the name
+ of a known color scale, and if a list is provided, it should be a list of CSS-
+ compatible colors.
+
+ color_continuous_midpoint : number
+ If set, computes the bounds of the continuous color scale to have the desired
+ midpoint. Overridden by range_color or zmin and zmax.
+
+ range_color : list of two numbers
+ If provided, overrides auto-scaling on the continuous color scale, including
+ overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only
+ for single-channel images.
+
+ title : str
+ The figure title.
+
+ template : str or dict or plotly.graph_objects.layout.Template instance
+ The figure template name or definition.
+
+ width : number
+ The figure width in pixels.
+
+ height: number
+ The figure height in pixels.
+
+ aspect: 'equal', 'auto', or None
+ - 'equal': Ensures an aspect ratio of 1 or pixels (square pixels)
+ - 'auto': The axes is kept fixed and the aspect ratio of pixels is
+ adjusted so that the data fit in the axes. In general, this will
+ result in non-square pixels.
+ - if None, 'equal' is used for numpy arrays and 'auto' for xarrays
+ (which have typically heterogeneous coordinates)
+
+ contrast_rescaling: 'minmax', 'infer', or None
+ how to determine data values corresponding to the bounds of the color
+ range, when zmin or zmax are not passed. If `minmax`, the min and max
+ values of the image are used. If `infer`, a heuristic based on the image
+ data type is used.
+
+ binary_string: bool, default None
+ if True, the image data are first rescaled and encoded as uint8 and
+ then passed to plotly.js as a b64 PNG string. If False, data are passed
+ unchanged as a numerical array. Setting to True may lead to performance
+ gains, at the cost of a loss of precision depending on the original data
+ type. If None, use_binary_string is set to True for multichannel (eg) RGB
+ arrays, and to False for single-channel (2D) arrays. 2D arrays are
+ represented as grayscale and with no colorbar if use_binary_string is
+ True.
+
+ binary_backend: str, 'auto' (default), 'pil' or 'pypng'
+ Third-party package for the transformation of numpy arrays to
+ png b64 strings. If 'auto', Pillow is used if installed, otherwise
+ pypng.
+
+ binary_compression_level: int, between 0 and 9 (default 4)
+ png compression level to be passed to the backend when transforming an
+ array to a png b64 string. Increasing `binary_compression` decreases the
+ size of the png string, but the compression step takes more time. For most
+ images it is not worth using levels greater than 5, but it's possible to
+ test `len(fig.data[0].source)` and to time the execution of `imshow` to
+ tune the level of compression. 0 means no compression (not recommended).
+
+ binary_format: str, 'png' (default) or 'jpg'
+ compression format used to generate b64 string. 'png' is recommended
+ since it uses lossless compression, but 'jpg' (lossy) compression can
+ result if smaller binary strings for natural images.
+
+ text_auto: bool or str (default `False`)
+ If `True` or a string, single-channel `img` values will be displayed as text.
+ A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive.
+
+ Returns
+ -------
+ fig : graph_objects.Figure containing the displayed image
+
+ See also
+ --------
+
+ plotly.graph_objects.Image : image trace
+ plotly.graph_objects.Heatmap : heatmap trace
+
+ Notes
+ -----
+
+ In order to update and customize the returned figure, use
+ `go.Figure.update_traces` or `go.Figure.update_layout`.
+
+ If an xarray is passed, dimensions names and coordinates are used for
+ axes labels and ticks.
+ """
+ args = locals()
+ apply_default_cascade(args)
+ labels = labels.copy()
+ nslices_facet = 1
+ if facet_col is not None:
+ if isinstance(facet_col, str):
+ facet_col = img.dims.index(facet_col)
+ nslices_facet = img.shape[facet_col]
+ facet_slices = range(nslices_facet)
+ ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
+ nrows = (
+ nslices_facet // ncols + 1
+ if nslices_facet % ncols
+ else nslices_facet // ncols
+ )
+ else:
+ nrows = 1
+ ncols = 1
+ if animation_frame is not None:
+ if isinstance(animation_frame, str):
+ animation_frame = img.dims.index(animation_frame)
+ nslices_animation = img.shape[animation_frame]
+ animation_slices = range(nslices_animation)
+ slice_dimensions = (facet_col is not None) + (
+ animation_frame is not None
+ ) # 0, 1, or 2
+ facet_label = None
+ animation_label = None
+ img_is_xarray = False
+ # ----- Define x and y, set labels if img is an xarray -------------------
+ if xarray_imported and isinstance(img, xarray.DataArray):
+ dims = list(img.dims)
+ img_is_xarray = True
+ pop_indexes = []
+ if facet_col is not None:
+ facet_slices = img.coords[img.dims[facet_col]].values
+ pop_indexes.append(facet_col)
+ facet_label = img.dims[facet_col]
+ if animation_frame is not None:
+ animation_slices = img.coords[img.dims[animation_frame]].values
+ pop_indexes.append(animation_frame)
+ animation_label = img.dims[animation_frame]
+ # Remove indices in sorted order.
+ for index in sorted(pop_indexes, reverse=True):
+ _ = dims.pop(index)
+ y_label, x_label = dims[0], dims[1]
+ # np.datetime64 is not handled correctly by go.Heatmap
+ for ax in [x_label, y_label]:
+ if np.issubdtype(img.coords[ax].dtype, np.datetime64):
+ img.coords[ax] = img.coords[ax].astype(str)
+ if x is None:
+ x = img.coords[x_label].values
+ if y is None:
+ y = img.coords[y_label].values
+ if aspect is None:
+ aspect = "auto"
+ if labels.get("x", None) is None:
+ labels["x"] = x_label
+ if labels.get("y", None) is None:
+ labels["y"] = y_label
+ if labels.get("animation_frame", None) is None:
+ labels["animation_frame"] = animation_label
+ if labels.get("facet_col", None) is None:
+ labels["facet_col"] = facet_label
+ if labels.get("color", None) is None:
+ labels["color"] = xarray.plot.utils.label_from_attrs(img)
+ labels["color"] = labels["color"].replace("\n", "<br>")
+ else:
+ if hasattr(img, "columns") and hasattr(img.columns, "__len__"):
+ if x is None:
+ x = img.columns
+ if labels.get("x", None) is None and hasattr(img.columns, "name"):
+ labels["x"] = img.columns.name or ""
+ if hasattr(img, "index") and hasattr(img.index, "__len__"):
+ if y is None:
+ y = img.index
+ if labels.get("y", None) is None and hasattr(img.index, "name"):
+ labels["y"] = img.index.name or ""
+
+ if labels.get("x", None) is None:
+ labels["x"] = ""
+ if labels.get("y", None) is None:
+ labels["y"] = ""
+ if labels.get("color", None) is None:
+ labels["color"] = ""
+ if aspect is None:
+ aspect = "equal"
+
+ # --- Set the value of binary_string (forbidden for pandas)
+ img = nw.from_native(img, pass_through=True)
+ if isinstance(img, nw.DataFrame):
+ if binary_string:
+ raise ValueError("Binary strings cannot be used with pandas arrays")
+ is_dataframe = True
+ else:
+ is_dataframe = False
+
+ # --------------- Starting from here img is always a numpy array --------
+ img = np.asanyarray(img)
+ # Reshape array so that animation dimension comes first, then facets, then images
+ if facet_col is not None:
+ img = np.moveaxis(img, facet_col, 0)
+ if animation_frame is not None and animation_frame < facet_col:
+ animation_frame += 1
+ facet_col = True
+ if animation_frame is not None:
+ img = np.moveaxis(img, animation_frame, 0)
+ animation_frame = True
+ args["animation_frame"] = (
+ "animation_frame"
+ if labels.get("animation_frame") is None
+ else labels["animation_frame"]
+ )
+ iterables = ()
+ if animation_frame is not None:
+ iterables += (range(nslices_animation),)
+ if facet_col is not None:
+ iterables += (range(nslices_facet),)
+
+ # Default behaviour of binary_string: True for RGB images, False for 2D
+ if binary_string is None:
+ binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe
+
+ # Cast bools to uint8 (also one byte)
+ if img.dtype == bool:
+ img = 255 * img.astype(np.uint8)
+
+ if range_color is not None:
+ zmin = range_color[0]
+ zmax = range_color[1]
+
+ # -------- Contrast rescaling: either minmax or infer ------------------
+ if contrast_rescaling is None:
+ contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"
+
+ # We try to set zmin and zmax only if necessary, because traces have good defaults
+ if contrast_rescaling == "minmax":
+ # When using binary_string and minmax we need to set zmin and zmax to rescale the image
+ if (zmin is not None or binary_string) and zmax is None:
+ zmax = img.max()
+ if (zmax is not None or binary_string) and zmin is None:
+ zmin = img.min()
+ else:
+ # For uint8 data and infer we let zmin and zmax to be None if passed as None
+ if zmax is None and img.dtype != np.uint8:
+ zmax = _infer_zmax_from_type(img)
+ if zmin is None and zmax is not None:
+ zmin = 0
+
+ # For 2d data, use Heatmap trace, unless binary_string is True
+ if img.ndim == 2 + slice_dimensions and not binary_string:
+ y_index = slice_dimensions
+ if y is not None and img.shape[y_index] != len(y):
+ raise ValueError(
+ "The length of the y vector must match the length of the first "
+ + "dimension of the img matrix."
+ )
+ x_index = slice_dimensions + 1
+ if x is not None and img.shape[x_index] != len(x):
+ raise ValueError(
+ "The length of the x vector must match the length of the second "
+ + "dimension of the img matrix."
+ )
+
+ texttemplate = None
+ if text_auto is True:
+ texttemplate = "%{z}"
+ elif text_auto is not False:
+ texttemplate = "%{z:" + text_auto + "}"
+
+ traces = [
+ go.Heatmap(
+ x=x,
+ y=y,
+ z=img[index_tup],
+ coloraxis="coloraxis1",
+ name=str(i),
+ texttemplate=texttemplate,
+ )
+ for i, index_tup in enumerate(itertools.product(*iterables))
+ ]
+ autorange = True if origin == "lower" else "reversed"
+ layout = dict(yaxis=dict(autorange=autorange))
+ if aspect == "equal":
+ layout["xaxis"] = dict(scaleanchor="y", constrain="domain")
+ layout["yaxis"]["constrain"] = "domain"
+ colorscale_validator = ColorscaleValidator("colorscale", "imshow")
+ layout["coloraxis1"] = dict(
+ colorscale=colorscale_validator.validate_coerce(
+ args["color_continuous_scale"]
+ ),
+ cmid=color_continuous_midpoint,
+ cmin=zmin,
+ cmax=zmax,
+ )
+ if labels["color"]:
+ layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
+
+ # For 2D+RGB data, use Image trace
+ elif (
+ img.ndim >= 3
+ and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
+ ) or (img.ndim == 2 and binary_string):
+ rescale_image = True # to check whether image has been modified
+ if zmin is not None and zmax is not None:
+ zmin, zmax = (
+ _vectorize_zvalue(zmin, mode="min"),
+ _vectorize_zvalue(zmax, mode="max"),
+ )
+ x0, y0, dx, dy = (None,) * 4
+ error_msg_xarray = (
+ "Non-numerical coordinates were passed with xarray `img`, but "
+ "the Image trace cannot handle it. Please use `binary_string=False` "
+ "for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
+ )
+ if x is not None:
+ x = np.asanyarray(x)
+ if np.issubdtype(x.dtype, np.number):
+ x0 = x[0]
+ dx = x[1] - x[0]
+ else:
+ error_msg = (
+ error_msg_xarray
+ if img_is_xarray
+ else (
+ "Only numerical values are accepted for the `x` parameter "
+ "when an Image trace is used."
+ )
+ )
+ raise ValueError(error_msg)
+ if y is not None:
+ y = np.asanyarray(y)
+ if np.issubdtype(y.dtype, np.number):
+ y0 = y[0]
+ dy = y[1] - y[0]
+ else:
+ error_msg = (
+ error_msg_xarray
+ if img_is_xarray
+ else (
+ "Only numerical values are accepted for the `y` parameter "
+ "when an Image trace is used."
+ )
+ )
+ raise ValueError(error_msg)
+ if binary_string:
+ if zmin is None and zmax is None: # no rescaling, faster
+ img_rescaled = img
+ rescale_image = False
+ elif img.ndim == 2 + slice_dimensions: # single-channel image
+ img_rescaled = rescale_intensity(
+ img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
+ )
+ else:
+ img_rescaled = np.stack(
+ [
+ rescale_intensity(
+ img[..., ch],
+ in_range=(zmin[ch], zmax[ch]),
+ out_range=np.uint8,
+ )
+ for ch in range(img.shape[-1])
+ ],
+ axis=-1,
+ )
+ img_str = [
+ image_array_to_data_uri(
+ img_rescaled[index_tup],
+ backend=binary_backend,
+ compression=binary_compression_level,
+ ext=binary_format,
+ )
+ for index_tup in itertools.product(*iterables)
+ ]
+
+ traces = [
+ go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy)
+ for i, img_str_slice in enumerate(img_str)
+ ]
+ else:
+ colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
+ traces = [
+ go.Image(
+ z=img[index_tup],
+ zmin=zmin,
+ zmax=zmax,
+ colormodel=colormodel,
+ x0=x0,
+ y0=y0,
+ dx=dx,
+ dy=dy,
+ )
+ for index_tup in itertools.product(*iterables)
+ ]
+ layout = {}
+ if origin == "lower" or (dy is not None and dy < 0):
+ layout["yaxis"] = dict(autorange=True)
+ if dx is not None and dx < 0:
+ layout["xaxis"] = dict(autorange="reversed")
+ else:
+ raise ValueError(
+ "px.imshow only accepts 2D single-channel, RGB or RGBA images. "
+ "An image of shape %s was provided. "
+ "Alternatively, 3- or 4-D single or multichannel datasets can be "
+ "visualized using the `facet_col` or/and `animation_frame` arguments."
+ % str(img.shape)
+ )
+
+ # Now build figure
+ col_labels = []
+ if facet_col is not None:
+ slice_label = (
+ "facet_col" if labels.get("facet_col") is None else labels["facet_col"]
+ )
+ col_labels = [f"{slice_label}={i}" for i in facet_slices]
+ fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
+ for attr_name in ["height", "width"]:
+ if args[attr_name]:
+ layout[attr_name] = args[attr_name]
+ if args["title"]:
+ layout["title_text"] = args["title"]
+ elif args["template"].layout.margin.t is None:
+ layout["margin"] = {"t": 60}
+
+ frame_list = []
+ for index, trace in enumerate(traces):
+ if (facet_col and index < nrows * ncols) or index == 0:
+ fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
+ if animation_frame is not None:
+ for i, index in zip(range(nslices_animation), animation_slices):
+ frame_list.append(
+ dict(
+ data=traces[nslices_facet * i : nslices_facet * (i + 1)],
+ layout=layout,
+ name=str(index),
+ )
+ )
+ if animation_frame:
+ fig.frames = frame_list
+ fig.update_layout(layout)
+ # Hover name, z or color
+ if binary_string and rescale_image and not np.all(img == img_rescaled):
+ # we rescaled the image, hence z is not displayed in hover since it does
+ # not correspond to img values
+ hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % (
+ labels["x"] or "x",
+ labels["y"] or "y",
+ )
+ else:
+ if trace["type"] == "heatmap":
+ hover_name = "%{z}"
+ elif img.ndim == 2:
+ hover_name = "%{z[0]}"
+ elif img.ndim == 3 and img.shape[-1] == 3:
+ hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]"
+ else:
+ hover_name = "%{z}"
+ hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % (
+ labels["x"] or "x",
+ labels["y"] or "y",
+ labels["color"] or "color",
+ hover_name,
+ )
+ fig.update_traces(hovertemplate=hovertemplate)
+ if labels["x"]:
+ fig.update_xaxes(title_text=labels["x"], row=1)
+ if labels["y"]:
+ fig.update_yaxes(title_text=labels["y"], col=1)
+ configure_animation_controls(args, go.Image, fig)
+ fig.update_layout(template=args["template"], overwrite=True)
+ return fig