diff options
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/express')
10 files changed, 6759 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/express/__init__.py b/venv/lib/python3.8/site-packages/plotly/express/__init__.py new file mode 100644 index 0000000..62a9bca --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/__init__.py @@ -0,0 +1,132 @@ +# ruff: noqa: E402 + +""" +`plotly.express` is a terse, consistent, high-level wrapper around `plotly.graph_objects` +for rapid data exploration and figure generation. Learn more at https://plotly.com/python/plotly-express/ +""" + +from plotly import optional_imports + +np = optional_imports.get_module("numpy") +if np is None: + raise ImportError( + """\ +Plotly Express requires numpy to be installed. You can install numpy using pip with: + +$ pip install numpy + +Or install Plotly Express and its dependencies directly with: + +$ pip install "plotly[express]" + +You can also use Plotly Graph Objects to create a large number of charts without installing +numpy. See examples here: https://plotly.com/python/graph-objects/ +""" + ) + +from ._imshow import imshow +from ._chart_types import ( # noqa: F401 + scatter, + scatter_3d, + scatter_polar, + scatter_ternary, + scatter_map, + scatter_mapbox, + scatter_geo, + line, + line_3d, + line_polar, + line_ternary, + line_map, + line_mapbox, + line_geo, + area, + bar, + timeline, + bar_polar, + violin, + box, + strip, + histogram, + ecdf, + scatter_matrix, + parallel_coordinates, + parallel_categories, + choropleth, + density_contour, + density_heatmap, + pie, + sunburst, + treemap, + icicle, + funnel, + funnel_area, + choropleth_map, + choropleth_mapbox, + density_map, + density_mapbox, +) + + +from ._core import ( # noqa: F401 + set_mapbox_access_token, + defaults, + get_trendline_results, + NO_COLOR, +) + +from ._special_inputs import IdentityMap, Constant, Range # noqa: F401 + +from . 
import data, colors, trendline_functions # noqa: F401 + +__all__ = [ + "scatter", + "scatter_3d", + "scatter_polar", + "scatter_ternary", + "scatter_map", + "scatter_mapbox", + "scatter_geo", + "scatter_matrix", + "density_contour", + "density_heatmap", + "density_map", + "density_mapbox", + "line", + "line_3d", + "line_polar", + "line_ternary", + "line_map", + "line_mapbox", + "line_geo", + "parallel_coordinates", + "parallel_categories", + "area", + "bar", + "timeline", + "bar_polar", + "violin", + "box", + "strip", + "histogram", + "ecdf", + "choropleth", + "choropleth_map", + "choropleth_mapbox", + "pie", + "sunburst", + "treemap", + "icicle", + "funnel", + "funnel_area", + "imshow", + "data", + "colors", + "trendline_functions", + "set_mapbox_access_token", + "get_trendline_results", + "IdentityMap", + "Constant", + "Range", + "NO_COLOR", +] diff --git a/venv/lib/python3.8/site-packages/plotly/express/_chart_types.py b/venv/lib/python3.8/site-packages/plotly/express/_chart_types.py new file mode 100644 index 0000000..9ec2b4a --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/_chart_types.py @@ -0,0 +1,1950 @@ +from warnings import warn + +from ._core import make_figure +from ._doc import make_docstring +import plotly.graph_objs as go + +_wide_mode_xy_append = [ + "Either `x` or `y` can optionally be a list of column references or array_likes, ", + "in which case the data will be treated as if it were 'wide' rather than 'long'.", +] +_cartesian_append_dict = dict(x=_wide_mode_xy_append, y=_wide_mode_xy_append) + + +def scatter( + data_frame=None, + x=None, + y=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + 
orientation=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + size_max=None, + marginal_x=None, + marginal_y=None, + trendline=None, + trendline_options=None, + trendline_color_override=None, + trendline_scope="trace", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + render_mode="auto", + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a scatter plot, each row of `data_frame` is represented by a symbol + mark in 2D space. + """ + return make_figure(args=locals(), constructor=go.Scatter) + + +scatter.__doc__ = make_docstring(scatter, append_dict=_cartesian_append_dict) + + +def density_contour( + data_frame=None, + x=None, + y=None, + z=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + orientation=None, + color_discrete_sequence=None, + color_discrete_map=None, + marginal_x=None, + marginal_y=None, + trendline=None, + trendline_options=None, + trendline_color_override=None, + trendline_scope="trace", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + histnorm=None, + nbinsx=None, + nbinsy=None, + text_auto=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a density contour plot, rows of `data_frame` are grouped together + into contour marks to visualize the 2D distribution of an aggregate + function `histfunc` (e.g. the count or sum) of the value `z`. 
+ """ + return make_figure( + args=locals(), + constructor=go.Histogram2dContour, + trace_patch=dict( + contours=dict(coloring="none"), + histfunc=histfunc, + histnorm=histnorm, + nbinsx=nbinsx, + nbinsy=nbinsy, + xbingroup="x", + ybingroup="y", + ), + ) + + +density_contour.__doc__ = make_docstring( + density_contour, + append_dict=dict( + x=_wide_mode_xy_append, + y=_wide_mode_xy_append, + z=[ + "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", + ], + histfunc=["The arguments to this function are the values of `z`."], + ), +) + + +def density_heatmap( + data_frame=None, + x=None, + y=None, + z=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + orientation=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + marginal_x=None, + marginal_y=None, + opacity=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + histnorm=None, + nbinsx=None, + nbinsy=None, + text_auto=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a density heatmap, rows of `data_frame` are grouped together into + colored rectangular tiles to visualize the 2D distribution of an + aggregate function `histfunc` (e.g. the count or sum) of the value `z`. 
+ """ + return make_figure( + args=locals(), + constructor=go.Histogram2d, + trace_patch=dict( + histfunc=histfunc, + histnorm=histnorm, + nbinsx=nbinsx, + nbinsy=nbinsy, + xbingroup="x", + ybingroup="y", + ), + ) + + +density_heatmap.__doc__ = make_docstring( + density_heatmap, + append_dict=dict( + x=_wide_mode_xy_append, + y=_wide_mode_xy_append, + z=[ + "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", + ], + histfunc=[ + "The arguments to this function are the values of `z`.", + ], + ), +) + + +def line( + data_frame=None, + x=None, + y=None, + line_group=None, + color=None, + line_dash=None, + symbol=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + orientation=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + line_shape=None, + render_mode="auto", + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a 2D line plot, each row of `data_frame` is represented as a vertex of + a polyline mark in 2D space. 
+ """ + return make_figure(args=locals(), constructor=go.Scatter) + + +line.__doc__ = make_docstring(line, append_dict=_cartesian_append_dict) + + +def area( + data_frame=None, + x=None, + y=None, + line_group=None, + color=None, + pattern_shape=None, + symbol=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + pattern_shape_sequence=None, + pattern_shape_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + orientation=None, + groupnorm=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + line_shape=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a stacked area plot, each row of `data_frame` is represented as + a vertex of a polyline mark in 2D space. The area between + successive polylines is filled. 
+ """ + return make_figure( + args=locals(), + constructor=go.Scatter, + trace_patch=dict(stackgroup=1, mode="lines", groupnorm=groupnorm), + ) + + +area.__doc__ = make_docstring(area, append_dict=_cartesian_append_dict) + + +def bar( + data_frame=None, + x=None, + y=None, + color=None, + pattern_shape=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + base=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + pattern_shape_sequence=None, + pattern_shape_map=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + orientation=None, + barmode="relative", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + text_auto=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a bar plot, each row of `data_frame` is represented as a rectangular + mark. 
+ """ + return make_figure( + args=locals(), + constructor=go.Bar, + trace_patch=dict(textposition="auto"), + layout_patch=dict(barmode=barmode), + ) + + +bar.__doc__ = make_docstring(bar, append_dict=_cartesian_append_dict) + + +def timeline( + data_frame=None, + x_start=None, + x_end=None, + y=None, + color=None, + pattern_shape=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + pattern_shape_sequence=None, + pattern_shape_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + range_x=None, + range_y=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a timeline plot, each row of `data_frame` is represented as a rectangular + mark on an x axis of type `date`, spanning from `x_start` to `x_end`. 
+ """ + return make_figure( + args=locals(), + constructor="timeline", + trace_patch=dict(textposition="auto", orientation="h"), + layout_patch=dict(barmode="overlay"), + ) + + +timeline.__doc__ = make_docstring(timeline) + + +def histogram( + data_frame=None, + x=None, + y=None, + color=None, + pattern_shape=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + pattern_shape_sequence=None, + pattern_shape_map=None, + marginal=None, + opacity=None, + orientation=None, + barmode="relative", + barnorm=None, + histnorm=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + cumulative=None, + nbins=None, + text_auto=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a histogram, rows of `data_frame` are grouped together into a + rectangular mark to visualize the 1D distribution of an aggregate + function `histfunc` (e.g. the count or sum) of the value `y` (or `x` if + `orientation` is `'h'`). 
+ """ + return make_figure( + args=locals(), + constructor=go.Histogram, + trace_patch=dict( + histnorm=histnorm, + histfunc=histfunc, + cumulative=dict(enabled=cumulative), + ), + layout_patch=dict(barmode=barmode, barnorm=barnorm), + ) + + +histogram.__doc__ = make_docstring( + histogram, + append_dict=dict( + x=["If `orientation` is `'h'`, these values are used as inputs to `histfunc`."] + + _wide_mode_xy_append, + y=["If `orientation` is `'v'`, these values are used as inputs to `histfunc`."] + + _wide_mode_xy_append, + histfunc=[ + "The arguments to this function are the values of `y` (`x`) if `orientation` is `'v'` (`'h'`).", + ], + ), +) + + +def ecdf( + data_frame=None, + x=None, + y=None, + color=None, + text=None, + line_dash=None, + symbol=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + markers=False, + lines=True, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + marginal=None, + opacity=None, + orientation=None, + ecdfnorm="probability", + ecdfmode="standard", + render_mode="auto", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a Empirical Cumulative Distribution Function (ECDF) plot, rows of `data_frame` + are sorted by the value `x` (or `y` if `orientation` is `'h'`) and their cumulative + count (or the cumulative sum of `y` if supplied and `orientation` is `h`) is drawn + as a line. + """ + return make_figure(args=locals(), constructor=go.Scatter) + + +ecdf.__doc__ = make_docstring( + ecdf, + append_dict=dict( + x=[ + "If `orientation` is `'h'`, the cumulative sum of this argument is plotted rather than the cumulative count." 
+ ] + + _wide_mode_xy_append, + y=[ + "If `orientation` is `'v'`, the cumulative sum of this argument is plotted rather than the cumulative count." + ] + + _wide_mode_xy_append, + ), +) + + +def violin( + data_frame=None, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + orientation=None, + violinmode=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + points=None, + box=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a violin plot, rows of `data_frame` are grouped together into a + curved mark to visualize their distribution. + """ + return make_figure( + args=locals(), + constructor=go.Violin, + trace_patch=dict( + points=points, + box=dict(visible=box), + scalegroup=True, + x0=" ", + y0=" ", + ), + layout_patch=dict(violinmode=violinmode), + ) + + +violin.__doc__ = make_docstring(violin, append_dict=_cartesian_append_dict) + + +def box( + data_frame=None, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + orientation=None, + boxmode=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + points=None, + notched=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a box plot, rows of `data_frame` are grouped together into a + box-and-whisker mark to visualize their distribution. 
+ + Each box spans from quartile 1 (Q1) to quartile 3 (Q3). The second + quartile (Q2) is marked by a line inside the box. By default, the + whiskers correspond to the box' edges +/- 1.5 times the interquartile + range (IQR: Q3-Q1), see "points" for other options. + """ + return make_figure( + args=locals(), + constructor=go.Box, + trace_patch=dict(boxpoints=points, notched=notched, x0=" ", y0=" "), + layout_patch=dict(boxmode=boxmode), + ) + + +box.__doc__ = make_docstring(box, append_dict=_cartesian_append_dict) + + +def strip( + data_frame=None, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + orientation=None, + stripmode=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a strip plot each row of `data_frame` is represented as a jittered + mark within categories. 
+ """ + return make_figure( + args=locals(), + constructor=go.Box, + trace_patch=dict( + boxpoints="all", + pointpos=0, + hoveron="points", + fillcolor="rgba(255,255,255,0)", + line={"color": "rgba(255,255,255,0)"}, + x0=" ", + y0=" ", + ), + layout_patch=dict(boxmode=stripmode), + ) + + +strip.__doc__ = make_docstring(strip, append_dict=_cartesian_append_dict) + + +def scatter_3d( + data_frame=None, + x=None, + y=None, + z=None, + color=None, + symbol=None, + size=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + error_z=None, + error_z_minus=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + size_max=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + log_x=False, + log_y=False, + log_z=False, + range_x=None, + range_y=None, + range_z=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a 3D scatter plot, each row of `data_frame` is represented by a + symbol mark in 3D space. 
+ """ + return make_figure(args=locals(), constructor=go.Scatter3d) + + +scatter_3d.__doc__ = make_docstring(scatter_3d) + + +def line_3d( + data_frame=None, + x=None, + y=None, + z=None, + color=None, + line_dash=None, + text=None, + line_group=None, + symbol=None, + hover_name=None, + hover_data=None, + custom_data=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + error_z=None, + error_z_minus=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + log_x=False, + log_y=False, + log_z=False, + range_x=None, + range_y=None, + range_z=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a 3D line plot, each row of `data_frame` is represented as a vertex of + a polyline mark in 3D space. + """ + return make_figure(args=locals(), constructor=go.Scatter3d) + + +line_3d.__doc__ = make_docstring(line_3d) + + +def scatter_ternary( + data_frame=None, + a=None, + b=None, + c=None, + color=None, + symbol=None, + size=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + size_max=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a ternary scatter plot, each row of `data_frame` is represented by a + symbol mark in ternary coordinates. 
+ """ + return make_figure(args=locals(), constructor=go.Scatterternary) + + +scatter_ternary.__doc__ = make_docstring(scatter_ternary) + + +def line_ternary( + data_frame=None, + a=None, + b=None, + c=None, + color=None, + line_dash=None, + line_group=None, + symbol=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + line_shape=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a ternary line plot, each row of `data_frame` is represented as + a vertex of a polyline mark in ternary coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterternary) + + +line_ternary.__doc__ = make_docstring(line_ternary) + + +def scatter_polar( + data_frame=None, + r=None, + theta=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + direction="clockwise", + start_angle=90, + size_max=None, + range_r=None, + range_theta=None, + log_r=False, + render_mode="auto", + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a polar scatter plot, each row of `data_frame` is represented by a + symbol mark in polar coordinates. 
+ """ + return make_figure(args=locals(), constructor=go.Scatterpolar) + + +scatter_polar.__doc__ = make_docstring(scatter_polar) + + +def line_polar( + data_frame=None, + r=None, + theta=None, + color=None, + line_dash=None, + hover_name=None, + hover_data=None, + custom_data=None, + line_group=None, + text=None, + symbol=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + direction="clockwise", + start_angle=90, + line_close=False, + line_shape=None, + render_mode="auto", + range_r=None, + range_theta=None, + log_r=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a polar line plot, each row of `data_frame` is represented as a + vertex of a polyline mark in polar coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterpolar) + + +line_polar.__doc__ = make_docstring(line_polar) + + +def bar_polar( + data_frame=None, + r=None, + theta=None, + color=None, + pattern_shape=None, + hover_name=None, + hover_data=None, + custom_data=None, + base=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + pattern_shape_sequence=None, + pattern_shape_map=None, + range_color=None, + color_continuous_midpoint=None, + barnorm=None, + barmode="relative", + direction="clockwise", + start_angle=90, + range_r=None, + range_theta=None, + log_r=False, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a polar bar plot, each row of `data_frame` is represented as a wedge + mark in polar coordinates. 
+ """ + return make_figure( + args=locals(), + constructor=go.Barpolar, + layout_patch=dict(barnorm=barnorm, barmode=barmode), + ) + + +bar_polar.__doc__ = make_docstring(bar_polar) + + +def choropleth( + data_frame=None, + lat=None, + lon=None, + locations=None, + locationmode=None, + geojson=None, + featureidkey=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + projection=None, + scope=None, + center=None, + fitbounds=None, + basemap_visible=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a choropleth map, each row of `data_frame` is represented by a + colored region mark on a map. 
+ """ + return make_figure( + args=locals(), + constructor=go.Choropleth, + trace_patch=dict(locationmode=locationmode), + ) + + +choropleth.__doc__ = make_docstring(choropleth) + + +def scatter_geo( + data_frame=None, + lat=None, + lon=None, + locations=None, + locationmode=None, + geojson=None, + featureidkey=None, + color=None, + text=None, + symbol=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + size_max=None, + projection=None, + scope=None, + center=None, + fitbounds=None, + basemap_visible=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a geographic scatter plot, each row of `data_frame` is represented + by a symbol mark on a map. 
+ """ + return make_figure( + args=locals(), + constructor=go.Scattergeo, + trace_patch=dict(locationmode=locationmode), + ) + + +scatter_geo.__doc__ = make_docstring(scatter_geo) + + +def line_geo( + data_frame=None, + lat=None, + lon=None, + locations=None, + locationmode=None, + geojson=None, + featureidkey=None, + color=None, + line_dash=None, + text=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + line_group=None, + symbol=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + line_dash_sequence=None, + line_dash_map=None, + symbol_sequence=None, + symbol_map=None, + markers=False, + projection=None, + scope=None, + center=None, + fitbounds=None, + basemap_visible=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a geographic line plot, each row of `data_frame` is represented as + a vertex of a polyline mark on a map. + """ + return make_figure( + args=locals(), + constructor=go.Scattergeo, + trace_patch=dict(locationmode=locationmode), + ) + + +line_geo.__doc__ = make_docstring(line_geo) + + +def scatter_map( + data_frame=None, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + size_max=None, + zoom=8, + center=None, + map_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a scatter map, each row of `data_frame` is represented by a + symbol mark on the map. 
+ """ + return make_figure(args=locals(), constructor=go.Scattermap) + + +scatter_map.__doc__ = make_docstring(scatter_map) + + +def choropleth_map( + data_frame=None, + geojson=None, + featureidkey=None, + locations=None, + color=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=8, + center=None, + map_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a choropleth map, each row of `data_frame` is represented by a + colored region on the map. + """ + return make_figure(args=locals(), constructor=go.Choroplethmap) + + +choropleth_map.__doc__ = make_docstring(choropleth_map) + + +def density_map( + data_frame=None, + lat=None, + lon=None, + z=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=8, + center=None, + map_style=None, + radius=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a density map, each row of `data_frame` contributes to the intensity of + the color of the region around the corresponding point on the map. 
+ """ + return make_figure( + args=locals(), constructor=go.Densitymap, trace_patch=dict(radius=radius) + ) + + +density_map.__doc__ = make_docstring(density_map) + + +def line_map( + data_frame=None, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + line_group=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + zoom=8, + center=None, + map_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a line map, each row of `data_frame` is represented as + a vertex of a polyline mark on the map. + """ + return make_figure(args=locals(), constructor=go.Scattermap) + + +line_map.__doc__ = make_docstring(line_map) + + +def scatter_mapbox( + data_frame=None, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + size_max=None, + zoom=8, + center=None, + mapbox_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + *scatter_mapbox* is deprecated! Use *scatter_map* instead. + Learn more at: https://plotly.com/python/mapbox-to-maplibre/ + In a Mapbox scatter plot, each row of `data_frame` is represented by a + symbol mark on a Mapbox map. + """ + warn( + "*scatter_mapbox* is deprecated!" + + " Use *scatter_map* instead." 
+ + " Learn more at: https://plotly.com/python/mapbox-to-maplibre/", + stacklevel=2, + category=DeprecationWarning, + ) + return make_figure(args=locals(), constructor=go.Scattermapbox) + + +scatter_mapbox.__doc__ = make_docstring(scatter_mapbox) + + +def choropleth_mapbox( + data_frame=None, + geojson=None, + featureidkey=None, + locations=None, + color=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=8, + center=None, + mapbox_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + *choropleth_mapbox* is deprecated! Use *choropleth_map* instead. + Learn more at: https://plotly.com/python/mapbox-to-maplibre/ + In a Mapbox choropleth map, each row of `data_frame` is represented by a + colored region on a Mapbox map. + """ + warn( + "*choropleth_mapbox* is deprecated!" + + " Use *choropleth_map* instead." + + " Learn more at: https://plotly.com/python/mapbox-to-maplibre/", + stacklevel=2, + category=DeprecationWarning, + ) + return make_figure(args=locals(), constructor=go.Choroplethmapbox) + + +choropleth_mapbox.__doc__ = make_docstring(choropleth_mapbox) + + +def density_mapbox( + data_frame=None, + lat=None, + lon=None, + z=None, + hover_name=None, + hover_data=None, + custom_data=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + zoom=8, + center=None, + mapbox_style=None, + radius=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + *density_mapbox* is deprecated! Use *density_map* instead. 
+ Learn more at: https://plotly.com/python/mapbox-to-maplibre/ + In a Mapbox density map, each row of `data_frame` contributes to the intensity of + the color of the region around the corresponding point on the map + """ + warn( + "*density_mapbox* is deprecated!" + + " Use *density_map* instead." + + " Learn more at: https://plotly.com/python/mapbox-to-maplibre/", + stacklevel=2, + category=DeprecationWarning, + ) + return make_figure( + args=locals(), constructor=go.Densitymapbox, trace_patch=dict(radius=radius) + ) + + +density_mapbox.__doc__ = make_docstring(density_mapbox) + + +def line_mapbox( + data_frame=None, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + custom_data=None, + line_group=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + zoom=8, + center=None, + mapbox_style=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + *line_mapbox* is deprecated! Use *line_map* instead. + Learn more at: https://plotly.com/python/mapbox-to-maplibre/ + In a Mapbox line plot, each row of `data_frame` is represented as + a vertex of a polyline mark on a Mapbox map. + """ + warn( + "*line_mapbox* is deprecated!" + + " Use *line_map* instead." 
+ + " Learn more at: https://plotly.com/python/mapbox-to-maplibre/", + stacklevel=2, + category=DeprecationWarning, + ) + return make_figure(args=locals(), constructor=go.Scattermapbox) + + +line_mapbox.__doc__ = make_docstring(line_mapbox) + + +def scatter_matrix( + data_frame=None, + dimensions=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + custom_data=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map=None, + opacity=None, + size_max=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a scatter plot matrix (or SPLOM), each row of `data_frame` is + represented by a multiple symbol marks, one in each cell of a grid of + 2D scatter plots, which plot each pair of `dimensions` against each + other. + """ + return make_figure( + args=locals(), constructor=go.Splom, layout_patch=dict(dragmode="select") + ) + + +scatter_matrix.__doc__ = make_docstring(scatter_matrix) + + +def parallel_coordinates( + data_frame=None, + dimensions=None, + color=None, + labels=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a parallel coordinates plot, each row of `data_frame` is represented + by a polyline mark which traverses a set of parallel axes, one for each + of the `dimensions`. 
+ """ + return make_figure(args=locals(), constructor=go.Parcoords) + + +parallel_coordinates.__doc__ = make_docstring(parallel_coordinates) + + +def parallel_categories( + data_frame=None, + dimensions=None, + color=None, + labels=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + dimensions_max_cardinality=50, +) -> go.Figure: + """ + In a parallel categories (or parallel sets) plot, each row of + `data_frame` is grouped with other rows that share the same values of + `dimensions` and then plotted as a polyline mark through a set of + parallel axes, one for each of the `dimensions`. + """ + return make_figure(args=locals(), constructor=go.Parcats) + + +parallel_categories.__doc__ = make_docstring(parallel_categories) + + +def pie( + data_frame=None, + names=None, + values=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + color_discrete_sequence=None, + color_discrete_map=None, + hover_name=None, + hover_data=None, + custom_data=None, + category_orders=None, + labels=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + opacity=None, + hole=None, +) -> go.Figure: + """ + In a pie plot, each row of `data_frame` is represented as a sector of a + pie. + """ + if color_discrete_sequence is not None: + layout_patch = {"piecolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Pie, + trace_patch=dict(showlegend=(names is not None), hole=hole), + layout_patch=layout_patch, + ) + + +pie.__doc__ = make_docstring( + pie, + override_dict=dict( + hole=[ + "float", + "Sets the fraction of the radius to cut out of the pie." 
+ "Use this to make a donut chart.", + ], + ), +) + + +def sunburst( + data_frame=None, + names=None, + values=None, + parents=None, + path=None, + ids=None, + color=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + color_discrete_sequence=None, + color_discrete_map=None, + hover_name=None, + hover_data=None, + custom_data=None, + labels=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + branchvalues=None, + maxdepth=None, +) -> go.Figure: + """ + A sunburst plot represents hierarchial data as sectors laid out over + several levels of concentric rings. + """ + if color_discrete_sequence is not None: + layout_patch = {"sunburstcolorway": color_discrete_sequence} + else: + layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) + if path is not None and branchvalues is None: + branchvalues = "total" + return make_figure( + args=locals(), + constructor=go.Sunburst, + trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth), + layout_patch=layout_patch, + ) + + +sunburst.__doc__ = make_docstring(sunburst) + + +def treemap( + data_frame=None, + names=None, + values=None, + parents=None, + ids=None, + path=None, + color=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + color_discrete_sequence=None, + color_discrete_map=None, + hover_name=None, + hover_data=None, + custom_data=None, + labels=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + branchvalues=None, + maxdepth=None, +) -> go.Figure: + """ + A treemap plot represents hierarchial data as nested rectangular + sectors. 
+ """ + if color_discrete_sequence is not None: + layout_patch = {"treemapcolorway": color_discrete_sequence} + else: + layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." + ) + if path is not None and branchvalues is None: + branchvalues = "total" + return make_figure( + args=locals(), + constructor=go.Treemap, + trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth), + layout_patch=layout_patch, + ) + + +treemap.__doc__ = make_docstring(treemap) + + +def icicle( + data_frame=None, + names=None, + values=None, + parents=None, + path=None, + ids=None, + color=None, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + color_discrete_sequence=None, + color_discrete_map=None, + hover_name=None, + hover_data=None, + custom_data=None, + labels=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + branchvalues=None, + maxdepth=None, +) -> go.Figure: + """ + An icicle plot represents hierarchial data with adjoined rectangular + sectors that all cascade from root down to leaf in one direction. + """ + if color_discrete_sequence is not None: + layout_patch = {"iciclecolorway": color_discrete_sequence} + else: + layout_patch = {} + if path is not None and (ids is not None or parents is not None): + raise ValueError( + "Either `path` should be provided, or `ids` and `parents`." + "These parameters are mutually exclusive and cannot be passed together." 
+ ) + if path is not None and branchvalues is None: + branchvalues = "total" + return make_figure( + args=locals(), + constructor=go.Icicle, + trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth), + layout_patch=layout_patch, + ) + + +icicle.__doc__ = make_docstring(icicle) + + +def funnel( + data_frame=None, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + facet_col_wrap=0, + facet_row_spacing=None, + facet_col_spacing=None, + hover_name=None, + hover_data=None, + custom_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders=None, + labels=None, + color_discrete_sequence=None, + color_discrete_map=None, + opacity=None, + orientation=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, +) -> go.Figure: + """ + In a funnel plot, each row of `data_frame` is represented as a + rectangular sector of a funnel. + """ + return make_figure(args=locals(), constructor=go.Funnel) + + +funnel.__doc__ = make_docstring(funnel, append_dict=_cartesian_append_dict) + + +def funnel_area( + data_frame=None, + names=None, + values=None, + color=None, + color_discrete_sequence=None, + color_discrete_map=None, + hover_name=None, + hover_data=None, + custom_data=None, + labels=None, + title=None, + subtitle=None, + template=None, + width=None, + height=None, + opacity=None, +) -> go.Figure: + """ + In a funnel area plot, each row of `data_frame` is represented as a + trapezoidal sector of a funnel. 
+ """ + if color_discrete_sequence is not None: + layout_patch = {"funnelareacolorway": color_discrete_sequence} + else: + layout_patch = {} + return make_figure( + args=locals(), + constructor=go.Funnelarea, + trace_patch=dict(showlegend=(names is not None)), + layout_patch=layout_patch, + ) + + +funnel_area.__doc__ = make_docstring(funnel_area) diff --git a/venv/lib/python3.8/site-packages/plotly/express/_core.py b/venv/lib/python3.8/site-packages/plotly/express/_core.py new file mode 100644 index 0000000..d2dbc84 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/_core.py @@ -0,0 +1,2905 @@ +import plotly.graph_objs as go +import plotly.io as pio +from collections import namedtuple, OrderedDict +from ._special_inputs import IdentityMap, Constant, Range +from .trendline_functions import ols, lowess, rolling, expanding, ewm + +from _plotly_utils.basevalidators import ColorscaleValidator +from plotly.colors import qualitative, sequential +import math + +from plotly._subplots import ( + make_subplots, + _set_trace_grid_reference, + _subplot_type_for_trace_type, +) + +import narwhals.stable.v1 as nw + +# The reason to use narwhals.stable.v1 is to have a stable and perfectly +# backwards-compatible API, hence the confidence to not pin the Narwhals version exactly, +# allowing for multiple major libraries to have Narwhals as a dependency without +# forbidding users to install them all together due to dependency conflicts. 
+ +NO_COLOR = "px_no_color_constant" + + +trendline_functions = dict( + lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols +) + +# Declare all supported attributes, across all plot types +direct_attrables = ( + ["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"] + + ["hover_name", "text", "names", "values", "parents", "wide_cross"] + + ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"] + + ["error_z_minus", "lat", "lon", "locations", "animation_group"] +) +array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"] +group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] +renameable_group_attrables = [ + "color", # renamed to marker.color or line.color in infer_config + "symbol", # renamed to marker.symbol in infer_config + "line_dash", # renamed to line.dash in infer_config + "pattern_shape", # renamed to marker.pattern.shape in infer_config +] +all_attrables = ( + direct_attrables + array_attrables + group_attrables + renameable_group_attrables +) + +cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin] +cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour] + + +class PxDefaults(object): + __slots__ = [ + "template", + "width", + "height", + "color_discrete_sequence", + "color_discrete_map", + "color_continuous_scale", + "symbol_sequence", + "symbol_map", + "line_dash_sequence", + "line_dash_map", + "pattern_shape_sequence", + "pattern_shape_map", + "size_max", + "category_orders", + "labels", + ] + + def __init__(self): + self.reset() + + def reset(self): + self.template = None + self.width = None + self.height = None + self.color_discrete_sequence = None + self.color_discrete_map = {} + self.color_continuous_scale = None + self.symbol_sequence = None + self.symbol_map = {} + self.line_dash_sequence = None + self.line_dash_map = {} + self.pattern_shape_sequence = None + self.pattern_shape_map = {} + 
self.size_max = 20 + self.category_orders = {} + self.labels = {} + + +defaults = PxDefaults() +del PxDefaults + + +MAPBOX_TOKEN = None + + +def set_mapbox_access_token(token): + """ + Arguments: + token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \ + `plotly.express.line_mapbox` figures. See \ + https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details + """ + global MAPBOX_TOKEN + MAPBOX_TOKEN = token + + +def get_trendline_results(fig): + """ + Extracts fit statistics for trendlines (when applied to figures generated with + the `trendline` argument set to `"ols"`). + + Arguments: + fig: the output of a `plotly.express` charting call + Returns: + A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels` + results objects, along with columns identifying the subset of the data the + trendline was fit on. + """ + return fig._px_trendlines + + +Mapping = namedtuple( + "Mapping", + [ + "show_in_trace_name", + "grouper", + "val_map", + "sequence", + "updater", + "variable", + "facet", + ], +) +TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) + + +def get_label(args, column): + try: + return args["labels"][column] + except Exception: + return column + + +def invert_label(args, column): + """Invert mapping. + Find key corresponding to value column in dict args["labels"]. + Returns `column` if the value does not exist. + """ + reversed_labels = {value: key for (key, value) in args["labels"].items()} + try: + return reversed_labels[column] + except Exception: + return column + + +def _is_continuous(df: nw.DataFrame, col_name: str) -> bool: + if nw.dependencies.is_pandas_like_dataframe(df_native := df.to_native()): + # fastpath for pandas: Narwhals' Series.dtype has a bit of overhead, as it + # tries to distinguish between true "object" columns, and "string" columns + # disguised as "object". But here, we deal with neither. 
+ return df_native[col_name].dtype.kind in "ifc" + return df.get_column(col_name).dtype.is_numeric() + + +def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series: + dtype = s.dtype + if dtype == nw.Date: + return s.dt.timestamp("ms") / 1_000 + if dtype == nw.Datetime: + if dtype.time_unit in ("s", "ms"): + return s.dt.timestamp("ms") / 1_000 + elif dtype.time_unit == "us": + return s.dt.timestamp("us") / 1_000_000 + elif dtype.time_unit == "ns": + return s.dt.timestamp("ns") / 1_000_000_000 + else: + msg = "Unexpected dtype, please report a bug" + raise ValueError(msg) + else: + msg = f"Expected Date or Datetime, got {dtype}" + raise TypeError(msg) + + +def _generate_temporary_column_name(n_bytes, columns) -> str: + """Wraps of Narwhals generate_temporary_column_name to generate a token + which is guaranteed to not be in columns, nor in [col + token for col in columns] + """ + counter = 0 + while True: + # This is guaranteed to not be in columns by Narwhals + token = nw.generate_temporary_column_name(n_bytes, columns=columns) + + # Now check that it is not in the [col + token for col in columns] list + if token not in {f"{c}{token}" for c in columns}: + return token + + counter += 1 + if counter > 100: + msg = ( + "Internal Error: Plotly was not able to generate a column name with " + f"{n_bytes=} and not in {columns}.\n" + "Please report this to " + "https://github.com/plotly/plotly.py/issues/new and we will try to " + "replicate and fix it." 
+ ) + raise AssertionError(msg) + + +def get_decorated_label(args, column, role): + original_label = label = get_label(args, column) + if "histfunc" in args and ( + (role == "z") + or (role == "x" and "orientation" in args and args["orientation"] == "h") + or (role == "y" and "orientation" in args and args["orientation"] == "v") + ): + histfunc = args["histfunc"] or "count" + if histfunc != "count": + label = "%s of %s" % (histfunc, label) + else: + label = "count" + + if "histnorm" in args and args["histnorm"] is not None: + if label == "count": + label = args["histnorm"] + else: + histnorm = args["histnorm"] + if histfunc == "sum": + if histnorm == "probability": + label = "%s of %s" % ("fraction", label) + elif histnorm == "percent": + label = "%s of %s" % (histnorm, label) + else: + label = "%s weighted by %s" % (histnorm, original_label) + elif histnorm == "probability": + label = "%s of sum of %s" % ("fraction", label) + elif histnorm == "percent": + label = "%s of sum of %s" % ("percent", label) + else: + label = "%s of %s" % (histnorm, label) + + if "barnorm" in args and args["barnorm"] is not None: + label = "%s (normalized as %s)" % (label, args["barnorm"]) + + return label + + +def make_mapping(args, variable): + if variable == "line_group" or variable == "animation_frame": + return Mapping( + show_in_trace_name=False, + grouper=args[variable], + val_map={}, + sequence=[""], + variable=variable, + updater=(lambda trace, v: v), + facet=None, + ) + if variable == "facet_row" or variable == "facet_col": + letter = "x" if variable == "facet_col" else "y" + return Mapping( + show_in_trace_name=False, + variable=letter, + grouper=args[variable], + val_map={}, + sequence=[i for i in range(1, 1000)], + updater=(lambda trace, v: v), + facet="row" if variable == "facet_row" else "col", + ) + (parent, variable, *other_variables) = variable.split(".") + vprefix = variable + arg_name = variable + if variable == "color": + vprefix = "color_discrete" + if variable == 
"dash": + arg_name = "line_dash" + vprefix = "line_dash" + if variable in ["pattern", "shape"]: + arg_name = "pattern_shape" + vprefix = "pattern_shape" + if args[vprefix + "_map"] == "identity": + val_map = IdentityMap() + else: + val_map = args[vprefix + "_map"].copy() + return Mapping( + show_in_trace_name=True, + variable=variable, + grouper=args[arg_name], + val_map=val_map, + sequence=args[vprefix + "_sequence"], + updater=lambda trace, v: trace.update( + {parent: {".".join([variable] + other_variables): v}} + ), + facet=None, + ) + + +def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): + """Populates a dict with arguments to update trace + + Parameters + ---------- + args : dict + args to be used for the trace + trace_spec : NamedTuple + which kind of trace to be used (has constructor, marginal etc. + attributes) + trace_data : pandas DataFrame + data + mapping_labels : dict + to be used for hovertemplate + sizeref : float + marker sizeref + + Returns + ------- + trace_patch : dict + dict to be used to update trace + fit_results : dict + fit information to be used for trendlines + """ + trace_data: nw.DataFrame + df: nw.DataFrame = args["data_frame"] + + if "line_close" in args and args["line_close"]: + trace_data = nw.concat([trace_data, trace_data.head(1)], how="vertical") + + trace_patch = trace_spec.trace_patch.copy() or {} + fit_results = None + hover_header = "" + for attr_name in trace_spec.attrs: + attr_value = args[attr_name] + attr_label = get_decorated_label(args, attr_value, attr_name) + if attr_name == "dimensions": + dims = [ + (name, trace_data.get_column(name)) + for name in trace_data.columns + if ((not attr_value) or (name in attr_value)) + and (trace_spec.constructor != go.Parcoords or _is_continuous(df, name)) + and ( + trace_spec.constructor != go.Parcats + or (attr_value is not None and name in attr_value) + or nw.to_py_scalar(df.get_column(name).n_unique()) + <= args["dimensions_max_cardinality"] + ) + ] + 
trace_patch["dimensions"] = [ + dict(label=get_label(args, name), values=column) + for (name, column) in dims + ] + if trace_spec.constructor == go.Splom: + for d in trace_patch["dimensions"]: + d["axis"] = dict(matches=True) + mapping_labels["%{xaxis.title.text}"] = "%{x}" + mapping_labels["%{yaxis.title.text}"] = "%{y}" + + elif attr_value is not None: + if attr_name == "size": + if "marker" not in trace_patch: + trace_patch["marker"] = dict() + trace_patch["marker"]["size"] = trace_data.get_column(attr_value) + trace_patch["marker"]["sizemode"] = "area" + trace_patch["marker"]["sizeref"] = sizeref + mapping_labels[attr_label] = "%{marker.size}" + elif attr_name == "marginal_x": + if trace_spec.constructor == go.Histogram: + mapping_labels["count"] = "%{y}" + elif attr_name == "marginal_y": + if trace_spec.constructor == go.Histogram: + mapping_labels["count"] = "%{x}" + elif attr_name == "trendline": + if ( + args["x"] + and args["y"] + and len( + trace_data.select(nw.col(args["x"], args["y"])).drop_nulls() + ) + > 1 + ): + # sorting is bad but trace_specs with "trendline" have no other attrs + sorted_trace_data = trace_data.sort(by=args["x"], nulls_last=True) + y = sorted_trace_data.get_column(args["y"]) + x = sorted_trace_data.get_column(args["x"]) + + if x.dtype == nw.Datetime or x.dtype == nw.Date: + # convert to unix epoch seconds + x = _to_unix_epoch_seconds(x) + elif not x.dtype.is_numeric(): + try: + x = x.cast(nw.Float64()) + except ValueError: + raise ValueError( + "Could not convert value of 'x' ('%s') into a numeric type. " + "If 'x' contains stringified dates, please convert to a datetime column." + % args["x"] + ) + + if not y.dtype.is_numeric(): + try: + y = y.cast(nw.Float64()) + except ValueError: + raise ValueError( + "Could not convert value of 'y' into a numeric type." 
+ ) + + # preserve original values of "x" in case they're dates + # otherwise numpy/pandas can mess with the timezones + # NB this means trendline functions must output one-to-one with the input series + # i.e. we can't do resampling, because then the X values might not line up! + non_missing = ~(x.is_null() | y.is_null()) + trace_patch["x"] = sorted_trace_data.filter(non_missing).get_column( + args["x"] + ) + if ( + trace_patch["x"].dtype == nw.Datetime + and trace_patch["x"].dtype.time_zone is not None + ): + # Remove time zone so that local time is displayed + trace_patch["x"] = ( + trace_patch["x"].dt.replace_time_zone(None).to_numpy() + ) + else: + trace_patch["x"] = trace_patch["x"].to_numpy() + + trendline_function = trendline_functions[attr_value] + y_out, hover_header, fit_results = trendline_function( + args["trendline_options"], + sorted_trace_data.get_column(args["x"]), # narwhals series + x.to_numpy(), # numpy array + y.to_numpy(), # numpy array + args["x"], + args["y"], + non_missing.to_numpy(), # numpy array + ) + assert len(y_out) == len(trace_patch["x"]), ( + "missing-data-handling failure in trendline code" + ) + trace_patch["y"] = y_out + mapping_labels[get_label(args, args["x"])] = "%{x}" + mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>" + elif attr_name.startswith("error"): + error_xy = attr_name[:7] + arr = "arrayminus" if attr_name.endswith("minus") else "array" + if error_xy not in trace_patch: + trace_patch[error_xy] = {} + trace_patch[error_xy][arr] = trace_data.get_column(attr_value) + elif attr_name == "custom_data": + if len(attr_value) > 0: + # here we store a data frame in customdata, and it's serialized + # as a list of row lists, which is what we want + trace_patch["customdata"] = trace_data.select(nw.col(attr_value)) + elif attr_name == "hover_name": + if trace_spec.constructor not in [ + go.Histogram, + go.Histogram2d, + go.Histogram2dContour, + ]: + trace_patch["hovertext"] = trace_data.get_column(attr_value) + 
if hover_header == "": + hover_header = "<b>%{hovertext}</b><br><br>" + elif attr_name == "hover_data": + if trace_spec.constructor not in [ + go.Histogram, + go.Histogram2d, + go.Histogram2dContour, + ]: + hover_is_dict = isinstance(attr_value, dict) + customdata_cols = args.get("custom_data") or [] + for col in attr_value: + if hover_is_dict and not attr_value[col]: + continue + if col in [ + args.get("x"), + args.get("y"), + args.get("z"), + args.get("base"), + ]: + continue + try: + position = args["custom_data"].index(col) + except (ValueError, AttributeError, KeyError): + position = len(customdata_cols) + customdata_cols.append(col) + attr_label_col = get_decorated_label(args, col, None) + mapping_labels[attr_label_col] = "%%{customdata[%d]}" % ( + position + ) + + if len(customdata_cols) > 0: + # here we store a data frame in customdata, and it's serialized + # as a list of row lists, which is what we want + + # dict.fromkeys(customdata_cols) allows to deduplicate column + # names, yet maintaining the original order. 
+ trace_patch["customdata"] = trace_data.select( + *[nw.col(c) for c in dict.fromkeys(customdata_cols)] + ) + elif attr_name == "color": + if trace_spec.constructor in [ + go.Choropleth, + go.Choroplethmap, + go.Choroplethmapbox, + ]: + trace_patch["z"] = trace_data.get_column(attr_value) + trace_patch["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%{z}" + elif trace_spec.constructor in [ + go.Sunburst, + go.Treemap, + go.Icicle, + go.Pie, + go.Funnelarea, + ]: + if "marker" not in trace_patch: + trace_patch["marker"] = dict() + + if args.get("color_is_continuous"): + trace_patch["marker"]["colors"] = trace_data.get_column( + attr_value + ) + trace_patch["marker"]["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%{color}" + else: + trace_patch["marker"]["colors"] = [] + if args["color_discrete_map"] is not None: + mapping = args["color_discrete_map"].copy() + else: + mapping = {} + for cat in trace_data.get_column(attr_value).to_list(): + # although trace_data.get_column(attr_value) is a Narwhals + # Series, which is an iterable, explicitly calling a to_list() + # makes sure that the elements we loop over are python objects + # in all cases, since depending on the backend this may not be + # the case (e.g. 
PyArrow) + if mapping.get(cat) is None: + mapping[cat] = args["color_discrete_sequence"][ + len(mapping) % len(args["color_discrete_sequence"]) + ] + trace_patch["marker"]["colors"].append(mapping[cat]) + else: + colorable = "marker" + if trace_spec.constructor in [go.Parcats, go.Parcoords]: + colorable = "line" + if colorable not in trace_patch: + trace_patch[colorable] = dict() + trace_patch[colorable]["color"] = trace_data.get_column(attr_value) + trace_patch[colorable]["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%%{%s.color}" % colorable + elif attr_name == "animation_group": + trace_patch["ids"] = trace_data.get_column(attr_value) + elif attr_name == "locations": + trace_patch[attr_name] = trace_data.get_column(attr_value) + mapping_labels[attr_label] = "%{location}" + elif attr_name == "values": + trace_patch[attr_name] = trace_data.get_column(attr_value) + _label = "value" if attr_label == "values" else attr_label + mapping_labels[_label] = "%{value}" + elif attr_name == "parents": + trace_patch[attr_name] = trace_data.get_column(attr_value) + _label = "parent" if attr_label == "parents" else attr_label + mapping_labels[_label] = "%{parent}" + elif attr_name == "ids": + trace_patch[attr_name] = trace_data.get_column(attr_value) + _label = "id" if attr_label == "ids" else attr_label + mapping_labels[_label] = "%{id}" + elif attr_name == "names": + if trace_spec.constructor in [ + go.Sunburst, + go.Treemap, + go.Icicle, + go.Pie, + go.Funnelarea, + ]: + trace_patch["labels"] = trace_data.get_column(attr_value) + _label = "label" if attr_label == "names" else attr_label + mapping_labels[_label] = "%{label}" + else: + trace_patch[attr_name] = trace_data.get_column(attr_value) + else: + trace_patch[attr_name] = trace_data.get_column(attr_value) + mapping_labels[attr_label] = "%%{%s}" % attr_name + elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or ( + trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour] + and 
            attr_name == "z"
        ):
            # ensure that stuff like "count" gets into the hoverlabel
            mapping_labels[attr_label] = "%%{%s}" % attr_name
    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
        # Modify mapping_labels according to hover_data keys
        # if hover_data is a dict
        mapping_labels_copy = OrderedDict(mapping_labels)
        if args["hover_data"] and isinstance(args["hover_data"], dict):
            for k, v in mapping_labels.items():
                # We need to invert the mapping here
                k_args = invert_label(args, k)
                if k_args in args["hover_data"]:
                    formatter = args["hover_data"][k_args][0]
                    if formatter:
                        if isinstance(formatter, str):
                            # splice the d3-format spec (e.g. ":.2f") into the
                            # "%{...}" placeholder before its closing brace
                            mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
                    else:
                        # formatter is falsy (False/None): drop this field from hover
                        _ = mapping_labels_copy.pop(k)
        hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
        trace_patch["hovertemplate"] = hover_header + "<br>".join(hover_lines)
        # <extra></extra> suppresses the secondary hover box
        trace_patch["hovertemplate"] += "<extra></extra>"
    return trace_patch, fit_results


def configure_axes(args, constructor, fig, orders):
    """Dispatch to the axis-configuration routine matching the trace type.

    Trace constructors not present in the table (e.g. pie, sunburst) get no
    axis configuration at all.
    """
    configurators = {
        go.Scatter3d: configure_3d_axes,
        go.Scatterternary: configure_ternary_axes,
        go.Scatterpolar: configure_polar_axes,
        go.Scatterpolargl: configure_polar_axes,
        go.Barpolar: configure_polar_axes,
        go.Scattermap: configure_map,
        go.Choroplethmap: configure_map,
        go.Densitymap: configure_map,
        go.Scattermapbox: configure_mapbox,
        go.Choroplethmapbox: configure_mapbox,
        go.Densitymapbox: configure_mapbox,
        go.Scattergeo: configure_geo,
        go.Choropleth: configure_geo,
    }
    # all 2D-cartesian constructors share one configurator
    for c in cartesians:
        configurators[c] = configure_cartesian_axes
    if constructor in configurators:
        configurators[constructor](args, fig, orders)


def set_cartesian_axis_opts(args, axis, letter, orders):
    """Apply log scaling, explicit range and category ordering to one axis.

    `letter` is "x" or "y"; `axis` is the axis object to mutate in place.
    """
    log_key = "log_" + letter
    range_key = "range_" + letter
    if log_key in args and args[log_key]:
        axis["type"] = "log"
        if range_key in args and args[range_key]:
            # on a log axis, plotly ranges are given as exponents (base 10)
            axis["range"] = [math.log(r, 10) for r in args[range_key]]
    elif range_key in args and args[range_key]:
        axis["range"] = args[range_key]

    if args[letter] in orders:
        axis["categoryorder"] = "array"
        axis["categoryarray"] = (
            orders[args[letter]]
            if isinstance(axis, go.layout.XAxis)
            else list(reversed(orders[args[letter]]))  # top down for Y axis
        )


def configure_cartesian_marginal_axes(args, fig, orders):
    """Configure axes of a cartesian figure that carries marginal subplots.

    Marginal subplots (histogram/violin/box/rug along the top and/or right)
    get their tick labels, lines and ranges stripped, and are matched to the
    main axes so they pan/zoom together.
    """
    nrows = len(fig._grid_ref)
    ncols = len(fig._grid_ref[0])

    # Set y-axis titles and axis options in the left-most column
    for yaxis in fig.select_yaxes(col=1):
        set_cartesian_axis_opts(args, yaxis, "y", orders)

    # Set x-axis titles and axis options in the bottom-most row
    for xaxis in fig.select_xaxes(row=1):
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Configure axis ticks on marginal subplots
    if args["marginal_x"]:
        fig.update_yaxes(
            showticklabels=False, showline=False, ticks="", range=None, row=nrows
        )
        # only override the grid when the template does not pin it down
        if args["template"].layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
        if args["template"].layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=True, row=nrows)

    if args["marginal_y"]:
        fig.update_xaxes(
            showticklabels=False, showline=False, ticks="", range=None, col=ncols
        )
        if args["template"].layout.xaxis.showgrid is None:
            fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
        if args["template"].layout.yaxis.showgrid is None:
            fig.update_yaxes(showgrid=True, col=ncols)

    # Add axis titles to non-marginal subplots
    y_title = get_decorated_label(args, args["y"], "y")
    if args["marginal_x"]:
        fig.update_yaxes(title_text=y_title, row=1, col=1)
    else:
        for row in range(1, nrows + 1):
            fig.update_yaxes(title_text=y_title, row=row, col=1)

    x_title = get_decorated_label(args, args["x"], "x")
    if args["marginal_y"]:
        fig.update_xaxes(title_text=x_title, row=1, col=1)
    else:
        for col in range(1, ncols + 1):
            fig.update_xaxes(title_text=x_title, row=1, col=col)

    # Configure axis type across all x-axes
    if "log_x" in args and args["log_x"]:
        fig.update_xaxes(type="log")

    # Configure axis type across all y-axes
    if "log_y" in args and args["log_y"]:
        fig.update_yaxes(type="log")

    # Configure matching and axis type for marginal y-axes
    # NOTE(review): matches target "y<ncols+1>" appears to rely on the
    # subplot-grid axis numbering produced upstream — confirm against the
    # grid construction in make_figure.
    matches_y = "y" + str(ncols + 1)
    if args["marginal_x"]:
        for row in range(2, nrows + 1, 2):
            fig.update_yaxes(matches=matches_y, type=None, row=row)

    if args["marginal_y"]:
        for col in range(2, ncols + 1, 2):
            fig.update_xaxes(matches="x2", type=None, col=col)


def configure_cartesian_axes(args, fig, orders):
    """Configure 2D-cartesian axes: titles, log/range/category options.

    Delegates entirely to the marginal variant when marginal subplots are
    requested.
    """
    if ("marginal_x" in args and args["marginal_x"]) or (
        "marginal_y" in args and args["marginal_y"]
    ):
        configure_cartesian_marginal_axes(args, fig, orders)
        return

    # Set y-axis titles and axis options in the left-most column
    y_title = get_decorated_label(args, args["y"], "y")
    for yaxis in fig.select_yaxes(col=1):
        yaxis.update(title_text=y_title)
        set_cartesian_axis_opts(args, yaxis, "y", orders)

    # Set x-axis titles and axis options in the bottom-most row
    x_title = get_decorated_label(args, args["x"], "x")
    for xaxis in fig.select_xaxes(row=1):
        # timeline() supplies its own x meaning; don't title it
        if "is_timeline" not in args:
            xaxis.update(title_text=x_title)
        set_cartesian_axis_opts(args, xaxis, "x", orders)

    # Configure axis type across all x-axes
    if "log_x" in args and args["log_x"]:
        fig.update_xaxes(type="log")

    # Configure axis type across all y-axes
    if "log_y" in args and args["log_y"]:
        fig.update_yaxes(type="log")

    if "is_timeline" in args:
        fig.update_xaxes(type="date")

    if "ecdfmode" in args:
        # ECDF curves start at 0 on the value axis
        if args["orientation"] == "v":
            fig.update_yaxes(rangemode="tozero")
        else:
            fig.update_xaxes(rangemode="tozero")


def configure_ternary_axes(args, fig, orders):
    """Title the a/b/c axes of a ternary plot from the column labels."""
    fig.update_ternaries(
        aaxis=dict(title_text=get_label(args, args["a"])),
        baxis=dict(title_text=get_label(args, args["b"])),
        caxis=dict(title_text=get_label(args, args["c"])),
    )


def configure_polar_axes(args, fig, orders):
    """Configure polar plots: direction, start angle, r/theta ordering,
    log/linear radial range and angular sector."""
    patch = dict(
        angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
        radialaxis=dict(),
    )

    for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
        if args[var] in orders:
            patch[axis]["categoryorder"] = "array"
            patch[axis]["categoryarray"] = orders[args[var]]

    radialaxis = patch["radialaxis"]
    if args["log_r"]:
        radialaxis["type"] = "log"
        if args["range_r"]:
            # log-axis ranges are exponents (base 10)
            radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]]
    else:
        if args["range_r"]:
            radialaxis["range"] = args["range_r"]

    if args["range_theta"]:
        patch["sector"] = args["range_theta"]
    fig.update_polars(patch)


def configure_3d_axes(args, fig, orders):
    """Configure the x/y/z axes of a 3D scene: titles, log/range, ordering."""
    patch = dict(
        xaxis=dict(title_text=get_label(args, args["x"])),
        yaxis=dict(title_text=get_label(args, args["y"])),
        zaxis=dict(title_text=get_label(args, args["z"])),
    )

    for letter in ["x", "y", "z"]:
        axis = patch[letter + "axis"]
        if args["log_" + letter]:
            axis["type"] = "log"
            if args["range_" + letter]:
                # log-axis ranges are exponents (base 10)
                axis["range"] = [math.log(x, 10) for x in args["range_" + letter]]
        else:
            if args["range_" + letter]:
                axis["range"] = args["range_" + letter]
        if args[letter] in orders:
            axis["categoryorder"] = "array"
            axis["categoryarray"] = orders[args[letter]]
    fig.update_scenes(patch)


def configure_mapbox(args, fig, orders):
    """Configure Mapbox-based maps: token, center, zoom and style.

    If no explicit center was given, defaults to the mean lat/lon of the data.
    MAPBOX_TOKEN is a module-level value set elsewhere (via
    set_mapbox_access_token) — not visible in this chunk.
    """
    center = args["center"]
    if not center and "lat" in args and "lon" in args:
        center = dict(
            lat=args["data_frame"][args["lat"]].mean(),
            lon=args["data_frame"][args["lon"]].mean(),
        )
    fig.update_mapboxes(
        accesstoken=MAPBOX_TOKEN,
        center=center,
        zoom=args["zoom"],
        style=args["mapbox_style"],
    )


def configure_map(args, fig, orders):
    """Configure MapLibre-based maps (the non-Mapbox `*_map` traces):
    center, zoom and style. No access token is needed for these."""
    center = args["center"]
    if not center and "lat" in args and "lon" in args:
        center = dict(
            lat=args["data_frame"][args["lat"]].mean(),
            lon=args["data_frame"][args["lon"]].mean(),
        )
    fig.update_maps(
        center=center,
        zoom=args["zoom"],
        style=args["map_style"],
    )

def configure_geo(args, fig, orders):
    """Configure outline-map (geo) subplots: center, scope, bounds-fitting,
    basemap visibility and projection type."""
    fig.update_geos(
        center=args["center"],
        scope=args["scope"],
        fitbounds=args["fitbounds"],
        visible=args["basemap_visible"],
        projection=dict(type=args["projection"]),
    )


def configure_animation_controls(args, constructor, fig):
    """Attach play/stop buttons and a frame slider when the figure has
    more than one animation frame; otherwise leave the layout untouched."""

    def frame_args(duration):
        # redraw is skipped for go.Scatter since its frames can be animated
        # without a full redraw; other trace types need redraw=True
        return {
            "frame": {"duration": duration, "redraw": constructor != go.Scatter},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1:
        fig.layout.updatemenus = [
            {
                "buttons": [
                    {
                        "args": [None, frame_args(500)],
                        "label": "&#9654;",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(0)],
                        "label": "&#9724;",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "showactive": False,
                "type": "buttons",
                "x": 0.1,
                "xanchor": "right",
                "y": 0,
                "yanchor": "top",
            }
        ]
        fig.layout.sliders = [
            {
                "active": 0,
                "yanchor": "top",
                "xanchor": "left",
                "currentvalue": {
                    "prefix": get_label(args, args["animation_frame"]) + "="
                },
                "pad": {"b": 10, "t": 60},
                "len": 0.9,
                "x": 0.1,
                "y": 0,
                "steps": [
                    {
                        "args": [[f.name], frame_args(0)],
                        "label": f.name,
                        "method": "animate",
                    }
                    for f in fig.frames
                ],
            }
        ]


def make_trace_spec(args, constructor, attrs, trace_patch):
    """Build the list of TraceSpec objects for one figure.

    Returns the base spec first, then any marginal specs (histogram, violin,
    box or rug along x and/or y), then a trendline spec when per-trace
    trendlines are requested. May swap the constructor for its WebGL
    counterpart when render_mode asks for (or auto-selects) webgl.
    """
    if constructor in [go.Scatter, go.Scatterpolar]:
        if "render_mode" in args and (
            args["render_mode"] == "webgl"
            or (
                # auto mode: use webgl for large non-spline, non-animated data
                args["render_mode"] == "auto"
                and len(args["data_frame"]) > 1000
                and args.get("line_shape") != "spline"
                and args["animation_frame"] is None
            )
        ):
            if constructor == go.Scatter:
                constructor = go.Scattergl
                # Scattergl does not accept an orientation attribute
                if "orientation" in trace_patch:
                    del trace_patch["orientation"]
            else:
                constructor = go.Scatterpolargl
    # Create base trace specification
    result = [TraceSpec(constructor, attrs, trace_patch, None)]

    # Add marginal trace specifications
    for letter in ["x", "y"]:
        if "marginal_" + letter in args and args["marginal_" + letter]:
            trace_spec = None
            # route the marginal onto the secondary axis of the other dimension
            axis_map = dict(
                xaxis="x1" if letter == "x" else "x2",
                yaxis="y1" if letter == "y" else "y2",
            )
            if args["marginal_" + letter] == "histogram":
                trace_spec = TraceSpec(
                    constructor=go.Histogram,
                    attrs=[letter, "marginal_" + letter],
                    trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "violin":
                trace_spec = TraceSpec(
                    constructor=go.Violin,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(scalegroup=letter),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "box":
                trace_spec = TraceSpec(
                    constructor=go.Box,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(notched=True),
                    marginal=letter,
                )
            elif args["marginal_" + letter] == "rug":
                symbols = {"x": "line-ns-open", "y": "line-ew-open"}
                # a rug is drawn as a fully transparent box plot showing
                # only its jitterless points with tick-shaped markers
                trace_spec = TraceSpec(
                    constructor=go.Box,
                    attrs=[letter, "hover_name", "hover_data"],
                    trace_patch=dict(
                        fillcolor="rgba(255,255,255,0)",
                        line={"color": "rgba(255,255,255,0)"},
                        boxpoints="all",
                        jitter=0,
                        hoveron="points",
                        marker={"symbol": symbols[letter]},
                    ),
                    marginal=letter,
                )
            if "color" in attrs or "color" not in args:
                if "marker" not in trace_spec.trace_patch:
                    trace_spec.trace_patch["marker"] = dict()
                # default marginals to the first color of the continuous scale
                first_default_color = args["color_continuous_scale"][0]
                trace_spec.trace_patch["marker"]["color"] = first_default_color
            result.append(trace_spec)

    # Add trendline trace specifications
    if args.get("trendline") and args.get("trendline_scope", "trace") == "trace":
        result.append(make_trendline_spec(args, constructor))
    return result


def make_trendline_spec(args, constructor):
    """Return the TraceSpec for a trendline overlay, matching the base
    trace's GL/non-GL rendering and honoring trendline_color_override."""
    trace_spec = TraceSpec(
        constructor=(
            go.Scattergl
            if constructor == go.Scattergl  # could be contour
            else go.Scatter
        ),
        attrs=["trendline"],
        trace_patch=dict(mode="lines"),
        marginal=None,
    )
    if args["trendline_color_override"]:
        trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"])
    return trace_spec


def one_group(x):
    """Grouping key that places every row in a single group."""
    return ""


def apply_default_cascade(args):
    """Resolve unspecified arguments through the defaults cascade:
    px.defaults first, then the active template, then hard-coded fallbacks.
    Mutates `args` in place."""
    # first we apply px.defaults to unspecified args

    for param in defaults.__slots__:
        if param in args and args[param] is None:
            args[param] = getattr(defaults, param)

    # load the default template if set, otherwise "plotly"
    if args["template"] is None:
        if pio.templates.default is not None:
            args["template"] = pio.templates.default
        else:
            args["template"] = "plotly"

    try:
        # retrieve the actual template if we were given a name
        args["template"] = pio.templates[args["template"]]
    except Exception:
        # otherwise try to build a real template
        args["template"] = go.layout.Template(args["template"])

    # if colors not set explicitly or in px.defaults, defer to a template
    # if the template doesn't have one, we set some final fallback defaults
    if "color_continuous_scale" in args:
        if (
            args["color_continuous_scale"] is None
            and args["template"].layout.colorscale.sequential
        ):
            args["color_continuous_scale"] = [
                x[1] for x in args["template"].layout.colorscale.sequential
            ]
        if args["color_continuous_scale"] is None:
            args["color_continuous_scale"] = sequential.Viridis

    if "color_discrete_sequence" in args:
        if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
            args["color_discrete_sequence"] = args["template"].layout.colorway
        if args["color_discrete_sequence"] is None:
            args["color_discrete_sequence"] = qualitative.D3

    # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
    # see if we can defer to template. If not, set reasonable defaults
    if "symbol_sequence" in args:
        if args["symbol_sequence"] is None and args["template"].data.scatter:
            args["symbol_sequence"] = [
                scatter.marker.symbol for scatter in args["template"].data.scatter
            ]
        if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
            args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]

    if "line_dash_sequence" in args:
        if args["line_dash_sequence"] is None and args["template"].data.scatter:
            args["line_dash_sequence"] = [
                scatter.line.dash for scatter in args["template"].data.scatter
            ]
        if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
            args["line_dash_sequence"] = [
                "solid",
                "dot",
                "dash",
                "longdash",
                "dashdot",
                "longdashdot",
            ]

    if "pattern_shape_sequence" in args:
        if args["pattern_shape_sequence"] is None and args["template"].data.bar:
            args["pattern_shape_sequence"] = [
                bar.marker.pattern.shape for bar in args["template"].data.bar
            ]
        if not args["pattern_shape_sequence"] or not any(
            args["pattern_shape_sequence"]
        ):
            args["pattern_shape_sequence"] = ["", "/", "\\", "x", "+", "."]


def _check_name_not_reserved(field_name, reserved_names):
    """Return `field_name` unchanged, or raise NameError if it collides with
    a column/index name already used by the data_frame."""
    if field_name not in reserved_names:
        return field_name
    else:
        raise NameError(
            "A name conflict was encountered for argument '%s'. "
            "A column or index with name '%s' is ambiguous." % (field_name, field_name)
        )


def _get_reserved_col_names(args):
    """
    This function builds a list of columns of the data_frame argument used
    as arguments, either as str/int arguments or given as columns
    (pandas series type).
    """
    df: nw.DataFrame = args["data_frame"]
    reserved_names = set()
    for field in args:
        if field not in all_attrables:
            continue
        names = args[field] if field in array_attrables else [args[field]]
        if names is None:
            continue
        for arg in names:
            if arg is None:
                continue
            elif isinstance(arg, str):  # no need to add ints since kw arg are not ints
                reserved_names.add(arg)
            elif nw.dependencies.is_into_series(arg):
                arg_series = nw.from_native(arg, series_only=True)
                arg_name = arg_series.name
                if arg_name and arg_name in df.columns:
                    # only reserve the name if the series is actually the
                    # same data as the like-named dataframe column
                    in_df = (arg_series == df.get_column(arg_name)).all()
                    if in_df:
                        reserved_names.add(arg_name)
            elif arg is nw.maybe_get_index(df) and arg.name is not None:
                reserved_names.add(arg.name)

    return reserved_names


def _is_col_list(columns, arg, is_pd_like, native_namespace):
    """Returns True if arg looks like it's a list of columns or references to columns
    in df_input, and False otherwise (in which case it's assumed to be a single column
    or reference to a column).
    """
    if arg is None or isinstance(arg, str) or isinstance(arg, int):
        return False
    if is_pd_like and isinstance(arg, native_namespace.MultiIndex):
        return False  # just to keep existing behaviour for now
    try:
        iter(arg)
    except TypeError:
        return False  # not iterable
    for c in arg:
        if isinstance(c, str) or isinstance(c, int):
            if columns is None or c not in columns:
                return False
        else:
            try:
                iter(c)
            except TypeError:
                return False  # not iterable
    return True


def _isinstance_listlike(x):
    """Returns True if x is an iterable which can be transformed into a pandas Series,
    False for the other types of possible values of a `hover_data` dict.
    A tuple of length 2 is a special case corresponding to a (format, data) tuple.
    """
    if (
        isinstance(x, str)
        or (isinstance(x, tuple) and len(x) == 2)
        or isinstance(x, bool)
        or x is None
    ):
        return False
    else:
        return True


def _escape_col_name(columns, col_name, extra):
    """Prefix `col_name` with underscores until it collides with neither
    `columns` nor the names in `extra`; pass through when columns is None."""
    if columns is None:
        return col_name
    while col_name in columns or col_name in extra:
        col_name = "_" + col_name
    return col_name


def to_named_series(x, name=None, native_namespace=None):
    """Assuming x is list-like or even an existing Series, returns a new Series named `name`."""
    # With `pass_through=True`, the original object will be returned if unable to convert
    # to a Narwhals Series.
    x = nw.from_native(x, series_only=True, pass_through=True)
    if isinstance(x, nw.Series):
        return x.rename(name)
    elif native_namespace is not None:
        return nw.new_series(name=name, values=x, native_namespace=native_namespace)
    else:
        # no namespace known: fall back to pandas as the series backend
        try:
            import pandas as pd

            return nw.new_series(name=name, values=x, native_namespace=pd)
        except ImportError:
            msg = "Pandas installation is required if no dataframe is provided."
            raise NotImplementedError(msg)


def process_args_into_dataframe(
    args, wide_mode, var_name, value_name, is_pd_like, native_namespace
):
    """
    After this function runs, the `all_attrables` keys of `args` all contain only
    references to columns of `df_output`. This function handles the extraction of data
    from `args["attrable"]` and column-name-generation as appropriate, and adds the
    data to `df_output` and then replaces `args["attrable"]` with the appropriate
    reference.
    """

    df_input: nw.DataFrame | None = args["data_frame"]
    df_provided = df_input is not None

    # we use a dict instead of a dataframe directly so that it doesn't cause
    # PerformanceWarning by pandas by repeatedly setting the columns.
    # a dict is used instead of a list as the columns needs to be overwritten.
    df_output = {}
    constants = {}
    ranges = []
    wide_id_vars = set()
    reserved_names = _get_reserved_col_names(args) if df_provided else set()

    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
    if "dimensions" in args and args["dimensions"] is None:
        if not df_provided:
            raise ValueError(
                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
            )
        else:
            # dimensions defaults to all columns of the dataframe
            df_output = {col: df_input.get_column(col) for col in df_input.columns}

    # hover_data is a dict
    hover_data_is_dict = (
        "hover_data" in args
        and args["hover_data"]
        and isinstance(args["hover_data"], dict)
    )
    # If dict, convert all values of hover_data to tuples to simplify processing
    if hover_data_is_dict:
        for k in args["hover_data"]:
            if _isinstance_listlike(args["hover_data"][k]):
                args["hover_data"][k] = (True, args["hover_data"][k])
            if not isinstance(args["hover_data"][k], tuple):
                args["hover_data"][k] = (args["hover_data"][k], None)
            if df_provided and args["hover_data"][k][1] is not None and k in df_input:
                raise ValueError(
                    "Ambiguous input: values for '%s' appear both in hover_data and data_frame"
                    % k
                )
    # Loop over possible arguments
    for field_name in all_attrables:
        # Massaging variables
        argument_list = (
            [args.get(field_name)]
            if field_name not in array_attrables
            else args.get(field_name)
        )

        # argument not specified, continue
        # The original also tested `or argument_list is [None]` but
        # that clause is always False, so it has been removed. The
        # alternative fix would have been to test that `argument_list`
        # is of length 1 and its sole element is `None`, but that
        # feels pedantic. All tests pass with the change below; let's
        # see if the world decides we were wrong.
        if argument_list is None:
            continue

        # Argument name: field_name if the argument is not a list
        # Else we give names like ["hover_data_0, hover_data_1"] etc.
        field_list = (
            [field_name]
            if field_name not in array_attrables
            else [field_name + "_" + str(i) for i in range(len(argument_list))]
        )
        # argument_list and field_list ready, iterate over them
        # Core of the loop starts here
        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
            # length of the columns processed so far; every subsequent
            # argument must match it
            length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
            if argument is None:
                continue
            col_name = None
            # Case of multiindex
            if is_pd_like and isinstance(argument, native_namespace.MultiIndex):
                raise TypeError(
                    f"Argument '{field}' is a {native_namespace.__name__} MultiIndex. "
                    f"{native_namespace.__name__} MultiIndex is not supported by plotly "
                    "express at the moment."
                )
            # ----------------- argument is a special value ----------------------
            if isinstance(argument, (Constant, Range)):
                col_name = _check_name_not_reserved(
                    str(argument.label) if argument.label is not None else field,
                    reserved_names,
                )
                if isinstance(argument, Constant):
                    constants[col_name] = argument.value
                else:
                    ranges.append(col_name)
            # ----------------- argument is likely a col name ----------------------
            elif isinstance(argument, str) or not hasattr(argument, "__len__"):
                if (
                    field_name == "hover_data"
                    and hover_data_is_dict
                    and args["hover_data"][str(argument)][1] is not None
                ):
                    # hover_data has onboard data
                    # previously-checked to have no name-conflict with data_frame
                    col_name = str(argument)
                    real_argument = args["hover_data"][col_name][1]

                    if length and (real_length := len(real_argument)) != length:
                        raise ValueError(
                            "All arguments should have the same length. "
                            "The length of hover_data key `%s` is %d, whereas the "
                            "length of previously-processed arguments %s is %d"
                            % (
                                argument,
                                real_length,
                                str(list(df_output.keys())),
                                length,
                            )
                        )
                    df_output[col_name] = to_named_series(
                        real_argument, col_name, native_namespace
                    )
                elif not df_provided:
                    raise ValueError(
                        "String or int arguments are only possible when a "
                        "DataFrame or an array is provided in the `data_frame` "
                        "argument. No DataFrame was provided, but argument "
                        "'%s' is of type str or int." % field
                    )
                # Check validity of column name
                elif argument not in df_input.columns:
                    if wide_mode and argument in (value_name, var_name):
                        # melt-generated columns don't exist yet; resolved later
                        continue
                    else:
                        err_msg = (
                            "Value of '%s' is not the name of a column in 'data_frame'. "
                            "Expected one of %s but received: %s"
                            % (field, str(list(df_input.columns)), argument)
                        )
                        if argument == "index":
                            err_msg += "\n To use the index, pass it in directly as `df.index`."
                        raise ValueError(err_msg)
                elif length and (actual_len := len(df_input)) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of column argument `df[%s]` is %d, whereas the "
                        "length of previously-processed arguments %s is %d"
                        % (
                            field,
                            actual_len,
                            str(list(df_output.keys())),
                            length,
                        )
                    )
                else:
                    col_name = str(argument)
                    df_output[col_name] = to_named_series(
                        df_input.get_column(argument), col_name
                    )
            # ----------------- argument is likely a column / array / list.... -------
            else:
                if df_provided and hasattr(argument, "name"):
                    if is_pd_like and argument is nw.maybe_get_index(df_input):
                        # the dataframe's own index was passed in
                        if argument.name is None or argument.name in df_input.columns:
                            col_name = "index"
                        else:
                            col_name = argument.name
                        col_name = _escape_col_name(
                            df_input.columns, col_name, [var_name, value_name]
                        )
                    else:
                        # reuse the column name only when the series matches
                        # the like-named dataframe column element-wise
                        if (
                            argument.name is not None
                            and argument.name in df_input.columns
                            and (
                                to_named_series(
                                    argument, argument.name, native_namespace
                                )
                                == df_input.get_column(argument.name)
                            ).all()
                        ):
                            col_name = argument.name
                if col_name is None:  # numpy array, list...
                    col_name = _check_name_not_reserved(field, reserved_names)

                if length and (len_arg := len(argument)) != length:
                    raise ValueError(
                        "All arguments should have the same length. "
                        "The length of argument `%s` is %d, whereas the "
                        "length of previously-processed arguments %s is %d"
                        % (field, len_arg, str(list(df_output.keys())), length)
                    )

                df_output[str(col_name)] = to_named_series(
                    x=argument,
                    name=str(col_name),
                    native_namespace=native_namespace,
                )

            # Finally, update argument with column name now that column exists
            assert col_name is not None, (
                "Data-frame processing failure, likely due to a internal bug. "
                "Please report this to "
                "https://github.com/plotly/plotly.py/issues/new and we will try to "
                "replicate and fix it."
            )
            if field_name not in array_attrables:
                args[field_name] = str(col_name)
            elif isinstance(args[field_name], dict):
                pass
            else:
                args[field_name][i] = str(col_name)
            # everything except the wide_variable columns is an id var for
            # the wide-mode melt performed downstream
            if field_name != "wide_variable":
                wide_id_vars.add(str(col_name))

    length = len(df_output[next(iter(df_output))]) if len(df_output) else 0

    if native_namespace is None:
        try:
            import pandas as pd

            native_namespace = pd
        except ImportError:
            msg = "Pandas installation is required if no dataframe is provided."
            raise NotImplementedError(msg)

    if ranges:
        import numpy as np

        # Range columns become 0..length-1 counters
        range_series = nw.new_series(
            name="__placeholder__",
            values=np.arange(length),
            native_namespace=native_namespace,
        )
        df_output.update(
            {col_name: range_series.alias(col_name) for col_name in ranges}
        )

    df_output.update(
        {
            # constant is single value. repeat by len to avoid creating NaN on concatenating
            col_name: nw.new_series(
                name=col_name,
                values=[constants[col_name]] * length,
                native_namespace=native_namespace,
            )
            for col_name in constants
        }
    )

    if df_output:
        df_output = nw.from_dict(df_output)
    else:
        # no data at all: return an empty pandas-backed frame
        try:
            import pandas as pd
        except ImportError:
            msg = "Pandas installation is required."
            raise NotImplementedError(msg)
        df_output = nw.from_native(pd.DataFrame({}), eager_only=True)
    return df_output, wide_id_vars


def build_dataframe(args, constructor):
    """
    Constructs a dataframe and modifies `args` in-place.

    The argument values in `args` can be either strings corresponding to
    existing columns of a dataframe, or data arrays (lists, numpy arrays,
    pandas columns, series).

    Parameters
    ----------
    args : OrderedDict
        arguments passed to the px function and subsequently modified
    constructor : graph_object trace class
        the trace type selected for this figure
    """
    # NOTE: this definition continues beyond this chunk of the file.

    # make copies of all the fields via dict() and list()
    for field in args:
        if field in array_attrables and args[field] is not None:
            if isinstance(args[field], dict):
                args[field] = dict(args[field])
            elif field in ["custom_data", "hover_data"] and isinstance(
                args[field], str
            ):
                args[field] = [args[field]]
            else:
                args[field] = list(args[field])

    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
    df_provided = args["data_frame"] is not None

    # Flag that indicates if the resulting data_frame after parsing is pandas-like
    # (in terms of resulting Narwhals DataFrame).
+ # True if pandas, modin.pandas or cudf DataFrame/Series instance, or converted from + # PySpark to pandas. + is_pd_like = False + + # Flag that indicates if data_frame needs to be converted to PyArrow. + # True if Ibis, DuckDB, Vaex, or implements __dataframe__ + needs_interchanging = False + + # If data_frame is provided, we parse it into a narwhals DataFrame, while accounting + # for compatibility with pandas specific paths (e.g. Index/MultiIndex case). + if df_provided: + # data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf) + if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]): + columns = args["data_frame"].columns # This can be multi index + args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True) + is_pd_like = True + + # data_frame is pandas-like Series (pandas, modin.pandas, cudf) + elif nw.dependencies.is_pandas_like_series(args["data_frame"]): + args["data_frame"] = nw.from_native( + args["data_frame"], series_only=True + ).to_frame() + columns = args["data_frame"].columns + is_pd_like = True + + # data_frame is any other DataFrame object natively supported via Narwhals. + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals DataFrame, making this condition False. + elif isinstance( + data_frame := nw.from_native( + args["data_frame"], eager_or_interchange_only=True, pass_through=True + ), + nw.DataFrame, + ): + args["data_frame"] = data_frame + needs_interchanging = nw.get_level(data_frame) == "interchange" + columns = args["data_frame"].columns + + # data_frame is any other Series object natively supported via Narwhals. + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals Series, making this condition False. 
+ elif isinstance( + series := nw.from_native( + args["data_frame"], series_only=True, pass_through=True + ), + nw.Series, + ): + args["data_frame"] = series.to_frame() + columns = args["data_frame"].columns + + # data_frame is PySpark: it does not support interchange protocol and it is not + # integrated in Narwhals. We use its native method to convert it to pandas. + elif hasattr(args["data_frame"], "toPandas"): + args["data_frame"] = nw.from_native( + args["data_frame"].toPandas(), eager_only=True + ) + columns = args["data_frame"].columns + is_pd_like = True + + # data_frame is some other object type (e.g. dict, list, ...) + # We try to import pandas, and then try to instantiate a pandas dataframe from + # this such object + else: + try: + import pandas as pd + + try: + args["data_frame"] = nw.from_native( + pd.DataFrame(args["data_frame"]) + ) + columns = args["data_frame"].columns + is_pd_like = True + except Exception: + msg = ( + f"Unable to convert data_frame of type {type(args['data_frame'])} " + "to pandas DataFrame. Please provide a supported dataframe type " + "or a type that can be passed to pd.DataFrame." + ) + + raise NotImplementedError(msg) + except ImportError: + msg = ( + f"Attempting to convert data_frame of type {type(args['data_frame'])} " + "to pandas DataFrame, but Pandas is not installed. " + "Convert it to supported dataframe type or install pandas." 
+ ) + raise NotImplementedError(msg) + + # data_frame is not provided + else: + columns = None + + df_input: nw.DataFrame | None = args["data_frame"] + index = ( + nw.maybe_get_index(df_input) + if df_provided and not needs_interchanging + else None + ) + native_namespace = ( + nw.get_native_namespace(df_input) + if df_provided and not needs_interchanging + else None + ) + + # now we handle special cases like wide-mode or x-xor-y specification + # by rearranging args to tee things up for process_args_into_dataframe to work + no_x = args.get("x") is None + no_y = args.get("y") is None + wide_x = ( + False + if no_x + else _is_col_list(columns, args["x"], is_pd_like, native_namespace) + ) + wide_y = ( + False + if no_y + else _is_col_list(columns, args["y"], is_pd_like, native_namespace) + ) + + wide_mode = False + var_name = None # will likely be "variable" in wide_mode + wide_cross_name = None # will likely be "index" in wide_mode + value_name = None # will likely be "value" in wide_mode + hist2d_types = [go.Histogram2d, go.Histogram2dContour] + hist1d_orientation = constructor == go.Histogram or "ecdfmode" in args + if constructor in cartesians: + if wide_x and wide_y: + raise ValueError( + "Cannot accept list of column references or list of columns for both `x` and `y`." + ) + if df_provided and no_x and no_y: + wide_mode = True + if is_pd_like and isinstance(columns, native_namespace.MultiIndex): + raise TypeError( + f"Data frame columns is a {native_namespace.__name__} MultiIndex. " + f"{native_namespace.__name__} MultiIndex is not supported by plotly " + "express at the moment." 
+ ) + args["wide_variable"] = list(columns) + if is_pd_like and isinstance(columns, native_namespace.Index): + var_name = columns.name + else: + var_name = None + if var_name in [None, "value", "index"] or var_name in columns: + var_name = "variable" + if constructor == go.Funnel: + wide_orientation = args.get("orientation") or "h" + else: + wide_orientation = args.get("orientation") or "v" + args["orientation"] = wide_orientation + args["wide_cross"] = None + elif wide_x != wide_y: + wide_mode = True + args["wide_variable"] = args["y"] if wide_y else args["x"] + if df_provided and is_pd_like and args["wide_variable"] is columns: + var_name = columns.name + if is_pd_like and isinstance(args["wide_variable"], native_namespace.Index): + args["wide_variable"] = list(args["wide_variable"]) + if var_name in [None, "value", "index"] or ( + df_provided and var_name in columns + ): + var_name = "variable" + if hist1d_orientation: + wide_orientation = "v" if wide_x else "h" + else: + wide_orientation = "v" if wide_y else "h" + args["y" if wide_y else "x"] = None + args["wide_cross"] = None + if not no_x and not no_y: + wide_cross_name = "__x__" if wide_y else "__y__" + + if wide_mode: + value_name = _escape_col_name(columns, "value", []) + var_name = _escape_col_name(columns, var_name, []) + + # If the data_frame has interchange-only support levelin Narwhals, then we need to + # convert it to a full support level backend. + # Hence we convert requires Interchange to PyArrow. + if needs_interchanging: + if wide_mode: + args["data_frame"] = nw.from_native( + args["data_frame"].to_arrow(), eager_only=True + ) + else: + # Save precious resources by only interchanging columns that are + # actually going to be plotted. This is tricky to do in the general case, + # because Plotly allows calls like `px.line(df, x='x', y=['y1', df['y1']])`, + # but interchange-only objects (e.g. DuckDB) don't typically have a concept + # of self-standing Series. 
It's more important to perform projection pushdown
+ ) + args["wide_cross"] = index + else: + args["wide_cross"] = Range( + label=_escape_col_name(columns, "index", [var_name, value_name]) + ) + + no_color = False + if isinstance(args.get("color"), str) and args["color"] == NO_COLOR: + no_color = True + args["color"] = None + # now that things have been prepped, we do the systematic rewriting of `args` + + df_output, wide_id_vars = process_args_into_dataframe( + args, + wide_mode, + var_name, + value_name, + is_pd_like, + native_namespace, + ) + df_output: nw.DataFrame + # now that `df_output` exists and `args` contains only references, we complete + # the special-case and wide-mode handling by further rewriting args and/or mutating + # df_output + + count_name = _escape_col_name(df_output.columns, "count", [var_name, value_name]) + if not wide_mode and missing_bar_dim and constructor == go.Bar: + # now that we've populated df_output, we check to see if the non-missing + # dimension is categorical: if so, then setting the missing dimension to a + # constant 1 is a less-insane thing to do than setting it to the index by + # default and we let the normal auto-orientation-code do its thing later + other_dim = "x" if missing_bar_dim == "y" else "y" + if not _is_continuous(df_output, args[other_dim]): + args[missing_bar_dim] = count_name + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) + else: + # on the other hand, if the non-missing dimension is continuous, then we + # can use this information to override the normal auto-orientation code + if args["orientation"] is None: + args["orientation"] = "v" if missing_bar_dim == "x" else "h" + + if constructor in hist2d_types: + del args["orientation"] + + if wide_mode: + # at this point, `df_output` is semi-long/semi-wide, but we know which columns + # are which, so we melt it and reassign `args` to refer to the newly-tidy + # columns, keeping track of various names and manglings set up above + wide_value_vars = [c for c in args["wide_variable"] if c not in 
wide_id_vars] + del args["wide_variable"] + if wide_cross_name == "__x__": + wide_cross_name = args["x"] + elif wide_cross_name == "__y__": + wide_cross_name = args["y"] + else: + wide_cross_name = args["wide_cross"] + del args["wide_cross"] + dtype = None + for v in wide_value_vars: + v_dtype = df_output.get_column(v).dtype + v_dtype = "number" if v_dtype.is_numeric() else str(v_dtype) + if dtype is None: + dtype = v_dtype + elif dtype != v_dtype: + raise ValueError( + "Plotly Express cannot process wide-form data with columns of different type." + ) + df_output = df_output.unpivot( + index=wide_id_vars, + on=wide_value_vars, + variable_name=var_name, + value_name=value_name, + ) + assert len(df_output.columns) == len(set(df_output.columns)), ( + "Wide-mode name-inference failure, likely due to a internal bug. " + "Please report this to " + "https://github.com/plotly/plotly.py/issues/new and we will try to " + "replicate and fix it." + ) + df_output = df_output.with_columns(nw.col(var_name).cast(nw.String)) + orient_v = wide_orientation == "v" + + if hist1d_orientation: + args["x" if orient_v else "y"] = value_name + args["y" if orient_v else "x"] = wide_cross_name + args["color"] = args["color"] or var_name + elif constructor in [go.Scatter, go.Funnel] + hist2d_types: + args["x" if orient_v else "y"] = wide_cross_name + args["y" if orient_v else "x"] = value_name + if constructor != go.Histogram2d: + args["color"] = args["color"] or var_name + if "line_group" in args: + args["line_group"] = args["line_group"] or var_name + elif constructor == go.Bar: + if _is_continuous(df_output, value_name): + args["x" if orient_v else "y"] = wide_cross_name + args["y" if orient_v else "x"] = value_name + args["color"] = args["color"] or var_name + else: + args["x" if orient_v else "y"] = value_name + args["y" if orient_v else "x"] = count_name + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) + args["color"] = args["color"] or var_name + elif constructor in 
[go.Violin, go.Box]: + args["x" if orient_v else "y"] = wide_cross_name or var_name + args["y" if orient_v else "x"] = value_name + + if hist1d_orientation and constructor == go.Scatter: + if args["x"] is not None and args["y"] is not None: + args["histfunc"] = "sum" + elif args["x"] is None: + args["histfunc"] = None + args["orientation"] = "h" + args["x"] = count_name + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) + else: + args["histfunc"] = None + args["orientation"] = "v" + args["y"] = count_name + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) + + if no_color: + args["color"] = None + args["data_frame"] = df_output + return args + + +def _check_dataframe_all_leaves(df: nw.DataFrame) -> None: + cols = df.columns + df_sorted = df.sort(by=cols, descending=False, nulls_last=True) + null_mask = df_sorted.select(nw.all().is_null()) + df_sorted = df_sorted.select(nw.all().cast(nw.String())) + null_indices_mask = null_mask.select( + null_mask=nw.any_horizontal(nw.all()) + ).get_column("null_mask") + + null_mask_filtered = null_mask.filter(null_indices_mask) + if not null_mask_filtered.is_empty(): + for col_idx in range(1, null_mask_filtered.shape[1]): + # For each row, if a True value is encountered, then check that + # all values in subsequent columns are also True + null_entries_with_non_null_children = ( + ~null_mask_filtered[:, col_idx] & null_mask_filtered[:, col_idx - 1] + ) + if nw.to_py_scalar(null_entries_with_non_null_children.any()): + row_idx = null_entries_with_non_null_children.to_list().index(True) + raise ValueError( + "None entries cannot have not-None children", + df_sorted.row(row_idx), + ) + + fill_series = nw.new_series( + name="fill_value", + values=[""] * len(df_sorted), + dtype=nw.String(), + native_namespace=nw.get_native_namespace(df_sorted), + ) + df_sorted = df_sorted.with_columns( + **{ + c: df_sorted.get_column(c).zip_with(~null_mask.get_column(c), fill_series) + for c in cols + } + ) + + # Conversion 
to list is due to python native vs pyarrow scalars + row_strings = ( + df_sorted.select( + row_strings=nw.concat_str(cols, separator="", ignore_nulls=False) + ) + .get_column("row_strings") + .to_list() + ) + + null_indices = set(null_indices_mask.arg_true().to_list()) + for i, (current_row, next_row) in enumerate( + zip(row_strings[:-1], row_strings[1:]), start=1 + ): + if (next_row in current_row) and (i in null_indices): + raise ValueError( + "Non-leaves rows are not permitted in the dataframe \n", + df_sorted.row(i), + "is not a leaf.", + ) + + +def process_dataframe_hierarchy(args): + """ + Build dataframe for sunburst, treemap, or icicle when the path argument is provided. + """ + df: nw.DataFrame = args["data_frame"] + path = args["path"][::-1] + _check_dataframe_all_leaves(df[path[::-1]]) + discrete_color = not _is_continuous(df, args["color"]) if args["color"] else False + + df = df.lazy() + + new_path = [col_name + "_path_copy" for col_name in path] + df = df.with_columns( + nw.col(col_name).alias(new_col_name) + for new_col_name, col_name in zip(new_path, path) + ) + path = new_path + # ------------ Define aggregation functions -------------------------------- + agg_f = {} + if args["values"]: + try: + df = df.with_columns(nw.col(args["values"]).cast(nw.Float64())) + + except Exception: # pandas, Polars and pyarrow exception types are different + raise ValueError( + "Column `%s` of `df` could not be converted to a numerical data type." 
+ % args["values"] + ) + + if args["color"] and args["color"] == args["values"]: + new_value_col_name = args["values"] + "_sum" + df = df.with_columns(nw.col(args["values"]).alias(new_value_col_name)) + args["values"] = new_value_col_name + count_colname = args["values"] + else: + # we need a count column for the first groupby and the weighted mean of color + # trick to be sure the col name is unused: take the sum of existing names + columns = df.collect_schema().names() + count_colname = ( + "count" if "count" not in columns else "".join([str(el) for el in columns]) + ) + # we can modify df because it's a copy of the px argument + df = df.with_columns(nw.lit(1).alias(count_colname)) + args["values"] = count_colname + + # Since count_colname is always in agg_f, it can be used later to normalize color + # in the continuous case after some gymnastic + agg_f[count_colname] = nw.sum(count_colname) + + discrete_aggs = [] + continuous_aggs = [] + + n_unique_token = _generate_temporary_column_name( + n_bytes=16, columns=df.collect_schema().names() + ) + + # In theory, for discrete columns aggregation, we should have a way to do + # `.agg(nw.col(x).unique())` in group_by and successively unpack/parse it as: + # ``` + # (nw.when(nw.col(x).list.len()==1) + # .then(nw.col(x).list.first()) + # .otherwise(nw.lit("(?)")) + # ) + # ``` + # which replicates the original pandas only codebase: + # ``` + # def discrete_agg(x): + # uniques = x.unique() + # return uniques[0] if len(uniques) == 1 else "(?)" + # + # df.groupby(path[i:]).agg(...) + # ``` + # However this is not possible, therefore the following workaround is provided. 
+ # We make two aggregations for the same column: + # - take the max value + # - take the number of unique values + # Finally, after the group by statement, it is unpacked via: + # ``` + # (nw.when(nw.col(col_n_unique) == 1) + # .then(nw.col(col_max_value)) # which is the unique value + # .otherwise(nw.lit("(?)")) + # ) + # ``` + + if args["color"]: + if discrete_color: + discrete_aggs.append(args["color"]) + agg_f[args["color"]] = nw.col(args["color"]).max() + agg_f[f"{args['color']}{n_unique_token}"] = ( + nw.col(args["color"]) + .n_unique() + .alias(f"{args['color']}{n_unique_token}") + ) + else: + # This first needs to be multiplied by `count_colname` + continuous_aggs.append(args["color"]) + + agg_f[args["color"]] = nw.sum(args["color"]) + + # Other columns (for color, hover_data, custom_data etc.) + cols = list(set(df.collect_schema().names()).difference(path)) + df = df.with_columns(nw.col(c).cast(nw.String()) for c in cols if c not in agg_f) + + for col in cols: # for hover_data, custom_data etc. 
+ if col not in agg_f: + # Similar trick as above + discrete_aggs.append(col) + agg_f[col] = nw.col(col).max() + agg_f[f"{col}{n_unique_token}"] = ( + nw.col(col).n_unique().alias(f"{col}{n_unique_token}") + ) + # Avoid collisions with reserved names - columns in the path have been copied already + cols = list(set(cols) - set(["labels", "parent", "id"])) + # ---------------------------------------------------------------------------- + all_trees = [] + + if args["color"] and not discrete_color: + df = df.with_columns( + (nw.col(args["color"]) * nw.col(count_colname)).alias(args["color"]) + ) + + def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame: + """ + - continuous_aggs is either [] or [args["color"]] + - discrete_aggs is either [args["color"], <rest_of_cols>] or [<rest_of cols>] + """ + return dframe.with_columns( + *[nw.col(col) / nw.col(count_colname) for col in continuous_aggs], + *[ + ( + nw.when(nw.col(f"{col}{n_unique_token}") == 1) + .then(nw.col(col)) + .otherwise(nw.lit("(?)")) + .alias(col) + ) + for col in discrete_aggs + ], + ).drop([f"{col}{n_unique_token}" for col in discrete_aggs]) + + for i, level in enumerate(path): + dfg = ( + df.group_by(path[i:], drop_null_keys=True) + .agg(**agg_f) + .pipe(post_agg, continuous_aggs, discrete_aggs) + ) + + # Path label massaging + df_tree = dfg.with_columns( + *cols, + labels=nw.col(level).cast(nw.String()), + parent=nw.lit(""), + id=nw.col(level).cast(nw.String()), + ) + if i < len(path) - 1: + _concat_str_token = _generate_temporary_column_name( + n_bytes=16, columns=[*cols, "labels", "parent", "id"] + ) + df_tree = ( + df_tree.with_columns( + nw.concat_str( + [ + nw.col(path[j]).cast(nw.String()) + for j in range(len(path) - 1, i, -1) + ], + separator="/", + ).alias(_concat_str_token) + ) + .with_columns( + parent=nw.concat_str( + [nw.col(_concat_str_token), nw.col("parent")], separator="/" + ), + id=nw.concat_str( + [nw.col(_concat_str_token), nw.col("id")], separator="/" + 
), + ) + .drop(_concat_str_token) + ) + + # strip "/" if at the end of the string, equivalent to `.str.rstrip` + df_tree = df_tree.with_columns( + parent=nw.col("parent").str.replace("/?$", "").str.replace("^/?", "") + ) + + all_trees.append(df_tree.select(*["labels", "parent", "id", *cols])) + + df_all_trees = nw.maybe_reset_index(nw.concat(all_trees, how="vertical").collect()) + + # we want to make sure than (?) is the first color of the sequence + if args["color"] and discrete_color: + sort_col_name = "sort_color_if_discrete_color" + while sort_col_name in df_all_trees.columns: + sort_col_name += "0" + df_all_trees = df_all_trees.with_columns( + nw.col(args["color"]).cast(nw.String()).alias(sort_col_name) + ).sort(by=sort_col_name, nulls_last=True) + + # Now modify arguments + args["data_frame"] = df_all_trees + args["path"] = None + args["ids"] = "id" + args["names"] = "labels" + args["parents"] = "parent" + if args["color"]: + if not args["hover_data"]: + args["hover_data"] = [args["color"]] + elif isinstance(args["hover_data"], dict): + if not args["hover_data"].get(args["color"]): + args["hover_data"][args["color"]] = (True, None) + else: + args["hover_data"].append(args["color"]) + return args + + +def process_dataframe_timeline(args): + """ + Massage input for bar traces for px.timeline() + """ + args["is_timeline"] = True + if args["x_start"] is None or args["x_end"] is None: + raise ValueError("Both x_start and x_end are required") + + df: nw.DataFrame = args["data_frame"] + schema = df.schema + to_convert_to_datetime = [ + col + for col in [args["x_start"], args["x_end"]] + if schema[col] != nw.Datetime and schema[col] != nw.Date + ] + + if to_convert_to_datetime: + try: + df = df.with_columns(nw.col(to_convert_to_datetime).str.to_datetime()) + except Exception as exc: + raise TypeError( + "Both x_start and x_end must refer to data convertible to datetimes." 
+ ) from exc + + # note that we are not adding any columns to the data frame here, so no risk of overwrite + args["data_frame"] = df.with_columns( + (nw.col(args["x_end"]) - nw.col(args["x_start"])) + .dt.total_milliseconds() + .alias(args["x_end"]) + ) + args["x"] = args["x_end"] + args["base"] = args["x_start"] + del args["x_start"], args["x_end"] + return args + + +def process_dataframe_pie(args, trace_patch): + import numpy as np + + names = args.get("names") + if names is None: + return args, trace_patch + order_in = args["category_orders"].get(names, {}).copy() + if not order_in: + return args, trace_patch + df: nw.DataFrame = args["data_frame"] + trace_patch["sort"] = False + trace_patch["direction"] = "clockwise" + uniques = df.get_column(names).unique(maintain_order=True).to_list() + order = [x for x in OrderedDict.fromkeys(list(order_in) + uniques) if x in uniques] + + # Sort args['data_frame'] by column `names` according to order `order`. + token = nw.generate_temporary_column_name(8, df.columns) + args["data_frame"] = ( + df.with_columns( + nw.col(names) + .replace_strict(order, np.arange(len(order)), return_dtype=nw.UInt32) + .alias(token) + ) + .sort(token) + .drop(token) + ) + return args, trace_patch + + +def infer_config(args, constructor, trace_patch, layout_patch): + attrs = [k for k in direct_attrables + array_attrables if k in args] + grouped_attrs = [] + df: nw.DataFrame = args["data_frame"] + + # Compute sizeref + sizeref = 0 + if "size" in args and args["size"]: + sizeref = ( + nw.to_py_scalar(df.get_column(args["size"]).max()) / args["size_max"] ** 2 + ) + + # Compute color attributes and grouping attributes + if "color" in args: + if "color_continuous_scale" in args: + if "color_discrete_sequence" not in args: + attrs.append("color") + else: + if args["color"] and _is_continuous(df, args["color"]): + attrs.append("color") + args["color_is_continuous"] = True + elif constructor in [go.Sunburst, go.Treemap, go.Icicle]: + 
attrs.append("color") + args["color_is_continuous"] = False + else: + grouped_attrs.append("marker.color") + elif "line_group" in args or constructor == go.Histogram2dContour: + grouped_attrs.append("line.color") + elif constructor in [go.Pie, go.Funnelarea]: + attrs.append("color") + if args["color"]: + if args["hover_data"] is None: + args["hover_data"] = [] + args["hover_data"].append(args["color"]) + else: + grouped_attrs.append("marker.color") + + show_colorbar = bool( + "color" in attrs + and args["color"] + and constructor not in [go.Pie, go.Funnelarea] + and ( + constructor not in [go.Treemap, go.Sunburst, go.Icicle] + or args.get("color_is_continuous") + ) + ) + else: + show_colorbar = False + + if "line_dash" in args: + grouped_attrs.append("line.dash") + + if "symbol" in args: + grouped_attrs.append("marker.symbol") + + if "pattern_shape" in args: + if constructor in [go.Scatter]: + grouped_attrs.append("fillpattern.shape") + else: + grouped_attrs.append("marker.pattern.shape") + + if "orientation" in args: + has_x = args["x"] is not None + has_y = args["y"] is not None + if args["orientation"] is None: + if constructor in [go.Histogram, go.Scatter]: + if has_y and not has_x: + args["orientation"] = "h" + elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]: + if has_x and not has_y: + args["orientation"] = "h" + + if args["orientation"] is None and has_x and has_y: + x_is_continuous = _is_continuous(df, args["x"]) + y_is_continuous = _is_continuous(df, args["y"]) + if x_is_continuous and not y_is_continuous: + args["orientation"] = "h" + if y_is_continuous and not x_is_continuous: + args["orientation"] = "v" + + if args["orientation"] is None: + args["orientation"] = "v" + + if constructor == go.Histogram: + if has_x and has_y and args["histfunc"] is None: + args["histfunc"] = trace_patch["histfunc"] = "sum" + + orientation = args["orientation"] + nbins = args["nbins"] + trace_patch["nbinsx"] = nbins if orientation == "v" else None + 
trace_patch["nbinsy"] = None if orientation == "v" else nbins + trace_patch["bingroup"] = "x" if orientation == "v" else "y" + trace_patch["orientation"] = args["orientation"] + + if constructor in [go.Violin, go.Box]: + mode = "boxmode" if constructor == go.Box else "violinmode" + if layout_patch[mode] is None and args["color"] is not None: + if args["y"] == args["color"] and args["orientation"] == "h": + layout_patch[mode] = "overlay" + elif args["x"] == args["color"] and args["orientation"] == "v": + layout_patch[mode] = "overlay" + if layout_patch[mode] is None: + layout_patch[mode] = "group" + + if ( + constructor == go.Histogram2d + and args["z"] is not None + and args["histfunc"] is None + ): + args["histfunc"] = trace_patch["histfunc"] = "sum" + + if args.get("text_auto", False) is not False: + if constructor in [go.Histogram2d, go.Histogram2dContour]: + letter = "z" + elif constructor == go.Bar: + letter = "y" if args["orientation"] == "v" else "x" + else: + letter = "value" + if args["text_auto"] is True: + trace_patch["texttemplate"] = "%{" + letter + "}" + else: + trace_patch["texttemplate"] = "%{" + letter + ":" + args["text_auto"] + "}" + + if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]: + show_colorbar = True + trace_patch["coloraxis"] = "coloraxis1" + + if "opacity" in args: + if args["opacity"] is None: + if "barmode" in args and args["barmode"] == "overlay": + trace_patch["marker"] = dict(opacity=0.5) + elif constructor in [ + go.Densitymap, + go.Densitymapbox, + go.Pie, + go.Funnel, + go.Funnelarea, + ]: + trace_patch["opacity"] = args["opacity"] + else: + trace_patch["marker"] = dict(opacity=args["opacity"]) + if ( + "line_group" in args or "line_dash" in args + ): # px.line, px.line_*, px.area, px.ecdf + modes = set() + if args.get("lines", True): + modes.add("lines") + if args.get("text") or args.get("symbol") or args.get("markers"): + modes.add("markers") + if args.get("text"): + modes.add("text") + if len(modes) == 0: + 
modes.add("lines") + trace_patch["mode"] = "+".join(sorted(modes)) + elif constructor != go.Splom and ( + "symbol" in args or constructor in [go.Scattermap, go.Scattermapbox] + ): + trace_patch["mode"] = "markers" + ("+text" if args["text"] else "") + + if "line_shape" in args: + trace_patch["line"] = dict(shape=args["line_shape"]) + elif "ecdfmode" in args: + trace_patch["line"] = dict( + shape="vh" if args["ecdfmode"] == "reversed" else "hv" + ) + + if "geojson" in args: + trace_patch["featureidkey"] = args["featureidkey"] + trace_patch["geojson"] = ( + args["geojson"] + if not hasattr(args["geojson"], "__geo_interface__") # for geopandas + else args["geojson"].__geo_interface__ + ) + + # Compute marginal attribute: copy to appropriate marginal_* + if "marginal" in args: + position = "marginal_x" if args["orientation"] == "v" else "marginal_y" + other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" + args[position] = args["marginal"] + args[other_position] = None + + # Ignore facet rows and columns when data frame is empty so as to prevent nrows/ncols equaling 0 + if df.is_empty(): + args["facet_row"] = args["facet_col"] = None + + # If both marginals and faceting are specified, faceting wins + if args.get("facet_col") is not None and args.get("marginal_y") is not None: + args["marginal_y"] = None + + if args.get("facet_row") is not None and args.get("marginal_x") is not None: + args["marginal_x"] = None + + # facet_col_wrap only works if no marginals or row faceting is used + if ( + args.get("marginal_x") is not None + or args.get("marginal_y") is not None + or args.get("facet_row") is not None + ): + args["facet_col_wrap"] = 0 + + if "trendline" in args and args["trendline"] is not None: + if args["trendline"] not in trendline_functions: + raise ValueError( + "Value '%s' for `trendline` must be one of %s" + % (args["trendline"], trendline_functions.keys()) + ) + + if "trendline_options" in args and args["trendline_options"] is None: + 
args["trendline_options"] = dict() + + if "ecdfnorm" in args: + if args.get("ecdfnorm", None) not in [None, "percent", "probability"]: + raise ValueError( + "`ecdfnorm` must be one of None, 'percent' or 'probability'. " + + "'%s' was provided." % args["ecdfnorm"] + ) + args["histnorm"] = args["ecdfnorm"] + + # Compute applicable grouping attributes + grouped_attrs.extend([k for k in group_attrables if k in args]) + + # Create grouped mappings + grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] + + # Create trace specs + trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) + return trace_specs, grouped_mappings, sizeref, show_colorbar + + +def get_groups_and_orders(args, grouper): + """ + `orders` is the user-supplied ordering with the remaining data-frame-supplied + ordering appended if the column is used for grouping. It includes anything the user + gave, for any variable, including values not present in the dataset. It's a dict + where the keys are e.g. "x" or "color" + + `groups` is the dicts of groups, ordered by the order above. 
Its keys are + tuples like [("value1", ""), ("value2", "")] where each tuple contains the name + of a single dimension-group + """ + orders = {} if "category_orders" not in args else args["category_orders"].copy() + df: nw.DataFrame = args["data_frame"] + # figure out orders and what the single group name would be if there were one + single_group_name = [] + unique_cache = dict() + + for i, col in enumerate(grouper): + if col == one_group: + single_group_name.append("") + else: + if col not in unique_cache: + unique_cache[col] = ( + df.get_column(col).unique(maintain_order=True).to_list() + ) + uniques = unique_cache[col] + if len(uniques) == 1: + single_group_name.append(uniques[0]) + if col not in orders: + orders[col] = uniques + else: + orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) + + if len(single_group_name) == len(grouper): + # we have a single group, so we can skip all group-by operations! + groups = {tuple(single_group_name): df} + else: + required_grouper = [group for group in orders if group in grouper] + grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) + + sorted_group_names = sorted( + grouped.keys(), + key=lambda values: [ + orders[group].index(value) if value in orders[group] else -1 + for group, value in zip(required_grouper, values) + ], + ) + + # calculate the full group_names by inserting "" in the tuple index for one_group groups + full_sorted_group_names = [ + tuple( + [ + ( + "" + if col == one_group + else sub_group_names[required_grouper.index(col)] + ) + for col in grouper + ] + ) + for sub_group_names in sorted_group_names + ] + + groups = { + sf: grouped[s] for sf, s in zip(full_sorted_group_names, sorted_group_names) + } + return groups, orders + + +def make_figure(args, constructor, trace_patch=None, layout_patch=None): + trace_patch = trace_patch or {} + layout_patch = layout_patch or {} + apply_default_cascade(args) + + args = build_dataframe(args, constructor) + if constructor in 
[go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None: + args = process_dataframe_hierarchy(args) + if constructor in [go.Pie]: + args, trace_patch = process_dataframe_pie(args, trace_patch) + if constructor == "timeline": + constructor = go.Bar + args = process_dataframe_timeline(args) + + # If we have marginal histograms, set barmode to "overlay" + if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]: + layout_patch["barmode"] = "overlay" + + trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( + args, constructor, trace_patch, layout_patch + ) + grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] + groups, orders = get_groups_and_orders(args, grouper) + + col_labels = [] + row_labels = [] + nrows = ncols = 1 + for m in grouped_mappings: + if m.grouper not in orders: + m.val_map[""] = m.sequence[0] + else: + sorted_values = orders[m.grouper] + if m.facet == "col": + prefix = get_label(args, args["facet_col"]) + "=" + col_labels = [prefix + str(s) for s in sorted_values] + ncols = len(col_labels) + if m.facet == "row": + prefix = get_label(args, args["facet_row"]) + "=" + row_labels = [prefix + str(s) for s in sorted_values] + nrows = len(row_labels) + for val in sorted_values: + if val not in m.val_map: # always False if it's an IdentityMap + m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] + + subplot_type = _subplot_type_for_trace_type(constructor().type) + + trace_names_by_frame = {} + frames = OrderedDict() + trendline_rows = [] + trace_name_labels = None + facet_col_wrap = args.get("facet_col_wrap", 0) + for group_name, group in groups.items(): + mapping_labels = OrderedDict() + trace_name_labels = OrderedDict() + frame_name = "" + for col, val, m in zip(grouper, group_name, grouped_mappings): + if col != one_group: + key = get_label(args, col) + if not isinstance(m.val_map, IdentityMap): + mapping_labels[key] = str(val) + if m.show_in_trace_name: + trace_name_labels[key] = 
str(val) + if m.variable == "animation_frame": + frame_name = val + trace_name = ", ".join(trace_name_labels.values()) + if frame_name not in trace_names_by_frame: + trace_names_by_frame[frame_name] = set() + trace_names = trace_names_by_frame[frame_name] + + for trace_spec in trace_specs: + # Create the trace + trace = trace_spec.constructor(name=trace_name) + if trace_spec.constructor not in [ + go.Parcats, + go.Parcoords, + go.Choropleth, + go.Choroplethmap, + go.Choroplethmapbox, + go.Densitymap, + go.Densitymapbox, + go.Histogram2d, + go.Sunburst, + go.Treemap, + go.Icicle, + ]: + trace.update( + legendgroup=trace_name, + showlegend=(trace_name != "" and trace_name not in trace_names), + ) + + # Set 'offsetgroup' only in group barmode (or if no barmode is set) + barmode = layout_patch.get("barmode") + if trace_spec.constructor in [go.Bar, go.Box, go.Violin, go.Histogram] and ( + barmode == "group" or barmode is None + ): + trace.update(alignmentgroup=True, offsetgroup=trace_name) + trace_names.add(trace_name) + + # Init subplot row/col + trace._subplot_row = 1 + trace._subplot_col = 1 + + for i, m in enumerate(grouped_mappings): + val = group_name[i] + try: + m.updater(trace, m.val_map[val]) # covers most cases + except ValueError: + # this catches some odd cases like marginals + if ( + trace_spec != trace_specs[0] + and ( + trace_spec.constructor in [go.Violin, go.Box] + and m.variable in ["symbol", "pattern", "dash"] + ) + or ( + trace_spec.constructor in [go.Histogram] + and m.variable in ["symbol", "dash"] + ) + ): + pass + elif ( + trace_spec != trace_specs[0] + and trace_spec.constructor in [go.Histogram] + and m.variable == "color" + ): + trace.update(marker=dict(color=m.val_map[val])) + elif ( + trace_spec.constructor + in [go.Choropleth, go.Choroplethmap, go.Choroplethmapbox] + and m.variable == "color" + ): + trace.update( + z=[1] * len(group), + colorscale=[m.val_map[val]] * 2, + showscale=False, + showlegend=True, + ) + else: + raise + + # Find row 
for trace, handling facet_row and marginal_x + if m.facet == "row": + row = m.val_map[val] + else: + if ( + args.get("marginal_x") is not None # there is a marginal + and trace_spec.marginal != "x" # and we're not it + ): + row = 2 + else: + row = 1 + + # Find col for trace, handling facet_col and marginal_y + if m.facet == "col": + col = m.val_map[val] + if facet_col_wrap: # assumes no facet_row, no marginals + row = 1 + ((col - 1) // facet_col_wrap) + col = 1 + ((col - 1) % facet_col_wrap) + else: + if trace_spec.marginal == "y": + col = 2 + else: + col = 1 + + if row > 1: + trace._subplot_row = row + + if col > 1: + trace._subplot_col = col + if ( + trace_specs[0].constructor == go.Histogram2dContour + and trace_spec.constructor == go.Box + and trace.line.color + ): + trace.update(marker=dict(color=trace.line.color)) + + if "ecdfmode" in args: + base = args["x"] if args["orientation"] == "v" else args["y"] + var = args["x"] if args["orientation"] == "h" else args["y"] + ascending = args.get("ecdfmode", "standard") != "reversed" + group = group.sort(by=base, descending=not ascending, nulls_last=True) + group_sum = group.get_column( + var + ).sum() # compute here before next line mutates + group = group.with_columns(nw.col(var).cum_sum().alias(var)) + if not ascending: + group = group.sort(by=base, descending=False, nulls_last=True) + + if args.get("ecdfmode", "standard") == "complementary": + group = group.with_columns((group_sum - nw.col(var)).alias(var)) + + if args["ecdfnorm"] == "probability": + group = group.with_columns(nw.col(var) / group_sum) + elif args["ecdfnorm"] == "percent": + group = group.with_columns((nw.col(var) / group_sum) * 100.0) + + patch, fit_results = make_trace_kwargs( + args, trace_spec, group, mapping_labels.copy(), sizeref + ) + trace.update(patch) + if fit_results is not None: + trendline_rows.append(mapping_labels.copy()) + trendline_rows[-1]["px_fit_results"] = fit_results + if frame_name not in frames: + frames[frame_name] = 
dict(data=[], name=frame_name) + frames[frame_name]["data"].append(trace) + frame_list = [f for f in frames.values()] + if len(frame_list) > 1: + frame_list = sorted( + frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"]) + ) + + if show_colorbar: + colorvar = ( + "z" + if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox] + else "color" + ) + range_color = args["range_color"] or [None, None] + + colorscale_validator = ColorscaleValidator("colorscale", "make_figure") + layout_patch["coloraxis1"] = dict( + colorscale=colorscale_validator.validate_coerce( + args["color_continuous_scale"] + ), + cmid=args["color_continuous_midpoint"], + cmin=range_color[0], + cmax=range_color[1], + colorbar=dict( + title_text=get_decorated_label(args, args[colorvar], colorvar) + ), + ) + for v in ["height", "width"]: + if args[v]: + layout_patch[v] = args[v] + layout_patch["legend"] = dict(tracegroupgap=0) + if trace_name_labels: + layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels) + if args["title"]: + layout_patch["title_text"] = args["title"] + elif args["template"].layout.margin.t is None: + layout_patch["margin"] = {"t": 60} + if args["subtitle"]: + layout_patch["title_subtitle_text"] = args["subtitle"] + if ( + "size" in args + and args["size"] + and args["template"].layout.legend.itemsizing is None + ): + layout_patch["legend"]["itemsizing"] = "constant" + + if facet_col_wrap: + nrows = math.ceil(ncols / facet_col_wrap) + ncols = min(ncols, facet_col_wrap) + + if args.get("marginal_x") is not None: + nrows += 1 + + if args.get("marginal_y") is not None: + ncols += 1 + + fig = init_figure( + args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels + ) + + # Position traces in subplots + for frame in frame_list: + for trace in frame["data"]: + if isinstance(trace, go.Splom): + # Special case that is not compatible with make_subplots + continue + + _set_trace_grid_reference( + trace, + fig.layout, + fig._grid_ref, + 
nrows - trace._subplot_row + 1, + trace._subplot_col, + ) + + # Add traces, layout and frames to figure + fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) + fig.update_layout(layout_patch) + if "template" in args and args["template"] is not None: + fig.update_layout(template=args["template"], overwrite=True) + for f in frame_list: + f["name"] = str(f["name"]) + fig.frames = frame_list if len(frames) > 1 else [] + + if args.get("trendline") and args.get("trendline_scope", "trace") == "overall": + trendline_spec = make_trendline_spec(args, constructor) + trendline_trace = trendline_spec.constructor( + name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False + ) + if "line" not in trendline_spec.trace_patch: # no color override + for m in grouped_mappings: + if m.variable == "color": + next_color = m.sequence[len(m.val_map) % len(m.sequence)] + trendline_spec.trace_patch["line"] = dict(color=next_color) + patch, fit_results = make_trace_kwargs( + args, trendline_spec, args["data_frame"], {}, sizeref + ) + trendline_trace.update(patch) + fig.add_trace( + trendline_trace, row="all", col="all", exclude_empty_subplots=True + ) + fig.update_traces(selector=-1, showlegend=True) + if fit_results is not None: + trendline_rows.append(dict(px_fit_results=fit_results)) + + if trendline_rows: + try: + import pandas as pd + + fig._px_trendlines = pd.DataFrame(trendline_rows) + except ImportError: + msg = "Trendlines require pandas to be installed." 
+ raise NotImplementedError(msg) + else: + fig._px_trendlines = [] + + configure_axes(args, constructor, fig, orders) + configure_animation_controls(args, constructor, fig) + return fig + + +def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels): + # Build subplot specs + specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)] + + # Default row/column widths uniform + column_widths = [1.0] * ncols + row_heights = [1.0] * nrows + facet_col_wrap = args.get("facet_col_wrap", 0) + + # Build column_widths/row_heights + if subplot_type == "xy": + if args.get("marginal_x") is not None: + if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + row_heights = [main_size] * (nrows - 1) + [1 - main_size] + vertical_spacing = 0.01 + elif facet_col_wrap: + vertical_spacing = args.get("facet_row_spacing") or 0.07 + else: + vertical_spacing = args.get("facet_row_spacing") or 0.03 + + if args.get("marginal_y") is not None: + if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + column_widths = [main_size] * (ncols - 1) + [1 - main_size] + horizontal_spacing = 0.005 + else: + horizontal_spacing = args.get("facet_col_spacing") or 0.02 + else: + # Other subplot types: + # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None + # + # We can customize subplot spacing per type once we enable faceting + # for all plot types + if facet_col_wrap: + vertical_spacing = args.get("facet_row_spacing") or 0.07 + else: + vertical_spacing = args.get("facet_row_spacing") or 0.03 + horizontal_spacing = args.get("facet_col_spacing") or 0.02 + + if facet_col_wrap: + subplot_labels = [None] * nrows * ncols + while len(col_labels) < nrows * ncols: + col_labels.append(None) + for i in range(nrows): + for j in range(ncols): + subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j] + + def 
_spacing_error_translator(e, direction, facet_arg): + """ + Translates the spacing errors thrown by the underlying make_subplots + routine into one that describes an argument adjustable through px. + """ + if ("%s spacing" % (direction,)) in e.args[0]: + e.args = ( + e.args[0] + + """ +Use the {facet_arg} argument to adjust this spacing.""".format(facet_arg=facet_arg), + ) + raise e + + # Create figure with subplots + try: + fig = make_subplots( + rows=nrows, + cols=ncols, + specs=specs, + shared_xaxes="all", + shared_yaxes="all", + row_titles=[] if facet_col_wrap else list(reversed(row_labels)), + column_titles=[] if facet_col_wrap else col_labels, + subplot_titles=subplot_labels if facet_col_wrap else [], + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing, + row_heights=row_heights, + column_widths=column_widths, + start_cell="bottom-left", + ) + except ValueError as e: + _spacing_error_translator(e, "Horizontal", "facet_col_spacing") + _spacing_error_translator(e, "Vertical", "facet_row_spacing") + raise + + # Remove explicit font size of row/col titles so template can take over + for annot in fig.layout.annotations: + annot.update(font=None) + + return fig diff --git a/venv/lib/python3.8/site-packages/plotly/express/_doc.py b/venv/lib/python3.8/site-packages/plotly/express/_doc.py new file mode 100644 index 0000000..59faac4 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/_doc.py @@ -0,0 +1,640 @@ +import inspect +from textwrap import TextWrapper + +try: + getfullargspec = inspect.getfullargspec +except AttributeError: # python 2 + getfullargspec = inspect.getargspec + + +colref_type = "str or int or Series or array-like" +colref_desc = "Either a name of a column in `data_frame`, or a pandas Series or array_like object." 
+colref_list_type = "list of str or int, or Series or array-like" +colref_list_desc = ( + "Either names of columns in `data_frame`, or pandas Series, or array_like objects" +) + +docs = dict( + data_frame=[ + "DataFrame or array-like or dict", + "This argument needs to be passed for column names (and not keyword names) to be used.", + "Array-like and dict are transformed internally to a pandas DataFrame.", + "Optional: if missing, a DataFrame gets constructed under the hood using the other arguments.", + ], + x=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", + ], + y=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.", + ], + z=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.", + ], + x_start=[ + colref_type, + colref_desc, + "(required)", + "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", + ], + x_end=[ + colref_type, + colref_desc, + "(required)", + "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", + ], + a=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the a axis in ternary coordinates.", + ], + b=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the b axis in ternary coordinates.", + ], + c=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the c axis in ternary coordinates.", + ], + r=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks along the radial axis in polar coordinates.", + ], + theta=[ + colref_type, + colref_desc, + "Values from this 
column or array_like are used to position marks along the angular axis in polar coordinates.", + ], + values=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to set values associated to sectors.", + ], + parents=[ + colref_type, + colref_desc, + "Values from this column or array_like are used as parents in sunburst and treemap charts.", + ], + ids=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to set ids of sectors", + ], + path=[ + colref_list_type, + colref_list_desc, + "List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.", + "An error is raised if path AND ids or parents is passed", + ], + lat=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks according to latitude on a map.", + ], + lon=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position marks according to longitude on a map.", + ], + locations=[ + colref_type, + colref_desc, + "Values from this column or array_like are to be interpreted according to `locationmode` and mapped to longitude/latitude.", + ], + base=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to position the base of the bar.", + ], + dimensions=[ + colref_list_type, + colref_list_desc, + "Values from these columns are used for multidimensional visualization.", + ], + dimensions_max_cardinality=[ + "int (default 50)", + "When `dimensions` is `None` and `data_frame` is provided, " + "columns with more than this number of unique values are excluded from the output.", + "Not used when `dimensions` is passed.", + ], + error_x=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to size x-axis error bars.", + "If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.", + ], + error_x_minus=[ + colref_type, 
+ colref_desc, + "Values from this column or array_like are used to size x-axis error bars in the negative direction.", + "Ignored if `error_x` is `None`.", + ], + error_y=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to size y-axis error bars.", + "If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.", + ], + error_y_minus=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to size y-axis error bars in the negative direction.", + "Ignored if `error_y` is `None`.", + ], + error_z=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to size z-axis error bars.", + "If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.", + ], + error_z_minus=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to size z-axis error bars in the negative direction.", + "Ignored if `error_z` is `None`.", + ], + color=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign color to marks.", + ], + opacity=["float", "Value between 0 and 1. 
Sets the opacity for markers."], + line_dash=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign dash-patterns to lines.", + ], + line_group=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to group rows of `data_frame` into lines.", + ], + symbol=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign symbols to marks.", + ], + pattern_shape=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign pattern shapes to marks.", + ], + size=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign mark sizes.", + ], + radius=["int (default is 30)", "Sets the radius of influence of each point."], + hover_name=[ + colref_type, + colref_desc, + "Values from this column or array_like appear in bold in the hover tooltip.", + ], + hover_data=[ + "str, or list of str or int, or Series or array-like, or dict", + "Either a name or list of names of columns in `data_frame`, or pandas Series,", + "or array_like objects", + "or a dict with column names as keys, with values True (for default formatting)", + "False (in order to remove this column from hover information),", + "or a formatting string, for example ':.3f' or '|%a'", + "or list-like data to appear in the hover tooltip", + "or tuples with a bool or formatting string as first element,", + "and list-like data to appear in hover as second element", + "Values from these columns appear as extra data in the hover tooltip.", + ], + custom_data=[ + "str, or list of str or int, or Series or array-like", + "Either name or list of names of columns in `data_frame`, or pandas Series, or array_like objects", + "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. 
This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)", + ], + text=[ + colref_type, + colref_desc, + "Values from this column or array_like appear in the figure as text labels.", + ], + names=[ + colref_type, + colref_desc, + "Values from this column or array_like are used as labels for sectors.", + ], + locationmode=[ + "str", + "One of 'ISO-3', 'USA-states', or 'country names'", + "Determines the set of locations used to match entries in `locations` to regions on the map.", + ], + facet_row=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.", + ], + facet_col=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.", + ], + facet_col_wrap=[ + "int", + "Maximum number of facet columns.", + "Wraps the column variable at this width, so that the column facets span multiple rows.", + "Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.", + ], + facet_row_spacing=[ + "float between 0 and 1", + "Spacing between facet rows, in paper units. 
Default is 0.03 or 0.07 when facet_col_wrap is used.", + ], + facet_col_spacing=[ + "float between 0 and 1", + "Spacing between facet columns, in paper units Default is 0.02.", + ], + animation_frame=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to assign marks to animation frames.", + ], + animation_group=[ + colref_type, + colref_desc, + "Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", + ], + symbol_sequence=[ + "list of str", + "Strings should define valid plotly.js symbols.", + "When `symbol` is set, values in that column are assigned symbols by cycling through `symbol_sequence` in the order described in `category_orders`, unless the value of `symbol` is a key in `symbol_map`.", + ], + symbol_map=[ + "dict with str keys and str values (default `{}`)", + "String values should define plotly.js symbols", + "Used to override `symbol_sequence` to assign a specific symbols to marks corresponding with specific values.", + "Keys in `symbol_map` should be values in the column denoted by `symbol`.", + "Alternatively, if the values of `symbol` are valid symbol names, the string `'identity'` may be passed to cause them to be used directly.", + ], + line_dash_map=[ + "dict with str keys and str values (default `{}`)", + "Strings values define plotly.js dash-patterns.", + "Used to override `line_dash_sequences` to assign a specific dash-patterns to lines corresponding with specific values.", + "Keys in `line_dash_map` should be values in the column denoted by `line_dash`.", + "Alternatively, if the values of `line_dash` are valid line-dash names, the string `'identity'` may be passed to cause them to be used directly.", + ], + line_dash_sequence=[ + "list of str", + "Strings should define valid plotly.js dash-patterns.", + "When `line_dash` is set, values in that column are assigned 
dash-patterns by cycling through `line_dash_sequence` in the order described in `category_orders`, unless the value of `line_dash` is a key in `line_dash_map`.", + ], + pattern_shape_map=[ + "dict with str keys and str values (default `{}`)", + "Strings values define plotly.js patterns-shapes.", + "Used to override `pattern_shape_sequences` to assign a specific patterns-shapes to lines corresponding with specific values.", + "Keys in `pattern_shape_map` should be values in the column denoted by `pattern_shape`.", + "Alternatively, if the values of `pattern_shape` are valid patterns-shapes names, the string `'identity'` may be passed to cause them to be used directly.", + ], + pattern_shape_sequence=[ + "list of str", + "Strings should define valid plotly.js patterns-shapes.", + "When `pattern_shape` is set, values in that column are assigned patterns-shapes by cycling through `pattern_shape_sequence` in the order described in `category_orders`, unless the value of `pattern_shape` is a key in `pattern_shape_map`.", + ], + color_discrete_sequence=[ + "list of str", + "Strings should define valid CSS-colors.", + "When `color` is set and the values in the corresponding column are not numeric, values in that column are assigned colors by cycling through `color_discrete_sequence` in the order described in `category_orders`, unless the value of `color` is a key in `color_discrete_map`.", + "Various useful color sequences are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.qualitative`.", + ], + color_discrete_map=[ + "dict with str keys and str values (default `{}`)", + "String values should define valid CSS-colors", + "Used to override `color_discrete_sequence` to assign a specific colors to marks corresponding with specific values.", + "Keys in `color_discrete_map` should be values in the column denoted by `color`.", + "Alternatively, if the values of `color` are valid colors, the string `'identity'` may be passed to cause them 
to be used directly.", + ], + color_continuous_scale=[ + "list of str", + "Strings should define valid CSS-colors", + "This list is used to build a continuous color scale when the column denoted by `color` contains numeric data.", + "Various useful color scales are available in the `plotly.express.colors` submodules, specifically `plotly.express.colors.sequential`, `plotly.express.colors.diverging` and `plotly.express.colors.cyclical`.", + ], + color_continuous_midpoint=[ + "number (default `None`)", + "If set, computes the bounds of the continuous color scale to have the desired midpoint.", + "Setting this value is recommended when using `plotly.express.colors.diverging` color scales as the inputs to `color_continuous_scale`.", + ], + size_max=["int (default `20`)", "Set the maximum mark size when using `size`."], + markers=["boolean (default `False`)", "If `True`, markers are shown on lines."], + lines=[ + "boolean (default `True`)", + "If `False`, lines are not drawn (forced to `True` if `markers` is `False`).", + ], + log_x=[ + "boolean (default `False`)", + "If `True`, the x-axis is log-scaled in cartesian coordinates.", + ], + log_y=[ + "boolean (default `False`)", + "If `True`, the y-axis is log-scaled in cartesian coordinates.", + ], + log_z=[ + "boolean (default `False`)", + "If `True`, the z-axis is log-scaled in cartesian coordinates.", + ], + log_r=[ + "boolean (default `False`)", + "If `True`, the radial axis is log-scaled in polar coordinates.", + ], + range_x=[ + "list of two numbers", + "If provided, overrides auto-scaling on the x-axis in cartesian coordinates.", + ], + range_y=[ + "list of two numbers", + "If provided, overrides auto-scaling on the y-axis in cartesian coordinates.", + ], + range_z=[ + "list of two numbers", + "If provided, overrides auto-scaling on the z-axis in cartesian coordinates.", + ], + range_color=[ + "list of two numbers", + "If provided, overrides auto-scaling on the continuous color scale.", + ], + range_r=[ + "list of 
two numbers", + "If provided, overrides auto-scaling on the radial axis in polar coordinates.", + ], + range_theta=[ + "list of two numbers", + "If provided, overrides auto-scaling on the angular axis in polar coordinates.", + ], + title=["str", "The figure title."], + subtitle=["str", "The figure subtitle."], + template=[ + "str or dict or plotly.graph_objects.layout.Template instance", + "The figure template name (must be a key in plotly.io.templates) or definition.", + ], + width=["int (default `None`)", "The figure width in pixels."], + height=["int (default `None`)", "The figure height in pixels."], + labels=[ + "dict with str keys and str values (default `{}`)", + "By default, column names are used in the figure for axis titles, legend entries and hovers.", + "This parameter allows this to be overridden.", + "The keys of this dict should correspond to column names, and the values should correspond to the desired label to be displayed.", + ], + category_orders=[ + "dict with str keys and list of str values (default `{}`)", + "By default, in Python 3.6+, the order of categorical values in axes, legends and facets depends on the order in which these values are first encountered in `data_frame` (and no order is guaranteed by default in Python below 3.6).", + "This parameter is used to force a specific ordering of values per column.", + "The keys of this dict should correspond to column names, and the values should be lists of strings corresponding to the specific display order desired.", + ], + marginal=[ + "str", + "One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.", + "If set, a subplot is drawn alongside the main plot, visualizing the distribution.", + ], + marginal_x=[ + "str", + "One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.", + "If set, a horizontal subplot is drawn above the main plot, visualizing the x-distribution.", + ], + marginal_y=[ + "str", + "One of `'rug'`, `'box'`, `'violin'`, or `'histogram'`.", + "If set, a vertical subplot is 
drawn to the right of the main plot, visualizing the y-distribution.", + ], + trendline=[ + "str", + "One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.", + "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", + "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", + "If `'rolling`', a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.", + "If `'expanding`', an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.", + "If `'ewm`', an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.", + "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how", + "to configure them with the `trendline_options` argument.", + ], + trendline_options=[ + "dict", + "Options passed as the first argument to the function from `plotly.express.trendline_functions` ", + "named in the `trendline` argument.", + ], + trendline_color_override=[ + "str", + "Valid CSS color.", + "If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.", + ], + trendline_scope=[ + "str (one of `'trace'` or `'overall'`, default `'trace'`)", + "If `'trace'`, then one trendline is drawn per trace (i.e. 
per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.", + ], + render_mode=[ + "str", + "One of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`", + "Controls the browser API used to draw marks.", + "`'svg'` is appropriate for figures of less than 1000 data points, and will allow for fully-vectorized output.", + "`'webgl'` is likely necessary for acceptable performance above 1000 points but rasterizes part of the output. ", + "`'auto'` uses heuristics to choose the mode.", + ], + direction=[ + "str", + "One of '`counterclockwise'` or `'clockwise'`. Default is `'clockwise'`", + "Sets the direction in which increasing values of the angular axis are drawn.", + ], + start_angle=[ + "int (default `90`)", + "Sets start angle for the angular axis, with 0 being due east and 90 being due north.", + ], + histfunc=[ + "str (default `'count'` if no arguments are provided, else `'sum'`)", + "One of `'count'`, `'sum'`, `'avg'`, `'min'`, or `'max'`.", + "Function used to aggregate values for summarization (note: can be normalized with `histnorm`).", + ], + histnorm=[ + "str (default `None`)", + "One of `'percent'`, `'probability'`, `'density'`, or `'probability density'`", + "If `None`, the output of `histfunc` is used as is.", + "If `'probability'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins.", + "If `'percent'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins and multiplied by 100.", + "If `'density'`, the output of `histfunc` for a given bin is divided by the size of the bin.", + "If `'probability density'`, the output of `histfunc` for a given bin is normalized such that it corresponds to the probability that a random event whose distribution is described by the output of `histfunc` will fall into that bin.", + ], + barnorm=[ + "str (default `None`)", + "One of 
`'fraction'` or `'percent'`.", + "If `'fraction'`, the value of each bar is divided by the sum of all values at that location coordinate.", + "`'percent'` is the same but multiplied by 100 to show percentages.", + "`None` will stack up all values at each location coordinate.", + ], + groupnorm=[ + "str (default `None`)", + "One of `'fraction'` or `'percent'`.", + "If `'fraction'`, the value of each point is divided by the sum of all values at that location coordinate.", + "`'percent'` is the same but multiplied by 100 to show percentages.", + "`None` will stack up all values at each location coordinate.", + ], + barmode=[ + "str (default `'relative'`)", + "One of `'group'`, `'overlay'` or `'relative'`", + "In `'relative'` mode, bars are stacked above zero for positive values and below zero for negative values.", + "In `'overlay'` mode, bars are drawn on top of one another.", + "In `'group'` mode, bars are placed beside each other.", + ], + boxmode=[ + "str (default `'group'`)", + "One of `'group'` or `'overlay'`", + "In `'overlay'` mode, boxes are on drawn top of one another.", + "In `'group'` mode, boxes are placed beside each other.", + ], + violinmode=[ + "str (default `'group'`)", + "One of `'group'` or `'overlay'`", + "In `'overlay'` mode, violins are on drawn top of one another.", + "In `'group'` mode, violins are placed beside each other.", + ], + stripmode=[ + "str (default `'group'`)", + "One of `'group'` or `'overlay'`", + "In `'overlay'` mode, strips are on drawn top of one another.", + "In `'group'` mode, strips are placed beside each other.", + ], + zoom=["int (default `8`)", "Between 0 and 20.", "Sets map zoom level."], + orientation=[ + "str, one of `'h'` for horizontal or `'v'` for vertical. 
", + "(default `'v'` if `x` and `y` are provided and both continuous or both categorical, ", + "otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and `y`(`x`) is continuous, ", + "otherwise `'v'`(`'h'`) if only `x`(`y`) is provided) ", + ], + line_close=[ + "boolean (default `False`)", + "If `True`, an extra line segment is drawn between the first and last point.", + ], + line_shape=[ + "str (default `'linear'`)", + "One of `'linear'`, `'spline'`, `'hv'`, `'vh'`, `'hvh'`, or `'vhv'`", + ], + fitbounds=["str (default `False`).", "One of `False`, `locations` or `geojson`."], + basemap_visible=["bool", "Force the basemap visibility."], + scope=[ + "str (default `'world'`).", + "One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`" + "Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`.", + ], + projection=[ + "str ", + "One of `'equirectangular'`, `'mercator'`, `'orthographic'`, `'natural earth'`, `'kavrayskiy7'`, `'miller'`, `'robinson'`, `'eckert4'`, `'azimuthal equal area'`, `'azimuthal equidistant'`, `'conic equal area'`, `'conic conformal'`, `'conic equidistant'`, `'gnomonic'`, `'stereographic'`, `'mollweide'`, `'hammer'`, `'transverse mercator'`, `'albers usa'`, `'winkel tripel'`, `'aitoff'`, or `'sinusoidal'`" + "Default depends on `scope`.", + ], + center=[ + "dict", + "Dict keys are `'lat'` and `'lon'`", + "Sets the center point of the map.", + ], + map_style=[ + "str (default `'basic'`)", + "Identifier of base map style.", + "Allowed values are `'basic'`, `'carto-darkmatter'`, `'carto-darkmatter-nolabels'`, `'carto-positron'`, `'carto-positron-nolabels'`, `'carto-voyager'`, `'carto-voyager-nolabels'`, `'dark'`, `'light'`, `'open-street-map'`, `'outdoors'`, `'satellite'`, `'satellite-streets'`, `'streets'`, `'white-bg'`.", + ], + mapbox_style=[ + "str (default `'basic'`, needs Mapbox API token)", + "Identifier of base map style, some of which require a Mapbox or Stadia Maps 
API token to be set using `plotly.express.set_mapbox_access_token()`.", + "Allowed values which do not require a token are `'open-street-map'`, `'white-bg'`, `'carto-positron'`, `'carto-darkmatter'`.", + "Allowed values which require a Mapbox API token are `'basic'`, `'streets'`, `'outdoors'`, `'light'`, `'dark'`, `'satellite'`, `'satellite-streets'`.", + "Allowed values which require a Stadia Maps API token are `'stamen-terrain'`, `'stamen-toner'`, `'stamen-watercolor'`.", + ], + points=[ + "str or boolean (default `'outliers'`)", + "One of `'outliers'`, `'suspectedoutliers'`, `'all'`, or `False`.", + "If `'outliers'`, only the sample points lying outside the whiskers are shown.", + "If `'suspectedoutliers'`, all outlier points are shown and those less than 4*Q1-3*Q3 or greater than 4*Q3-3*Q1 are highlighted with the marker's `'outliercolor'`.", + "If `'outliers'`, only the sample points lying outside the whiskers are shown.", + "If `'all'`, all sample points are shown.", + "If `False`, no sample points are shown and the whiskers extend to the full range of the sample.", + ], + box=["boolean (default `False`)", "If `True`, boxes are drawn inside the violins."], + notched=["boolean (default `False`)", "If `True`, boxes are drawn with notches."], + geojson=[ + "GeoJSON-formatted dict", + "Must contain a Polygon feature collection, with IDs, which are references from `locations`.", + ], + featureidkey=[ + "str (default: `'id'`)", + "Path to field in GeoJSON feature object with which to match the values passed in to `locations`." 
def make_docstring(fn, override_dict=None, append_dict=None):
    """Build a numpydoc-style docstring for the Plotly Express function *fn*.

    Each parameter in *fn*'s signature is documented from the module-level
    ``docs`` mapping; entries in *override_dict* replace the default entry
    for a parameter, and entries in *append_dict* contribute extra
    description lines appended to it.
    """
    if override_dict is None:
        override_dict = {}
    if append_dict is None:
        append_dict = {}
    wrapper = TextWrapper(
        width=75,
        initial_indent="    ",
        subsequent_indent="    ",
        break_on_hyphens=False,
    )
    # Accumulate pieces and join once at the end instead of repeated +=.
    pieces = [(fn.__doc__ or "") + "\nParameters\n----------\n"]
    for param in getfullargspec(fn)[0]:
        # Per-parameter overrides win over the shared docs map.
        entry = (
            list(override_dict[param])
            if override_dict.get(param)
            else list(docs[param])
        )
        if append_dict.get(param):
            entry += append_dict[param]
        # entry[0] is the type line; the remainder is the wrapped description.
        if param in docs or param in override_dict:
            description = wrapper.fill(" ".join(entry[1:] or ""))
        else:
            description = "(documentation missing from map)"
        pieces.append("%s: %s\n%s\n" % (param, entry[0], description))
    pieces.append("\nReturns\n-------\n")
    pieces.append("    plotly.graph_objects.Figure")
    return "".join(pieces)
def _infer_zmax_from_type(img):
    """Guess a sensible upper color bound for *img* from its dtype/values.

    Integer images use the dtype's representable maximum.  Float images fall
    back to the smallest conventional full-scale value (1, 255 or 65535, each
    with a 5% tolerance) that covers the finite maximum, else ``2**32``.
    """
    kind = img.dtype.type
    if kind in _integer_types:
        # Integer data: the dtype's maximum is the natural bound.
        return _integer_ranges[kind][1]
    # Float data: compare the finite max against conventional scales,
    # allowing 5% headroom (rtol) before jumping to the next scale.
    rtol = 1.05
    finite_max = img[np.isfinite(img)].max()
    for candidate in (1, 255, 65535):
        if finite_max <= candidate * rtol:
            return candidate
    return 2**32
+ + origin : str, 'upper' or 'lower' (default 'upper') + position of the [0, 0] pixel of the image array, in the upper left or lower left + corner. The convention 'upper' is typically used for matrices and images. + + labels : dict with str keys and str values (default `{}`) + Sets names used in the figure for axis titles (keys ``x`` and ``y``), + colorbar title and hoverlabel (key ``color``). The values should correspond + to the desired label to be displayed. If ``img`` is an xarray, dimension + names are used for axis titles, and long name for the colorbar title + (unless overridden in ``labels``). Possible keys are: x, y, and color. + + x, y: list-like, optional + x and y are used to label the axes of single-channel heatmap visualizations and + their lengths must match the lengths of the second and first dimensions of the + img argument. They are auto-populated if the input is an xarray. + + animation_frame: int or str, optional (default None) + axis number along which the image array is sliced to create an animation plot. + If `img` is an xarray, `animation_frame` can be the name of one the dimensions. + + facet_col: int or str, optional (default None) + axis number along which the image array is sliced to create a facetted plot. + If `img` is an xarray, `facet_col` can be the name of one the dimensions. + + facet_col_wrap: int + Maximum number of facet columns. Wraps the column variable at this width, + so that the column facets span multiple rows. + Ignored if `facet_col` is None. + + facet_col_spacing: float between 0 and 1 + Spacing between facet columns, in paper units. Default is 0.02. + + facet_row_spacing: float between 0 and 1 + Spacing between facet rows created when ``facet_col_wrap`` is used, in + paper units. Default is 0.0.7. + + color_continuous_scale : str or list of str + colormap used to map scalar data to colors (for a 2D image). This parameter is + not used for RGB or RGBA images. 
If a string is provided, it should be the name + of a known color scale, and if a list is provided, it should be a list of CSS- + compatible colors. + + color_continuous_midpoint : number + If set, computes the bounds of the continuous color scale to have the desired + midpoint. Overridden by range_color or zmin and zmax. + + range_color : list of two numbers + If provided, overrides auto-scaling on the continuous color scale, including + overriding `color_continuous_midpoint`. Also overrides zmin and zmax. Used only + for single-channel images. + + title : str + The figure title. + + template : str or dict or plotly.graph_objects.layout.Template instance + The figure template name or definition. + + width : number + The figure width in pixels. + + height: number + The figure height in pixels. + + aspect: 'equal', 'auto', or None + - 'equal': Ensures an aspect ratio of 1 or pixels (square pixels) + - 'auto': The axes is kept fixed and the aspect ratio of pixels is + adjusted so that the data fit in the axes. In general, this will + result in non-square pixels. + - if None, 'equal' is used for numpy arrays and 'auto' for xarrays + (which have typically heterogeneous coordinates) + + contrast_rescaling: 'minmax', 'infer', or None + how to determine data values corresponding to the bounds of the color + range, when zmin or zmax are not passed. If `minmax`, the min and max + values of the image are used. If `infer`, a heuristic based on the image + data type is used. + + binary_string: bool, default None + if True, the image data are first rescaled and encoded as uint8 and + then passed to plotly.js as a b64 PNG string. If False, data are passed + unchanged as a numerical array. Setting to True may lead to performance + gains, at the cost of a loss of precision depending on the original data + type. If None, use_binary_string is set to True for multichannel (eg) RGB + arrays, and to False for single-channel (2D) arrays. 
2D arrays are + represented as grayscale and with no colorbar if use_binary_string is + True. + + binary_backend: str, 'auto' (default), 'pil' or 'pypng' + Third-party package for the transformation of numpy arrays to + png b64 strings. If 'auto', Pillow is used if installed, otherwise + pypng. + + binary_compression_level: int, between 0 and 9 (default 4) + png compression level to be passed to the backend when transforming an + array to a png b64 string. Increasing `binary_compression` decreases the + size of the png string, but the compression step takes more time. For most + images it is not worth using levels greater than 5, but it's possible to + test `len(fig.data[0].source)` and to time the execution of `imshow` to + tune the level of compression. 0 means no compression (not recommended). + + binary_format: str, 'png' (default) or 'jpg' + compression format used to generate b64 string. 'png' is recommended + since it uses lossless compression, but 'jpg' (lossy) compression can + result if smaller binary strings for natural images. + + text_auto: bool or str (default `False`) + If `True` or a string, single-channel `img` values will be displayed as text. + A string like `'.2f'` will be interpreted as a `texttemplate` numeric formatting directive. + + Returns + ------- + fig : graph_objects.Figure containing the displayed image + + See also + -------- + + plotly.graph_objects.Image : image trace + plotly.graph_objects.Heatmap : heatmap trace + + Notes + ----- + + In order to update and customize the returned figure, use + `go.Figure.update_traces` or `go.Figure.update_layout`. + + If an xarray is passed, dimensions names and coordinates are used for + axes labels and ticks. 
+ """ + args = locals() + apply_default_cascade(args) + labels = labels.copy() + nslices_facet = 1 + if facet_col is not None: + if isinstance(facet_col, str): + facet_col = img.dims.index(facet_col) + nslices_facet = img.shape[facet_col] + facet_slices = range(nslices_facet) + ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet + nrows = ( + nslices_facet // ncols + 1 + if nslices_facet % ncols + else nslices_facet // ncols + ) + else: + nrows = 1 + ncols = 1 + if animation_frame is not None: + if isinstance(animation_frame, str): + animation_frame = img.dims.index(animation_frame) + nslices_animation = img.shape[animation_frame] + animation_slices = range(nslices_animation) + slice_dimensions = (facet_col is not None) + ( + animation_frame is not None + ) # 0, 1, or 2 + facet_label = None + animation_label = None + img_is_xarray = False + # ----- Define x and y, set labels if img is an xarray ------------------- + if xarray_imported and isinstance(img, xarray.DataArray): + dims = list(img.dims) + img_is_xarray = True + pop_indexes = [] + if facet_col is not None: + facet_slices = img.coords[img.dims[facet_col]].values + pop_indexes.append(facet_col) + facet_label = img.dims[facet_col] + if animation_frame is not None: + animation_slices = img.coords[img.dims[animation_frame]].values + pop_indexes.append(animation_frame) + animation_label = img.dims[animation_frame] + # Remove indices in sorted order. 
+ for index in sorted(pop_indexes, reverse=True): + _ = dims.pop(index) + y_label, x_label = dims[0], dims[1] + # np.datetime64 is not handled correctly by go.Heatmap + for ax in [x_label, y_label]: + if np.issubdtype(img.coords[ax].dtype, np.datetime64): + img.coords[ax] = img.coords[ax].astype(str) + if x is None: + x = img.coords[x_label].values + if y is None: + y = img.coords[y_label].values + if aspect is None: + aspect = "auto" + if labels.get("x", None) is None: + labels["x"] = x_label + if labels.get("y", None) is None: + labels["y"] = y_label + if labels.get("animation_frame", None) is None: + labels["animation_frame"] = animation_label + if labels.get("facet_col", None) is None: + labels["facet_col"] = facet_label + if labels.get("color", None) is None: + labels["color"] = xarray.plot.utils.label_from_attrs(img) + labels["color"] = labels["color"].replace("\n", "<br>") + else: + if hasattr(img, "columns") and hasattr(img.columns, "__len__"): + if x is None: + x = img.columns + if labels.get("x", None) is None and hasattr(img.columns, "name"): + labels["x"] = img.columns.name or "" + if hasattr(img, "index") and hasattr(img.index, "__len__"): + if y is None: + y = img.index + if labels.get("y", None) is None and hasattr(img.index, "name"): + labels["y"] = img.index.name or "" + + if labels.get("x", None) is None: + labels["x"] = "" + if labels.get("y", None) is None: + labels["y"] = "" + if labels.get("color", None) is None: + labels["color"] = "" + if aspect is None: + aspect = "equal" + + # --- Set the value of binary_string (forbidden for pandas) + img = nw.from_native(img, pass_through=True) + if isinstance(img, nw.DataFrame): + if binary_string: + raise ValueError("Binary strings cannot be used with pandas arrays") + is_dataframe = True + else: + is_dataframe = False + + # --------------- Starting from here img is always a numpy array -------- + img = np.asanyarray(img) + # Reshape array so that animation dimension comes first, then facets, then 
images + if facet_col is not None: + img = np.moveaxis(img, facet_col, 0) + if animation_frame is not None and animation_frame < facet_col: + animation_frame += 1 + facet_col = True + if animation_frame is not None: + img = np.moveaxis(img, animation_frame, 0) + animation_frame = True + args["animation_frame"] = ( + "animation_frame" + if labels.get("animation_frame") is None + else labels["animation_frame"] + ) + iterables = () + if animation_frame is not None: + iterables += (range(nslices_animation),) + if facet_col is not None: + iterables += (range(nslices_facet),) + + # Default behaviour of binary_string: True for RGB images, False for 2D + if binary_string is None: + binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe + + # Cast bools to uint8 (also one byte) + if img.dtype == bool: + img = 255 * img.astype(np.uint8) + + if range_color is not None: + zmin = range_color[0] + zmax = range_color[1] + + # -------- Contrast rescaling: either minmax or infer ------------------ + if contrast_rescaling is None: + contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer" + + # We try to set zmin and zmax only if necessary, because traces have good defaults + if contrast_rescaling == "minmax": + # When using binary_string and minmax we need to set zmin and zmax to rescale the image + if (zmin is not None or binary_string) and zmax is None: + zmax = img.max() + if (zmax is not None or binary_string) and zmin is None: + zmin = img.min() + else: + # For uint8 data and infer we let zmin and zmax to be None if passed as None + if zmax is None and img.dtype != np.uint8: + zmax = _infer_zmax_from_type(img) + if zmin is None and zmax is not None: + zmin = 0 + + # For 2d data, use Heatmap trace, unless binary_string is True + if img.ndim == 2 + slice_dimensions and not binary_string: + y_index = slice_dimensions + if y is not None and img.shape[y_index] != len(y): + raise ValueError( + "The length of the y vector must match the length 
of the first " + + "dimension of the img matrix." + ) + x_index = slice_dimensions + 1 + if x is not None and img.shape[x_index] != len(x): + raise ValueError( + "The length of the x vector must match the length of the second " + + "dimension of the img matrix." + ) + + texttemplate = None + if text_auto is True: + texttemplate = "%{z}" + elif text_auto is not False: + texttemplate = "%{z:" + text_auto + "}" + + traces = [ + go.Heatmap( + x=x, + y=y, + z=img[index_tup], + coloraxis="coloraxis1", + name=str(i), + texttemplate=texttemplate, + ) + for i, index_tup in enumerate(itertools.product(*iterables)) + ] + autorange = True if origin == "lower" else "reversed" + layout = dict(yaxis=dict(autorange=autorange)) + if aspect == "equal": + layout["xaxis"] = dict(scaleanchor="y", constrain="domain") + layout["yaxis"]["constrain"] = "domain" + colorscale_validator = ColorscaleValidator("colorscale", "imshow") + layout["coloraxis1"] = dict( + colorscale=colorscale_validator.validate_coerce( + args["color_continuous_scale"] + ), + cmid=color_continuous_midpoint, + cmin=zmin, + cmax=zmax, + ) + if labels["color"]: + layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) + + # For 2D+RGB data, use Image trace + elif ( + img.ndim >= 3 + and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string) + ) or (img.ndim == 2 and binary_string): + rescale_image = True # to check whether image has been modified + if zmin is not None and zmax is not None: + zmin, zmax = ( + _vectorize_zvalue(zmin, mode="min"), + _vectorize_zvalue(zmax, mode="max"), + ) + x0, y0, dx, dy = (None,) * 4 + error_msg_xarray = ( + "Non-numerical coordinates were passed with xarray `img`, but " + "the Image trace cannot handle it. Please use `binary_string=False` " + "for 2D data or pass instead the numpy array `img.values` to `px.imshow`." 
+ ) + if x is not None: + x = np.asanyarray(x) + if np.issubdtype(x.dtype, np.number): + x0 = x[0] + dx = x[1] - x[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `x` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) + if y is not None: + y = np.asanyarray(y) + if np.issubdtype(y.dtype, np.number): + y0 = y[0] + dy = y[1] - y[0] + else: + error_msg = ( + error_msg_xarray + if img_is_xarray + else ( + "Only numerical values are accepted for the `y` parameter " + "when an Image trace is used." + ) + ) + raise ValueError(error_msg) + if binary_string: + if zmin is None and zmax is None: # no rescaling, faster + img_rescaled = img + rescale_image = False + elif img.ndim == 2 + slice_dimensions: # single-channel image + img_rescaled = rescale_intensity( + img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 + ) + else: + img_rescaled = np.stack( + [ + rescale_intensity( + img[..., ch], + in_range=(zmin[ch], zmax[ch]), + out_range=np.uint8, + ) + for ch in range(img.shape[-1]) + ], + axis=-1, + ) + img_str = [ + image_array_to_data_uri( + img_rescaled[index_tup], + backend=binary_backend, + compression=binary_compression_level, + ext=binary_format, + ) + for index_tup in itertools.product(*iterables) + ] + + traces = [ + go.Image(source=img_str_slice, name=str(i), x0=x0, y0=y0, dx=dx, dy=dy) + for i, img_str_slice in enumerate(img_str) + ] + else: + colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" + traces = [ + go.Image( + z=img[index_tup], + zmin=zmin, + zmax=zmax, + colormodel=colormodel, + x0=x0, + y0=y0, + dx=dx, + dy=dy, + ) + for index_tup in itertools.product(*iterables) + ] + layout = {} + if origin == "lower" or (dy is not None and dy < 0): + layout["yaxis"] = dict(autorange=True) + if dx is not None and dx < 0: + layout["xaxis"] = dict(autorange="reversed") + else: + raise ValueError( + "px.imshow only accepts 2D single-channel, RGB or RGBA images. 
" + "An image of shape %s was provided. " + "Alternatively, 3- or 4-D single or multichannel datasets can be " + "visualized using the `facet_col` or/and `animation_frame` arguments." + % str(img.shape) + ) + + # Now build figure + col_labels = [] + if facet_col is not None: + slice_label = ( + "facet_col" if labels.get("facet_col") is None else labels["facet_col"] + ) + col_labels = [f"{slice_label}={i}" for i in facet_slices] + fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) + for attr_name in ["height", "width"]: + if args[attr_name]: + layout[attr_name] = args[attr_name] + if args["title"]: + layout["title_text"] = args["title"] + elif args["template"].layout.margin.t is None: + layout["margin"] = {"t": 60} + + frame_list = [] + for index, trace in enumerate(traces): + if (facet_col and index < nrows * ncols) or index == 0: + fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) + if animation_frame is not None: + for i, index in zip(range(nslices_animation), animation_slices): + frame_list.append( + dict( + data=traces[nslices_facet * i : nslices_facet * (i + 1)], + layout=layout, + name=str(index), + ) + ) + if animation_frame: + fig.frames = frame_list + fig.update_layout(layout) + # Hover name, z or color + if binary_string and rescale_image and not np.all(img == img_rescaled): + # we rescaled the image, hence z is not displayed in hover since it does + # not correspond to img values + hovertemplate = "%s: %%{x}<br>%s: %%{y}<extra></extra>" % ( + labels["x"] or "x", + labels["y"] or "y", + ) + else: + if trace["type"] == "heatmap": + hover_name = "%{z}" + elif img.ndim == 2: + hover_name = "%{z[0]}" + elif img.ndim == 3 and img.shape[-1] == 3: + hover_name = "[%{z[0]}, %{z[1]}, %{z[2]}]" + else: + hover_name = "%{z}" + hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %s<extra></extra>" % ( + labels["x"] or "x", + labels["y"] or "y", + labels["color"] or "color", + hover_name, + ) + 
class IdentityMap(object):
    """A `dict`-like stand-in whose lookup returns the key unchanged.

    Passing an instance as any `_map` argument to a Plotly Express function
    (e.g. `color_discrete_map`, `line_dash_map`, `symbol_map`) makes PX use
    the data values themselves (for example, directly as colors) instead of
    cycling through the corresponding `_sequence` argument.
    """

    def __getitem__(self, key):
        # Every key maps to itself.
        return key

    def __contains__(self, key):
        # Every key is considered present.
        return True

    def copy(self):
        # Stateless, so a "copy" can safely be the object itself.
        return self


class Constant(object):
    """Marks a Plotly Express attribute as taking a single constant value.

    Pass an instance where a column identifier or list-like object is
    expected to give that attribute the fixed `value`; `label` optionally
    names it in the figure.
    """

    def __init__(self, value, label=None):
        self.value, self.label = value, label


class Range(object):
    """Marks a Plotly Express attribute as an integer sequence from 0.

    Pass an instance where a column identifier or list-like object is
    expected to map that attribute onto 0, 1, 2, ...; `label` optionally
    names it in the figure.
    """

    def __init__(self, label=None):
        self.label = label
def intensity_range(image, range_values="image", clip_negative=False):
    """Return the (min, max) intensity range for *image*.

    ``range_values`` selects the source of the bounds:

    * ``'image'`` -- the actual min/max of *image*.
    * ``'dtype'`` -- the limits of *image*'s own dtype.
    * a key of ``DTYPE_RANGE`` -- the limits of that named dtype
      (*image* is ignored).
    * a 2-tuple -- used verbatim as ``(min, max)``.

    When *clip_negative* is true, a negative dtype-derived lower bound is
    clipped to 0.
    """
    # 'dtype' resolves to the image's own dtype type, which then falls
    # through to the DTYPE_RANGE lookup below.
    if range_values == "dtype":
        range_values = image.dtype.type

    if range_values == "image":
        return np.min(image), np.max(image)
    if range_values in DTYPE_RANGE:
        lo, hi = DTYPE_RANGE[range_values]
        return (0 if clip_negative else lo), hi
    # Explicit (min, max) pair supplied by the caller.
    lo, hi = range_values
    return lo, hi
+ return np.float_ + if isinstance(dtype_or_range, type): + # already a type: return it + return dtype_or_range + if dtype_or_range in DTYPE_RANGE: + # string key in DTYPE_RANGE dictionary + try: + # if it's a canonical numpy dtype, convert + return np.dtype(dtype_or_range).type + except TypeError: # uint10, uint12, uint14 + # otherwise, return uint16 + return np.uint16 + else: + raise ValueError( + "Incorrect value for out_range, should be a valid image data " + "type or a pair of values, got %s." % str(dtype_or_range) + ) + + +def rescale_intensity(image, in_range="image", out_range="dtype"): + """Return image after stretching or shrinking its intensity levels. + + The desired intensity range of the input and output, `in_range` and + `out_range` respectively, are used to stretch or shrink the intensity range + of the input image. See examples below. + + Parameters + ---------- + image : array + Image array. + in_range, out_range : str or 2-tuple, optional + Min and max intensity values of input and output image. + The possible values for this parameter are enumerated below. + + 'image' + Use image min/max as the intensity range. + 'dtype' + Use min/max of the image's dtype as the intensity range. + dtype-name + Use intensity range based on desired `dtype`. Must be valid key + in `DTYPE_RANGE`. + 2-tuple + Use `range_values` as explicit min/max intensities. + + Returns + ------- + out : array + Image array after rescaling its intensity. This image is the same dtype + as the input image. + + Notes + ----- + .. versionchanged:: 0.17 + The dtype of the output array has changed to match the output dtype, or + float if the output range is specified by a pair of floats. 
+ + See Also + -------- + equalize_hist + + Examples + -------- + By default, the min/max intensities of the input image are stretched to + the limits allowed by the image's dtype, since `in_range` defaults to + 'image' and `out_range` defaults to 'dtype': + + >>> image = np.array([51, 102, 153], dtype=np.uint8) + >>> rescale_intensity(image) + array([ 0, 127, 255], dtype=uint8) + + It's easy to accidentally convert an image dtype from uint8 to float: + + >>> 1.0 * image + array([ 51., 102., 153.]) + + Use `rescale_intensity` to rescale to the proper range for float dtypes: + + >>> image_float = 1.0 * image + >>> rescale_intensity(image_float) + array([0. , 0.5, 1. ]) + + To maintain the low contrast of the original, use the `in_range` parameter: + + >>> rescale_intensity(image_float, in_range=(0, 255)) + array([0.2, 0.4, 0.6]) + + If the min/max value of `in_range` is more/less than the min/max image + intensity, then the intensity levels are clipped: + + >>> rescale_intensity(image_float, in_range=(0, 102)) + array([0.5, 1. , 1. ]) + + If you have an image with signed integers but want to rescale the image to + just the positive range, use the `out_range` parameter. In that case, the + output dtype will be float: + + >>> image = np.array([-10, 0, 10], dtype=np.int8) + >>> rescale_intensity(image, out_range=(0, 127)) + array([ 0. , 63.5, 127. 
]) + + To get the desired range with a specific dtype, use ``.astype()``: + + >>> rescale_intensity(image, out_range=(0, 127)).astype(np.int8) + array([ 0, 63, 127], dtype=int8) + + If the input image is constant, the output will be clipped directly to the + output range: + >>> image = np.array([130, 130, 130], dtype=np.int32) + >>> rescale_intensity(image, out_range=(0, 127)).astype(np.int32) + array([127, 127, 127], dtype=int32) + """ + if out_range in ["dtype", "image"]: + out_dtype = _output_dtype(image.dtype.type) + else: + out_dtype = _output_dtype(out_range) + + imin, imax = map(float, intensity_range(image, in_range)) + omin, omax = map( + float, intensity_range(image, out_range, clip_negative=(imin >= 0)) + ) + + if np.any(np.isnan([imin, imax, omin, omax])): + warn( + "One or more intensity levels are NaN. Rescaling will broadcast " + "NaN to the full image. Provide intensity levels yourself to " + "avoid this. E.g. with np.nanmin(image), np.nanmax(image).", + stacklevel=2, + ) + + image = np.clip(image, imin, imax) + + if imin != imax: + image = (image - imin) / (imax - imin) + return np.asarray(image * (omax - omin) + omin, dtype=out_dtype) + else: + return np.clip(image, omin, omax).astype(out_dtype) diff --git a/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py b/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py new file mode 100644 index 0000000..18ff219 --- /dev/null +++ b/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py @@ -0,0 +1,170 @@ +""" +The `trendline_functions` module contains functions which are called by Plotly Express +when the `trendline` argument is used. Valid values for `trendline` are the names of the +functions in this module, and the value of the `trendline_options` argument to PX +functions is passed in as the first argument to these functions when called. 
+
+Note that the functions in this module are not meant to be called directly, and are
+exposed as part of the public API for documentation purposes.
+"""
+
+__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]
+
+
+def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Ordinary Least Squares (OLS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ This trendline function causes fit results to be stored within the figure,
+ accessible via the `plotly.express.get_trendline_results` function. The fit results
+ are the output of the `statsmodels.api.OLS` function.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
+ the origin but if `True` a y-intercept is fitted.
+
+ - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
+ respect to the base 10 logarithm of the input. Note that this means no zeros can
+ be present in the input.
+ """
+ import numpy as np
+
+ # Validate option keys before touching statsmodels so a bad key fails fast
+ # even when statsmodels is not installed.
+ valid_options = ["add_constant", "log_x", "log_y"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "OLS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ # statsmodels is imported lazily, inside the function, so the dependency is
+ # only required when an OLS trendline is actually requested.
+ import statsmodels.api as sm
+
+ add_constant = trendline_options.get("add_constant", True)
+ log_x = trendline_options.get("log_x", False)
+ log_y = trendline_options.get("log_y", False)
+
+ if log_y:
+ if np.any(y <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
+ )
+ y = np.log10(y)
+ y_label = "log10(%s)" % y_label
+ if log_x:
+ if np.any(x <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
+ )
+ x = np.log10(x)
+ x_label = "log10(%s)" % x_label
+ if add_constant:
+ x = sm.add_constant(x)
+ fit_results = sm.OLS(y, x, missing="drop").fit()
+ y_out = fit_results.predict()
+ if log_y:
+ # Undo the log10 transform so the trendline is in data coordinates.
+ y_out = np.power(10, y_out)
+ hover_header = "<b>OLS trendline</b><br>"
+ # Two fitted params => intercept + slope; a single param is either
+ # slope-only (add_constant=False) or a constant-only fit.
+ if len(fit_results.params) == 2:
+ hover_header += "%s = %g * %s + %g<br>" % (
+ y_label,
+ fit_results.params[1],
+ x_label,
+ fit_results.params[0],
+ )
+ elif not add_constant:
+ hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
+ else:
+ hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
+ hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
+ return y_out, hover_header, fit_results
+
+
+def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `frac` (`float`, default `0.6666666`): the `frac` parameter from the
+ `statsmodels.api.nonparametric.lowess` function
+ """
+
+ valid_options = ["frac"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "LOWESS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ # Imported lazily so statsmodels is only needed when LOWESS is requested.
+ import statsmodels.api as sm
+
+ frac = trendline_options.get("frac", 0.6666666)
+ # lowess returns an (n, 2) array of (x, fitted y); column 1 is the smoothed y.
+ y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
+ hover_header = "<b>LOWESS trendline</b><br><br>"
+ return y_out, hover_header, None
+
+
+# Shared driver for the rolling/expanding/ewm trendlines: builds a pandas
+# Series from y (indexed by the raw x values) and evaluates
+# Series.<mode>(**trendline_options).<function>(**function_args).
+def _pandas(mode, trendline_options, x_raw, y, non_missing):
+ import numpy as np
+
+ try:
+ import pandas as pd
+ except ImportError:
+ msg = "Trendline requires pandas to be installed"
+ raise ImportError(msg)
+
+ modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
+ # Copy before popping so the caller's options dict is not mutated.
+ trendline_options = trendline_options.copy()
+ function_name = trendline_options.pop("function", "mean")
+ function_args = trendline_options.pop("function_args", dict())
+
+ series = pd.Series(np.copy(y), index=x_raw.to_pandas())
+
+ # TODO: Narwhals Series/DataFrame do not support rolling, ewm nor expanding, therefore
+ # it fallbacks to pandas Series independently of the original type.
+ # Plotly issue: https://github.com/plotly/plotly.py/issues/4834
+ # Narwhals issue: https://github.com/narwhals-dev/narwhals/issues/1254
+ agg = getattr(series, mode) # e.g. series.rolling
+ agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
+ function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
+ y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
+ # Keep only the rows the caller retained; `non_missing` is presumably a
+ # boolean mask over the original points — TODO confirm against the caller.
+ y_out = y_out[non_missing]
+ hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
+ return y_out, hover_header, None
+
+
+def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Rolling trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.rolling` function.
+ """
+ return _pandas("rolling", trendline_options, x_raw, y, non_missing)
+
+
+def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Expanding trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.expanding` function.
+ """
+ return _pandas("expanding", trendline_options, x_raw, y, non_missing)
+
+
+def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Exponentially Weighted Moment (EWM) trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.ewm` function.
+ """
+ return _pandas("ewm", trendline_options, x_raw, y, non_missing)