aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter')
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py3
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py317
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py14
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py428
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py88
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py155
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py54
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py3
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py257
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py40
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py55
-rw-r--r--venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py382
12 files changed, 1796 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py
new file mode 100644
index 0000000..296a47e
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa: F401
+from .renderers import Renderer
+from .exporter import Exporter
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py
new file mode 100644
index 0000000..bbd1756
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/exporter.py
@@ -0,0 +1,317 @@
+"""
+Matplotlib Exporter
+===================
+This submodule contains tools for crawling a matplotlib figure and exporting
+relevant pieces to a renderer.
+"""
+
+import warnings
+import io
+from . import utils
+
+import matplotlib
+from matplotlib import transforms
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+
+class Exporter(object):
+ """Matplotlib Exporter
+
+ Parameters
+ ----------
+ renderer : Renderer object
+ The renderer object called by the exporter to create a figure
+ visualization. See mplexporter.Renderer for information on the
+ methods which should be defined within the renderer.
+ close_mpl : bool
+ If True (default), close the matplotlib figure as it is rendered. This
+ is useful for when the exporter is used within the notebook, or with
+ an interactive matplotlib backend.
+ """
+
+ def __init__(self, renderer, close_mpl=True):
+ self.close_mpl = close_mpl
+ self.renderer = renderer
+
+ def run(self, fig):
+ """
+ Run the exporter on the given figure
+
+ Parmeters
+ ---------
+ fig : matplotlib.Figure instance
+ The figure to export
+ """
+ # Calling savefig executes the draw() command, putting elements
+ # in the correct place.
+ if fig.canvas is None:
+ FigureCanvasAgg(fig)
+ fig.savefig(io.BytesIO(), format="png", dpi=fig.dpi)
+ if self.close_mpl:
+ import matplotlib.pyplot as plt
+
+ plt.close(fig)
+ self.crawl_fig(fig)
+
+ @staticmethod
+ def process_transform(
+ transform, ax=None, data=None, return_trans=False, force_trans=None
+ ):
+ """Process the transform and convert data to figure or data coordinates
+
+ Parameters
+ ----------
+ transform : matplotlib Transform object
+ The transform applied to the data
+ ax : matplotlib Axes object (optional)
+ The axes the data is associated with
+ data : ndarray (optional)
+ The array of data to be transformed.
+ return_trans : bool (optional)
+ If true, return the final transform of the data
+ force_trans : matplotlib.transform instance (optional)
+ If supplied, first force the data to this transform
+
+ Returns
+ -------
+ code : string
+ Code is either "data", "axes", "figure", or "display", indicating
+ the type of coordinates output.
+ transform : matplotlib transform
+ the transform used to map input data to output data.
+ Returned only if return_trans is True
+ new_data : ndarray
+ Data transformed to match the given coordinate code.
+ Returned only if data is specified
+ """
+ if isinstance(transform, transforms.BlendedGenericTransform):
+ warnings.warn(
+ "Blended transforms not yet supported. "
+ "Zoom behavior may not work as expected."
+ )
+
+ if force_trans is not None:
+ if data is not None:
+ data = (transform - force_trans).transform(data)
+ transform = force_trans
+
+ code = "display"
+ if ax is not None:
+ for c, trans in [
+ ("data", ax.transData),
+ ("axes", ax.transAxes),
+ ("figure", ax.figure.transFigure),
+ ("display", transforms.IdentityTransform()),
+ ]:
+ if transform.contains_branch(trans):
+ code, transform = (c, transform - trans)
+ break
+
+ if data is not None:
+ if return_trans:
+ return code, transform.transform(data), transform
+ else:
+ return code, transform.transform(data)
+ else:
+ if return_trans:
+ return code, transform
+ else:
+ return code
+
+ def crawl_fig(self, fig):
+ """Crawl the figure and process all axes"""
+ with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)):
+ for ax in fig.axes:
+ self.crawl_ax(ax)
+
+ def crawl_ax(self, ax):
+ """Crawl the axes and process all elements within"""
+ with self.renderer.draw_axes(ax=ax, props=utils.get_axes_properties(ax)):
+ for line in ax.lines:
+ self.draw_line(ax, line)
+ for text in ax.texts:
+ self.draw_text(ax, text)
+ for text, ttp in zip(
+ [ax.xaxis.label, ax.yaxis.label, ax.title],
+ ["xlabel", "ylabel", "title"],
+ ):
+ if hasattr(text, "get_text") and text.get_text():
+ self.draw_text(ax, text, force_trans=ax.transAxes, text_type=ttp)
+ for artist in ax.artists:
+ # TODO: process other artists
+ if isinstance(artist, matplotlib.text.Text):
+ self.draw_text(ax, artist)
+ for patch in ax.patches:
+ self.draw_patch(ax, patch)
+ for collection in ax.collections:
+ self.draw_collection(ax, collection)
+ for image in ax.images:
+ self.draw_image(ax, image)
+
+ legend = ax.get_legend()
+ if legend is not None:
+ props = utils.get_legend_properties(ax, legend)
+ with self.renderer.draw_legend(legend=legend, props=props):
+ if props["visible"]:
+ self.crawl_legend(ax, legend)
+
+ def crawl_legend(self, ax, legend):
+ """
+ Recursively look through objects in legend children
+ """
+ legendElements = list(
+ utils.iter_all_children(legend._legend_box, skipContainers=True)
+ )
+ legendElements.append(legend.legendPatch)
+ for child in legendElements:
+ # force a large zorder so it appears on top
+ child.set_zorder(1e6 + child.get_zorder())
+
+ # reorder border box to make sure marks are visible
+ if isinstance(child, matplotlib.patches.FancyBboxPatch):
+ child.set_zorder(child.get_zorder() - 1)
+
+ try:
+ # What kind of object...
+ if isinstance(child, matplotlib.patches.Patch):
+ self.draw_patch(ax, child, force_trans=ax.transAxes)
+ elif isinstance(child, matplotlib.text.Text):
+ if child.get_text() != "None":
+ self.draw_text(ax, child, force_trans=ax.transAxes)
+ elif isinstance(child, matplotlib.lines.Line2D):
+ self.draw_line(ax, child, force_trans=ax.transAxes)
+ elif isinstance(child, matplotlib.collections.Collection):
+ self.draw_collection(ax, child, force_pathtrans=ax.transAxes)
+ else:
+ warnings.warn("Legend element %s not impemented" % child)
+ except NotImplementedError:
+ warnings.warn("Legend element %s not impemented" % child)
+
+ def draw_line(self, ax, line, force_trans=None):
+ """Process a matplotlib line and call renderer.draw_line"""
+ coordinates, data = self.process_transform(
+ line.get_transform(), ax, line.get_xydata(), force_trans=force_trans
+ )
+ linestyle = utils.get_line_style(line)
+ if linestyle["dasharray"] is None and linestyle["drawstyle"] == "default":
+ linestyle = None
+ markerstyle = utils.get_marker_style(line)
+ if (
+ markerstyle["marker"] in ["None", "none", None]
+ or markerstyle["markerpath"][0].size == 0
+ ):
+ markerstyle = None
+ label = line.get_label()
+ if markerstyle or linestyle:
+ self.renderer.draw_marked_line(
+ data=data,
+ coordinates=coordinates,
+ linestyle=linestyle,
+ markerstyle=markerstyle,
+ label=label,
+ mplobj=line,
+ )
+
+ def draw_text(self, ax, text, force_trans=None, text_type=None):
+ """Process a matplotlib text object and call renderer.draw_text"""
+ content = text.get_text()
+ if content:
+ transform = text.get_transform()
+ position = text.get_position()
+ coords, position = self.process_transform(
+ transform, ax, position, force_trans=force_trans
+ )
+ style = utils.get_text_style(text)
+ self.renderer.draw_text(
+ text=content,
+ position=position,
+ coordinates=coords,
+ text_type=text_type,
+ style=style,
+ mplobj=text,
+ )
+
+ def draw_patch(self, ax, patch, force_trans=None):
+ """Process a matplotlib patch object and call renderer.draw_path"""
+ vertices, pathcodes = utils.SVG_path(patch.get_path())
+ transform = patch.get_transform()
+ coordinates, vertices = self.process_transform(
+ transform, ax, vertices, force_trans=force_trans
+ )
+ linestyle = utils.get_path_style(patch, fill=patch.get_fill())
+ self.renderer.draw_path(
+ data=vertices,
+ coordinates=coordinates,
+ pathcodes=pathcodes,
+ style=linestyle,
+ mplobj=patch,
+ )
+
+ def draw_collection(
+ self, ax, collection, force_pathtrans=None, force_offsettrans=None
+ ):
+ """Process a matplotlib collection and call renderer.draw_collection"""
+ (transform, transOffset, offsets, paths) = collection._prepare_points()
+
+ offset_coords, offsets = self.process_transform(
+ transOffset, ax, offsets, force_trans=force_offsettrans
+ )
+ path_coords = self.process_transform(transform, ax, force_trans=force_pathtrans)
+
+ processed_paths = [utils.SVG_path(path) for path in paths]
+ processed_paths = [
+ (
+ self.process_transform(
+ transform, ax, path[0], force_trans=force_pathtrans
+ )[1],
+ path[1],
+ )
+ for path in processed_paths
+ ]
+
+ path_transforms = collection.get_transforms()
+ try:
+ # matplotlib 1.3: path_transforms are transform objects.
+ # Convert them to numpy arrays.
+ path_transforms = [t.get_matrix() for t in path_transforms]
+ except AttributeError:
+ # matplotlib 1.4: path transforms are already numpy arrays.
+ pass
+
+ styles = {
+ "linewidth": collection.get_linewidths(),
+ "facecolor": collection.get_facecolors(),
+ "edgecolor": collection.get_edgecolors(),
+ "alpha": collection._alpha,
+ "zorder": collection.get_zorder(),
+ }
+
+ # TODO: When matplotlib's minimum version is bumped to 3.8, this can be
+ # simplified since collection.get_offset_position no longer exists.
+ offset_dict = {"data": "before", "screen": "after"}
+ offset_order = (
+ offset_dict[collection.get_offset_position()]
+ if hasattr(collection, "get_offset_position")
+ else "after"
+ )
+
+ self.renderer.draw_path_collection(
+ paths=processed_paths,
+ path_coordinates=path_coords,
+ path_transforms=path_transforms,
+ offsets=offsets,
+ offset_coordinates=offset_coords,
+ offset_order=offset_order,
+ styles=styles,
+ mplobj=collection,
+ )
+
+ def draw_image(self, ax, image):
+ """Process a matplotlib image object and call renderer.draw_image"""
+ self.renderer.draw_image(
+ imdata=utils.image_to_base64(image),
+ extent=image.get_extent(),
+ coordinates="data",
+ style={"alpha": image.get_alpha(), "zorder": image.get_zorder()},
+ mplobj=image,
+ )
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py
new file mode 100644
index 0000000..21113ad
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/__init__.py
@@ -0,0 +1,14 @@
+# ruff: noqa F401
+
+"""
+Matplotlib Renderers
+====================
+This submodule contains renderer objects which define renderer behavior used
+within the Exporter class. The base renderer class is :class:`Renderer`, an
+abstract base class
+"""
+
+from .base import Renderer
+from .vega_renderer import VegaRenderer, fig_to_vega
+from .vincent_renderer import VincentRenderer, fig_to_vincent
+from .fake_renderer import FakeRenderer, FullFakeRenderer
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py
new file mode 100644
index 0000000..fbb8819
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py
@@ -0,0 +1,428 @@
+import warnings
+import itertools
+from contextlib import contextmanager
+from packaging.version import Version
+
+import numpy as np
+import matplotlib as mpl
+from matplotlib import transforms
+
+from .. import utils
+
+
+class Renderer(object):
+ @staticmethod
+ def ax_zoomable(ax):
+ return bool(ax and ax.get_navigate())
+
+ @staticmethod
+ def ax_has_xgrid(ax):
+ return bool(ax and ax.xaxis._gridOnMajor and ax.yaxis.get_gridlines())
+
+ @staticmethod
+ def ax_has_ygrid(ax):
+ return bool(ax and ax.yaxis._gridOnMajor and ax.yaxis.get_gridlines())
+
+ @property
+ def current_ax_zoomable(self):
+ return self.ax_zoomable(self._current_ax)
+
+ @property
+ def current_ax_has_xgrid(self):
+ return self.ax_has_xgrid(self._current_ax)
+
+ @property
+ def current_ax_has_ygrid(self):
+ return self.ax_has_ygrid(self._current_ax)
+
+ @contextmanager
+ def draw_figure(self, fig, props):
+ if hasattr(self, "_current_fig") and self._current_fig is not None:
+ warnings.warn("figure embedded in figure: something is wrong")
+ self._current_fig = fig
+ self._fig_props = props
+ self.open_figure(fig=fig, props=props)
+ yield
+ self.close_figure(fig=fig)
+ self._current_fig = None
+ self._fig_props = {}
+
+ @contextmanager
+ def draw_axes(self, ax, props):
+ if hasattr(self, "_current_ax") and self._current_ax is not None:
+ warnings.warn("axes embedded in axes: something is wrong")
+ self._current_ax = ax
+ self._ax_props = props
+ self.open_axes(ax=ax, props=props)
+ yield
+ self.close_axes(ax=ax)
+ self._current_ax = None
+ self._ax_props = {}
+
+ @contextmanager
+ def draw_legend(self, legend, props):
+ self._current_legend = legend
+ self._legend_props = props
+ self.open_legend(legend=legend, props=props)
+ yield
+ self.close_legend(legend=legend)
+ self._current_legend = None
+ self._legend_props = {}
+
+ # Following are the functions which should be overloaded in subclasses
+
+ def open_figure(self, fig, props):
+ """
+ Begin commands for a particular figure.
+
+ Parameters
+ ----------
+ fig : matplotlib.Figure
+ The Figure which will contain the ensuing axes and elements
+ props : dictionary
+ The dictionary of figure properties
+ """
+ pass
+
+ def close_figure(self, fig):
+ """
+ Finish commands for a particular figure.
+
+ Parameters
+ ----------
+ fig : matplotlib.Figure
+ The figure which is finished being drawn.
+ """
+ pass
+
+ def open_axes(self, ax, props):
+ """
+ Begin commands for a particular axes.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ The Axes which will contain the ensuing axes and elements
+ props : dictionary
+ The dictionary of axes properties
+ """
+ pass
+
+ def close_axes(self, ax):
+ """
+ Finish commands for a particular axes.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ The Axes which is finished being drawn.
+ """
+ pass
+
+ def open_legend(self, legend, props):
+ """
+ Beging commands for a particular legend.
+
+ Parameters
+ ----------
+ legend : matplotlib.legend.Legend
+ The Legend that will contain the ensuing elements
+ props : dictionary
+ The dictionary of legend properties
+ """
+ pass
+
+ def close_legend(self, legend):
+ """
+ Finish commands for a particular legend.
+
+ Parameters
+ ----------
+ legend : matplotlib.legend.Legend
+ The Legend which is finished being drawn
+ """
+ pass
+
+ def draw_marked_line(
+ self, data, coordinates, linestyle, markerstyle, label, mplobj=None
+ ):
+ """Draw a line that also has markers.
+
+ If this isn't reimplemented by a renderer object, by default, it will
+ make a call to BOTH draw_line and draw_markers when both markerstyle
+ and linestyle are not None in the same Line2D object.
+
+ """
+ if linestyle is not None:
+ self.draw_line(data, coordinates, linestyle, label, mplobj)
+ if markerstyle is not None:
+ self.draw_markers(data, coordinates, markerstyle, label, mplobj)
+
+ def draw_line(self, data, coordinates, style, label, mplobj=None):
+ """
+ Draw a line. By default, draw the line via the draw_path() command.
+ Some renderers might wish to override this and provide more
+ fine-grained behavior.
+
+ In matplotlib, lines are generally created via the plt.plot() command,
+ though this command also can create marker collections.
+
+ Parameters
+ ----------
+ data : array_like
+ A shape (N, 2) array of datapoints.
+ coordinates : string
+ A string code, which should be either 'data' for data coordinates,
+ or 'figure' for figure (pixel) coordinates.
+ style : dictionary
+ a dictionary specifying the appearance of the line.
+ mplobj : matplotlib object
+ the matplotlib plot element which generated this line
+ """
+ pathcodes = ["M"] + (data.shape[0] - 1) * ["L"]
+ pathstyle = dict(facecolor="none", **style)
+ pathstyle["edgecolor"] = pathstyle.pop("color")
+ pathstyle["edgewidth"] = pathstyle.pop("linewidth")
+ self.draw_path(
+ data=data,
+ coordinates=coordinates,
+ pathcodes=pathcodes,
+ style=pathstyle,
+ mplobj=mplobj,
+ )
+
+ @staticmethod
+ def _iter_path_collection(paths, path_transforms, offsets, styles):
+ """Build an iterator over the elements of the path collection"""
+ N = max(len(paths), len(offsets))
+
+ # Before mpl 1.4.0, path_transform can be a false-y value, not a valid
+ # transformation matrix.
+ if Version(mpl.__version__) < Version("1.4.0"):
+ if path_transforms is None:
+ path_transforms = [np.eye(3)]
+
+ edgecolor = styles["edgecolor"]
+ if np.size(edgecolor) == 0:
+ edgecolor = ["none"]
+ facecolor = styles["facecolor"]
+ if np.size(facecolor) == 0:
+ facecolor = ["none"]
+
+ elements = [
+ paths,
+ path_transforms,
+ offsets,
+ edgecolor,
+ styles["linewidth"],
+ facecolor,
+ ]
+
+ it = itertools
+ return it.islice(zip(*map(it.cycle, elements)), N)
+
+ def draw_path_collection(
+ self,
+ paths,
+ path_coordinates,
+ path_transforms,
+ offsets,
+ offset_coordinates,
+ offset_order,
+ styles,
+ mplobj=None,
+ ):
+ """
+ Draw a collection of paths. The paths, offsets, and styles are all
+ iterables, and the number of paths is max(len(paths), len(offsets)).
+
+ By default, this is implemented via multiple calls to the draw_path()
+ function. For efficiency, Renderers may choose to customize this
+ implementation.
+
+ Examples of path collections created by matplotlib are scatter plots,
+ histograms, contour plots, and many others.
+
+ Parameters
+ ----------
+ paths : list
+ list of tuples, where each tuple has two elements:
+ (data, pathcodes). See draw_path() for a description of these.
+ path_coordinates: string
+ the coordinates code for the paths, which should be either
+ 'data' for data coordinates, or 'figure' for figure (pixel)
+ coordinates.
+ path_transforms: array_like
+ an array of shape (*, 3, 3), giving a series of 2D Affine
+ transforms for the paths. These encode translations, rotations,
+ and scalings in the standard way.
+ offsets: array_like
+ An array of offsets of shape (N, 2)
+ offset_coordinates : string
+ the coordinates code for the offsets, which should be either
+ 'data' for data coordinates, or 'figure' for figure (pixel)
+ coordinates.
+ offset_order : string
+ either "before" or "after". This specifies whether the offset
+ is applied before the path transform, or after. The matplotlib
+ backend equivalent is "before"->"data", "after"->"screen".
+ styles: dictionary
+ A dictionary in which each value is a list of length N, containing
+ the style(s) for the paths.
+ mplobj : matplotlib object
+ the matplotlib plot element which generated this collection
+ """
+ if offset_order == "before":
+ raise NotImplementedError("offset before transform")
+
+ for tup in self._iter_path_collection(paths, path_transforms, offsets, styles):
+ (path, path_transform, offset, ec, lw, fc) = tup
+ vertices, pathcodes = path
+ path_transform = transforms.Affine2D(path_transform)
+ vertices = path_transform.transform(vertices)
+ # This is a hack:
+ if path_coordinates == "figure":
+ path_coordinates = "points"
+ style = {
+ "edgecolor": utils.export_color(ec),
+ "facecolor": utils.export_color(fc),
+ "edgewidth": lw,
+ "dasharray": "10,0",
+ "alpha": styles["alpha"],
+ "zorder": styles["zorder"],
+ }
+ self.draw_path(
+ data=vertices,
+ coordinates=path_coordinates,
+ pathcodes=pathcodes,
+ style=style,
+ offset=offset,
+ offset_coordinates=offset_coordinates,
+ mplobj=mplobj,
+ )
+
+ def draw_markers(self, data, coordinates, style, label, mplobj=None):
+ """
+ Draw a set of markers. By default, this is done by repeatedly
+ calling draw_path(), but renderers should generally overload
+ this method to provide a more efficient implementation.
+
+ In matplotlib, markers are created using the plt.plot() command.
+
+ Parameters
+ ----------
+ data : array_like
+ A shape (N, 2) array of datapoints.
+ coordinates : string
+ A string code, which should be either 'data' for data coordinates,
+ or 'figure' for figure (pixel) coordinates.
+ style : dictionary
+ a dictionary specifying the appearance of the markers.
+ mplobj : matplotlib object
+ the matplotlib plot element which generated this marker collection
+ """
+ vertices, pathcodes = style["markerpath"]
+ pathstyle = dict(
+ (key, style[key])
+ for key in ["alpha", "edgecolor", "facecolor", "zorder", "edgewidth"]
+ )
+ pathstyle["dasharray"] = "10,0"
+ for vertex in data:
+ self.draw_path(
+ data=vertices,
+ coordinates="points",
+ pathcodes=pathcodes,
+ style=pathstyle,
+ offset=vertex,
+ offset_coordinates=coordinates,
+ mplobj=mplobj,
+ )
+
+ def draw_text(
+ self, text, position, coordinates, style, text_type=None, mplobj=None
+ ):
+ """
+ Draw text on the image.
+
+ Parameters
+ ----------
+ text : string
+ The text to draw
+ position : tuple
+ The (x, y) position of the text
+ coordinates : string
+ A string code, which should be either 'data' for data coordinates,
+ or 'figure' for figure (pixel) coordinates.
+ style : dictionary
+ a dictionary specifying the appearance of the text.
+ text_type : string or None
+ if specified, a type of text such as "xlabel", "ylabel", "title"
+ mplobj : matplotlib object
+ the matplotlib plot element which generated this text
+ """
+ raise NotImplementedError()
+
+ def draw_path(
+ self,
+ data,
+ coordinates,
+ pathcodes,
+ style,
+ offset=None,
+ offset_coordinates="data",
+ mplobj=None,
+ ):
+ """
+ Draw a path.
+
+ In matplotlib, paths are created by filled regions, histograms,
+ contour plots, patches, etc.
+
+ Parameters
+ ----------
+ data : array_like
+ A shape (N, 2) array of datapoints.
+ coordinates : string
+ A string code, which should be either 'data' for data coordinates,
+ 'figure' for figure (pixel) coordinates, or "points" for raw
+ point coordinates (useful in conjunction with offsets, below).
+ pathcodes : list
+ A list of single-character SVG pathcodes associated with the data.
+ Path codes are one of ['M', 'm', 'L', 'l', 'Q', 'q', 'T', 't',
+ 'S', 's', 'C', 'c', 'Z', 'z']
+ See the SVG specification for details. Note that some path codes
+ consume more than one datapoint (while 'Z' consumes none), so
+ in general, the length of the pathcodes list will not be the same
+ as that of the data array.
+ style : dictionary
+ a dictionary specifying the appearance of the line.
+ offset : list (optional)
+ the (x, y) offset of the path. If not given, no offset will
+ be used.
+ offset_coordinates : string (optional)
+ A string code, which should be either 'data' for data coordinates,
+ or 'figure' for figure (pixel) coordinates.
+ mplobj : matplotlib object
+ the matplotlib plot element which generated this path
+ """
+ raise NotImplementedError()
+
+ def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
+ """
+ Draw an image.
+
+ Parameters
+ ----------
+ imdata : string
+ base64 encoded png representation of the image
+ extent : list
+ the axes extent of the image: [xmin, xmax, ymin, ymax]
+ coordinates: string
+ A string code, which should be either 'data' for data coordinates,
+ or 'figure' for figure (pixel) coordinates.
+ style : dictionary
+ a dictionary specifying the appearance of the image
+ mplobj : matplotlib object
+ the matplotlib plot object which generated this image
+ """
+ raise NotImplementedError()
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py
new file mode 100644
index 0000000..de2ae40
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/fake_renderer.py
@@ -0,0 +1,88 @@
+from .base import Renderer
+
+
+class FakeRenderer(Renderer):
+ """
+ Fake Renderer
+
+ This is a fake renderer which simply outputs a text tree representing the
+ elements found in the plot(s). This is used in the unit tests for the
+ package.
+
+ Below are the methods your renderer must implement. You are free to do
+ anything you wish within the renderer (i.e. build an XML or JSON
+ representation, call an external API, etc.) Here the renderer just
+ builds a simple string representation for testing purposes.
+ """
+
+ def __init__(self):
+ self.output = ""
+
+ def open_figure(self, fig, props):
+ self.output += "opening figure\n"
+
+ def close_figure(self, fig):
+ self.output += "closing figure\n"
+
+ def open_axes(self, ax, props):
+ self.output += " opening axes\n"
+
+ def close_axes(self, ax):
+ self.output += " closing axes\n"
+
+ def open_legend(self, legend, props):
+ self.output += " opening legend\n"
+
+ def close_legend(self, legend):
+ self.output += " closing legend\n"
+
+ def draw_text(
+ self, text, position, coordinates, style, text_type=None, mplobj=None
+ ):
+ self.output += " draw text '{0}' {1}\n".format(text, text_type)
+
+ def draw_path(
+ self,
+ data,
+ coordinates,
+ pathcodes,
+ style,
+ offset=None,
+ offset_coordinates="data",
+ mplobj=None,
+ ):
+ self.output += " draw path with {0} vertices\n".format(data.shape[0])
+
+ def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
+ self.output += " draw image of size {0}\n".format(len(imdata))
+
+
+class FullFakeRenderer(FakeRenderer):
+ """
+ Renderer with the full complement of methods.
+
+ When the following are left undefined, they will be implemented via
+ other methods in the class. They can be defined explicitly for
+ more efficient or specialized use within the renderer implementation.
+ """
+
+ def draw_line(self, data, coordinates, style, label, mplobj=None):
+ self.output += " draw line with {0} points\n".format(data.shape[0])
+
+ def draw_markers(self, data, coordinates, style, label, mplobj=None):
+ self.output += " draw {0} markers\n".format(data.shape[0])
+
+ def draw_path_collection(
+ self,
+ paths,
+ path_coordinates,
+ path_transforms,
+ offsets,
+ offset_coordinates,
+ offset_order,
+ styles,
+ mplobj=None,
+ ):
+ self.output += " draw path collection with {0} offsets\n".format(
+ offsets.shape[0]
+ )
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py
new file mode 100644
index 0000000..eab02e1
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vega_renderer.py
@@ -0,0 +1,155 @@
+import warnings
+import json
+import random
+from .base import Renderer
+from ..exporter import Exporter
+
+
+class VegaRenderer(Renderer):
+ def open_figure(self, fig, props):
+ self.props = props
+ self.figwidth = int(props["figwidth"] * props["dpi"])
+ self.figheight = int(props["figheight"] * props["dpi"])
+ self.data = []
+ self.scales = []
+ self.axes = []
+ self.marks = []
+
+ def open_axes(self, ax, props):
+ if len(self.axes) > 0:
+ warnings.warn("multiple axes not yet supported")
+ self.axes = [
+ dict(type="x", scale="x", ticks=10),
+ dict(type="y", scale="y", ticks=10),
+ ]
+ self.scales = [
+ dict(
+ name="x",
+ domain=props["xlim"],
+ type="linear",
+ range="width",
+ ),
+ dict(
+ name="y",
+ domain=props["ylim"],
+ type="linear",
+ range="height",
+ ),
+ ]
+
+ def draw_line(self, data, coordinates, style, label, mplobj=None):
+ if coordinates != "data":
+ warnings.warn("Only data coordinates supported. Skipping this")
+ dataname = "table{0:03d}".format(len(self.data) + 1)
+
+ # TODO: respect the other style settings
+ self.data.append(
+ {"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
+ )
+ self.marks.append(
+ {
+ "type": "line",
+ "from": {"data": dataname},
+ "properties": {
+ "enter": {
+ "interpolate": {"value": "monotone"},
+ "x": {"scale": "x", "field": "data.x"},
+ "y": {"scale": "y", "field": "data.y"},
+ "stroke": {"value": style["color"]},
+ "strokeOpacity": {"value": style["alpha"]},
+ "strokeWidth": {"value": style["linewidth"]},
+ }
+ },
+ }
+ )
+
+ def draw_markers(self, data, coordinates, style, label, mplobj=None):
+ if coordinates != "data":
+ warnings.warn("Only data coordinates supported. Skipping this")
+ dataname = "table{0:03d}".format(len(self.data) + 1)
+
+ # TODO: respect the other style settings
+ self.data.append(
+ {"name": dataname, "values": [dict(x=d[0], y=d[1]) for d in data]}
+ )
+ self.marks.append(
+ {
+ "type": "symbol",
+ "from": {"data": dataname},
+ "properties": {
+ "enter": {
+ "interpolate": {"value": "monotone"},
+ "x": {"scale": "x", "field": "data.x"},
+ "y": {"scale": "y", "field": "data.y"},
+ "fill": {"value": style["facecolor"]},
+ "fillOpacity": {"value": style["alpha"]},
+ "stroke": {"value": style["edgecolor"]},
+ "strokeOpacity": {"value": style["alpha"]},
+ "strokeWidth": {"value": style["edgewidth"]},
+ }
+ },
+ }
+ )
+
+ def draw_text(
+ self, text, position, coordinates, style, text_type=None, mplobj=None
+ ):
+ if text_type == "xlabel":
+ self.axes[0]["title"] = text
+ elif text_type == "ylabel":
+ self.axes[1]["title"] = text
+
+
+class VegaHTML(object):
+ def __init__(self, renderer):
+ self.specification = dict(
+ width=renderer.figwidth,
+ height=renderer.figheight,
+ data=renderer.data,
+ scales=renderer.scales,
+ axes=renderer.axes,
+ marks=renderer.marks,
+ )
+
+ def html(self):
+ """Build the HTML representation for IPython."""
+ id = random.randint(0, 2**16)
+ html = '<div id="vis%d"></div>' % id
+ html += "<script>\n"
+ html += VEGA_TEMPLATE % (json.dumps(self.specification), id)
+ html += "</script>\n"
+ return html
+
+ def _repr_html_(self):
+ return self.html()
+
+
+def fig_to_vega(fig, notebook=False):
+ """Convert a matplotlib figure to vega dictionary
+
+ if notebook=True, then return an object which will display in a notebook
+ otherwise, return an HTML string.
+ """
+ renderer = VegaRenderer()
+ Exporter(renderer).run(fig)
+ vega_html = VegaHTML(renderer)
+ if notebook:
+ return vega_html
+ else:
+ return vega_html.html()
+
+
+VEGA_TEMPLATE = """
+( function() {
+ var _do_plot = function() {
+ if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) {
+ $([IPython.events]).on("vega_loaded.vincent", _do_plot);
+ return;
+ }
+ vg.parse.spec(%s, function(chart) {
+ chart({el: "#vis%d"}).update();
+ });
+ };
+ _do_plot();
+})();
+"""
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py
new file mode 100644
index 0000000..36074f6
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/renderers/vincent_renderer.py
@@ -0,0 +1,54 @@
+import warnings
+from .base import Renderer
+from ..exporter import Exporter
+
+
+class VincentRenderer(Renderer):
+ def open_figure(self, fig, props):
+ self.chart = None
+ self.figwidth = int(props["figwidth"] * props["dpi"])
+ self.figheight = int(props["figheight"] * props["dpi"])
+
+ def draw_line(self, data, coordinates, style, label, mplobj=None):
+ import vincent # only import if VincentRenderer is used
+
+ if coordinates != "data":
+ warnings.warn("Only data coordinates supported. Skipping this")
+ linedata = {"x": data[:, 0], "y": data[:, 1]}
+ line = vincent.Line(
+ linedata, iter_idx="x", width=self.figwidth, height=self.figheight
+ )
+
+ # TODO: respect the other style settings
+ line.scales["color"].range = [style["color"]]
+
+ if self.chart is None:
+ self.chart = line
+ else:
+ warnings.warn("Multiple plot elements not yet supported")
+
+ def draw_markers(self, data, coordinates, style, label, mplobj=None):
+ import vincent # only import if VincentRenderer is used
+
+ if coordinates != "data":
+ warnings.warn("Only data coordinates supported. Skipping this")
+ markerdata = {"x": data[:, 0], "y": data[:, 1]}
+ markers = vincent.Scatter(
+ markerdata, iter_idx="x", width=self.figwidth, height=self.figheight
+ )
+
+ # TODO: respect the other style settings
+ markers.scales["color"].range = [style["facecolor"]]
+
+ if self.chart is None:
+ self.chart = markers
+ else:
+ warnings.warn("Multiple plot elements not yet supported")
+
+
+def fig_to_vincent(fig):
+ """Convert a matplotlib figure to a vincent object"""
+ renderer = VincentRenderer()
+ exporter = Exporter(renderer)
+ exporter.run(fig)
+ return renderer.chart
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py
new file mode 100644
index 0000000..290cc21
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/__init__.py
@@ -0,0 +1,3 @@
+import matplotlib
+
+matplotlib.use("Agg")
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py
new file mode 100644
index 0000000..3739e13
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_basic.py
@@ -0,0 +1,257 @@
+import matplotlib
+import numpy as np
+import pytest
+from packaging.version import Version
+
+from ..exporter import Exporter
+from ..renderers import FakeRenderer, FullFakeRenderer
+import matplotlib.pyplot as plt
+
+
+def fake_renderer_output(fig, Renderer):
+ renderer = Renderer()
+ exporter = Exporter(renderer)
+ exporter.run(fig)
+ return renderer.output
+
+
+def _assert_output_equal(text1, text2):
+ for line1, line2 in zip(text1.strip().split(), text2.strip().split()):
+ assert line1 == line2
+
+
+def test_lines():
+ fig, ax = plt.subplots()
+ ax.plot(range(20), "-k")
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 20 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FullFakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw line with 20 points
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_markers():
+ fig, ax = plt.subplots()
+ ax.plot(range(2), "ok")
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 25 vertices
+ draw path with 25 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FullFakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw 2 markers
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_path_collection():
+ fig, ax = plt.subplots()
+ ax.scatter(range(3), range(3))
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 25 vertices
+ draw path with 25 vertices
+ draw path with 25 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FullFakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path collection with 3 offsets
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_text():
+ fig, ax = plt.subplots()
+ ax.set_xlabel("my x label")
+ ax.set_ylabel("my y label")
+ ax.set_title("my title")
+ ax.text(0.5, 0.5, "my text")
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw text 'my text' None
+ draw text 'my x label' xlabel
+ draw text 'my y label' ylabel
+ draw text 'my title' title
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_path():
+ fig, ax = plt.subplots()
+ ax.add_patch(plt.Circle((0, 0), 1))
+ ax.add_patch(plt.Rectangle((0, 0), 1, 2))
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 25 vertices
+ draw path with 4 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_Figure():
+ """if the fig is not associated with a canvas, FakeRenderer shall
+ not fail."""
+ fig = plt.Figure()
+ ax = fig.add_subplot(111)
+ ax.add_patch(plt.Circle((0, 0), 1))
+ ax.add_patch(plt.Rectangle((0, 0), 1, 2))
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 25 vertices
+ draw path with 4 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_multiaxes():
+ fig, ax = plt.subplots(2)
+ ax[0].plot(range(4))
+ ax[1].plot(range(10))
+
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 4 vertices
+ closing axes
+ opening axes
+ draw path with 10 vertices
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_image():
+ # Test fails for matplotlib 1.5+ because the size of the image
+ # generated by matplotlib has changed.
+ if Version(matplotlib.__version__) == Version("3.4.1"):
+ image_size = 432
+ else:
+ pytest.skip("Test fails for older matplotlib")
+ np.random.seed(0) # image size depends on the seed
+ fig, ax = plt.subplots(figsize=(2, 2))
+ ax.imshow(np.random.random((10, 10)), cmap=plt.cm.jet, interpolation="nearest")
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ f"""
+ opening figure
+ opening axes
+ draw image of size {image_size}
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_legend():
+ fig, ax = plt.subplots()
+ ax.plot([1, 2, 3], label="label")
+ ax.legend().set_visible(False)
+ _assert_output_equal(
+ fake_renderer_output(fig, FakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw path with 3 vertices
+ opening legend
+ closing legend
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_legend_dots():
+ fig, ax = plt.subplots()
+ ax.plot([1, 2, 3], label="label")
+ ax.plot([2, 2, 2], "o", label="dots")
+ ax.legend().set_visible(True)
+ # legend draws 1 line and 1 marker
+ # path around legend now has 13 vertices??
+ _assert_output_equal(
+ fake_renderer_output(fig, FullFakeRenderer),
+ """
+ opening figure
+ opening axes
+ draw line with 3 points
+ draw 3 markers
+ opening legend
+ draw line with 2 points
+ draw text 'label' None
+ draw 1 markers
+ draw text 'dots' None
+ draw path with 13 vertices
+ closing legend
+ closing axes
+ closing figure
+ """,
+ )
+
+
+def test_blended():
+ fig, ax = plt.subplots()
+ ax.axvline(0)
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py
new file mode 100644
index 0000000..5659163
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tests/test_utils.py
@@ -0,0 +1,40 @@
+from numpy.testing import assert_allclose, assert_equal
+from . import plt
+from .. import utils
+
+
+def test_path_data():
+ circle = plt.Circle((0, 0), 1)
+ vertices, codes = utils.SVG_path(circle.get_path())
+
+ assert_allclose(vertices.shape, (25, 2))
+ assert_equal(codes, ["M", "C", "C", "C", "C", "C", "C", "C", "C", "Z"])
+
+
+def test_linestyle():
+ linestyles = {
+ "solid": "none",
+ "-": "none",
+ "dashed": "5.550000000000001,2.4000000000000004",
+ "--": "5.550000000000001,2.4000000000000004",
+ "dotted": "1.5,2.4749999999999996",
+ ":": "1.5,2.4749999999999996",
+ "dashdot": "9.600000000000001,2.4000000000000004,1.5,2.4000000000000004",
+ "-.": "9.600000000000001,2.4000000000000004,1.5,2.4000000000000004",
+ "": None,
+ "None": None,
+ }
+
+ for ls, result in linestyles.items():
+ (line,) = plt.plot([1, 2, 3], linestyle=ls)
+ assert_equal(utils.get_dasharray(line), result)
+
+
+def test_axis_w_fixed_formatter():
+ positions, labels = [0, 1, 10], ["A", "B", "C"]
+
+ plt.xticks(positions, labels)
+ props = utils.get_axis_properties(plt.gca().xaxis)
+
+ assert_equal(props["tickvalues"], positions)
+ assert_equal(props["tickformat"], labels)
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py
new file mode 100644
index 0000000..f66fdfb
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/tools.py
@@ -0,0 +1,55 @@
+"""
+Tools for matplotlib plot exporting
+"""
+
+
+def ipynb_vega_init():
+ """Initialize the IPython notebook display elements
+
+ This function borrows heavily from the excellent vincent package:
+ http://github.com/wrobstory/vincent
+ """
+ try:
+ from IPython.core.display import display, HTML
+ except ImportError:
+ print("IPython Notebook could not be loaded.")
+
+ require_js = """
+ if (window['d3'] === undefined) {{
+ require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }});
+ require(["d3"], function(d3) {{
+ window.d3 = d3;
+ {0}
+ }});
+ }};
+ if (window['topojson'] === undefined) {{
+ require.config(
+ {{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }}
+ );
+ require(["topojson"], function(topojson) {{
+ window.topojson = topojson;
+ }});
+ }};
+ """
+ d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js"
+ d3_layout_cloud_js_url = "http://wrobstory.github.io/d3-cloud/d3.layout.cloud.js"
+ topojson_js_url = "http://d3js.org/topojson.v1.min.js"
+ vega_js_url = "http://trifacta.github.com/vega/vega.js"
+
+ dep_libs = """$.getScript("%s", function() {
+ $.getScript("%s", function() {
+ $.getScript("%s", function() {
+ $.getScript("%s", function() {
+ $([IPython.events]).trigger("vega_loaded.vincent");
+ })
+ })
+ })
+ });""" % (
+ d3_geo_projection_js_url,
+ d3_layout_cloud_js_url,
+ topojson_js_url,
+ vega_js_url,
+ )
+ load_js = require_js.format(dep_libs)
+ html = "<script>" + load_js + "</script>"
+ display(HTML(html))
diff --git a/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py
new file mode 100644
index 0000000..646e11e
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py
@@ -0,0 +1,382 @@
+"""
+Utility Routines for Working with Matplotlib Objects
+====================================================
+"""
+
+import itertools
+import io
+import base64
+
+import numpy as np
+
+import warnings
+
+import matplotlib
+from matplotlib.colors import colorConverter
+from matplotlib.path import Path
+from matplotlib.markers import MarkerStyle
+from matplotlib.transforms import Affine2D
+from matplotlib import ticker
+
+
+def export_color(color):
+ """Convert matplotlib color code to hex color or RGBA color"""
+ if color is None or colorConverter.to_rgba(color)[3] == 0:
+ return "none"
+ elif colorConverter.to_rgba(color)[3] == 1:
+ rgb = colorConverter.to_rgb(color)
+ return "#{0:02X}{1:02X}{2:02X}".format(*(int(255 * c) for c in rgb))
+ else:
+ c = colorConverter.to_rgba(color)
+ return (
+ "rgba("
+ + ", ".join(str(int(np.round(val * 255))) for val in c[:3])
+ + ", "
+ + str(c[3])
+ + ")"
+ )
+
+
+def _many_to_one(input_dict):
+ """Convert a many-to-one mapping to a one-to-one mapping"""
+ return dict((key, val) for keys, val in input_dict.items() for key in keys)
+
+
+LINESTYLES = _many_to_one(
+ {
+ ("solid", "-", (None, None)): "none",
+ ("dashed", "--"): "6,6",
+ ("dotted", ":"): "2,2",
+ ("dashdot", "-."): "4,4,2,4",
+ ("", " ", "None", "none"): None,
+ }
+)
+
+
+def get_dasharray(obj):
+ """Get an SVG dash array for the given matplotlib linestyle
+
+ Parameters
+ ----------
+ obj : matplotlib object
+ The matplotlib line or path object, which must have a get_linestyle()
+ method which returns a valid matplotlib line code
+
+ Returns
+ -------
+ dasharray : string
+ The HTML/SVG dasharray code associated with the object.
+ """
+ if obj.__dict__.get("_dashSeq", None) is not None:
+ return ",".join(map(str, obj._dashSeq))
+ else:
+ ls = obj.get_linestyle()
+ dasharray = LINESTYLES.get(ls, "not found")
+ if dasharray == "not found":
+ warnings.warn(
+ "line style '{0}' not understood: defaulting to solid line.".format(ls)
+ )
+ dasharray = LINESTYLES["solid"]
+ return dasharray
+
+
+PATH_DICT = {
+ Path.LINETO: "L",
+ Path.MOVETO: "M",
+ Path.CURVE3: "S",
+ Path.CURVE4: "C",
+ Path.CLOSEPOLY: "Z",
+}
+
+
+def SVG_path(path, transform=None, simplify=False):
+ """Construct the vertices and SVG codes for the path
+
+ Parameters
+ ----------
+ path : matplotlib.Path object
+
+ transform : matplotlib transform (optional)
+ if specified, the path will be transformed before computing the output.
+
+ Returns
+ -------
+ vertices : array
+ The shape (M, 2) array of vertices of the Path. Note that some Path
+ codes require multiple vertices, so the length of these vertices may
+ be longer than the list of path codes.
+ path_codes : list
+ A length N list of single-character path codes, N <= M. Each code is
+ a single character, in ['L','M','S','C','Z']. See the standard SVG
+ path specification for a description of these.
+ """
+ if transform is not None:
+ path = path.transformed(transform)
+
+ vc_tuples = [
+ (vertices if path_code != Path.CLOSEPOLY else [], PATH_DICT[path_code])
+ for (vertices, path_code) in path.iter_segments(simplify=simplify)
+ ]
+
+ if not vc_tuples:
+ # empty path is a special case
+ return np.zeros((0, 2)), []
+ else:
+ vertices, codes = zip(*vc_tuples)
+ vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2)
+ return vertices, list(codes)
+
+
+def get_path_style(path, fill=True):
+ """Get the style dictionary for matplotlib path objects"""
+ style = {}
+ style["alpha"] = path.get_alpha()
+ if style["alpha"] is None:
+ style["alpha"] = 1
+ style["edgecolor"] = export_color(path.get_edgecolor())
+ if fill:
+ style["facecolor"] = export_color(path.get_facecolor())
+ else:
+ style["facecolor"] = "none"
+ style["edgewidth"] = path.get_linewidth()
+ style["dasharray"] = get_dasharray(path)
+ style["zorder"] = path.get_zorder()
+ return style
+
+
+def get_line_style(line):
+ """Get the style dictionary for matplotlib line objects"""
+ style = {}
+ style["alpha"] = line.get_alpha()
+ if style["alpha"] is None:
+ style["alpha"] = 1
+ style["color"] = export_color(line.get_color())
+ style["linewidth"] = line.get_linewidth()
+ style["dasharray"] = get_dasharray(line)
+ style["zorder"] = line.get_zorder()
+ style["drawstyle"] = line.get_drawstyle()
+ return style
+
+
+def get_marker_style(line):
+ """Get the style dictionary for matplotlib marker objects"""
+ style = {}
+ style["alpha"] = line.get_alpha()
+ if style["alpha"] is None:
+ style["alpha"] = 1
+
+ style["facecolor"] = export_color(line.get_markerfacecolor())
+ style["edgecolor"] = export_color(line.get_markeredgecolor())
+ style["edgewidth"] = line.get_markeredgewidth()
+
+ style["marker"] = line.get_marker()
+ markerstyle = MarkerStyle(line.get_marker())
+ markersize = line.get_markersize()
+ markertransform = markerstyle.get_transform() + Affine2D().scale(
+ markersize, -markersize
+ )
+ style["markerpath"] = SVG_path(markerstyle.get_path(), markertransform)
+ style["markersize"] = markersize
+ style["zorder"] = line.get_zorder()
+ return style
+
+
+def get_text_style(text):
+ """Return the text style dict for a text instance"""
+ style = {}
+ style["alpha"] = text.get_alpha()
+ if style["alpha"] is None:
+ style["alpha"] = 1
+ style["fontsize"] = text.get_size()
+ style["color"] = export_color(text.get_color())
+ style["halign"] = text.get_horizontalalignment() # left, center, right
+ style["valign"] = text.get_verticalalignment() # baseline, center, top
+ style["malign"] = text._multialignment # text alignment when '\n' in text
+ style["rotation"] = text.get_rotation()
+ style["zorder"] = text.get_zorder()
+ return style
+
+
+def get_axis_properties(axis):
+ """Return the property dictionary for a matplotlib.Axis instance"""
+ props = {}
+ label1On = axis._major_tick_kw.get("label1On", True)
+
+ if isinstance(axis, matplotlib.axis.XAxis):
+ if label1On:
+ props["position"] = "bottom"
+ else:
+ props["position"] = "top"
+ elif isinstance(axis, matplotlib.axis.YAxis):
+ if label1On:
+ props["position"] = "left"
+ else:
+ props["position"] = "right"
+ else:
+ raise ValueError("{0} should be an Axis instance".format(axis))
+
+ # Use tick values if appropriate
+ locator = axis.get_major_locator()
+ props["nticks"] = len(locator())
+ if isinstance(locator, ticker.FixedLocator):
+ props["tickvalues"] = list(locator())
+ else:
+ props["tickvalues"] = None
+
+ # Find tick formats
+ formatter = axis.get_major_formatter()
+ if isinstance(formatter, ticker.NullFormatter):
+ props["tickformat"] = ""
+ elif isinstance(formatter, ticker.FixedFormatter):
+ props["tickformat"] = list(formatter.seq)
+ elif isinstance(formatter, ticker.FuncFormatter):
+ props["tickformat"] = list(formatter.func.args[0].values())
+ elif not any(label.get_visible() for label in axis.get_ticklabels()):
+ props["tickformat"] = ""
+ else:
+ props["tickformat"] = None
+
+ # Get axis scale
+ props["scale"] = axis.get_scale()
+
+ # Get major tick label size (assumes that's all we really care about!)
+ labels = axis.get_ticklabels()
+ if labels:
+ props["fontsize"] = labels[0].get_fontsize()
+ else:
+ props["fontsize"] = None
+
+ # Get associated grid
+ props["grid"] = get_grid_style(axis)
+
+ # get axis visibility
+ props["visible"] = axis.get_visible()
+
+ return props
+
+
+def get_grid_style(axis):
+ gridlines = axis.get_gridlines()
+ if axis._major_tick_kw["gridOn"] and len(gridlines) > 0:
+ color = export_color(gridlines[0].get_color())
+ alpha = gridlines[0].get_alpha()
+ dasharray = get_dasharray(gridlines[0])
+ return dict(gridOn=True, color=color, dasharray=dasharray, alpha=alpha)
+ else:
+ return {"gridOn": False}
+
+
+def get_figure_properties(fig):
+ return {
+ "figwidth": fig.get_figwidth(),
+ "figheight": fig.get_figheight(),
+ "dpi": fig.dpi,
+ }
+
+
+def get_axes_properties(ax):
+ props = {
+ "axesbg": export_color(ax.patch.get_facecolor()),
+ "axesbgalpha": ax.patch.get_alpha(),
+ "bounds": ax.get_position().bounds,
+ "dynamic": ax.get_navigate(),
+ "axison": ax.axison,
+ "frame_on": ax.get_frame_on(),
+ "patch_visible": ax.patch.get_visible(),
+ "axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
+ }
+
+ for axname in ["x", "y"]:
+ axis = getattr(ax, axname + "axis")
+ domain = getattr(ax, "get_{0}lim".format(axname))()
+ lim = domain
+ if isinstance(axis.converter, matplotlib.dates.DateConverter):
+ scale = "date"
+ try:
+ import pandas as pd
+ from pandas.tseries.converter import PeriodConverter
+ except ImportError:
+ pd = None
+
+ if pd is not None and isinstance(axis.converter, PeriodConverter):
+ _dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
+ domain = [
+ (d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
+ for d in _dates
+ ]
+ else:
+ domain = [
+ (
+ d.year,
+ d.month - 1,
+ d.day,
+ d.hour,
+ d.minute,
+ d.second,
+ d.microsecond * 1e-3,
+ )
+ for d in matplotlib.dates.num2date(domain)
+ ]
+ else:
+ scale = axis.get_scale()
+
+ if scale not in ["date", "linear", "log"]:
+ raise ValueError("Unknown axis scale: {0}".format(axis.get_scale()))
+
+ props[axname + "scale"] = scale
+ props[axname + "lim"] = lim
+ props[axname + "domain"] = domain
+
+ return props
+
+
+def iter_all_children(obj, skipContainers=False):
+ """
+ Returns an iterator over all childen and nested children using
+ obj's get_children() method
+
+ if skipContainers is true, only childless objects are returned.
+ """
+ if hasattr(obj, "get_children") and len(obj.get_children()) > 0:
+ for child in obj.get_children():
+ if not skipContainers:
+ yield child
+ # could use `yield from` in python 3...
+ for grandchild in iter_all_children(child, skipContainers):
+ yield grandchild
+ else:
+ yield obj
+
+
+def get_legend_properties(ax, legend):
+ handles, labels = ax.get_legend_handles_labels()
+ visible = legend.get_visible()
+ return {"handles": handles, "labels": labels, "visible": visible}
+
+
+def image_to_base64(image):
+ """
+ Convert a matplotlib image to a base64 png representation
+
+ Parameters
+ ----------
+ image : matplotlib image object
+ The image to be converted.
+
+ Returns
+ -------
+ image_base64 : string
+ The UTF8-encoded base64 string representation of the png image.
+ """
+ ax = image.axes
+ binary_buffer = io.BytesIO()
+
+ # image is saved in axes coordinates: we need to temporarily
+ # set the correct limits to get the correct image
+ lim = ax.axis()
+ ax.axis(image.get_extent())
+ image.write_png(binary_buffer)
+ ax.axis(lim)
+
+ binary_buffer.seek(0)
+ return base64.b64encode(binary_buffer.read()).decode("utf-8")