# Source code for pymc_marketing.clv.plotting

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Plotting functions for the CLV module."""

import warnings
from collections.abc import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
from matplotlib.lines import Line2D

from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel
from pymc_marketing.clv.utils import _expected_cumulative_transactions

__all__ = [
    "plot_customer_exposure",
    "plot_expected_purchases_over_time",
    "plot_expected_purchases_ppc",
    "plot_frequency_recency_matrix",
    "plot_probability_alive_matrix",
]


def plot_customer_exposure(
    df: pd.DataFrame,
    linewidth: float | None = None,
    size: float | None = None,
    labels: Sequence[str] | None = None,
    colors: Sequence[str] | None = None,
    padding: float = 0.25,
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """Plot the recency and T of a DataFrame of customers.

    Each customer is drawn as a horizontal line starting at time 0, with markers
    at their recency and at their T. Customers are plotted from the bottom up in
    the same order as the rows of the DataFrame. The segment from 0 to recency
    and the segment from recency to T are colored differently.

    Parameters
    ----------
    df : pd.DataFrame
        A DataFrame with columns "recency" and "T" representing the recency and
        age of customers.
    linewidth : float, optional
        The width of the horizontal lines in the plot.
    size : float, optional
        The size of the markers in the plot.
    labels : Sequence[str], optional
        A sequence of labels for the legend. Default is ["Recency", "T"].
    colors : Sequence[str], optional
        A sequence of colors for the legend. Default is ["C0", "C1"].
    padding : float, optional
        The padding around the plot. Default is 0.25.
    ax : plt.Axes, optional
        A matplotlib axes instance to plot on. If None, the current axes
        (``plt.gca()``) is used.

    Returns
    -------
    plt.Axes
        The matplotlib axes instance.

    Raises
    ------
    ValueError
        If ``padding``, ``size`` or ``linewidth`` is negative, or if ``colors``
        or ``labels`` does not contain exactly two entries.

    Examples
    --------
    Plot customer exposure

    .. code-block:: python

        df = pd.DataFrame({"recency": [0, 1, 2, 3, 4], "T": [5, 5, 5, 5, 5]})

        plot_customer_exposure(df)

    Plot customer exposure ordered by recency and T

    .. code-block:: python

        (df.sort_values(["recency", "T"]).pipe(plot_customer_exposure))

    Plot exposure only for customers whose time since last purchase
    (``T - recency``) is less than 3

    .. code-block:: python

        (df.query("T - recency < 3").pipe(plot_customer_exposure))

    """
    if padding < 0:
        raise ValueError("padding must be non-negative")
    if size is not None and size < 0:
        raise ValueError("size must be non-negative")
    if linewidth is not None and linewidth < 0:
        raise ValueError("linewidth must be non-negative")

    if colors is None:
        colors = ["C0", "C1"]
    if len(colors) != 2:
        raise ValueError("colors must be a sequence of length 2")

    # Validate labels before drawing anything so that an invalid argument does
    # not leave a half-drawn plot on the (possibly caller-supplied) axes.
    if labels is None:
        labels = ["Recency", "T"]
    if len(labels) != 2:
        raise ValueError("labels must be a sequence of length 2")

    recency_color, T_color = colors
    recency_label, T_label = labels

    if ax is None:
        ax = plt.gca()

    n = len(df)
    # Customers are stacked bottom-up starting at y=1.
    customer_idx = np.arange(1, n + 1)
    recency = df["recency"].to_numpy()
    T = df["T"].to_numpy()

    ax.hlines(
        y=customer_idx, xmin=0, xmax=recency, linewidth=linewidth, color=recency_color
    )
    ax.hlines(y=customer_idx, xmin=recency, xmax=T, linewidth=linewidth, color=T_color)

    ax.scatter(x=recency, y=customer_idx, linewidth=linewidth, s=size, c=recency_color)
    ax.scatter(x=T, y=customer_idx, linewidth=linewidth, s=size, c=T_color)

    ax.set(
        xlabel="Time since first purchase",
        ylabel="Customer",
        xlim=(0 - padding, T.max() + padding),
        ylim=(1 - padding, n + padding),
        title="Customer Exposure",
    )

    legend_elements = [
        Line2D([0], [0], color=recency_color, label=recency_label),
        Line2D([0], [0], color=T_color, label=T_label),
    ]

    ax.legend(handles=legend_elements, loc="best")

    return ax
def _create_frequency_recency_meshes( max_frequency: int, max_recency: int, ) -> tuple[np.ndarray, np.ndarray]: frequency = np.arange(max_frequency + 1) recency = np.arange(max_recency + 1) mesh_frequency, mesh_recency = np.meshgrid(frequency, recency) return mesh_frequency, mesh_recency
def plot_frequency_recency_matrix(
    model: BetaGeoModel | ParetoNBDModel,
    future_t: int = 1,
    max_frequency: int | None = None,
    max_recency: int | None = None,
    title: str | None = None,
    xlabel: str = "Customer's Historical Frequency",
    ylabel: str = "Customer's Recency",
    ax: plt.Axes | None = None,
    **kwargs,
) -> plt.Axes:
    """Plot expected purchases in *future_t* time periods as a heatmap based on customer population *frequency* and *recency*.

    Parameters
    ----------
    model: CLV model
        A fitted CLV model.
    future_t: int or float, optional
        Future time periods over which to run predictions. Default is 1.
    max_frequency: int, optional
        The maximum *frequency* to plot. Defaults to max observed *frequency*.
    max_recency: int, optional
        The maximum *recency* to plot. This is also used as the age (*T*) of
        every customer in the grid. Defaults to max observed *recency*.
    title: str, optional
        Figure title. Defaults to a title describing *future_t*.
    xlabel: str, optional
        Figure xlabel
    ylabel: str, optional
        Figure ylabel
    ax: plt.Axes, optional
        A matplotlib axes instance. Creates new axes instance by default.
    kwargs
        Passed into the matplotlib.imshow command.

    Returns
    -------
    axes: matplotlib.AxesSubplot
    """  # noqa: E501
    if max_frequency is None:
        max_frequency = int(model.data["frequency"].max())

    if max_recency is None:
        max_recency = int(model.data["recency"].max())

    mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
        max_frequency=max_frequency,
        max_recency=max_recency,
    )

    # create dataframe for model input; every grid cell shares T = max_recency
    transaction_data = pd.DataFrame(
        {
            "customer_id": np.arange(mesh_recency.size),  # placeholder
            "frequency": mesh_frequency.ravel(),
            "recency": mesh_recency.ravel(),
            "T": max_recency,
        }
    )

    # run model predictions, average over posterior samples, and fold the
    # flat per-customer result back into the (recency, frequency) grid
    Z = (
        model.expected_purchases(
            data=transaction_data,
            future_t=future_t,
        )
        .mean(("draw", "chain"))
        .values.reshape(mesh_recency.shape)
    )

    if ax is None:
        ax = plt.subplot(111)

    pcm = ax.imshow(Z, **kwargs)
    if title is None:
        unit_suffix = "" if future_t == 1 else "s"
        title = (
            f"Expected Number of Future Purchases for {future_t} Unit{unit_suffix} "
            "of Time,\nby Frequency and Recency of a Customer"
        )

    ax.set(
        xlabel=xlabel,
        ylabel=ylabel,
        title=title,
    )

    _force_aspect(ax)

    # Plot the colorbar beside the matrix on the figure that owns *ax*; going
    # through pyplot's current-figure state can target (or create) the wrong
    # figure when *ax* is not on the pyplot-current figure.
    ax.figure.colorbar(pcm, ax=ax)

    return ax
def plot_probability_alive_matrix(
    model: BetaGeoModel | ParetoNBDModel,
    max_frequency: int | None = None,
    max_recency: int | None = None,
    title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
    xlabel: str = "Customer's Historical Frequency",
    ylabel: str = "Customer's Recency",
    ax: plt.Axes | None = None,
    **kwargs,
) -> plt.Axes:
    """Plot probability alive matrix as a heatmap based on customer population *frequency* and *recency*.

    Parameters
    ----------
    model: CLV model
        A fitted CLV model.
    max_frequency: int, optional
        The maximum *frequency* to plot. Defaults to max observed *frequency*.
    max_recency: int, optional
        The maximum *recency* to plot. This is also used as the age (*T*) of
        every customer in the grid. Defaults to max observed *recency*.
    title: str, optional
        Figure title
    xlabel: str, optional
        Figure xlabel
    ylabel: str, optional
        Figure ylabel
    ax: plt.Axes, optional
        A matplotlib axes instance. Creates new axes instance by default.
    kwargs
        Passed into the matplotlib.imshow command. ``interpolation`` defaults
        to ``"none"`` if not given.

    Returns
    -------
    axes: matplotlib.AxesSubplot
    """  # noqa: E501
    if max_frequency is None:
        max_frequency = int(model.data["frequency"].max())

    if max_recency is None:
        max_recency = int(model.data["recency"].max())

    mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
        max_frequency=max_frequency,
        max_recency=max_recency,
    )

    # create dataframe for model input; every grid cell shares T = max_recency
    transaction_data = pd.DataFrame(
        {
            "customer_id": np.arange(mesh_recency.size),  # placeholder
            "frequency": mesh_frequency.ravel(),
            "recency": mesh_recency.ravel(),
            "T": max_recency,
        }
    )

    # run model predictions, average over posterior samples, and fold the
    # flat per-customer result back into the (recency, frequency) grid
    Z = (
        model.expected_probability_alive(data=transaction_data)
        .mean(("draw", "chain"))
        .values.reshape(mesh_recency.shape)
    )

    interpolation = kwargs.pop("interpolation", "none")

    if ax is None:
        ax = plt.subplot(111)

    pcm = ax.imshow(Z, interpolation=interpolation, **kwargs)

    ax.set(
        xlabel=xlabel,
        ylabel=ylabel,
        title=title,
    )

    _force_aspect(ax)

    # Plot the colorbar beside the matrix on the figure that owns *ax*; going
    # through pyplot's current-figure state can target (or create) the wrong
    # figure when *ax* is not on the pyplot-current figure.
    ax.figure.colorbar(pcm, ax=ax)

    return ax
def plot_expected_purchases_over_time(
    model,
    purchase_history: pd.DataFrame,
    customer_id_col: str,
    datetime_col: str,
    t: int,
    plot_cumulative: bool = True,
    t_start_eval: int | None = None,
    datetime_format: str | None = None,
    time_unit: str = "D",
    time_scaler: float | None = 1,
    sort_purchases: bool | None = True,
    set_index_date: bool | None = False,
    title: str | None = None,
    xlabel: str = "Time Periods",
    ylabel: str = "Purchases",
    ax: plt.Axes | None = None,
    t_unobserved: int | None = None,
    **kwargs,
) -> plt.Axes:
    """Plot actual and expected purchases over time for a fitted ``BetaGeoModel`` or ``ParetoNBDModel``.

    This function is based on the formulation on page 8 of [1]_. Specifically, we take only customers who have made their
    first purchase before the specified number of ``t`` time periods, and run ``expected_purchases_new_customer()``
    for all remaining time periods. Results can be either cumulative or incremental.

    Adapted from the legacy ``lifetimes`` library:
    https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/plotting.py#L392

    Parameters
    ----------
    model : ~CLVModel
        A fitted ``BetaGeoModel`` or ``ParetoNBDModel``.
    purchase_history : ~pandas.DataFrame
        A Pandas DataFrame containing *customer_id_col* and *datetime_col*.
    customer_id_col : string
        Column in the *purchase_history* DataFrame denoting the *customer_id*.
    datetime_col : string
        Column in the *purchase_history* DataFrame denoting the datetimes at
        which purchases were made.
    t : int
        Number of time units since earliest purchase to include in plot.
    plot_cumulative : bool
        Default: *True*
        Plot cumulative purchases over time. Set to *False* to plot incremental purchases.
    t_start_eval : int, optional
        If testing model on unobserved data, specify number of time units in training data to add an indicator for the
        start of the testing period. A vertical dashed line is drawn at this position.
    datetime_format : string, optional
        A string that represents the timestamp format. Useful if Pandas doesn't recognize the provided format.
    time_unit : string, optional
        Time granularity for study.
        Default: 'D' for days. Possible values listed here:
        https://numpy.org/devdocs/reference/arrays.datetime.html#datetime-units
    time_scaler : float, optional
        Default: 1. Scales *recency* & *T* to a different time granularity.
        This is useful for datasets spanning many years, and running predictions in different time scales.
    sort_purchases : bool, optional
        Default: *True*
        If *purchase_history* DataFrame is already sorted in chronological order,
        set to *False* to improve computational efficiency.
    set_index_date : bool, optional
        Set to True to use a datetime index; the *t_start_eval* indicator is
        then placed at the corresponding date.
    title : str, optional
        Figure title
    xlabel : str, optional
        Figure xlabel
    ylabel : str, optional
        Figure ylabel
    ax : matplotlib.Axes, optional
        A matplotlib Axes instance. Creates new axes instance by default.
    t_unobserved : int, optional
        Deprecated alias for *t_start_eval*; emits a ``DeprecationWarning``
        and, when given, overrides *t_start_eval*.
    kwargs
        Additional arguments to pass into the pandas.DataFrame.plot command.

    Returns
    -------
    axes: matplotlib.AxesSubplot

    References
    ----------
    .. [1] Fader, Peter S., Bruce G.S. Hardie, and Ka Lok Lee (2005),
        A Note on Implementing the Pareto/NBD Model in MATLAB.
        http://brucehardie.com/notes/008/
    """  # noqa: E501
    # Resolve the deprecated alias up front. stacklevel=2 attributes the
    # warning to the caller's line rather than to this function.
    if t_unobserved is not None:
        warnings.warn(
            "t_unobserved is deprecated and will be removed in a future release. "
            "Use t_start_eval instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        t_start_eval = t_unobserved

    if ax is None:
        ax = plt.subplot(111)

    df_cum_purchases = _expected_cumulative_transactions(
        model=model,
        transactions=purchase_history,
        customer_id_col=customer_id_col,
        datetime_col=datetime_col,
        t=t,
        datetime_format=datetime_format,
        time_unit=time_unit,
        time_scaler=time_scaler,
        sort_transactions=sort_purchases,
        set_index_date=set_index_date,
    )

    if not plot_cumulative:
        # Per-period increments; the first row becomes NaN.
        df_cum_purchases = df_cum_purchases.diff()
        if title is None:
            title = "Tracking Incremental Transactions"
    elif title is None:
        title = "Tracking Cumulative Transactions"

    # TODO: After utility func supports xarrays, refactor this for matplotlib API.
    ax = df_cum_purchases.plot(ax=ax, title=title, **kwargs)

    # Explicit None check so that an explicit 0 is honoured rather than
    # silently ignored as falsy.
    if t_start_eval is not None:
        if set_index_date:
            x_vline = df_cum_purchases.index[int(t_start_eval)]
        else:
            x_vline = t_start_eval
        ax.axvline(x=x_vline, color="r", linestyle="--")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax
[docs] def plot_expected_purchases_ppc( model, ppc: str = "posterior", max_purchases: int = 10, samples: int = 1000, random_seed: int = 45, ax: plt.Axes | None = None, **kwargs, ) -> plt.Axes: """Plot a prior or posterior predictive check for the customer purchase frequency distribution. ``ParetoNBDModel``, ``BetaGeoBetaBinomModel``, ``BetaGeoModel`` and ``ModifiedBetaGeoModel`` are supported. Adapted from legacy ``lifetimes`` library: https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/plotting.py#L25 Parameters ---------- model : CLVModel Prior predictive checks can be performed before or after a model is fit. Posterior predictive checks require a fitted model. ppc : string, optional Type of predictive check to perform. Options are 'prior' or 'posterior'; defaults to 'posterior'. max_purchases : int, optional Cutoff for bars of purchase counts to plot. Default is 10. samples : int, optional Number of samples to draw for prior predictive checks. This is not used for posterior predictive checks. random_seed : int, optional Random seed to fix sampling results ax : matplotlib.Axes, optional A matplotlib Axes instance. Creates new axes instance by default. **kwargs Additional arguments to pass into the pandas.DataFrame.plot command. 
Returns ------- axes : matplotlib.AxesSubplot """ if ax is None: ax = plt.subplot(111) match ppc: case "prior": # build model if it has not been fit yet model.build_model() prior_idata = pm.sample_prior_predictive( draws=samples, model=model.model, random_seed=random_seed, ) # obs_var must be retrieved from prior_idata if model has not been fit obs_freq = prior_idata.observed_data["recency_frequency"].sel( obs_var="frequency" ) ppc_freq = prior_idata.prior_predictive["recency_frequency"].sel( obs_var="frequency" ) title = "Prior Predictive Check for Customer Frequency" case "posterior": obs_freq = model.idata.observed_data["recency_frequency"].sel( obs_var="frequency" ) # Keep samples at 1 here because (chain * draw * customer) samples are already being drawn ppc_freq = model.distribution_new_customer_recency_frequency( random_seed=random_seed, n_samples=1, ).sel(obs_var="frequency") title = "Posterior Predictive Check for Customer Frequency" case _: raise NameError("Specify 'prior' or 'posterior' for 'ppc' parameter.") # convert estimated and observed xarrays into dataframes for plotting estimated = ppc_freq.to_dataframe().value_counts(normalize=True).sort_index() observed = obs_freq.to_dataframe().value_counts(normalize=True).sort_index() # PPC histogram plot ax = pd.DataFrame( { "Estimated": estimated.reset_index()["proportion"].head(max_purchases), "Observed": observed.reset_index()["proportion"].head(max_purchases), }, ).plot( kind="bar", ax=ax, title=title, xlabel="Repeat Purchases", ylabel="% of Customer Population", rot=0.0, **kwargs, ) return ax
def _force_aspect(ax: plt.Axes, aspect=1):
    """Set the data aspect of *ax* so its newest image renders with a fixed shape.

    The aspect is derived from the image extent so that the displayed
    width-to-height ratio of the image equals *aspect* (``1`` gives a square),
    whatever the number of rows and columns in the image.

    Parameters
    ----------
    ax : plt.Axes
        Axes containing at least one image (e.g. drawn with ``imshow``).
    aspect : float, optional
        Desired displayed width-to-height ratio of the image. Default is 1.

    Raises
    ------
    ValueError
        If *ax* contains no images.
    """
    images = ax.get_images()
    if not images:
        raise ValueError("_force_aspect requires an axes containing an image")
    # Use the most recently added image: when a caller reuses an axes that
    # already holds an image, the one just drawn by ``imshow`` is last.
    left, right, bottom, top = images[-1].get_extent()
    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)