MMM Multidimensional Example Notebook#

In this notebook, we present a new experimental media mix model class for creating multidimensional and customized marketing mix models. To showcase its capabilities, we extend the MMM Example Notebook simulation to build a multidimensional hierarchical model.

Warning

Even though the new MMM class is experimental, it is fully functional and can be used to create multidimensional marketing mix models. The model is under active development and will be further improved in the future (feedback welcome!).

Prepare Notebook#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.additive_effect import LinearTrendEffect
from pymc_marketing.mmm.linear_trend import LinearTrend
from pymc_marketing.mmm.multidimensional import (
    MMM,
    MultiDimensionalBudgetOptimizerWrapper,
)
from pymc_marketing.paths import data_dir
from pymc_marketing.prior import Prior

warnings.filterwarnings("ignore", category=UserWarning)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["xtick.labelsize"] = 10
plt.rcParams["ytick.labelsize"] = 8

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
/Users/carlostrujillo/Documents/GitHub/pymc-marketing/pymc_marketing/mmm/multidimensional.py:68: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
  warnings.warn(warning_msg, FutureWarning, stacklevel=1)
seed: int = sum(map(ord, "mmm_multidimensional"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Read Data#

We read the simulated data from the MMM Example Notebook.

data_path = data_dir / "mmm_example.csv"

raw_data_df = pd.read_csv(data_path, parse_dates=["date_week"])

raw_data_df = raw_data_df.rename(columns={"date_week": "date"})

raw_data_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 179 entries, 0 to 178
Data columns (total 8 columns):
 #   Column     Non-Null Count  Dtype         
---  ------     --------------  -----         
 0   date       179 non-null    datetime64[ns]
 1   y          179 non-null    float64       
 2   x1         179 non-null    float64       
 3   x2         179 non-null    float64       
 4   event_1    179 non-null    float64       
 5   event_2    179 non-null    float64       
 6   dayofyear  179 non-null    int64         
 7   t          179 non-null    int64         
dtypes: datetime64[ns](1), float64(5), int64(2)
memory usage: 11.3 KB

To generate a multidimensional dataset, we simply duplicate the simulated data and perturb the target of one copy with noise. We assign the two copies to different geography (geo) levels and keep the same media variables for both geos.

# Create copies of the original data for both geos
a_data_df = raw_data_df.copy().assign(geo="geo_a")
b_data_df = raw_data_df.copy().assign(geo="geo_b")

# Add noise to the target variable for the second geo
b_data_df["y"] = b_data_df["y"] + 500 * rng.normal(size=len(b_data_df))

# Concatenate the two datasets
data_df = pd.concat([a_data_df, b_data_df])

data_df.head()
date y x1 x2 event_1 event_2 dayofyear t geo
0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0 geo_a
1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1 geo_a
2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2 geo_a
3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3 geo_a
4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4 geo_a

Let’s plot the target variable for each geo to visually inspect the difference.

g = sns.relplot(
    data=data_df,
    x="date",
    y="y",
    color="black",
    col="geo",
    col_wrap=1,
    kind="line",
    height=4,
    aspect=3,
)
[Figure: target variable y over time, one panel per geo]

Overall, the targets follow a similar global pattern, but the short-term fluctuations differ.

Prior Specification#

Here is where we see the main power of the new class. We can now specify different priors for different geos and even build custom hierarchies.

Recall from the MMM Example Notebook that we want to use the spend shares as a prior for the beta parameters in the saturation function. Let’s compute those shares across geos.

channel_columns = ["x1", "x2"]
n_channels = len(channel_columns)

sum_spend_geo_channel = data_df.groupby(["geo"]).agg({"x1": "sum", "x2": "sum"})

spend_share = (
    sum_spend_geo_channel.to_numpy() / sum_spend_geo_channel.sum(axis=1).to_numpy()
)

prior_sigma = n_channels * spend_share

Now we are ready to specify the priors for the saturation function.

For the beta parameters, we use a half-normal distribution with a sigma parameter that is computed as a function of the spend shares. For illustrative purposes, we define the beta values to be independent across both channels and geos.

For the lambda parameters, we add a channel-level hierarchical structure through the hyperparameters of the gamma prior (mu and sigma). These hyperparameters vary by channel, while the lambda values themselves vary across both channel and geo. This setup enables partial pooling across geos within the same channel, allowing the model to share information across geos.

If you need an introduction to Bayesian hierarchical models, check out the comprehensive example “A Primer on Bayesian Methods for Multilevel Modeling” in the PyMC documentation. Note that the dims argument in the priors specifies the dimensions of the distribution; this is how you control the dimensions along which the hierarchies are defined.

saturation = LogisticSaturation(
    priors={
        "beta": Prior("HalfNormal", sigma=prior_sigma, dims=("channel", "geo")),
        "lam": Prior(
            "Gamma",
            mu=Prior("LogNormal", mu=np.log(3), sigma=np.log(1.5), dims="channel"),
            sigma=Prior("LogNormal", mu=np.log(1), sigma=np.log(1.5), dims="channel"),
            dims=("channel", "geo"),
        ),
    }
)

saturation.model_config
{'saturation_lam': Prior("Gamma", mu=Prior("LogNormal", mu=1.0986122886681098, sigma=0.4054651081081644, dims="channel"), sigma=Prior("LogNormal", mu=0.0, sigma=0.4054651081081644, dims="channel"), dims=("channel", "geo")),
 'saturation_beta': Prior("HalfNormal", sigma=[[1.31263903 0.68736097]
  [1.31263903 0.68736097]], dims=("channel", "geo"))}

For the adstock parameters we do not add any hierarchical structure. We simply keep the same prior for all the geos.

adstock = GeometricAdstock(
    priors={"alpha": Prior("Beta", alpha=2, beta=3, dims="channel")}, l_max=8
)

adstock.model_config
{'adstock_alpha': Prior("Beta", alpha=2, beta=3, dims="channel")}
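
The same nested-Prior mechanism used for the lambda parameter above could also be applied to the adstock parameter. The following sketch is purely illustrative and is not used in this notebook; it shows what a channel-level hierarchy across geos for alpha might look like.

# Illustrative sketch only (not used below): a hierarchical adstock prior where the
# Beta hyperparameters vary by channel and alpha varies by channel and geo.
hierarchical_adstock = GeometricAdstock(
    priors={
        "alpha": Prior(
            "Beta",
            alpha=Prior("HalfNormal", sigma=2, dims="channel"),
            beta=Prior("HalfNormal", sigma=3, dims="channel"),
            dims=("channel", "geo"),
        )
    },
    l_max=8,
)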

We complete the model specification with priors similar to those in the MMM Example Notebook. Pay attention to how the prior dimensions are specified.

model_config = {
    "intercept": Prior("Normal", mu=0.5, sigma=0.5, dims="geo"),
    "gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
    "gamma_fourier": Prior(
        "Laplace", mu=0, b=Prior("HalfNormal", sigma=0.2), dims=("geo", "fourier_mode")
    ),
    "likelihood": Prior(
        "TruncatedNormal",
        lower=0,
        sigma=Prior("HalfNormal", sigma=Prior("HalfNormal", sigma=1.5)),
        dims=("date", "geo"),
    ),
}

Model Definition#

We are now ready to define the model class. The API is very similar to the one in the MMM Example Notebook.

# Base MMM model specification
mmm = MMM(
    date_column="date",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    dims=("geo",),
    scaling={
        "channel": {"method": "max", "dims": ()},
        "target": {"method": "max", "dims": ()},
    },
    adstock=adstock,
    saturation=saturation,
    yearly_seasonality=2,
    model_config=model_config,
)

Tip

Observe that we have two new arguments:

  • dims: a tuple of strings that specifies the additional dimensions of the model.

  • scaling: a dictionary that specifies the scaling method and dimensions for the target and media variables. In this case we leave the dimensions empty because we want a separate scale for each geo (see the sketch below and the details that follow).
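
As a point of comparison, here is a sketch of an alternative scaling configuration. We assume here that listing "geo" under dims pools the maximum across geos, yielding a single shared scale instead of a per-geo one; please check the scaling documentation for the exact semantics.

# Illustrative sketch (assumed dims semantics): pool the scaling across geos so that
# the target and channels share one maximum instead of a per-geo maximum.
shared_scaling = {
    "channel": {"method": "max", "dims": ("geo",)},
    "target": {"method": "max", "dims": ("geo",)},
}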

We can add additional components to the model’s mean. Here, for example, we add a hierarchical linear trend component (with changepoints).

linear_trend = LinearTrend(
    priors={
        "delta": Prior(
            "Laplace",
            mu=0,
            b=Prior("HalfNormal", sigma=0.2),
            dims=("changepoint", "geo"),
        ),
    },
    n_changepoints=5,
    include_intercept=False,
    dims=("geo"),
)

linear_trend_effect = LinearTrendEffect(linear_trend, prefix="trend")

mmm.mu_effects.append(linear_trend_effect)

We can now prepare the training data.

x_train = data_df.drop(columns=["y"])

y_train = data_df["y"]

To build the model, we pass the training data and the target variable.

Tip

We do not need to build the model explicitly before fitting; we do it here only to inspect the model structure.

mmm.build_model(X=x_train, y=y_train)

Let’s look into the model graph:

pm.model_to_graphviz(mmm.model)
[Figure: model graph]

It is great to see that the model automatically vectorizes and creates the expected hierarchies and dimensions 🚀!
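
As a quick programmatic check (a sketch, using the variable names shown in the model graph above), we can list the dims attached to a few key parameters of the underlying PyMC model:

# Sketch: inspect the dims of a few parameters in the underlying PyMC model.
for rv_name in [
    "adstock_alpha",
    "saturation_lam",
    "saturation_beta",
    "intercept_contribution",
]:
    print(rv_name, mmm.model.named_vars_to_dims.get(rv_name))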

As we are scaling our data internally, we can add deterministic terms to recover the component contributions in the original scale.

mmm.add_original_scale_contribution_variable(
    var=[
        "channel_contribution",
        "control_contribution",
        "intercept_contribution",
        "yearly_seasonality_contribution",
        "y",
    ]
)

pm.model_to_graphviz(mmm.model)
[Figure: model graph with original-scale contribution variables]

Coming back to the scalers, we can retrieve them as xarray objects.

scalers = mmm.get_scales_as_xarray()

scalers
{'channel_scale': <xarray.DataArray '_channel' (geo: 2, channel: 2)> Size: 32B
 array([[0.99665813, 0.99437431],
        [0.99665813, 0.99437431]])
 Coordinates:
   * geo      (geo) object 16B 'geo_a' 'geo_b'
   * channel  (channel) object 16B 'x1' 'x2',
 'target_scale': <xarray.DataArray '_target' (geo: 2)> Size: 16B
 array([8312.40754439, 8440.6617456 ])
 Coordinates:
   * geo      (geo) object 16B 'geo_a' 'geo_b'}

As expected from the model definition, we have scalers for the target and media variables across geos.

Prior Predictive Checks#

Before fitting the model, we can inspect the prior predictive distribution.

prior_predictive = mmm.sample_prior_predictive(X=x_train, y=y_train, samples=1_000)
Sampling: [adstock_alpha, delta, delta_b, gamma_control, gamma_fourier, gamma_fourier_b, intercept_contribution, saturation_beta, saturation_lam, saturation_lam_mu, saturation_lam_sigma, y, y_sigma, y_sigma_sigma]
g = sns.relplot(
    data=data_df,
    x="date",
    y="y",
    color="black",
    col="geo",
    col_wrap=1,
    kind="line",
    height=4,
    aspect=3,
)

axes = g.axes.flatten()

for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
    az.plot_hdi(
        x=mmm.model.coords["date"],
        # We need to scale the prior predictive to the original scale to make it comparable to the data.
        y=(
            prior_predictive.sel(geo=geo)["y"].unstack().transpose(..., "date")
            * scalers["target_scale"].sel(geo=geo).item()
        ),
        smooth=False,
        color="C0",
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
        ax=ax,
    )
    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(
            prior_predictive.sel(geo=geo)["y"].unstack().transpose(..., "date")
            * scalers["target_scale"].sel(geo=geo).item()
        ),
        smooth=False,
        color="C0",
        hdi_prob=0.5,
        fill_kwargs={"alpha": 0.5, "label": "50% HDI"},
        ax=ax,
    )
    ax.legend(loc="upper left")

g.figure.suptitle("Prior Predictive", fontsize=16, fontweight="bold", y=1.03);
[Figure: Prior Predictive]

The prior predictive distribution looks good and not too restrictive.

Model Fitting#

We can now fit the model and generate the posterior predictive distribution.

mmm.fit(
    X=x_train,
    y=y_train,
    chains=4,
    target_accept=0.8,
    random_seed=rng,
)

mmm.sample_posterior_predictive(
    X=x_train,
    extend_idata=True,
    combined=True,
    random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam_mu, saturation_lam_sigma, saturation_lam, saturation_beta, gamma_control, gamma_fourier_b, gamma_fourier, delta_b, delta, y_sigma_sigma, y_sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 43 seconds.
Sampling: [y]

<xarray.Dataset> Size: 23MB
Dimensions:           (date: 179, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y                 (date, geo, sample) float64 11MB 0.4939 0.4977 ... 0.6185
    y_original_scale  (date, geo, sample) float64 11MB 4.106e+03 ... 5.221e+03
Attributes:
    created_at:                 2025-06-17T17:05:21.560445+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0

The sampling looks good: there are no divergences, and the r-hat values are close to \(1\).

mmm.idata.sample_stats.diverging.sum().item()
0
az.summary(
    mmm.idata,
    var_names=[
        "adstock_alpha",
        "delta",
        "delta_b",
        "gamma_control",
        "gamma_fourier",
        "intercept_contribution",
        "saturation_beta",
        "saturation_lam_mu",
        "saturation_lam_sigma",
        "saturation_lam",
        "y_sigma_sigma",
        "y_sigma",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
adstock_alpha[x1] 0.398 0.034 0.336 0.463 0.001 0.000 3592.0 3027.0 1.00
adstock_alpha[x2] 0.173 0.042 0.099 0.253 0.001 0.001 1975.0 2545.0 1.00
delta[0, geo_a] 0.110 0.051 0.014 0.208 0.001 0.001 2546.0 1932.0 1.00
delta[0, geo_b] 0.104 0.053 0.002 0.202 0.001 0.001 2803.0 1858.0 1.00
delta[1, geo_a] 0.019 0.065 -0.104 0.159 0.001 0.001 2444.0 2024.0 1.00
delta[1, geo_b] 0.009 0.067 -0.113 0.144 0.001 0.002 2879.0 1888.0 1.00
delta[2, geo_a] -0.041 0.059 -0.165 0.055 0.001 0.001 2871.0 2127.0 1.00
delta[2, geo_b] -0.030 0.059 -0.147 0.075 0.001 0.001 3289.0 2170.0 1.00
delta[3, geo_a] 0.013 0.067 -0.114 0.147 0.001 0.002 3406.0 2084.0 1.00
delta[3, geo_b] 0.033 0.073 -0.096 0.184 0.001 0.002 3031.0 1736.0 1.00
delta[4, geo_a] 0.003 0.134 -0.278 0.246 0.002 0.007 4169.0 1339.0 1.00
delta[4, geo_b] -0.004 0.122 -0.245 0.234 0.003 0.006 3185.0 1407.0 1.00
delta_b 0.082 0.042 0.023 0.156 0.001 0.001 1218.0 2023.0 1.00
gamma_control[event_1] 0.272 0.035 0.210 0.340 0.000 0.001 5785.0 2900.0 1.00
gamma_control[event_2] 0.325 0.035 0.256 0.387 0.000 0.001 5581.0 3131.0 1.00
gamma_fourier[geo_a, sin_1] 0.003 0.005 -0.007 0.013 0.000 0.000 4970.0 2745.0 1.00
gamma_fourier[geo_a, sin_2] -0.057 0.005 -0.067 -0.047 0.000 0.000 4868.0 2762.0 1.00
gamma_fourier[geo_a, cos_1] 0.062 0.005 0.053 0.073 0.000 0.000 5674.0 2447.0 1.00
gamma_fourier[geo_a, cos_2] 0.002 0.005 -0.007 0.012 0.000 0.000 5195.0 2573.0 1.00
gamma_fourier[geo_b, sin_1] 0.001 0.005 -0.008 0.011 0.000 0.000 5444.0 2873.0 1.00
gamma_fourier[geo_b, sin_2] -0.044 0.006 -0.055 -0.034 0.000 0.000 5443.0 2549.0 1.00
gamma_fourier[geo_b, cos_1] 0.068 0.005 0.058 0.078 0.000 0.000 5193.0 3108.0 1.00
gamma_fourier[geo_b, cos_2] 0.001 0.005 -0.010 0.010 0.000 0.000 4752.0 2827.0 1.00
intercept_contribution[geo_a] 0.353 0.020 0.315 0.390 0.000 0.000 3157.0 2901.0 1.00
intercept_contribution[geo_b] 0.349 0.020 0.308 0.383 0.000 0.000 3484.0 2772.0 1.00
saturation_beta[x1, geo_a] 0.366 0.030 0.313 0.424 0.001 0.001 3013.0 2716.0 1.00
saturation_beta[x1, geo_b] 0.384 0.029 0.330 0.442 0.001 0.001 3077.0 2774.0 1.00
saturation_beta[x2, geo_a] 0.290 0.157 0.184 0.488 0.009 0.025 774.0 356.0 1.01
saturation_beta[x2, geo_b] 0.282 0.130 0.169 0.492 0.005 0.010 938.0 794.0 1.01
saturation_lam_mu[x1] 3.832 0.732 2.462 5.273 0.013 0.013 3169.0 2299.0 1.00
saturation_lam_mu[x2] 3.027 1.007 1.272 4.831 0.029 0.019 1060.0 1567.0 1.01
saturation_lam_sigma[x1] 0.943 0.374 0.365 1.659 0.006 0.009 4164.0 2394.0 1.00
saturation_lam_sigma[x2] 1.022 0.407 0.358 1.751 0.007 0.008 4013.0 2961.0 1.00
saturation_lam[x1, geo_a] 3.949 0.563 2.917 5.019 0.010 0.010 3090.0 1950.0 1.00
saturation_lam[x1, geo_b] 3.982 0.535 3.001 4.984 0.010 0.009 3097.0 2168.0 1.00
saturation_lam[x2, geo_a] 3.107 1.313 0.630 5.427 0.045 0.030 709.0 367.0 1.01
saturation_lam[x2, geo_b] 2.749 1.156 0.594 4.676 0.036 0.022 873.0 806.0 1.01
y_sigma_sigma 0.510 0.592 0.016 1.688 0.009 0.012 4905.0 3207.0 1.00
y_sigma 0.049 0.002 0.046 0.052 0.000 0.000 5258.0 3202.0 1.00
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "adstock_alpha",
        "delta",
        "delta_b",
        "gamma_control",
        "gamma_fourier",
        "intercept_contribution",
        "saturation_beta",
        "saturation_lam_mu",
        "saturation_lam_sigma",
        "saturation_lam",
        "total_media_contribution_original_scale",
        "y_sigma_sigma",
        "y_sigma",
    ],
    compact=True,
    backend_kwargs={"figsize": (15, 15), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16, fontweight="bold", y=1.03);
[Figure: Model Trace]

Posterior Predictive Checks#

We can now inspect the posterior predictive distribution. As before, we work in the original scale (here via the y_original_scale deterministic) so that the results are comparable to the data.

fig, axes = plt.subplots(
    nrows=len(mmm.model.coords["geo"]),
    figsize=(12, 9),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for i, geo in enumerate(mmm.model.coords["geo"]):
    ax = axes[i]
    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
        color="C0",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.2, "label": "94% HDI"},
        ax=ax,
    )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
        color="C0",
        smooth=False,
        hdi_prob=0.5,
        fill_kwargs={"alpha": 0.4, "label": "50% HDI"},
        ax=ax,
    )

    sns.lineplot(
        data=data_df.query("geo == @geo"),
        x="date",
        y="y",
        color="black",
        ax=ax,
    )

    ax.legend(loc="upper left")
    ax.set(title=f"{geo}")

fig.suptitle("Posterior Predictive", fontsize=16, fontweight="bold", y=1.03);
[Figure: Posterior Predictive]

The fit looks very good!
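
To go beyond the visual check, here is a sketch of a simple numeric comparison: the mean absolute percentage error of the posterior predictive mean against the observed target, per geo.

# Sketch: compare the posterior predictive mean against the observed target per geo.
# We assume the posterior predictive group stores chain and draw dimensions.
pp_mean = mmm.idata["posterior_predictive"]["y_original_scale"].mean(dim=["chain", "draw"])

for geo in mmm.model.coords["geo"]:
    y_obs = data_df.query("geo == @geo")["y"].to_numpy()
    y_hat = pp_mean.sel(geo=geo).to_numpy()
    mape = np.mean(np.abs(y_obs - y_hat) / y_obs)
    print(f"{geo}: MAPE of posterior predictive mean = {mape:.2%}")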

Model Components#

We can extract the contributions of each component of the model in the original scale thanks to the deterministic variables added to the model.

fig, axes = plt.subplots(
    nrows=len(mmm.model.coords["geo"]),
    figsize=(15, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for i, geo in enumerate(mmm.model.coords["geo"]):
    ax = axes[i]

    for j, channel in enumerate(mmm.model.coords["channel"]):
        az.plot_hdi(
            x=mmm.model.coords["date"],
            y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
                geo=geo, channel=channel
            ),
            color=f"C{j}",
            smooth=False,
            hdi_prob=0.94,
            fill_kwargs={"alpha": 0.5, "label": f"94% HDI ({channel})"},
            ax=ax,
        )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=mmm.idata["posterior"]["intercept_contribution_original_scale"]
        .sel(geo=geo)
        .expand_dims({"date": mmm.model.coords["date"]})
        .transpose(..., "date"),
        color="C2",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.5, "label": "94% HDI intercept"},
        ax=ax,
    )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=mmm.idata["posterior"]["yearly_seasonality_contribution_original_scale"].sel(
            geo=geo,
        ),
        color="C3",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.5, "label": "94% HDI Fourier"},
        ax=ax,
    )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=mmm.idata["posterior"]["trend_effect_contribution"].sel(
            geo=geo,
        )
        * mmm.scalers["_target"].sel(geo=geo).item(),
        color="C4",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.5, "label": "94% HDI Trend"},
        ax=ax,
    )

    for k, control in enumerate(mmm.model.coords["control"]):
        az.plot_hdi(
            x=mmm.model.coords["date"],
            y=mmm.idata["posterior"]["control_contribution_original_scale"].sel(
                geo=geo, control=control
            ),
            color=f"C{5 + k}",
            smooth=False,
            hdi_prob=0.94,
            fill_kwargs={"alpha": 0.5, "label": f"94% HDI control ({control})"},
            ax=ax,
        )

    sns.lineplot(
        data=data_df.query("geo == @geo"),
        x="date",
        y="y",
        color="black",
        label="y",
        ax=ax,
    )
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.1),
        ncol=4,
    )
    ax.set(title=f"{geo}")

fig.suptitle(
    "Posterior Predictive - Channel Contributions",
    fontsize=16,
    fontweight="bold",
    y=1.03,
);
[Figure: Posterior Predictive - Channel Contributions]

Media Deep Dive#

Next, we can look into the individual channel contributions across geos.

fig, axes = plt.subplots(
    nrows=len(mmm.model.coords["geo"]),
    ncols=len(mmm.model.coords["channel"]),
    figsize=(12, 9),
    layout="constrained",
)

for i, geo in enumerate(mmm.model.coords["geo"]):
    for j, channel in enumerate(mmm.model.coords["channel"]):
        ax = axes[i, j]
        az.plot_hdi(
            x=mmm.model.coords["date"],
            y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
                geo=geo, channel=channel
            ),
            color=f"C{j}",
            smooth=False,
            hdi_prob=0.94,
            ax=ax,
        )
        ax.set_title(f"{geo} - {channel}")
[Figure: channel contributions per geo and channel]

Observe that all the vectorization and heavy lifting are done under the hood by the new class.

The new class also exposes a plot namespace with many plotting methods. For example, we can reproduce the plots above by simply calling:

fig, axes = mmm.plot.contributions_over_time(
    var=["channel_contribution_original_scale"],
)

# Adjust figure size and layout to 2x2
fig.set_size_inches(14, 10)
fig.set_constrained_layout(True)

# Reshape axes to 2x2 grid
num_axes = len(axes.flatten())
if num_axes > 0:
    # Create a new 2x2 grid
    gs = fig.add_gridspec(2, 2)

    # Move existing axes to the new grid
    for i, ax in enumerate(axes.flatten()):
        if i < 4:  # Only handle up to 4 axes for 2x2 grid
            ax.set_position(gs[i // 2, i % 2].get_position(fig))

axes = axes.flatten()

# Share x and y axes across all subplots
for ax in axes:
    ax.legend().remove()
    ax.tick_params(axis="both", which="major", labelsize=6)
    ax.tick_params(axis="both", which="minor", labelsize=6)

# Share y axis limits
y_min = min(ax.get_ylim()[0] for ax in axes)
y_max = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
    ax.set_ylim(y_min, y_max)

# Share x axis limits
x_min = min(ax.get_xlim()[0] for ax in axes)
x_max = max(ax.get_xlim()[1] for ax in axes)
for ax in axes:
    ax.set_xlim(x_min, x_max)
[Figure: channel contributions (original scale) via the plot namespace]

We can also plot the saturation curves for each channel and geo.

mmm.plot.saturation_curves_scatter(
    width_per_col=8, height_per_row=4, original_scale=True
);
[Figure: saturation curve scatter plots per channel and geo]

Out of Sample Predictions#

It is very important to be able to make out-of-sample predictions. This is key for model validation, forward-looking scenario planning, and business decision making. As in the MMM Example Notebook, we assume the future spends are the same as in the last observation of the training sample. This way, we can create a new dataset with the future dates and channel spends and use the model to make predictions.

last_date = x_train["date"].max()

# New dates starting from last in dataset
n_new = 7
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]

x_out_of_sample_geo_a = pd.DataFrame({"date": new_dates, "geo": "geo_a"})
x_out_of_sample_geo_b = pd.DataFrame({"date": new_dates, "geo": "geo_b"})

# Same channel spends as in the last observation
x_out_of_sample_geo_a["x1"] = x_train.query("geo == 'geo_a'")["x1"].iloc[-1]
x_out_of_sample_geo_a["x2"] = x_train.query("geo == 'geo_a'")["x2"].iloc[-1]

x_out_of_sample_geo_b["x1"] = x_train.query("geo == 'geo_b'")["x1"].iloc[-1]
x_out_of_sample_geo_b["x2"] = x_train.query("geo == 'geo_b'")["x2"].iloc[-1]

# Other features: no events in the out-of-sample period
## geo_a
x_out_of_sample_geo_a["event_1"] = 0.0
x_out_of_sample_geo_a["event_2"] = 0.0
## geo_b
x_out_of_sample_geo_b["event_1"] = 0.0
x_out_of_sample_geo_b["event_2"] = 0.0

x_out_of_sample = pd.concat([x_out_of_sample_geo_a, x_out_of_sample_geo_b])

# Final dataset to generate out of sample predictions.
x_out_of_sample
date geo x1 x2 event_1 event_2
0 2021-09-06 geo_a 0.438857 0.0 0.0 0.0
1 2021-09-13 geo_a 0.438857 0.0 0.0 0.0
2 2021-09-20 geo_a 0.438857 0.0 0.0 0.0
3 2021-09-27 geo_a 0.438857 0.0 0.0 0.0
4 2021-10-04 geo_a 0.438857 0.0 0.0 0.0
5 2021-10-11 geo_a 0.438857 0.0 0.0 0.0
6 2021-10-18 geo_a 0.438857 0.0 0.0 0.0
0 2021-09-06 geo_b 0.438857 0.0 0.0 0.0
1 2021-09-13 geo_b 0.438857 0.0 0.0 0.0
2 2021-09-20 geo_b 0.438857 0.0 0.0 0.0
3 2021-09-27 geo_b 0.438857 0.0 0.0 0.0
4 2021-10-04 geo_b 0.438857 0.0 0.0 0.0
5 2021-10-11 geo_b 0.438857 0.0 0.0 0.0
6 2021-10-18 geo_b 0.438857 0.0 0.0 0.0

Using the same sample_posterior_predictive method, we can now generate the forecast.

y_out_of_sample = mmm.sample_posterior_predictive(
    x_out_of_sample,
    extend_idata=False,
    include_last_observations=True,
    random_seed=rng,
    var_names=["y_original_scale"],
)

y_out_of_sample
Sampling: [y]

<xarray.Dataset> Size: 544kB
Dimensions:           (date: 7, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 56B 2021-09-06 ... 2021-10-18
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y_original_scale  (date, geo, sample) float64 448kB 5.109e+03 ... 6.617e+03
Attributes:
    created_at:                 2025-06-17T17:05:30.845296+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(12, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

n_train_to_plot = 30

for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
    for hdi_prob in [0.94, 0.5]:
        az.plot_hdi(
            x=mmm.model.coords["date"][-n_train_to_plot:],
            y=(
                mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)[
                    :, :, -n_train_to_plot:
                ]
            ),
            color="C0",
            smooth=False,
            hdi_prob=hdi_prob,
            fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
            ax=ax,
        )

        az.plot_hdi(
            x_out_of_sample.query("geo == @geo")["date"],
            (
                y_out_of_sample["y_original_scale"]
                .sel(geo=geo)
                .unstack()
                .transpose(..., "date")
            ),
            color="C1",
            smooth=False,
            hdi_prob=hdi_prob,
            fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
            ax=ax,
        )

        ax.plot(
            x_out_of_sample.query("geo == @geo")["date"],
            y_out_of_sample["y_original_scale"].sel(geo=geo).mean(dim="sample"),
            marker="o",
            color="C1",
            label="posterior predictive mean",
        )

    sns.lineplot(
        data=data_df.query("(geo == @geo)").tail(n_train_to_plot),
        x="date",
        y="y",
        marker="o",
        color="black",
        label="observed",
        ax=ax,
    )

    ax.axvline(x=last_date, color="gray", linestyle="--", label="last observation")
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=3,
    )
    ax.set(title=f"{geo}")

fig.suptitle(
    "Posterior Predictive - Out of Sample", fontsize=16, fontweight="bold", y=1.03
);
[Figure: Posterior Predictive - Out of Sample]

Optimization#

If you want to run optimizations, then you need to use the MultiDimensionalBudgetOptimizerWrapper.

optimizable_model = MultiDimensionalBudgetOptimizerWrapper(
    model=mmm, start_date="2021-10-01", end_date="2021-12-31"
)

allocation_xarray, scipy_opt_result = optimizable_model.optimize_budget(
    budget=10,  # Total budget to allocate (here, spend in millions)
)

sample_allocation = optimizable_model.sample_response_distribution(
    allocation_strategy=allocation_xarray,
)
Sampling: [y]

This object is an xarray dataset containing the allocation and the posterior predictive responses!

sample_allocation
<xarray.Dataset> Size: 3MB
Dimensions:                              (date: 13, geo: 2, sample: 4000,
                                          channel: 2)
Coordinates:
  * date                                 (date) datetime64[ns] 104B 2021-10-0...
  * geo                                  (geo) <U5 40B 'geo_a' 'geo_b'
  * channel                              (channel) <U2 16B 'x1' 'x2'
  * sample                               (sample) object 32kB MultiIndex
  * chain                                (sample) int64 32kB 0 0 0 0 ... 3 3 3 3
  * draw                                 (sample) int64 32kB 0 1 2 ... 998 999
Data variables:
    y                                    (date, geo, sample) float64 832kB 1....
    channel_contribution_original_scale  (date, geo, channel, sample) float64 2MB ...
    allocation                           (geo, channel) float64 32B 2.291 ......
Attributes:
    created_at:                 2025-06-17T17:05:36.066184+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
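
Before plotting, here is a sketch of how one might summarize this dataset directly: the total expected contribution per geo and channel over the optimization horizon, next to the corresponding optimal budget allocation.

# Sketch: summarize the optimization output per geo and channel.
expected_contribution = (
    sample_allocation["channel_contribution_original_scale"]
    .sum(dim="date")  # accumulate over the optimization horizon
    .mean(dim="sample")  # posterior mean
)

print(expected_contribution.to_dataframe(name="expected_contribution"))
print(sample_allocation["allocation"].to_dataframe(name="allocation"))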

Once you get the allocation, you can plot the results 🚀

optimizable_model.plot.budget_allocation(
    samples=sample_allocation,
);
[Figure: budget allocation per channel and geo]

The graph shows the optimal budget for each channel in each geo, next to the respective mean contribution given the optimal budget. The method automatically identifies the number of dimensions and creates a plot based on them.

If you want to see the full uncertainty over time, you can use the plot suite and the allocated_contribution_by_channel_over_time method.

optimizable_model.plot.allocated_contribution_by_channel_over_time(
    samples=sample_allocation,
);
[Figure: allocated contribution by channel over time]

If you have a custom model, you can wrap it in the model protocol and then use the optimizer. If your model handles scaling internally, you don’t need to modify anything. Otherwise, for the plots, you may want to use scale_factor=N, e.g.:

optimizable_model.plot.budget_allocation(
    samples=sample_allocation,
    scale_factor=120
);

Note

We are very excited about this new feature and the possibilities it opens up. We are looking forward to hearing your feedback!

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,nutpie
Last updated: Tue Jun 17 2025

Python implementation: CPython
Python version       : 3.10.17
IPython version      : 8.35.0

pymc_marketing: 0.14.0
pytensor      : 2.30.2+77.g8f2982d3d
nutpie        : 0.15.1

arviz         : 0.21.0
pymc_marketing: 0.14.0
pymc          : 5.22.0
seaborn       : 0.13.2
numpy         : 1.26.4
pandas        : 2.2.3
matplotlib    : 3.10.1

Watermark: 2.5.0