aboutsummaryrefslogtreecommitdiff
path: root/venv/lib/python3.8/site-packages/plotly/express/trendline_functions
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.8/site-packages/plotly/express/trendline_functions')
-rw-r--r--venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py b/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py
new file mode 100644
index 0000000..18ff219
--- /dev/null
+++ b/venv/lib/python3.8/site-packages/plotly/express/trendline_functions/__init__.py
@@ -0,0 +1,170 @@
+"""
+The `trendline_functions` module contains functions which are called by Plotly Express
+when the `trendline` argument is used. Valid values for `trendline` are the names of the
+functions in this module, and the value of the `trendline_options` argument to PX
+functions is passed in as the first argument to these functions when called.
+
+Note that the functions in this module are not meant to be called directly, and are
+exposed as part of the public API for documentation purposes.
+"""
+
+__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]
+
+
+def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Ordinary Least Squares (OLS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ This trendline function causes fit results to be stored within the figure,
+ accessible via the `plotly.express.get_trendline_results` function. The fit results
+ are the output of the `statsmodels.api.OLS` function.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through
+ the origin but if `True` a y-intercept is fitted.
+
+ - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with
+ respect to the base 10 logarithm of the input. Note that this means no zeros can
+ be present in the input.
+ """
+ import numpy as np
+
+ valid_options = ["add_constant", "log_x", "log_y"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "OLS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ import statsmodels.api as sm
+
+ add_constant = trendline_options.get("add_constant", True)
+ log_x = trendline_options.get("log_x", False)
+ log_y = trendline_options.get("log_y", False)
+
+ if log_y:
+ if np.any(y <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values."
+ )
+ y = np.log10(y)
+ y_label = "log10(%s)" % y_label
+ if log_x:
+ if np.any(x <= 0):
+ raise ValueError(
+ "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values."
+ )
+ x = np.log10(x)
+ x_label = "log10(%s)" % x_label
+ if add_constant:
+ x = sm.add_constant(x)
+ fit_results = sm.OLS(y, x, missing="drop").fit()
+ y_out = fit_results.predict()
+ if log_y:
+ y_out = np.power(10, y_out)
+ hover_header = "<b>OLS trendline</b><br>"
+ if len(fit_results.params) == 2:
+ hover_header += "%s = %g * %s + %g<br>" % (
+ y_label,
+ fit_results.params[1],
+ x_label,
+ fit_results.params[0],
+ )
+ elif not add_constant:
+ hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
+ else:
+ hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
+ hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
+ return y_out, hover_header, fit_results
+
+
+def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function
+
+ Requires `statsmodels` to be installed.
+
+ Valid keys for the `trendline_options` dict are:
+
+ - `frac` (`float`, default `0.6666666`): the `frac` parameter from the
+ `statsmodels.api.nonparametric.lowess` function
+ """
+
+ valid_options = ["frac"]
+ for k in trendline_options.keys():
+ if k not in valid_options:
+ raise ValueError(
+ "LOWESS trendline_options keys must be one of [%s] but got '%s'"
+ % (", ".join(valid_options), k)
+ )
+
+ import statsmodels.api as sm
+
+ frac = trendline_options.get("frac", 0.6666666)
+ y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1]
+ hover_header = "<b>LOWESS trendline</b><br><br>"
+ return y_out, hover_header, None
+
+
+def _pandas(mode, trendline_options, x_raw, y, non_missing):
+ import numpy as np
+
+ try:
+ import pandas as pd
+ except ImportError:
+ msg = "Trendline requires pandas to be installed"
+ raise ImportError(msg)
+
+ modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
+ trendline_options = trendline_options.copy()
+ function_name = trendline_options.pop("function", "mean")
+ function_args = trendline_options.pop("function_args", dict())
+
+ series = pd.Series(np.copy(y), index=x_raw.to_pandas())
+
+ # TODO: Narwhals Series/DataFrame do not support rolling, ewm nor expanding, therefore
+ # it fallbacks to pandas Series independently of the original type.
+ # Plotly issue: https://github.com/plotly/plotly.py/issues/4834
+ # Narwhals issue: https://github.com/narwhals-dev/narwhals/issues/1254
+ agg = getattr(series, mode) # e.g. series.rolling
+ agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
+ function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
+ y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
+ y_out = y_out[non_missing]
+ hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
+ return y_out, hover_header, None
+
+
+def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Rolling trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.rolling` function.
+ """
+ return _pandas("rolling", trendline_options, x_raw, y, non_missing)
+
+
+def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Expanding trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.expanding` function.
+ """
+ return _pandas("expanding", trendline_options, x_raw, y, non_missing)
+
+
+def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
+ """Exponentially Weighted Moment (EWM) trendline function
+
+ The value of the `function` key of the `trendline_options` dict is the function to
+ use (defaults to `mean`) and the value of the `function_args` key are taken to be
+ its arguments as a dict. The remainder of the `trendline_options` dict is passed as
+ keyword arguments into the `pandas.Series.ewm` function.
+ """
+ return _pandas("ewm", trendline_options, x_raw, y, non_missing)