Source code for pymc_marketing.plot

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Plot distributions stored in xarray.DataArray across coordinates.

Used to plot the prior and posterior of the various MMM components.

See the :func:`plot_curve` function for more information.

"""

import warnings
from collections.abc import Callable, Generator, Iterable, MutableMapping, Sequence
from itertools import product, repeat
from typing import Any, Concatenate, ParamSpec, cast

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]


[docs] def get_plot_coords(coords: Coords, non_grid_names: set[str]) -> Coords: """Get the plot coordinates. Parameters ---------- coords : Coords The coordinates to get the plot coordinates from. non_grid_names : set[str] The names to exclude from the grid. Returns ------- Coords The plot coordinates. """ plot_coord_names = list(key for key in coords.keys() if key not in non_grid_names) return {name: np.array(coords[name]) for name in plot_coord_names}
[docs] def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray: """Remove scalar coordinates from an xarray DataArray. This function identifies and removes scalar coordinates from the given DataArray. Scalar coordinates are those with a single value that are not part of the DataArray's indexes. The function returns a new DataArray with the scalar coordinates removed. Parameters ---------- curve : xr.DataArray The input DataArray from which scalar coordinates will be removed. Returns ------- xr.DataArray A new DataArray with the identified scalar coordinates removed. """ scalar_coords_to_drop = [] for coord, values in curve.coords.items(): if values.size == 1 and coord not in curve.indexes: scalar_coords_to_drop.append(coord) return curve.reset_coords(scalar_coords_to_drop, drop=True)
[docs] def get_total_coord_size(coords: Coords) -> int: """Get the total size of the coordinates. Parameters ---------- coords : Coords The coordinates to get the total size of. Returns ------- int The total size of the coordinates. """ total_size: int = ( 1 if coords == {} else np.prod([len(values) for values in coords.values()]) # type: ignore ) if total_size >= 12: warnings.warn("Large number of coordinates!", stacklevel=2) return total_size
[docs] def create_legend_handles( colors: Iterable[str], alpha: float = 0.5, line: bool = True, patch: bool = True, ) -> list[Line2D | Patch | tuple[Line2D, Patch]]: """Create the legend handles for the given colors. Parameters ---------- colors : Iterable[str] The colors to create the legend handles. alpha : float, optional The alpha value for the patches, by default 0.5. line : bool, optional Whether to include the line, by default True. patch : bool, optional Whether to include the patch, by default True. Returns ------- list[Line2D | Patch | tuple[Line2D, Patch]] The legend handles. """ if not line and not patch: raise ValueError("At least one of line or patch must be True") def create_handle( color: str, alpha: float ) -> Line2D | Patch | tuple[Line2D, Patch]: if line and patch: return Line2D([0], [0], color=color), Patch(color=color, alpha=alpha) if line: return Line2D([0], [0], color=color) return Patch(color=color, alpha=alpha) return [create_handle(color, alpha) for color in colors]
[docs] def set_subplot_kwargs_defaults( subplot_kwargs: MutableMapping[str, Any], total_size: int, ) -> None: """Set the defaults for the subplot kwargs. Parameters ---------- subplot_kwargs : MutableMapping[str, Any] The subplot kwargs to set the defaults for. total_size : int The total size of the coordinates. Raises ------ ValueError If both `ncols` and `nrows` are specified. """ if "ncols" in subplot_kwargs and "nrows" in subplot_kwargs: raise ValueError("Only specify one") if "ncols" not in subplot_kwargs and "nrows" not in subplot_kwargs: subplot_kwargs["ncols"] = total_size if "ncols" in subplot_kwargs: subplot_kwargs["nrows"] = total_size // subplot_kwargs["ncols"] elif "nrows" in subplot_kwargs: subplot_kwargs["ncols"] = total_size // subplot_kwargs["nrows"]
Selection = dict[str, Any]
[docs] def selections( coords: Coords, ) -> Generator[Selection, None, None]: """Create generator of selections. Parameters ---------- coords : Coords The coordinates to create the selections from. Yields ------ dict[str, Any] The selections. """ coord_names = coords.keys() for values in product(*coords.values()): yield {name: value for name, value in zip(coord_names, values, strict=True)}
P = ParamSpec("P") GetPlotData = Callable[[xr.DataArray], xr.DataArray] MakeSelection = Callable[[xr.DataArray, Selection], pd.DataFrame] PlotSelection = Callable[Concatenate[pd.DataFrame, Axes, str, P], Axes] def _get_sample_plot_data(data): return data def _create_make_sample_selection( rng, n: int, n_chains: int, n_draws: int, ) -> MakeSelection: rng = rng or np.random.default_rng() idx = random_samples( rng, n=n, n_chains=n_chains, n_draws=n_draws, ) def make_sample_selection(data, sel): return data.sel(sel).to_series().unstack().loc[idx, :].T return make_sample_selection def _plot_sample_selection(df, ax: Axes, color: str, **plot_kwargs) -> Axes: return df.plot(ax=ax, color=color, **plot_kwargs) def _create_get_hdi_plot_data(hdi_kwargs) -> GetPlotData: def get_plot_data(data: xr.DataArray) -> xr.DataArray: hdi: xr.Dataset = az.hdi(data, **hdi_kwargs) return hdi[data.name] return get_plot_data def _make_hdi_selection(data: xr.DataArray, sel: dict[str, Any]) -> pd.DataFrame: return data.sel(sel).to_series().unstack() def _plot_hdi_selection( df: pd.DataFrame, ax: Axes, color: str, **plot_kwargs, ) -> Axes: ax.fill_between( x=df.index, y1=df["lower"], y2=df["higher"], color=color, **plot_kwargs, ) return ax SelToString = Callable[[Selection], str]
[docs] def random_samples( rng: np.random.Generator, n: int, n_chains: int, n_draws: int, ) -> list[tuple[int, int]]: """Generate random samples from the chains and draws. Parameters ---------- rng : np.random.Generator Random number generator n : int Number of samples to generate n_chains : int Number of chains n_draws : int Number of draws Returns ------- list[tuple[int, int]] The random samples """ combinations = list(product(range(n_chains), range(n_draws))) return [ tuple(pair) for pair in list(rng.choice(combinations, size=n, replace=False)) ]
[docs] def generate_colors(n: int, start: int = 0) -> list[str]: """Generate list of colors. Parameters ---------- n : int Number of colors to generate start : int, optional Starting index, by default 0 Returns ------- list[str] List of colors Examples -------- Generate 5 colors starting from index 1 .. code-block:: python colors = generate_colors(5, start=1) print(colors) # ['C1', 'C2', 'C3', 'C4', 'C5'] """ return [f"C{i}" for i in range(start, start + n)]
def _plot_across_coord( curve: xr.DataArray, non_grid_names: set[str], get_plot_data: GetPlotData, make_selection: MakeSelection, plot_selection: PlotSelection, subplot_kwargs: dict | None = None, axes: npt.NDArray[Axes] | None = None, same_axes: bool = False, colors: Iterable[str] | None = None, legend: bool = False, plot_kwargs: dict[str, Any] | None = None, patch: bool = True, line: bool = True, sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot data array across coords. Commonality used for the `plot_samples` and `plot_hdi` functions. Differences depending on the `get_plot_data`, `make_selection` and `plot_selection` functions passed. Allows for plotting each coordinate combination on a separate axis or on the same axis. """ if sel_to_string is None: def sel_to_string(sel): return ", ".join(f"{key}={value}" for key, value in sel.items()) curve = drop_scalar_coords(curve) data = get_plot_data(curve) plot_coords = get_plot_coords( data.coords, non_grid_names=non_grid_names.union({"chain", "draw", "hdi"}), ) total_size = get_total_coord_size(plot_coords) if axes is None and not same_axes: subplot_kwargs = subplot_kwargs or {} subplot_kwargs = {**{"sharey": True, "sharex": True}, **subplot_kwargs} set_subplot_kwargs_defaults(subplot_kwargs, total_size) fig, axes = plt.subplots(**subplot_kwargs) axes_iter = np.ravel(axes) return_axes = axes create_title = sel_to_string create_legend_label = None elif axes is not None and same_axes: fig = plt.gcf() axes_iter = repeat(axes[0], total_size) # type: ignore return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes def create_title(sel): return "" create_legend_label = sel_to_string elif axes is None and same_axes: fig, ax = plt.subplots(ncols=1, nrows=1) axes_iter = repeat(ax, total_size) # type: ignore return_axes = np.array([ax]) def create_title(sel): return "" create_legend_label = sel_to_string else: fig = plt.gcf() axes_iter = np.ravel(axes) # type: ignore return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes create_title = sel_to_string # type: ignore create_legend_label = None colors = cast(Iterable[str], colors or generate_colors(n=total_size, start=0)) for color, ax, sel in zip(colors, axes_iter, selections(plot_coords), strict=False): ax = data.pipe(make_selection, sel=sel).pipe( plot_selection, ax=ax, color=color, **plot_kwargs, ) title = create_title(sel) ax.set_title(title) if same_axes and legend and create_legend_label is not None: handles = create_legend_handles(colors, patch=patch, line=line) labels = [create_legend_label(sel) for sel in selections(plot_coords)] ax.legend(handles=handles, labels=labels) return fig, return_axes
[docs] def plot_hdi( curve: xr.DataArray, non_grid_names: str | set[str], hdi_prob: float | None = None, hdi_kwargs: dict | None = None, subplot_kwargs: dict[str, Any] | None = None, plot_kwargs: dict[str, Any] | None = None, axes: npt.NDArray[Axes] | None = None, same_axes: bool = False, colors: Iterable[str] | None = None, legend: bool = False, sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot hdi of the curve across coords. Parameters ---------- curve : xr.DataArray Curve to plot non_grid_names : str | set[str] The names to exclude from the grid. chain and draw are excluded automatically n : int, optional Number of samples to plot rng : np.random.Generator, optional Random number generator axes : npt.NDArray[plt.Axes], optional Axes to plot on subplot_kwargs : dict, optional Additional kwargs to while creating the fig and axes plot_kwargs : dict, optional Kwargs for the plot function Returns ------- tuple[plt.Figure, npt.NDArray[plt.Axes]] Figure and the axes """ hdi_kwargs = hdi_kwargs or {} hdi_kwargs = {**dict(hdi_prob=hdi_prob), **hdi_kwargs} get_plot_data = _create_get_hdi_plot_data(hdi_kwargs) make_selection = _make_hdi_selection plot_selection = _plot_hdi_selection if isinstance(non_grid_names, str): non_grid_names = {non_grid_names} plot_kwargs = plot_kwargs or {} plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs} return _plot_across_coord( curve=curve, non_grid_names=non_grid_names, get_plot_data=get_plot_data, make_selection=make_selection, plot_selection=plot_selection, subplot_kwargs=subplot_kwargs, same_axes=same_axes, axes=axes, colors=colors, legend=legend, plot_kwargs=plot_kwargs, patch=True, line=False, sel_to_string=sel_to_string, )
[docs] def plot_samples( curve: xr.DataArray, non_grid_names: str | set[str], n: int = 10, rng: np.random.Generator | None = None, axes: npt.NDArray[Axes] | None = None, subplot_kwargs: dict[str, Any] | None = None, plot_kwargs: dict[str, Any] | None = None, same_axes: bool = False, colors: Iterable[str] | None = None, legend: bool = False, sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot n samples of the curve across coords. Parameters ---------- curve : xr.DataArray Curve to plot non_grid_names : str | set[str] The names to exclude from the grid. chain and draw are excluded automatically n : int, optional Number of samples to plot rng : np.random.Generator, optional Random number generator axes : npt.NDArray[plt.Axes], optional Axes to plot on subplot_kwargs : dict, optional Additional kwargs to while creating the fig and axes plot_kwargs : dict, optional Kwargs for the plot function same_axes : bool All of the plots in the same axis Returns ------- tuple[plt.Figure, npt.NDArray[plt.Axes]] Figure and the axes """ get_plot_data = _get_sample_plot_data if isinstance(non_grid_names, str): non_grid_names = {non_grid_names} n_chains = curve.sizes["chain"] n_draws = curve.sizes["draw"] make_selection = _create_make_sample_selection( rng=rng, n=n, n_chains=n_chains, n_draws=n_draws, ) plot_selection = _plot_sample_selection plot_kwargs = plot_kwargs or {} plot_kwargs = { **{"alpha": 0.3, "legend": False}, **plot_kwargs, } return _plot_across_coord( curve=curve, non_grid_names=non_grid_names, get_plot_data=get_plot_data, make_selection=make_selection, plot_selection=plot_selection, subplot_kwargs=subplot_kwargs, plot_kwargs=plot_kwargs, same_axes=same_axes, axes=axes, colors=colors, legend=legend, patch=False, line=True, sel_to_string=sel_to_string, )
[docs] def plot_curve( curve: xr.DataArray, non_grid_names: str | set[str], n_samples: int = 10, hdi_probs: float | list[float] | None = None, random_seed: np.random.Generator | None = None, subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, axes: npt.NDArray[Axes] | None = None, same_axes: bool = False, colors: Iterable[str] | None = None, legend: bool | None = None, sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot HDI with samples of the curve across coords. Parameters ---------- curve : xr.DataArray Curve to plot non_grid_names : str | set[str] The names to exclude from the grid. HDI and samples both have defaults of hdi and chain, draw, respectively n_samples : int, optional Number of samples hdi_probs : float | list[float], optional HDI probabilities. Defaults to None which uses arviz default for stats.ci_prob which is 94% random_seed : np.random.Generator, optional Random number generator. Defaults to None which uses np.random.default_rng() subplot_kwargs : dict, optional Additional kwargs to while creating the fig and axes sample_kwargs : dict, optional Kwargs for the :func:`plot_samples` function hdi_kwargs : dict, optional Kwargs for the :func:`plot_hdi` function same_axes : bool If all of the plots are on the same axis colors : Iterable[str], optional Colors for the plots legend : bool, optional If to include a legend. Defaults to True if same_axes sel_to_string : Callable[[Selection], str], optional Function to convert selection to a string. Defaults to ", ".join(f"{key}={value}" for key, value in sel.items()) Returns ------- tuple[plt.Figure, npt.NDArray[plt.Axes]] Figure and the axes Examples -------- Plot prior for arbitrary Deterministic in PyMC model .. plot:: :include-source: True :context: reset import numpy as np import pandas as pd import pymc as pm import matplotlib.pyplot as plt from pymc_marketing.plot import plot_curve seed = sum(map(ord, "Arbitrary curve")) rng = np.random.default_rng(seed) dates = pd.date_range("2024-01-01", periods=52, freq="W") coords = {"date": dates, "product": ["A", "B"]} with pm.Model(coords=coords) as model: data = pm.Normal( "data", mu=[-0.5, 0.5], sigma=1, dims=("date", "product"), ) cumsum = pm.Deterministic( "cumsum", data.cumsum(axis=0), dims=("date", "product"), ) idata = pm.sample_prior_predictive(random_seed=rng) curve = idata.prior["cumsum"] fig, axes = plot_curve( curve, "date", subplot_kwargs={"figsize": (15, 5)}, random_seed=rng, ) plt.show() Choose the HDI intervals and number of samples .. plot:: :include-source: True :context: reset fig, axes = plot_curve( curve, "date", n_samples=3, hdi_probs=[0.5, 0.95], random_seed=rng, ) fig.suptitle("Same data but fewer lines and more HDIs") plt.show() Plot same curve on same axes with custom colors .. plot:: :include-source: True :context: close-figs colors = ["red", "blue"] fig, axes = plot_curve( curve, "date", same_axes=True, colors=colors, random_seed=rng, ) axes[0].set(title="Same data but on same axes and custom colors") plt.show() """ curve = drop_scalar_coords(curve) hdi_probs = hdi_probs or None if not isinstance(hdi_probs, list): hdi_probs = [hdi_probs] # type: ignore hdi_kwargs = hdi_kwargs or {} sample_kwargs = sample_kwargs or {} sample_kwargs = {**dict(n=n_samples, rng=random_seed), **sample_kwargs} if "subplot_kwargs" not in sample_kwargs: sample_kwargs["subplot_kwargs"] = subplot_kwargs if "axes" not in sample_kwargs: sample_kwargs["axes"] = axes if same_axes: sample_kwargs["same_axes"] = True sample_kwargs["legend"] = False hdi_kwargs["same_axes"] = True hdi_kwargs["legend"] = legend if isinstance(legend, bool) else True if colors is not None: sample_kwargs["colors"] = colors hdi_kwargs["colors"] = colors if sel_to_string is not None: sample_kwargs["sel_to_string"] = sel_to_string hdi_kwargs["sel_to_string"] = sel_to_string fig, axes = plot_samples( curve, non_grid_names=non_grid_names, **sample_kwargs, ) for hdi_prob in hdi_probs: fig, axes = plot_hdi( curve, hdi_prob=hdi_prob, non_grid_names=non_grid_names, axes=axes, **hdi_kwargs, ) return fig, axes