Model deployment#

One of the main goals of PyMC-Marketing is to facilitate the deployment of its models.

This is achieved by building our models on top of ModelBuilder, which offers a scikit-learn-like API and makes PyMC models easy to deploy.

PyMC-Marketing models inherit two easy-to-use methods, save and load, which can be used once the model has been fitted. All models can be configured with two standard dictionaries, model_config and sampler_config, which are serialized during save and restored by load, so a model can be reused across workflows.

We illustrate this functionality with the example model described in the MMM Example Notebook. For the sake of generality, we omit most technical details here.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from pymc_marketing.prior import Prior

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%config InlineBackend.figure_format = "retina"
seed = sum(map(ord, "mmm"))
rng = np.random.default_rng(seed=seed)

Let’s load the dataset:

url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
df = pd.read_csv(url, parse_dates=["date_week"])

columns_to_keep = [
    "date_week",
    "y",
    "x1",
    "x2",
    "event_1",
    "event_2",
    "dayofyear",
]

data = df[columns_to_keep].copy()
data["t"] = np.arange(df.shape[0])
data.head()
   date_week            y        x1   x2  event_1  event_2  dayofyear  t
0 2018-04-02  3984.662237  0.318580  0.0      0.0      0.0         92  0
1 2018-04-09  3762.871794  0.112388  0.0      0.0      0.0         99  1
2 2018-04-16  4466.967388  0.292400  0.0      0.0      0.0        106  2
3 2018-04-23  3864.219373  0.071399  0.0      0.0      0.0        113  3
4 2018-04-30  4441.625278  0.386745  0.0      0.0      0.0        120  4

Our model needs only a small part of the original dataset. Many of the original features were used only to generate other columns, and since the target variable is already computed, we keep just the columns listed in columns_to_keep and add a linear trend t.

Model and sampling configuration#

Model configuration#

We first illustrate the use of model_config to define custom priors within the model.

Because there are potentially many variables that can be configured, each model provides a default_model_config attribute. This will allow you to see which settings are available by default and only define the ones you need to change.

We need to create a dummy model to be able to see the configuration dictionary.

adstock = GeometricAdstock(l_max=8)
saturation = LogisticSaturation()

dummy_model = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    adstock=adstock,
    saturation=saturation,
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    yearly_seasonality=2,
)
dummy_model.default_model_config
{'intercept': Prior("Normal", mu=0, sigma=2),
 'likelihood': Prior("Normal", sigma=Prior("HalfNormal", sigma=2)),
 'gamma_control': Prior("Normal", mu=0, sigma=2, dims="control"),
 'gamma_fourier': Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
 'adstock_alpha': Prior("Beta", alpha=1, beta=3, dims="channel"),
 'saturation_lam': Prior("Gamma", alpha=3, beta=1, dims="channel"),
 'saturation_beta': Prior("HalfNormal", sigma=2, dims="channel")}
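The idea of defining only the settings you want to change can be pictured as a dictionary merge: user-supplied keys replace the defaults, and every other key keeps its default value. The following is a minimal sketch with plain strings standing in for Prior objects. The helper name merge_config is hypothetical, and the library's own merging logic may differ in detail.

```python
# Plain strings stand in for Prior objects; this is an illustration only.
default_config = {
    "intercept": "Normal(mu=0, sigma=2)",
    "adstock_alpha": "Beta(alpha=1, beta=3)",
    "saturation_beta": "HalfNormal(sigma=2)",
}


def merge_config(defaults, overrides):
    """Hypothetical helper: return a new dict of defaults updated with overrides."""
    return {**defaults, **overrides}


# Only the key we want to change is supplied.
custom_config = {"saturation_beta": "HalfNormal(sigma=[2.18, 1.14])"}
merged = merge_config(default_config, custom_config)

for name, prior in merged.items():
    print(f"{name}: {prior}")
```

Because the merge builds a new dictionary, the defaults themselves are left untouched and can be inspected again later.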

We can change the parameters that go into the distribution of each term. In this case we simply replace the sigma of saturation_beta with a custom one:

n_channels = 2

total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel / total_spend_per_channel.sum()

# The scale necessary to make a HalfNormal distribution have unit variance
HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)
prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()
prior_sigma
array([2.1775326 , 1.14026088])
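The constant HALFNORMAL_SCALE follows from the variance of a HalfNormal distribution with scale sigma, which is sigma**2 * (1 - 2 / pi). Setting sigma = 1 / sqrt(1 - 2 / pi) therefore gives unit variance. The sketch below checks this both analytically and by simulation, using only the Python standard library. The sample size and seed are arbitrary choices made for this illustration.

```python
import math
import random
import statistics

# Same constant as HALFNORMAL_SCALE, computed with the standard library.
scale = 1 / math.sqrt(1 - 2 / math.pi)

# Analytic variance of HalfNormal(sigma=scale): sigma**2 * (1 - 2 / pi).
analytic_variance = scale**2 * (1 - 2 / math.pi)

# Monte Carlo check: the absolute value of a Normal(0, sigma) draw
# is HalfNormal(sigma) distributed.
py_rng = random.Random(0)
samples = [abs(py_rng.gauss(0.0, scale)) for _ in range(200_000)]
empirical_variance = statistics.pvariance(samples)

print(f"analytic variance:  {analytic_variance:.6f}")
print(f"empirical variance: {empirical_variance:.3f}")
```

Multiplying this unit-variance scale by n_channels * spend_share then gives each channel a prior standard deviation proportional to its share of total spend.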
saturation_beta = Prior("HalfNormal", sigma=prior_sigma, dims="channel")
my_model_config = {"saturation_beta": saturation_beta}
my_model_config
{'saturation_beta': Prior("HalfNormal", sigma=[2.1775326  1.14026088], dims="channel")}

As mentioned in the original notebook: “For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the MMM class has some default priors that you can use as a starting point.”

Sampling configuration#

The second thing we can customize is sampler_config. Like model_config, it is a dictionary that is saved with the model; it holds the keyword arguments you would otherwise pass to fit(). Creating your own sampler_config is optional. The default MMM sampler_config is empty because the default sampling parameters are usually sufficient as a starting point.

dummy_model.default_sampler_config
{}
my_sampler_config = {
    "tune": 1000,
    "draws": 1000,
    "chains": 4,
    "target_accept": 0.91,
    "nuts_sampler": "numpyro",
}

Let’s finally assemble our model!

mmm = MMM(
    model_config=my_model_config,
    sampler_config=my_sampler_config,
    date_column="date_week",
    channel_columns=["x1", "x2"],
    adstock=adstock,
    saturation=saturation,
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    yearly_seasonality=2,
)

We can confirm that our settings are being used:

mmm.model_config["saturation_beta"]
Prior("HalfNormal", sigma=[2.1775326  1.14026088], dims="channel")
mmm.sampler_config
{'tune': 1000,
 'draws': 1000,
 'chains': 4,
 'target_accept': 0.91,
 'nuts_sampler': 'numpyro'}

Model Fitting#

Note that we didn’t pass the dataset to the class constructor itself. This is done to mimic the scikit-learn API and to make it easier to get started with PyMC-Marketing models.

# Split X, and y
X = data.drop("y", axis=1)
y = data["y"]

All that’s left now is to fit the model.

As the call below shows, you can still pass sampler kwargs directly to the fit() method. However, only the kwargs set through sampler_config are saved and reused after loading the model.

mmm.fit(X=X, y=y, random_seed=rng)
arviz.InferenceData
    • <xarray.Dataset> Size: 63MB
      Dimensions:                          (chain: 4, draw: 1000, control: 3,
                                            fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                            (chain) int64 32B 0 1 2 3
        * draw                             (draw) int64 8kB 0 1 2 3 ... 997 998 999
        * control                          (control) <U7 84B 'event_1' 'event_2' 't'
        * fourier_mode                     (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'
        * channel                          (channel) <U2 16B 'x1' 'x2'
        * date                             (date) datetime64[ns] 1kB 2018-04-02 ......
      Data variables:
          intercept                        (chain, draw) float64 32kB 0.3241 ... 0....
          gamma_control                    (chain, draw, control) float64 96kB 0.24...
          gamma_fourier                    (chain, draw, fourier_mode) float64 128kB ...
          adstock_alpha                    (chain, draw, channel) float64 64kB 0.44...
          saturation_lam                   (chain, draw, channel) float64 64kB 3.86...
          saturation_beta                  (chain, draw, channel) float64 64kB 0.41...
          y_sigma                          (chain, draw) float64 32kB 0.0307 ... 0....
          channel_contributions            (chain, draw, date, channel) float64 11MB ...
          control_contributions            (chain, draw, date, control) float64 17MB ...
          fourier_contributions            (chain, draw, date, fourier_mode) float64 23MB ...
          yearly_seasonality_contribution  (chain, draw, date) float64 6MB 0.003468...
          mu                               (chain, draw, date) float64 6MB 0.4647 ....
      Attributes:
          created_at:     2024-11-14T13:56:50.170234
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.9682 0.9951 ... 0.9999 0.9654
          step_size        (chain, draw) float64 32kB 0.00572 0.00572 ... 0.005936
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB -337.8 -349.0 ... -340.3 -345.5
          n_steps          (chain, draw) int64 32kB 1023 511 511 511 ... 1023 511 511
          tree_depth       (chain, draw) int64 32kB 10 9 9 9 9 9 9 ... 10 9 9 10 9 9
          lp               (chain, draw) float64 32kB -352.0 -355.9 ... -352.2 -352.1
      Attributes:
          created_at:     2024-11-14T13:56:50.174899
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625
      Attributes:
          created_at:                 2024-11-14T13:56:50.176001
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 9kB
      Dimensions:       (date: 179, channel: 2, control: 3)
      Coordinates:
        * date          (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * channel       (channel) <U2 16B 'x1' 'x2'
        * control       (control) <U7 84B 'event_1' 'event_2' 't'
      Data variables:
          channel_data  (date, channel) float64 3kB 0.3196 0.0 0.1128 ... 0.4403 0.0
          control_data  (date, control) float64 4kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0
          dayofyear     (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242
      Attributes:
          created_at:                 2024-11-14T13:56:50.177660
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 13kB
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
          x1         (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389
          x2         (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0
          event_1    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          dayofyear  (index) int64 1kB 92 99 106 113 120 127 ... 214 221 228 235 242
          t          (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
          y          (index) float64 1kB 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03
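The difference between settings stored in sampler_config and keyword arguments passed to a single fit() call can be pictured with a toy class. This is a sketch under stated assumptions and not the ModelBuilder implementation: the class ToyModel and its methods are hypothetical, and the rule that per-call kwargs take precedence over the stored configuration is an assumption made for the illustration.

```python
import json


class ToyModel:
    """Hypothetical stand-in for a model with a persisted sampler_config."""

    def __init__(self, sampler_config=None):
        self.sampler_config = dict(sampler_config or {})

    def fit(self, **kwargs):
        # Assumed behaviour: per-call kwargs apply to this run only
        # and do not modify the stored configuration.
        return {**self.sampler_config, **kwargs}

    def save(self):
        # Only the stored configuration is serialized.
        return json.dumps({"sampler_config": self.sampler_config})


model = ToyModel({"draws": 1000, "chains": 4})
used_for_this_run = model.fit(chains=2, random_seed=42)
restored = json.loads(model.save())

print("used for this run:", used_for_this_run)
print("persisted:", restored["sampler_config"])
```

In this sketch, chains=2 and random_seed=42 affect only the single call, while the serialized configuration still holds the original values.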

The fit() method automatically builds the model using the priors from model_config, and assigns the created model to our instance. You can access it as a normal attribute.

type(mmm.model)
pymc.model.core.Model
mmm.graphviz()
[Figure: model graph produced by mmm.graphviz()]

The posterior trace can be accessed through the fit_result attribute:

mmm.fit_result
<xarray.Dataset> Size: 63MB
Dimensions:                          (chain: 4, draw: 1000, control: 3,
                                      fourier_mode: 4, channel: 2, date: 179)
Coordinates:
  * chain                            (chain) int64 32B 0 1 2 3
  * draw                             (draw) int64 8kB 0 1 2 3 ... 997 998 999
  * control                          (control) <U7 84B 'event_1' 'event_2' 't'
  * fourier_mode                     (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'
  * channel                          (channel) <U2 16B 'x1' 'x2'
  * date                             (date) datetime64[ns] 1kB 2018-04-02 ......
Data variables:
    intercept                        (chain, draw) float64 32kB 0.3241 ... 0....
    gamma_control                    (chain, draw, control) float64 96kB 0.24...
    gamma_fourier                    (chain, draw, fourier_mode) float64 128kB ...
    adstock_alpha                    (chain, draw, channel) float64 64kB 0.44...
    saturation_lam                   (chain, draw, channel) float64 64kB 3.86...
    saturation_beta                  (chain, draw, channel) float64 64kB 0.41...
    y_sigma                          (chain, draw) float64 32kB 0.0307 ... 0....
    channel_contributions            (chain, draw, date, channel) float64 11MB ...
    control_contributions            (chain, draw, date, control) float64 17MB ...
    fourier_contributions            (chain, draw, date, fourier_mode) float64 23MB ...
    yearly_seasonality_contribution  (chain, draw, date) float64 6MB 0.003468...
    mu                               (chain, draw, date) float64 6MB 0.4647 ....
Attributes:
    created_at:     2024-11-14T13:56:50.170234
    arviz_version:  0.17.1

If you wish to inspect the entire inference data, use the idata attribute. Within idata, you can find the entire dataset passed to the model under fit_data.

mmm.idata
arviz.InferenceData
    • <xarray.Dataset> Size: 63MB
      Dimensions:                          (chain: 4, draw: 1000, control: 3,
                                            fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                            (chain) int64 32B 0 1 2 3
        * draw                             (draw) int64 8kB 0 1 2 3 ... 997 998 999
        * control                          (control) <U7 84B 'event_1' 'event_2' 't'
        * fourier_mode                     (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'
        * channel                          (channel) <U2 16B 'x1' 'x2'
        * date                             (date) datetime64[ns] 1kB 2018-04-02 ......
      Data variables:
          intercept                        (chain, draw) float64 32kB 0.3241 ... 0....
          gamma_control                    (chain, draw, control) float64 96kB 0.24...
          gamma_fourier                    (chain, draw, fourier_mode) float64 128kB ...
          adstock_alpha                    (chain, draw, channel) float64 64kB 0.44...
          saturation_lam                   (chain, draw, channel) float64 64kB 3.86...
          saturation_beta                  (chain, draw, channel) float64 64kB 0.41...
          y_sigma                          (chain, draw) float64 32kB 0.0307 ... 0....
          channel_contributions            (chain, draw, date, channel) float64 11MB ...
          control_contributions            (chain, draw, date, control) float64 17MB ...
          fourier_contributions            (chain, draw, date, fourier_mode) float64 23MB ...
          yearly_seasonality_contribution  (chain, draw, date) float64 6MB 0.003468...
          mu                               (chain, draw, date) float64 6MB 0.4647 ....
      Attributes:
          created_at:     2024-11-14T13:56:50.170234
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB 0.9682 0.9951 ... 0.9999 0.9654
          step_size        (chain, draw) float64 32kB 0.00572 0.00572 ... 0.005936
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB -337.8 -349.0 ... -340.3 -345.5
          n_steps          (chain, draw) int64 32kB 1023 511 511 511 ... 1023 511 511
          tree_depth       (chain, draw) int64 32kB 10 9 9 9 9 9 9 ... 10 9 9 10 9 9
          lp               (chain, draw) float64 32kB -352.0 -355.9 ... -352.2 -352.1
      Attributes:
          created_at:     2024-11-14T13:56:50.174899
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625
      Attributes:
          created_at:                 2024-11-14T13:56:50.176001
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 9kB
      Dimensions:       (date: 179, channel: 2, control: 3)
      Coordinates:
        * date          (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * channel       (channel) <U2 16B 'x1' 'x2'
        * control       (control) <U7 84B 'event_1' 'event_2' 't'
      Data variables:
          channel_data  (date, channel) float64 3kB 0.3196 0.0 0.1128 ... 0.4403 0.0
          control_data  (date, control) float64 4kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0
          dayofyear     (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242
      Attributes:
          created_at:                 2024-11-14T13:56:50.177660
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 13kB
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
          x1         (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389
          x2         (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0
          event_1    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          dayofyear  (index) int64 1kB 92 99 106 113 120 127 ... 214 221 228 235 242
          t          (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
          y          (index) float64 1kB 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

Saving and loading a fitted model#

All the data passed to the model on initialization is stored in idata.attrs. The save() method later uses it to write both this data and all the fit data to a file in the netCDF format. You can read more about this format in the netCDF documentation.

The save and load methods require only a path that tells them where the model should be written to and read from.

mmm.save("my_saved_model.nc")
loaded_model = MMM.load("my_saved_model.nc")
loaded_model.model_config["saturation_beta"]
Prior("HalfNormal", sigma=[2.1775326  1.14026088], dims="channel")
loaded_model.idata.attrs
{'id': 'cbf06a279ecf0af6',
 'model_type': 'MMM',
 'version': '0.0.2',
 'sampler_config': '{"tune": 1000, "draws": 1000, "chains": 4, "target_accept": 0.91, "nuts_sampler": "numpyro"}',
 'model_config': '{"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}}, "likelihood": {"dist": "Normal", "kwargs": {"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}}}, "dims": ["date"]}, "gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}, "dims": ["control"]}, "gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}, "dims": ["fourier_mode"]}, "adstock_alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}, "dims": ["channel"]}, "saturation_lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}, "dims": ["channel"]}, "saturation_beta": {"dist": "HalfNormal", "kwargs": {"sigma": [2.1775326025486734, 1.140260877391939]}, "dims": ["channel"]}}',
 'date_column': '"date_week"',
 'adstock': '{"lookup_name": "geometric", "prefix": "adstock", "priors": {"alpha": {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 3}, "dims": ["channel"]}}, "l_max": 8, "normalize": true, "mode": "After"}',
 'saturation': '{"lookup_name": "logistic", "prefix": "saturation", "priors": {"lam": {"dist": "Gamma", "kwargs": {"alpha": 3, "beta": 1}, "dims": ["channel"]}, "beta": {"dist": "HalfNormal", "kwargs": {"sigma": [2.1775326025486734, 1.140260877391939]}, "dims": ["channel"]}}}',
 'adstock_first': 'true',
 'control_columns': '["event_1", "event_2", "t"]',
 'channel_columns': '["x1", "x2"]',
 'validate_data': 'true',
 'yearly_seasonality': '2',
 'time_varying_intercept': 'false',
 'time_varying_media': 'false'}
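As the output above shows, most initialization arguments are stored in idata.attrs as JSON-encoded strings, which is why date_column appears with an extra pair of quotes and yearly_seasonality appears as the string '2'. A minimal sketch of decoding such values with the standard json module follows. The entries are copied from the output above. The id, model_type and version entries are plain strings rather than JSON, so they are left out.

```python
import json

# A few entries copied from the idata.attrs output above.
attrs = {
    "sampler_config": (
        '{"tune": 1000, "draws": 1000, "chains": 4, '
        '"target_accept": 0.91, "nuts_sampler": "numpyro"}'
    ),
    "date_column": '"date_week"',
    "channel_columns": '["x1", "x2"]',
    "yearly_seasonality": "2",
    "validate_data": "true",
}

# Decode each JSON string back into a regular Python object.
decoded = {key: json.loads(value) for key, value in attrs.items()}

for key, value in decoded.items():
    print(f"{key}: {value!r} ({type(value).__name__})")
```

Storing the settings as strings keeps the attributes compatible with the netCDF file, and decoding them recovers the original dictionaries, lists, numbers and booleans.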
loaded_model.graphviz()
[Figure: model graph produced by loaded_model.graphviz(), identical to the graph of the original model]
loaded_model.idata
arviz.InferenceData
    • <xarray.Dataset> Size: 63MB
      Dimensions:                          (chain: 4, draw: 1000, control: 3,
                                            fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                            (chain) int64 32B 0 1 2 3
        * draw                             (draw) int64 8kB 0 1 2 3 ... 997 998 999
        * control                          (control) <U7 84B 'event_1' 'event_2' 't'
        * fourier_mode                     (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'
        * channel                          (channel) <U2 16B 'x1' 'x2'
        * date                             (date) datetime64[ns] 1kB 2018-04-02 ......
      Data variables:
          intercept                        (chain, draw) float64 32kB ...
          gamma_control                    (chain, draw, control) float64 96kB ...
          gamma_fourier                    (chain, draw, fourier_mode) float64 128kB ...
          adstock_alpha                    (chain, draw, channel) float64 64kB ...
          saturation_lam                   (chain, draw, channel) float64 64kB ...
          saturation_beta                  (chain, draw, channel) float64 64kB ...
          y_sigma                          (chain, draw) float64 32kB ...
          channel_contributions            (chain, draw, date, channel) float64 11MB ...
          control_contributions            (chain, draw, date, control) float64 17MB ...
          fourier_contributions            (chain, draw, date, fourier_mode) float64 23MB ...
          yearly_seasonality_contribution  (chain, draw, date) float64 6MB ...
          mu                               (chain, draw, date) float64 6MB ...
      Attributes:
          created_at:     2024-11-14T13:56:50.170234
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 32kB ...
          step_size        (chain, draw) float64 32kB ...
          diverging        (chain, draw) bool 4kB ...
          energy           (chain, draw) float64 32kB ...
          n_steps          (chain, draw) int64 32kB ...
          tree_depth       (chain, draw) int64 32kB ...
          lp               (chain, draw) float64 32kB ...
      Attributes:
          created_at:     2024-11-14T13:56:50.174899
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB ...
      Attributes:
          created_at:                 2024-11-14T13:56:50.176001
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 9kB
      Dimensions:       (date: 179, channel: 2, control: 3)
      Coordinates:
        * date          (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * channel       (channel) <U2 16B 'x1' 'x2'
        * control       (control) <U7 84B 'event_1' 'event_2' 't'
      Data variables:
          channel_data  (date, channel) float64 3kB ...
          control_data  (date, control) float64 4kB ...
          dayofyear     (date) int32 716B ...
      Attributes:
          created_at:                 2024-11-14T13:56:50.177660
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.15.2
          sampling_time:              14.669591

    • <xarray.Dataset> Size: 13kB
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
          x1         (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389
          x2         (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0
          event_1    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          dayofyear  (index) int64 1kB 92 99 106 113 120 127 ... 214 221 228 235 242
          t          (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
          y          (index) float64 1kB 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

A loaded model is ready to be used for sampling and prediction, making use of the previous fitting results and data if needed.

loaded_model.sample_posterior_predictive(
    X, extend_idata=True, combined=False, random_seed=rng
)
Sampling: [y]


<xarray.Dataset> Size: 6MB
Dimensions:  (chain: 4, draw: 1000, date: 179)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
Data variables:
    y        (chain, draw, date) float64 6MB 3.945e+03 3.424e+03 ... 5.061e+03
Attributes:
    created_at:                 2024-11-14T13:56:53.715774
    arviz_version:              0.17.1
    inference_library:          pymc
    inference_library_version:  5.15.1
az.plot_ppc(loaded_model.idata);
/Users/juanitorduz/Documents/envs/pymc-marketing-env/lib/python3.12/site-packages/arviz/stats/density_utils.py:487: UserWarning: Your data appears to have a single value or no finite values
  warnings.warn("Your data appears to have a single value or no finite values")
[Figure: posterior predictive check produced by az.plot_ppc(loaded_model.idata)]

Other models#

Although this introduction uses MMM, all PyMC-Marketing models (both MMM and CLV) provide the same functionality.

Summary#

The PyMC-Marketing functionality described here is intended to make it easier to share models among data science teams without requiring everyone involved to have extensive technical modelling knowledge. We are still iterating on our API and would love to hear more feedback from our users!

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
Last updated: Thu Nov 14 2024

Python implementation: CPython
Python version       : 3.12.4
IPython version      : 8.27.0

pytensor: 2.22.1

numpy     : 1.26.4
matplotlib: 3.9.2
arviz     : 0.17.1
pandas    : 2.2.2

Watermark: 2.4.3