MMM Multidimensional Example Notebook#
In this notebook, we present a new experimental media mix model class for creating multidimensional and customized marketing mix models. To showcase its capabilities, we extend the MMM Example Notebook simulation to create a multidimensional hierarchical model.
Warning
Although the new MMM class is experimental, it is fully functional and can be used to create multidimensional marketing mix models. The model is under active development and will be further improved in the future (feedback welcome!).
Prepare Notebook#
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.additive_effect import LinearTrendEffect
from pymc_marketing.mmm.linear_trend import LinearTrend
from pymc_marketing.mmm.multidimensional import (
MMM,
MultiDimensionalBudgetOptimizerWrapper,
)
from pymc_marketing.paths import data_dir
from pymc_marketing.prior import Prior
warnings.filterwarnings("ignore", category=UserWarning)
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["xtick.labelsize"] = 10
plt.rcParams["ytick.labelsize"] = 8
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
/Users/carlostrujillo/Documents/GitHub/pymc-marketing/pymc_marketing/mmm/multidimensional.py:68: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
warnings.warn(warning_msg, FutureWarning, stacklevel=1)
seed: int = sum(map(ord, "mmm_multidimensional"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
Read Data#
We read the simulated data from the MMM Example Notebook.
data_path = data_dir / "mmm_example.csv"
raw_data_df = pd.read_csv(data_path, parse_dates=["date_week"])
raw_data_df = raw_data_df.rename(columns={"date_week": "date"})
raw_data_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 179 entries, 0 to 178
Data columns (total 8 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 179 non-null datetime64[ns]
1 y 179 non-null float64
2 x1 179 non-null float64
3 x2 179 non-null float64
4 event_1 179 non-null float64
5 event_2 179 non-null float64
6 dayofyear 179 non-null int64
7 t 179 non-null int64
dtypes: datetime64[ns](1), float64(5), int64(2)
memory usage: 11.3 KB
To generate a multidimensional dataset, we simply duplicate the simulated target data and perturb it with noise. We assign the two copies to different geography (geo) levels and keep the same media variables for both geos.
# Create copies of the original data for both geos
a_data_df = raw_data_df.copy().assign(geo="geo_a")
b_data_df = raw_data_df.copy().assign(geo="geo_b")
# Add noise to the target variable for the second geo
b_data_df["y"] = b_data_df["y"] + 500 * rng.normal(size=len(b_data_df))
# Concatenate the two datasets
data_df = pd.concat([a_data_df, b_data_df])
data_df.head()
 | date | y | x1 | x2 | event_1 | event_2 | dayofyear | t | geo |
---|---|---|---|---|---|---|---|---|---|
0 | 2018-04-02 | 3984.662237 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 | 0 | geo_a |
1 | 2018-04-09 | 3762.871794 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 | 1 | geo_a |
2 | 2018-04-16 | 4466.967388 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 | 2 | geo_a |
3 | 2018-04-23 | 3864.219373 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 | 3 | geo_a |
4 | 2018-04-30 | 4441.625278 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 | 4 | geo_a |
Let’s plot the target variable for each geo to visually inspect the difference.
g = sns.relplot(
data=data_df,
x="date",
y="y",
color="black",
col="geo",
col_wrap=1,
kind="line",
height=4,
aspect=3,
)

Overall, the targets follow a similar global pattern, but the short-term fluctuations differ.
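To make this concrete, we can check how correlated the two target series are:
# Correlation between the two geos' target series (one column per geo after pivoting)
data_df.pivot(index="date", columns="geo", values="y").corr()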
Prior Specification#
This is where the main strength of the new class shows: we can now specify different priors for different geos and even build custom hierarchies.
Recall from the MMM Example Notebook that we want to use the spend shares as a prior for the beta parameters in the saturation function. Let’s compute those shares across geos.
channel_columns = ["x1", "x2"]
n_channels = len(channel_columns)
sum_spend_geo_channel = data_df.groupby(["geo"]).agg({"x1": "sum", "x2": "sum"})
spend_share = (
sum_spend_geo_channel.to_numpy() / sum_spend_geo_channel.sum(axis=1).to_numpy()
)
prior_sigma = n_channels * spend_share
Now we are ready to specify the priors for the saturation function.
For the beta parameters, we use a half-normal distribution with a sigma parameter that is computed as a function of the spend shares. For illustrative purposes, we define the beta values to be independent across both channels and geos.
For the lambda parameters, we add a channel-level hierarchical structure through the hyperparameters of the gamma prior (mu and sigma). These hyperparameters vary by channel, while the lambda values themselves vary across both channel and geo. This setup enables partial pooling across geos within the same channel, allowing the model to share information across geos.
If you need an introduction to Bayesian hierarchical models, check out the comprehensive example “A Primer on Bayesian Methods for Multilevel Modeling” in the PyMC documentation. Note that the dims argument in the priors specifies the dimensions of the distribution; this is how you control the dimensions along which the hierarchies are defined.
saturation = LogisticSaturation(
priors={
"beta": Prior("HalfNormal", sigma=prior_sigma, dims=("channel", "geo")),
"lam": Prior(
"Gamma",
mu=Prior("LogNormal", mu=np.log(3), sigma=np.log(1.5), dims="channel"),
sigma=Prior("LogNormal", mu=np.log(1), sigma=np.log(1.5), dims="channel"),
dims=("channel", "geo"),
),
}
)
saturation.model_config
{'saturation_lam': Prior("Gamma", mu=Prior("LogNormal", mu=1.0986122886681098, sigma=0.4054651081081644, dims="channel"), sigma=Prior("LogNormal", mu=0.0, sigma=0.4054651081081644, dims="channel"), dims=("channel", "geo")),
'saturation_beta': Prior("HalfNormal", sigma=[[1.31263903 0.68736097]
[1.31263903 0.68736097]], dims=("channel", "geo"))}
For the adstock parameters we do not add any hierarchical structure. We simply keep the same prior for all the geos.
adstock = GeometricAdstock(
priors={"alpha": Prior("Beta", alpha=2, beta=3, dims="channel")}, l_max=8
)
adstock.model_config
{'adstock_alpha': Prior("Beta", alpha=2, beta=3, dims="channel")}
We complete the model specification with priors similar to those in the MMM Example Notebook. Pay attention to how the prior dimensions are specified.
model_config = {
"intercept": Prior("Normal", mu=0.5, sigma=0.5, dims="geo"),
"gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
"gamma_fourier": Prior(
"Laplace", mu=0, b=Prior("HalfNormal", sigma=0.2), dims=("geo", "fourier_mode")
),
"likelihood": Prior(
"TruncatedNormal",
lower=0,
sigma=Prior("HalfNormal", sigma=Prior("HalfNormal", sigma=1.5)),
dims=("date", "geo"),
),
}
Model Definition#
We are now ready to define the model class. The API is very similar to the one in the MMM Example Notebook.
# Base MMM model specification
mmm = MMM(
date_column="date",
target_column="y",
channel_columns=["x1", "x2"],
control_columns=["event_1", "event_2"],
dims=("geo",),
scaling={
"channel": {"method": "max", "dims": ()},
"target": {"method": "max", "dims": ()},
},
adstock=adstock,
saturation=saturation,
yearly_seasonality=2,
model_config=model_config,
)
Tip
Observe that we have the following two new arguments:
- dims: a tuple of strings that specifies the dimensions of the model.
- scaling: a dictionary that specifies the scaling method and dimensions for the target and media variables. In this case we leave the dimensions empty because we want to scale the target variable separately for each geo (see the sketch below and further details later on).
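To illustrate the scaling argument, here is a sketch of an alternative configuration (not used in this notebook). We assume that including "geo" in the scaling dims would pool the maximum across geos into a single shared scaler, instead of one scaler per geo:
# Sketch (assumption, not used below): share a single scaler across geos by
# including "geo" in the scaling dims instead of leaving them empty.
shared_scaling = {
    "channel": {"method": "max", "dims": ("geo",)},
    "target": {"method": "max", "dims": ("geo",)},
}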
We can add additional components to the model's mean. Here, for example, we add a hierarchical linear trend component (with changepoints).
linear_trend = LinearTrend(
priors={
"delta": Prior(
"Laplace",
mu=0,
b=Prior("HalfNormal", sigma=0.2),
dims=("changepoint", "geo"),
),
},
n_changepoints=5,
include_intercept=False,
dims=("geo"),
)
linear_trend_effect = LinearTrendEffect(linear_trend, prefix="trend")
mmm.mu_effects.append(linear_trend_effect)
We can now prepare the training data.
x_train = data_df.drop(columns=["y"])
y_train = data_df["y"]
To build the model, we need to pass the training data and the target variable.
Tip
We do not need to build the model explicitly before fitting it; we build it here only to inspect the model structure.
mmm.build_model(X=x_train, y=y_train)
Let’s look into the model graph:
pm.model_to_graphviz(mmm.model)
It is great to see that the model automatically vectorizes and creates the expected hierarchies and dimensions 🚀!
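We can also list the coordinates the model created together with their sizes:
# Coordinate names and sizes registered on the model
# (guarding against dims without explicit coordinate values)
{name: (len(values) if values is not None else None) for name, values in mmm.model.coords.items()}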
As we are scaling our data internally, we can add deterministic terms to recover the component contributions in the original scale.
mmm.add_original_scale_contribution_variable(
var=[
"channel_contribution",
"control_contribution",
"intercept_contribution",
"yearly_seasonality_contribution",
"y",
]
)
pm.model_to_graphviz(mmm.model)
Coming back to the scalers, we can get them as an xarray dataset.
scalers = mmm.get_scales_as_xarray()
scalers
{'channel_scale': <xarray.DataArray '_channel' (geo: 2, channel: 2)> Size: 32B
array([[0.99665813, 0.99437431],
[0.99665813, 0.99437431]])
Coordinates:
* geo (geo) object 16B 'geo_a' 'geo_b'
* channel (channel) object 16B 'x1' 'x2',
'target_scale': <xarray.DataArray '_target' (geo: 2)> Size: 16B
array([8312.40754439, 8440.6617456 ])
Coordinates:
* geo (geo) object 16B 'geo_a' 'geo_b'}
As expected from the model definition, we have scalers for the target and media variables across geos.
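Since we chose the "max" scaling method with empty dims, the target scaler should simply match the per-geo maximum of the observed target, which we can verify directly:
# The target scaler should equal the maximum of y within each geo
data_df.groupby("geo")["y"].max()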
Prior Predictive Checks#
Before fitting the model, we can inspect the prior predictive distribution.
prior_predictive = mmm.sample_prior_predictive(X=x_train, y=y_train, samples=1_000)
Sampling: [adstock_alpha, delta, delta_b, gamma_control, gamma_fourier, gamma_fourier_b, intercept_contribution, saturation_beta, saturation_lam, saturation_lam_mu, saturation_lam_sigma, y, y_sigma, y_sigma_sigma]
g = sns.relplot(
data=data_df,
x="date",
y="y",
color="black",
col="geo",
col_wrap=1,
kind="line",
height=4,
aspect=3,
)
axes = g.axes.flatten()
for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
az.plot_hdi(
x=mmm.model.coords["date"],
# We need to scale the prior predictive to the original scale to make it comparable to the data.
y=(
prior_predictive.sel(geo=geo)["y"].unstack().transpose(..., "date")
* scalers["target_scale"].sel(geo=geo).item()
),
smooth=False,
color="C0",
hdi_prob=0.94,
fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=(
prior_predictive.sel(geo=geo)["y"].unstack().transpose(..., "date")
* scalers["target_scale"].sel(geo=geo).item()
),
smooth=False,
color="C0",
hdi_prob=0.5,
fill_kwargs={"alpha": 0.5, "label": "50% HDI"},
ax=ax,
)
ax.legend(loc="upper left")
g.figure.suptitle("Prior Predictive", fontsize=16, fontweight="bold", y=1.03);

The prior predictive distribution looks good and not too restrictive.
Model Fitting#
We can now fit the model and generate the posterior predictive distribution.
mmm.fit(
X=x_train,
y=y_train,
chains=4,
target_accept=0.8,
random_seed=rng,
)
mmm.sample_posterior_predictive(
X=x_train,
extend_idata=True,
combined=True,
random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam_mu, saturation_lam_sigma, saturation_lam, saturation_beta, gamma_control, gamma_fourier_b, gamma_fourier, delta_b, delta, y_sigma_sigma, y_sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 43 seconds.
Sampling: [y]
<xarray.Dataset> Size: 23MB
Dimensions:           (date: 179, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y                 (date, geo, sample) float64 11MB 0.4939 0.4977 ... 0.6185
    y_original_scale  (date, geo, sample) float64 11MB 4.106e+03 ... 5.221e+03
Attributes:
    created_at:                 2025-06-17T17:05:21.560445+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
The sampling looks good: no divergences, and the r-hat values are close to \(1\).
mmm.idata.sample_stats.diverging.sum().item()

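We can also confirm the r-hat values programmatically, for example by looking at the largest r-hat across all variables:
# Largest r-hat across all sampled variables (values close to 1 indicate good convergence)
rhat = az.rhat(mmm.idata)
max(float(rhat[var].max()) for var in rhat.data_vars)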
az.summary(
mmm.idata,
var_names=[
"adstock_alpha",
"delta",
"delta_b",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_lam_mu",
"saturation_lam_sigma",
"saturation_lam",
"y_sigma_sigma",
"y_sigma",
],
)
 | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
---|---|---|---|---|---|---|---|---|---|
adstock_alpha[x1] | 0.398 | 0.034 | 0.336 | 0.463 | 0.001 | 0.000 | 3592.0 | 3027.0 | 1.00 |
adstock_alpha[x2] | 0.173 | 0.042 | 0.099 | 0.253 | 0.001 | 0.001 | 1975.0 | 2545.0 | 1.00 |
delta[0, geo_a] | 0.110 | 0.051 | 0.014 | 0.208 | 0.001 | 0.001 | 2546.0 | 1932.0 | 1.00 |
delta[0, geo_b] | 0.104 | 0.053 | 0.002 | 0.202 | 0.001 | 0.001 | 2803.0 | 1858.0 | 1.00 |
delta[1, geo_a] | 0.019 | 0.065 | -0.104 | 0.159 | 0.001 | 0.001 | 2444.0 | 2024.0 | 1.00 |
delta[1, geo_b] | 0.009 | 0.067 | -0.113 | 0.144 | 0.001 | 0.002 | 2879.0 | 1888.0 | 1.00 |
delta[2, geo_a] | -0.041 | 0.059 | -0.165 | 0.055 | 0.001 | 0.001 | 2871.0 | 2127.0 | 1.00 |
delta[2, geo_b] | -0.030 | 0.059 | -0.147 | 0.075 | 0.001 | 0.001 | 3289.0 | 2170.0 | 1.00 |
delta[3, geo_a] | 0.013 | 0.067 | -0.114 | 0.147 | 0.001 | 0.002 | 3406.0 | 2084.0 | 1.00 |
delta[3, geo_b] | 0.033 | 0.073 | -0.096 | 0.184 | 0.001 | 0.002 | 3031.0 | 1736.0 | 1.00 |
delta[4, geo_a] | 0.003 | 0.134 | -0.278 | 0.246 | 0.002 | 0.007 | 4169.0 | 1339.0 | 1.00 |
delta[4, geo_b] | -0.004 | 0.122 | -0.245 | 0.234 | 0.003 | 0.006 | 3185.0 | 1407.0 | 1.00 |
delta_b | 0.082 | 0.042 | 0.023 | 0.156 | 0.001 | 0.001 | 1218.0 | 2023.0 | 1.00 |
gamma_control[event_1] | 0.272 | 0.035 | 0.210 | 0.340 | 0.000 | 0.001 | 5785.0 | 2900.0 | 1.00 |
gamma_control[event_2] | 0.325 | 0.035 | 0.256 | 0.387 | 0.000 | 0.001 | 5581.0 | 3131.0 | 1.00 |
gamma_fourier[geo_a, sin_1] | 0.003 | 0.005 | -0.007 | 0.013 | 0.000 | 0.000 | 4970.0 | 2745.0 | 1.00 |
gamma_fourier[geo_a, sin_2] | -0.057 | 0.005 | -0.067 | -0.047 | 0.000 | 0.000 | 4868.0 | 2762.0 | 1.00 |
gamma_fourier[geo_a, cos_1] | 0.062 | 0.005 | 0.053 | 0.073 | 0.000 | 0.000 | 5674.0 | 2447.0 | 1.00 |
gamma_fourier[geo_a, cos_2] | 0.002 | 0.005 | -0.007 | 0.012 | 0.000 | 0.000 | 5195.0 | 2573.0 | 1.00 |
gamma_fourier[geo_b, sin_1] | 0.001 | 0.005 | -0.008 | 0.011 | 0.000 | 0.000 | 5444.0 | 2873.0 | 1.00 |
gamma_fourier[geo_b, sin_2] | -0.044 | 0.006 | -0.055 | -0.034 | 0.000 | 0.000 | 5443.0 | 2549.0 | 1.00 |
gamma_fourier[geo_b, cos_1] | 0.068 | 0.005 | 0.058 | 0.078 | 0.000 | 0.000 | 5193.0 | 3108.0 | 1.00 |
gamma_fourier[geo_b, cos_2] | 0.001 | 0.005 | -0.010 | 0.010 | 0.000 | 0.000 | 4752.0 | 2827.0 | 1.00 |
intercept_contribution[geo_a] | 0.353 | 0.020 | 0.315 | 0.390 | 0.000 | 0.000 | 3157.0 | 2901.0 | 1.00 |
intercept_contribution[geo_b] | 0.349 | 0.020 | 0.308 | 0.383 | 0.000 | 0.000 | 3484.0 | 2772.0 | 1.00 |
saturation_beta[x1, geo_a] | 0.366 | 0.030 | 0.313 | 0.424 | 0.001 | 0.001 | 3013.0 | 2716.0 | 1.00 |
saturation_beta[x1, geo_b] | 0.384 | 0.029 | 0.330 | 0.442 | 0.001 | 0.001 | 3077.0 | 2774.0 | 1.00 |
saturation_beta[x2, geo_a] | 0.290 | 0.157 | 0.184 | 0.488 | 0.009 | 0.025 | 774.0 | 356.0 | 1.01 |
saturation_beta[x2, geo_b] | 0.282 | 0.130 | 0.169 | 0.492 | 0.005 | 0.010 | 938.0 | 794.0 | 1.01 |
saturation_lam_mu[x1] | 3.832 | 0.732 | 2.462 | 5.273 | 0.013 | 0.013 | 3169.0 | 2299.0 | 1.00 |
saturation_lam_mu[x2] | 3.027 | 1.007 | 1.272 | 4.831 | 0.029 | 0.019 | 1060.0 | 1567.0 | 1.01 |
saturation_lam_sigma[x1] | 0.943 | 0.374 | 0.365 | 1.659 | 0.006 | 0.009 | 4164.0 | 2394.0 | 1.00 |
saturation_lam_sigma[x2] | 1.022 | 0.407 | 0.358 | 1.751 | 0.007 | 0.008 | 4013.0 | 2961.0 | 1.00 |
saturation_lam[x1, geo_a] | 3.949 | 0.563 | 2.917 | 5.019 | 0.010 | 0.010 | 3090.0 | 1950.0 | 1.00 |
saturation_lam[x1, geo_b] | 3.982 | 0.535 | 3.001 | 4.984 | 0.010 | 0.009 | 3097.0 | 2168.0 | 1.00 |
saturation_lam[x2, geo_a] | 3.107 | 1.313 | 0.630 | 5.427 | 0.045 | 0.030 | 709.0 | 367.0 | 1.01 |
saturation_lam[x2, geo_b] | 2.749 | 1.156 | 0.594 | 4.676 | 0.036 | 0.022 | 873.0 | 806.0 | 1.01 |
y_sigma_sigma | 0.510 | 0.592 | 0.016 | 1.688 | 0.009 | 0.012 | 4905.0 | 3207.0 | 1.00 |
y_sigma | 0.049 | 0.002 | 0.046 | 0.052 | 0.000 | 0.000 | 5258.0 | 3202.0 | 1.00 |
_ = az.plot_trace(
data=mmm.fit_result,
var_names=[
"adstock_alpha",
"delta",
"delta_b",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_lam_mu",
"saturation_lam_sigma",
"saturation_lam",
"total_media_contribution_original_scale",
"y_sigma_sigma",
"y_sigma",
],
compact=True,
backend_kwargs={"figsize": (15, 15), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16, fontweight="bold", y=1.03);

Posterior Predictive Checks#
We can now inspect the posterior predictive distribution. As before, we need to scale the posterior predictive to the original scale to make it comparable to the data.
fig, axes = plt.subplots(
nrows=len(mmm.model.coords["geo"]),
figsize=(12, 9),
sharex=True,
sharey=True,
layout="constrained",
)
for i, geo in enumerate(mmm.model.coords["geo"]):
ax = axes[i]
az.plot_hdi(
x=mmm.model.coords["date"],
y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
color="C0",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.2, "label": "94% HDI"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
color="C0",
smooth=False,
hdi_prob=0.5,
fill_kwargs={"alpha": 0.4, "label": "50% HDI"},
ax=ax,
)
sns.lineplot(
data=data_df.query("geo == @geo"),
x="date",
y="y",
color="black",
ax=ax,
)
ax.legend(loc="upper left")
ax.set(title=f"{geo}")
fig.suptitle("Posterior Predictive", fontsize=16, fontweight="bold", y=1.03);

The fit looks very good!
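As a rough numeric complement to the plot, we can compare the posterior predictive mean against the observations per geo (this sketch assumes the posterior_predictive group carries chain, draw, date and geo dimensions, as used in the plotting code above):
# Mean absolute percentage error of the in-sample posterior predictive mean, per geo
pp_mean = mmm.idata["posterior_predictive"]["y_original_scale"].mean(dim=("chain", "draw"))
for geo in mmm.model.coords["geo"]:
    y_obs = data_df.query("geo == @geo")["y"].to_numpy()
    y_hat = pp_mean.sel(geo=geo).to_numpy()
    print(f"{geo}: MAPE = {np.mean(np.abs(y_obs - y_hat) / y_obs):.2%}")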
Model Components#
We can extract the contributions of each component of the model in the original scale thanks to the deterministic variables added to the model.
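For example, before plotting the contributions over time, we can summarize the posterior mean of the total channel contributions per geo (summed over the training period):
# Posterior mean of the total channel contributions per geo, in the original scale
(
    mmm.idata["posterior"]["channel_contribution_original_scale"]
    .sum(dim="date")
    .mean(dim=("chain", "draw"))
    .to_series()
    .unstack("channel")
)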
fig, axes = plt.subplots(
nrows=len(mmm.model.coords["geo"]),
figsize=(15, 10),
sharex=True,
sharey=True,
layout="constrained",
)
for i, geo in enumerate(mmm.model.coords["geo"]):
ax = axes[i]
for j, channel in enumerate(mmm.model.coords["channel"]):
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
geo=geo, channel=channel
),
color=f"C{j}",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": f"94% HDI ({channel})"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["intercept_contribution_original_scale"]
.sel(geo=geo)
.expand_dims({"date": mmm.model.coords["date"]})
.transpose(..., "date"),
color="C2",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": "94% HDI intercept"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["yearly_seasonality_contribution_original_scale"].sel(
geo=geo,
),
color="C3",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": "94% HDI Fourier"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["trend_effect_contribution"].sel(
geo=geo,
)
* mmm.scalers["_target"].sel(geo=geo).item(),
color="C4",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": "94% HDI Trend"},
ax=ax,
)
for k, control in enumerate(mmm.model.coords["control"]):
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["control_contribution_original_scale"].sel(
geo=geo, control=control
),
color=f"C{5 + k}",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": f"94% HDI control ({control})"},
ax=ax,
)
sns.lineplot(
data=data_df.query("geo == @geo"),
x="date",
y="y",
color="black",
label="y",
ax=ax,
)
ax.legend(
loc="upper center",
bbox_to_anchor=(0.5, -0.1),
ncol=4,
)
ax.set(title=f"{geo}")
fig.suptitle(
"Posterior Predictive - Channel Contributions",
fontsize=16,
fontweight="bold",
y=1.03,
);

Media Deep Dive#
Next, we can look into the individual channel contributions across geos.
fig, axes = plt.subplots(
nrows=len(mmm.model.coords["geo"]),
ncols=len(mmm.model.coords["channel"]),
figsize=(12, 9),
layout="constrained",
)
for i, geo in enumerate(mmm.model.coords["geo"]):
for j, channel in enumerate(mmm.model.coords["channel"]):
ax = axes[i, j]
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
geo=geo, channel=channel
),
color=f"C{j}",
smooth=False,
hdi_prob=0.94,
ax=ax,
)
ax.set_title(f"{geo} - {channel}")

Observe that all the vectorization and heavy lifting is done under the hood by the new class.
This new class has a new plot namespace that contains many plotting methods. For example, we can reproduce the plots above by simply calling:
fig, axes = mmm.plot.contributions_over_time(
var=["channel_contribution_original_scale"],
)
# Adjust figure size and layout to 2x2
fig.set_size_inches(14, 10)
fig.set_constrained_layout(True)
# Reshape axes to 2x2 grid
num_axes = len(axes.flatten())
if num_axes > 0:
# Create a new 2x2 grid
gs = fig.add_gridspec(2, 2)
# Move existing axes to the new grid
for i, ax in enumerate(axes.flatten()):
if i < 4: # Only handle up to 4 axes for 2x2 grid
ax.set_position(gs[i // 2, i % 2].get_position(fig))
axes = axes.flatten()
# Share x and y axes across all subplots
for ax in axes:
ax.legend().remove()
ax.tick_params(axis="both", which="major", labelsize=6)
ax.tick_params(axis="both", which="minor", labelsize=6)
# Share y axis limits
y_min = min(ax.get_ylim()[0] for ax in axes)
y_max = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
ax.set_ylim(y_min, y_max)
# Share x axis limits
x_min = min(ax.get_xlim()[0] for ax in axes)
x_max = max(ax.get_xlim()[1] for ax in axes)
for ax in axes:
ax.set_xlim(x_min, x_max)

We can also plot the saturation curves for each channel and geo.
mmm.plot.saturation_curves_scatter(
width_per_col=8, height_per_row=4, original_scale=True
);

Out of Sample Predictions#
Being able to make out-of-sample predictions is very important: it is key for model validation, forward-looking scenario planning, and business decision making. As in the MMM Example Notebook, we assume the future spends are the same as in the last period of the training sample. This way we can create a new dataset with the future dates and channel spends and use the model to make predictions.
last_date = x_train["date"].max()
# New dates starting from last in dataset
n_new = 7
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]
x_out_of_sample_geo_a = pd.DataFrame({"date": new_dates, "geo": "geo_a"})
x_out_of_sample_geo_b = pd.DataFrame({"date": new_dates, "geo": "geo_b"})
# Same channel spends as in the last observed period
x_out_of_sample_geo_a["x1"] = x_train.query("geo == 'geo_a'")["x1"].iloc[-1]
x_out_of_sample_geo_a["x2"] = x_train.query("geo == 'geo_a'")["x2"].iloc[-1]
x_out_of_sample_geo_b["x1"] = x_train.query("geo == 'geo_b'")["x1"].iloc[-1]
x_out_of_sample_geo_b["x2"] = x_train.query("geo == 'geo_b'")["x2"].iloc[-1]
# Other features
## geo_a
x_out_of_sample_geo_a["event_1"] = 0.0
x_out_of_sample_geo_a["event_2"] = 0.0
## geo_b
x_out_of_sample_geo_b["event_1"] = 0.0
x_out_of_sample_geo_b["event_2"] = 0.0
x_out_of_sample = pd.concat([x_out_of_sample_geo_a, x_out_of_sample_geo_b])
# Final dataset to generate out of sample predictions.
x_out_of_sample
 | date | geo | x1 | x2 | event_1 | event_2 |
---|---|---|---|---|---|---|
0 | 2021-09-06 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
1 | 2021-09-13 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
2 | 2021-09-20 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
3 | 2021-09-27 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
4 | 2021-10-04 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
5 | 2021-10-11 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
6 | 2021-10-18 | geo_a | 0.438857 | 0.0 | 0.0 | 0.0 |
0 | 2021-09-06 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
1 | 2021-09-13 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
2 | 2021-09-20 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
3 | 2021-09-27 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
4 | 2021-10-04 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
5 | 2021-10-11 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
6 | 2021-10-18 | geo_b | 0.438857 | 0.0 | 0.0 | 0.0 |
Using the same sample_posterior_predictive
method, we can now generate the forecast.
y_out_of_sample = mmm.sample_posterior_predictive(
x_out_of_sample,
extend_idata=False,
include_last_observations=True,
random_seed=rng,
var_names=["y_original_scale"],
)
y_out_of_sample
Sampling: [y]
<xarray.Dataset> Size: 544kB
Dimensions:           (date: 7, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 56B 2021-09-06 ... 2021-10-18
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y_original_scale  (date, geo, sample) float64 448kB 5.109e+03 ... 6.617e+03
Attributes:
    created_at:                 2025-06-17T17:05:30.845296+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
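Before plotting, we can also summarize the forecast numerically, for example via the posterior mean per geo:
# Posterior mean of the out-of-sample forecast for each date and geo
y_out_of_sample["y_original_scale"].mean(dim="sample").to_dataframe().unstack("geo")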
fig, axes = plt.subplots(
nrows=2,
ncols=1,
figsize=(12, 10),
sharex=True,
sharey=True,
layout="constrained",
)
n_train_to_plot = 30
for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
for hdi_prob in [0.94, 0.5]:
az.plot_hdi(
x=mmm.model.coords["date"][-n_train_to_plot:],
y=(
mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)[
:, :, -n_train_to_plot:
]
),
color="C0",
smooth=False,
hdi_prob=hdi_prob,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
ax=ax,
)
az.plot_hdi(
x_out_of_sample.query("geo == @geo")["date"],
(
y_out_of_sample["y_original_scale"]
.sel(geo=geo)
.unstack()
.transpose(..., "date")
),
color="C1",
smooth=False,
hdi_prob=hdi_prob,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
ax=ax,
)
ax.plot(
x_out_of_sample.query("geo == @geo")["date"],
y_out_of_sample["y_original_scale"].sel(geo=geo).mean(dim="sample"),
marker="o",
color="C1",
label="posterior predictive mean",
)
sns.lineplot(
data=data_df.query("(geo == @geo)").tail(n_train_to_plot),
x="date",
y="y",
marker="o",
color="black",
label="observed",
ax=ax,
)
ax.axvline(x=last_date, color="gray", linestyle="--", label="last observation")
ax.legend(
loc="upper center",
bbox_to_anchor=(0.5, -0.15),
ncol=3,
)
ax.set(title=f"{geo}")
fig.suptitle(
"Posterior Predictive - Out of Sample", fontsize=16, fontweight="bold", y=1.03
);

Optimization#
If you want to run optimizations, then you need to use the MultiDimensionalBudgetOptimizerWrapper
.
optimizable_model = MultiDimensionalBudgetOptimizerWrapper(
model=mmm, start_date="2021-10-01", end_date="2021-12-31"
)
allocation_xarray, scipy_opt_result = optimizable_model.optimize_budget(
budget=10,  # Total budget to allocate (here, spend is in millions)
)
sample_allocation = optimizable_model.sample_response_distribution(
allocation_strategy=allocation_xarray,
)
Sampling: [y]
This object is an xarray Dataset with the allocation and the posterior predictive responses!
sample_allocation
<xarray.Dataset> Size: 3MB
Dimensions:                              (date: 13, geo: 2, sample: 4000, channel: 2)
Coordinates:
  * date                                 (date) datetime64[ns] 104B 2021-10-0...
  * geo                                  (geo) <U5 40B 'geo_a' 'geo_b'
  * channel                              (channel) <U2 16B 'x1' 'x2'
  * sample                               (sample) object 32kB MultiIndex
  * chain                                (sample) int64 32kB 0 0 0 0 ... 3 3 3 3
  * draw                                 (sample) int64 32kB 0 1 2 ... 998 999
Data variables:
    y                                    (date, geo, sample) float64 832kB 1....
    channel_contribution_original_scale  (date, geo, channel, sample) float64 2MB ...
    allocation                           (geo, channel) float64 32B 2.291 ...
Attributes:
    created_at:                 2025-06-17T17:05:36.066184+00:00
    arviz_version:              0.21.0
    inference_library:          pymc
    inference_library_version:  5.22.0
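For instance, the optimal allocation itself can be extracted as a small DataFrame:
# Optimal spend allocation per geo (rows) and channel (columns)
sample_allocation["allocation"].to_series().unstack("channel")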
Once you get the allocation, you can plot the results 🚀
optimizable_model.plot.budget_allocation(
samples=sample_allocation,
);

The graph shows the optimal budget for each channel in each geo, next to its respective mean contribution given the optimal budget. The method automatically identifies the number of dimensions and tries to create a plot based on them.
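For instance, we can compute the mean total contribution per channel and geo over the optimization window from the sampled responses:
# Posterior mean of the total contribution per geo and channel under the optimal allocation
(
    sample_allocation["channel_contribution_original_scale"]
    .sum(dim="date")
    .mean(dim="sample")
    .to_series()
    .unstack("channel")
)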
If you want to see the full uncertainty over time, you can use the plot suite and the method allocated_contribution_by_channel_over_time
.
optimizable_model.plot.allocated_contribution_by_channel_over_time(
samples=sample_allocation,
);

If you have a custom model, you can wrap it in the model protocol and then use the optimizer. If your model handles scaling internally, you don’t need to modify anything. Otherwise, for the plots, you may want to use scale_factor=N. For example:
optimizable_model.plot.budget_allocation(
samples=sample_allocation,
scale_factor=120
);
Note
We are very excited about this new feature and the possibilities it opens up. We are looking forward to hearing your feedback!
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,nutpie
Last updated: Tue Jun 17 2025
Python implementation: CPython
Python version : 3.10.17
IPython version : 8.35.0
pymc_marketing: 0.14.0
pytensor : 2.30.2+77.g8f2982d3d
nutpie : 0.15.1
arviz : 0.21.0
pymc_marketing: 0.14.0
pymc : 5.22.0
seaborn : 0.13.2
numpy : 1.26.4
pandas : 2.2.3
matplotlib : 3.10.1
Watermark: 2.5.0