{ "cells": [ { "cell_type": "markdown", "id": "7fb27b941602401d91542211134fc71a", "metadata": {}, "source": [ "(model_deployment)=\n", "# Model deployment" ] }, { "cell_type": "markdown", "id": "acae54e37e7d407bbb7b55eff062a284", "metadata": {}, "source": [ "One of the main goals of PyMC-Marketing is to facilitate the deployment of its models. " ] }, { "cell_type": "markdown", "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": {}, "source": [ "This is achieved by building our models on top of [ModelBuilder](https://www.pymc-marketing.io/en/stable/api/generated/pymc_marketing.model_builder.ModelBuilder.html#pymc_marketing.model_builder.ModelBuilder), which offers a scikit-learn-like API and makes PyMC models easy to deploy.\n", "\n", "PyMC-Marketing models inherit two easy-to-use methods, `save` and `load`, which can be used after the model has been fitted. All models can be configured with two standard dictionaries, `model_config` and `sampler_config`, which are serialized by `save` and restored by `load`, allowing model reuse across workflows." ] }, { "cell_type": "markdown", "id": "8dd0d8092fe74a7c96281538738b07e2", "metadata": {}, "source": [ "We will illustrate this functionality with the example model described in the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html). For the sake of generality, we omit most technical details here." 
] }, { "cell_type": "code", "execution_count": 1, "id": "72eea5119410473aa328ad9291626812", "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation\n", "from pymc_marketing.prior import Prior\n", "\n", "az.style.use(\"arviz-darkgrid\")\n", "plt.rcParams[\"figure.figsize\"] = [12, 7]\n", "plt.rcParams[\"figure.dpi\"] = 100\n", "\n", "%config InlineBackend.figure_format = \"retina\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "8edb47106e1a46a883d545849b8ab81b", "metadata": {}, "outputs": [], "source": [ "seed = sum(map(ord, \"mmm\"))\n", "rng = np.random.default_rng(seed=seed)" ] }, { "cell_type": "markdown", "id": "10185d26023b46108eb7d9f57d49d2b3", "metadata": {}, "source": [ "Let's load the dataset:" ] }, { "cell_type": "code", "execution_count": 3, "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
|   | date_week | y | x1 | x2 | event_1 | event_2 | dayofyear | t |
|---|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 3984.662237 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 | 0 |
| 1 | 2018-04-09 | 3762.871794 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 | 1 |
| 2 | 2018-04-16 | 4466.967388 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 | 2 |
| 3 | 2018-04-23 | 3864.219373 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 | 3 |
| 4 | 2018-04-30 | 4441.625278 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 | 4 |
<xarray.Dataset> Size: 63MB\n", "Dimensions: (chain: 4, draw: 1000, control: 3,\n", " fourier_mode: 4, channel: 2, date: 179)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 ... 997 998 999\n", " * control (control) <U7 84B 'event_1' 'event_2' 't'\n", " * fourier_mode (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'\n", " * channel (channel) <U2 16B 'x1' 'x2'\n", " * date (date) datetime64[ns] 1kB 2018-04-02 ......\n", "Data variables:\n", " intercept (chain, draw) float64 32kB 0.3241 ... 0....\n", " gamma_control (chain, draw, control) float64 96kB 0.24...\n", " gamma_fourier (chain, draw, fourier_mode) float64 128kB ...\n", " adstock_alpha (chain, draw, channel) float64 64kB 0.44...\n", " saturation_lam (chain, draw, channel) float64 64kB 3.86...\n", " saturation_beta (chain, draw, channel) float64 64kB 0.41...\n", " y_sigma (chain, draw) float64 32kB 0.0307 ... 0....\n", " channel_contributions (chain, draw, date, channel) float64 11MB ...\n", " control_contributions (chain, draw, date, control) float64 17MB ...\n", " fourier_contributions (chain, draw, date, fourier_mode) float64 23MB ...\n", " yearly_seasonality_contribution (chain, draw, date) float64 6MB 0.003468...\n", " mu (chain, draw, date) float64 6MB 0.4647 ....\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.170234\n", " arviz_version: 0.17.1
<xarray.Dataset> Size: 204kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", "Data variables:\n", " acceptance_rate (chain, draw) float64 32kB 0.9682 0.9951 ... 0.9999 0.9654\n", " step_size (chain, draw) float64 32kB 0.00572 0.00572 ... 0.005936\n", " diverging (chain, draw) bool 4kB False False False ... False False\n", " energy (chain, draw) float64 32kB -337.8 -349.0 ... -340.3 -345.5\n", " n_steps (chain, draw) int64 32kB 1023 511 511 511 ... 1023 511 511\n", " tree_depth (chain, draw) int64 32kB 10 9 9 9 9 9 9 ... 10 9 9 10 9 9\n", " lp (chain, draw) float64 32kB -352.0 -355.9 ... -352.2 -352.1\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.174899\n", " arviz_version: 0.17.1
<xarray.Dataset> Size: 3kB\n", "Dimensions: (date: 179)\n", "Coordinates:\n", " * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", "Data variables:\n", " y (date) float64 1kB 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.176001\n", " arviz_version: 0.17.1\n", " inference_library: numpyro\n", " inference_library_version: 0.15.2\n", " sampling_time: 14.669591
<xarray.Dataset> Size: 9kB\n", "Dimensions: (date: 179, channel: 2, control: 3)\n", "Coordinates:\n", " * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", " * channel (channel) <U2 16B 'x1' 'x2'\n", " * control (control) <U7 84B 'event_1' 'event_2' 't'\n", "Data variables:\n", " channel_data (date, channel) float64 3kB 0.3196 0.0 0.1128 ... 0.4403 0.0\n", " control_data (date, control) float64 4kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n", " dayofyear (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.177660\n", " arviz_version: 0.17.1\n", " inference_library: numpyro\n", " inference_library_version: 0.15.2\n", " sampling_time: 14.669591
<xarray.Dataset> Size: 13kB\n", "Dimensions: (index: 179)\n", "Coordinates:\n", " * index (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178\n", "Data variables:\n", " date_week (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", " x1 (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n", " x2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0\n", " event_1 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n", " event_2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n", " dayofyear (index) int64 1kB 92 99 106 113 120 127 ... 214 221 228 235 242\n", " t (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178\n", " y (index) float64 1kB 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03
<xarray.Dataset> Size: 63MB\n", "Dimensions: (chain: 4, draw: 1000, control: 3,\n", " fourier_mode: 4, channel: 2, date: 179)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 ... 997 998 999\n", " * control (control) <U7 84B 'event_1' 'event_2' 't'\n", " * fourier_mode (fourier_mode) <U5 80B 'sin_1' ... 'cos_2'\n", " * channel (channel) <U2 16B 'x1' 'x2'\n", " * date (date) datetime64[ns] 1kB 2018-04-02 ......\n", "Data variables:\n", " intercept (chain, draw) float64 32kB ...\n", " gamma_control (chain, draw, control) float64 96kB ...\n", " gamma_fourier (chain, draw, fourier_mode) float64 128kB ...\n", " adstock_alpha (chain, draw, channel) float64 64kB ...\n", " saturation_lam (chain, draw, channel) float64 64kB ...\n", " saturation_beta (chain, draw, channel) float64 64kB ...\n", " y_sigma (chain, draw) float64 32kB ...\n", " channel_contributions (chain, draw, date, channel) float64 11MB ...\n", " control_contributions (chain, draw, date, control) float64 17MB ...\n", " fourier_contributions (chain, draw, date, fourier_mode) float64 23MB ...\n", " yearly_seasonality_contribution (chain, draw, date) float64 6MB ...\n", " mu (chain, draw, date) float64 6MB ...\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.170234\n", " arviz_version: 0.17.1
<xarray.Dataset> Size: 204kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", "Data variables:\n", " acceptance_rate (chain, draw) float64 32kB ...\n", " step_size (chain, draw) float64 32kB ...\n", " diverging (chain, draw) bool 4kB ...\n", " energy (chain, draw) float64 32kB ...\n", " n_steps (chain, draw) int64 32kB ...\n", " tree_depth (chain, draw) int64 32kB ...\n", " lp (chain, draw) float64 32kB ...\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.174899\n", " arviz_version: 0.17.1
<xarray.Dataset> Size: 3kB\n", "Dimensions: (date: 179)\n", "Coordinates:\n", " * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", "Data variables:\n", " y (date) float64 1kB ...\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.176001\n", " arviz_version: 0.17.1\n", " inference_library: numpyro\n", " inference_library_version: 0.15.2\n", " sampling_time: 14.669591
<xarray.Dataset> Size: 9kB\n", "Dimensions: (date: 179, channel: 2, control: 3)\n", "Coordinates:\n", " * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", " * channel (channel) <U2 16B 'x1' 'x2'\n", " * control (control) <U7 84B 'event_1' 'event_2' 't'\n", "Data variables:\n", " channel_data (date, channel) float64 3kB ...\n", " control_data (date, control) float64 4kB ...\n", " dayofyear (date) int32 716B ...\n", "Attributes:\n", " created_at: 2024-11-14T13:56:50.177660\n", " arviz_version: 0.17.1\n", " inference_library: numpyro\n", " inference_library_version: 0.15.2\n", " sampling_time: 14.669591
<xarray.Dataset> Size: 13kB\n", "Dimensions: (index: 179)\n", "Coordinates:\n", " * index (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178\n", "Data variables:\n", " date_week (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", " x1 (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n", " x2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0\n", " event_1 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n", " event_2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n", " dayofyear (index) int64 1kB 92 99 106 113 120 127 ... 214 221 228 235 242\n", " t (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178\n", " y (index) float64 1kB 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
<xarray.Dataset> Size: 6MB\n", "Dimensions: (chain: 4, draw: 1000, date: 179)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30\n", "Data variables:\n", " y (chain, draw, date) float64 6MB 3.945e+03 3.424e+03 ... 5.061e+03\n", "Attributes:\n", " created_at: 2024-11-14T13:56:53.715774\n", " arviz_version: 0.17.1\n", " inference_library: pymc\n", " inference_library_version: 5.15.1