# Source code for pymc_marketing.mmm.linear_trend

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Linear trend using change points.

Examples
--------
Define a linear trend with 8 changepoints:

.. code-block:: python

    from pymc_marketing.mmm import LinearTrend

    trend = LinearTrend(n_changepoints=8)

Sample the prior for the trend parameters and curve:

.. code-block:: python

    import numpy as np

    seed = sum(map(ord, "Linear Trend"))
    rng = np.random.default_rng(seed)

    prior = trend.sample_prior(random_seed=rng)
    curve = trend.sample_curve(prior)

Plot the curve samples:

.. code-block:: python

    _, axes = trend.plot_curve(curve, random_seed=rng)
    ax = axes[0]
    ax.set(
        xlabel="Time",
        ylabel="Trend",
        title=f"Linear Trend with {trend.n_changepoints} Change Points",
    )

.. image:: /_static/linear-trend-prior.png
    :alt: LinearTrend prior

"""

from collections.abc import Iterable
from typing import Any, cast

import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pydantic import BaseModel, ConfigDict, Field, InstanceOf, model_validator
from pymc.distributions.shape_utils import Dims
from pytensor.tensor.variable import TensorVariable
from typing_extensions import Self

from pymc_marketing.plot import SelToString, plot_curve
from pymc_marketing.prior import Prior, create_dim_handler


class LinearTrend(BaseModel):
    r"""LinearTrend class.

    Linear trend component using change points. The trend is defined as:

    .. math::

        f(t) = k + \sum_{m=0}^{M-1} \delta_m (t - s_m) I(t > s_m)

    where:

    - :math:`t \ge 0`,
    - :math:`k` is the base intercept (only present when ``include_intercept`` is True),
    - :math:`\delta_m` is the change in the slope of the trend at change point :math:`m`,
    - :math:`I` is the indicator function,
    - :math:`s_m` is the change point.

    The change points are defined as:

    .. math::

        s_m = \frac{m}{M-1} T, 0 \le m \le M-1

    where :math:`M` is the number of change points (:math:`M>1`)
    and :math:`T` is the time of the last observed data point.

    The priors for the trend parameters are:

    - :math:`k \sim \text{Normal}(0, 0.05)`
    - :math:`\delta_m \sim \text{Laplace}(0, 0.25)`

    Parameters
    ----------
    priors : dict[str, Prior], optional
        Dictionary with the priors for the trend parameters. The dictionary
        must have 'delta' key. If `include_intercept` is True, the 'k' key
        is also required. No other keys are accepted. By default None, or
        the default priors.
    dims : Dims, optional
        Dimensions of the parameters, by default None or empty.
    n_changepoints : int, optional
        Number of changepoints, by default 10.
    include_intercept : bool, optional
        Include an intercept in the trend, by default False

    Examples
    --------
    Linear trend with 10 changepoints:

    .. code-block:: python

        from pymc_marketing.mmm import LinearTrend

        trend = LinearTrend(n_changepoints=10)

    Use the trend in a model:

    .. code-block:: python

        import pymc as pm
        import numpy as np

        import pandas as pd

        n_years = 3
        n_dates = 52 * n_years
        first_date = "2020-01-01"
        dates = pd.date_range(first_date, periods=n_dates, freq="W-MON")
        dayofyear = dates.dayofyear.to_numpy()
        t = (dates - dates[0]).days.to_numpy()
        t = t / 365.25

        coords = {"date": dates}
        with pm.Model(coords=coords) as model:
            intercept = pm.Normal("intercept", mu=0, sigma=1)
            mu = intercept + trend.apply(t)

            sigma = pm.Gamma("sigma", mu=0.1, sigma=0.025)

            pm.Normal("obs", mu=mu, sigma=sigma, dims="date")

    Hierarchical LinearTrend via hierarchical prior:

    .. code-block:: python

        from pymc_marketing.prior import Prior

        hierarchical_delta = Prior(
            "Laplace",
            mu=Prior("Normal", dims="changepoint"),
            b=Prior("HalfNormal", dims="changepoint"),
            dims=("changepoint", "geo"),
        )
        priors = dict(delta=hierarchical_delta)

        hierarchical_trend = LinearTrend(
            priors=priors,
            n_changepoints=10,
            dims="geo",
        )

    Sample the hierarchical trend:

    .. code-block:: python

        seed = sum(map(ord, "Hierarchical LinearTrend"))
        rng = np.random.default_rng(seed)

        coords = {"geo": ["A", "B"]}
        prior = hierarchical_trend.sample_prior(
            coords=coords,
            random_seed=rng,
        )
        curve = hierarchical_trend.sample_curve(prior)

    Plot the curve HDI and samples:

    .. code-block:: python

        fig, axes = hierarchical_trend.plot_curve(
            curve,
            n_samples=3,
            random_seed=rng,
        )
        fig.suptitle("Hierarchical Linear Trend")
        axes[0].set(ylabel="Trend", xlabel="Time")
        axes[1].set(xlabel="Time")

    .. image:: /_static/hierarchical-linear-trend-prior.png
        :alt: Hierarchical LinearTrend prior

    References
    ----------
    Adapted from MBrouns/timeseers package:
        https://github.com/MBrouns/timeseers/blob/master/src/timeseers/linear_trend.py

    """

    priors: InstanceOf[dict[str, Prior]] = Field(
        None,
        description="Priors for the trend parameters.",
    )
    dims: tuple[str, ...] | InstanceOf[Dims] | str | None = Field(
        None,
        description="The additional dimensions for the trend.",
    )
    n_changepoints: int = Field(
        10,
        description="Number of changepoints.",
        ge=1,
    )
    include_intercept: bool = Field(
        False,
        description="Include an intercept in the trend.",
    )

    model_config = ConfigDict(extra="forbid")

    @model_validator(mode="after")
    def _dims_is_tuple(self) -> Self:
        """Normalize ``dims`` to a tuple (a single string becomes a 1-tuple, None becomes ())."""
        dims = self.dims
        if isinstance(dims, str):
            self.dims = (dims,)

        self.dims: tuple[str] = self.dims or ()

        return self

    @model_validator(mode="after")
    def _priors_are_set(self) -> Self:
        """Fall back to a copy of the default priors when none (or an empty dict) are given."""
        self.priors = self.priors or self.default_priors.copy()
        return self

    @model_validator(mode="after")
    def _check_parameters(self) -> Self:
        """Ensure the prior keys are exactly the parameters the trend uses.

        Raises
        ------
        ValueError
            If a required prior is missing or an unknown prior is given.
        """
        required_parameters = set(self.default_priors.keys())
        # Exact match: unknown keys are never used by ``apply``, and missing
        # keys (e.g. "k" with include_intercept=True) would otherwise only
        # surface later as a KeyError inside ``apply``.
        if set(self.priors.keys()) != required_parameters:
            msg = f"Invalid priors. The required parameters are {required_parameters}."
            raise ValueError(msg)

        return self

    @model_validator(mode="after")
    def _check_dims_are_subsets(self) -> Self:
        """Ensure every prior only uses "changepoint" and the trend's own dims.

        Raises
        ------
        ValueError
            If a prior has a dimension outside of the allowed set.
        """
        allowed_dims = {"changepoint"}.union(cast(Dims, self.dims))

        if not all(set(prior.dims) <= allowed_dims for prior in self.priors.values()):
            msg = "Invalid dimensions in the priors."
            raise ValueError(msg)

        return self

    @property
    def default_priors(self) -> dict[str, Prior]:
        """Default priors for the trend parameters.

        Returns
        -------
        dict[str, Prior]
            Dictionary with the default priors. Always contains "delta";
            contains "k" only when ``include_intercept`` is True.

        """
        priors = {
            "delta": Prior(
                "Laplace",
                mu=0,
                b=0.25,
                dims="changepoint",
            ),
        }
        if self.include_intercept:
            priors["k"] = Prior("Normal", mu=0, sigma=0.05)

        return priors

    @property
    def non_broadcastable_dims(self) -> tuple[str, ...]:
        """Get the dimensions of the trend that are not just broadcastable.

        A dimension of ``self.dims`` is non-broadcastable when at least one
        prior actually varies along it.

        Returns
        -------
        tuple[str, ...]
            Tuple with the dimensions of the trend, in ``self.dims`` order.

        """
        dims = set()
        for prior in self.priors.values():
            dims.update(prior.dims)

        dims = dims.difference({"changepoint"})

        return tuple(dim for dim in cast(tuple[str, ...], self.dims) if dim in dims)

    def apply(self, t: pt.TensorLike) -> TensorVariable:
        """Create the linear trend for the given x values.

        Must be called inside a PyMC model context; registers the
        "changepoint" coordinate and the "delta" (and optionally "k")
        random variables on that model.

        Parameters
        ----------
        t : pt.TensorLike
            1D array of strictly increasing time values for the trend starting from 0.

        Returns
        -------
        pt.TensorVariable
            TensorVariable with the trend values.

        """
        dims = cast(Dims, self.dims)
        model = pm.modelcontext(None)
        model.add_coord("changepoint", range(self.n_changepoints))

        DUMMY_DIM = "DATE"

        out_dims = (DUMMY_DIM, "changepoint", *dims)
        dim_handler = create_dim_handler(desired_dims=out_dims)

        # (changepoints, )
        # max(t) is evaluated eagerly, so the changepoint locations are
        # constants fixed when the graph is built.
        s = pt.linspace(0, pt.max(t).eval(), self.n_changepoints)
        # NOTE(review): presumably pins the static length, which linspace
        # does not infer on its own.
        s.type.shape = (self.n_changepoints,)
        s = dim_handler(
            s,
            ("changepoint",),
        )

        # (dates, changepoints): indicator I(t > s_m)
        A = (dim_handler(t, (DUMMY_DIM,)) > s) * 1.0

        delta_dist = self.priors["delta"]
        delta = dim_handler(
            delta_dist.create_variable("delta"),
            delta_dist.dims,
        )

        k_dim_handler = create_dim_handler((DUMMY_DIM, *dims))

        # sum_m delta_m * t * I(t > s_m)
        first = (A * delta).sum(axis=1) * k_dim_handler(t, (DUMMY_DIM,))
        if self.include_intercept:
            # (additional_groups)
            k_dist = self.priors["k"]
            k = k_dim_handler(
                k_dist.create_variable("k"),
                k_dist.dims,
            )
            first += k

        # sum_m -delta_m * s_m * I(t > s_m), making each segment continuous
        gamma = -s * delta
        second = (A * gamma).sum(axis=1)

        return first + second

    def sample_prior(
        self,
        coords=None,
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample the prior for the parameters used in the trend.

        Parameters
        ----------
        coords : dict, optional
            Coordinates in the priors, by default includes the changepoints.
            The passed dictionary is not modified.
        sample_prior_predictive_kwargs : dict, optional
            Keyword arguments for the `pm.sample_prior_predictive` function.

        Returns
        -------
        xr.Dataset
            Dataset with the prior samples.

        """
        # Build a new dict so the caller's ``coords`` is not mutated.
        coords = {**(coords or {}), "changepoint": range(self.n_changepoints)}

        with pm.Model(coords=coords):
            for key, param in self.priors.items():
                param.create_variable(key)

            return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior

    def sample_curve(
        self,
        parameters: xr.Dataset,
        max_value: float = 1.0,
    ) -> xr.DataArray:
        """Sample the curve given parameters.

        Parameters
        ----------
        parameters : xr.Dataset
            Dataset with the parameters to condition on. Would be
            either the prior or the posterior.
        max_value : float, optional
            Largest time value of the curve; the curve is evaluated on 100
            evenly spaced points in ``[0, max_value]``. By default 1.0.

        Returns
        -------
        xr.DataArray
            DataArray with the curve samples.

        """
        t = np.linspace(0, max_value, 100)

        # Reuse every non-sampling coordinate of the parameters so the
        # trend is built with the same dimensions.
        coords: dict[str, Any] = {"t": t}
        for name in self.priors.keys():
            for key, values in parameters[name].coords.items():
                if key in {"chain", "draw"}:
                    continue
                coords[key] = values.to_numpy()

        with pm.Model(coords=coords):
            name = "trend"
            pm.Deterministic(
                name,
                self.apply(t),
                dims=("t", *cast(Dims, self.dims)),
            )

            return pm.sample_posterior_predictive(
                parameters,
                var_names=[name],
            ).posterior_predictive[name]

    def plot_curve(
        self,
        curve: xr.DataArray,
        n_samples: int = 10,
        hdi_probs: float | list[float] | None = None,
        random_seed: np.random.Generator | None = None,
        subplot_kwargs: dict | None = None,
        sample_kwargs: dict | None = None,
        hdi_kwargs: dict | None = None,
        include_changepoints: bool = True,
        axes: npt.NDArray[Axes] | None = None,
        same_axes: bool = False,
        colors: Iterable[str] | None = None,
        legend: bool | None = None,
        sel_to_string: SelToString | None = None,
    ) -> tuple[Figure, npt.NDArray[Axes]]:
        """Plot the curve samples from the trend.

        Parameters
        ----------
        curve : xr.DataArray
            DataArray with the curve samples.
        n_samples : int, optional
            Number of samples
        hdi_probs : float | list[float], optional
            HDI probabilities. Defaults to None which uses arviz default for
            stats.ci_prob which is 94%
        random_seed : np.random.Generator, optional
            Random number generator. Defaults to None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplots, by default None.
        sample_kwargs : dict, optional
            Keyword arguments for the samples, by default None.
        hdi_kwargs : dict, optional
            Keyword arguments for the HDI, by default None.
        include_changepoints : bool, optional
            Include the change points in the plot, by default True.
        axes : npt.NDArray[plt.Axes], optional
            Axes to plot the curve, by default None.
        same_axes : bool, optional
            Use the same axes for the samples, by default False.
        colors : Iterable[str], optional
            Colors for the samples, by default None.
        legend : bool, optional
            Include a legend in the plot, by default None.
        sel_to_string : SelToString, optional
            Function to convert the selection to a string, by default None.

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Tuple with the figure and the axes.

        """
        fig, axes = plot_curve(
            curve,
            {"t"},
            n_samples=n_samples,
            hdi_probs=hdi_probs,
            random_seed=random_seed,
            subplot_kwargs=subplot_kwargs,
            sample_kwargs=sample_kwargs,
            hdi_kwargs=hdi_kwargs,
            axes=axes,
            same_axes=same_axes,
            colors=colors,
            legend=legend,
            sel_to_string=sel_to_string,
        )

        if not include_changepoints:
            return fig, axes

        max_value = curve.coords["t"].max().item()
        # Same spacing as the ``pt.linspace`` in ``apply``. Unlike
        # ``i / (n_changepoints - 1)``, this is also valid for a single
        # changepoint (no division by zero).
        changepoints = np.linspace(0, max_value, self.n_changepoints)

        for ax in np.ravel(axes):
            for changepoint in changepoints:
                ax.axvline(
                    changepoint,
                    color="gray",
                    linestyle="--",
                )

        return fig, axes