# Source code for pymc_marketing.mmm.evaluation

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Evaluation and diagnostics for MMM models."""

from typing import cast

import arviz as az
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    root_mean_squared_error,
)

from pymc_marketing.metrics import nmae, nrmse


def calculate_metric_distributions(
    y_true: npt.NDArray | pd.Series,
    y_pred: npt.NDArray | xr.DataArray,
    metrics_to_calculate: list[str] | None = None,
) -> dict[str, npt.NDArray]:
    """Calculate distributions of evaluation metrics for posterior samples.

    Each requested metric is evaluated once per posterior sample (column of
    ``y_pred``), producing a distribution of metric values rather than a
    single point estimate.

    Parameters
    ----------
    y_true : npt.NDArray | pd.Series
        True values for the dataset. Shape: (date,)
    y_pred : npt.NDArray | xr.DataArray
        Posterior predictive samples. Shape: (date, sample)
    metrics_to_calculate : list of str or None, optional
        List of metrics to calculate. Options include:

            * `r_squared`: Bayesian R-squared.
            * `rmse`: Root Mean Squared Error.
            * `nrmse`: Normalized Root Mean Squared Error.
            * `mae`: Mean Absolute Error.
            * `nmae`: Normalized Mean Absolute Error.
            * `mape`: Mean Absolute Percentage Error.

        Defaults to all metrics if None.

    Returns
    -------
    dict of str to npt.NDArray
        A dictionary mapping each metric name to an array of metric values,
        one entry per posterior sample.

    Raises
    ------
    ValueError
        If ``metrics_to_calculate`` contains an unrecognized metric name.
    """
    # Normalize inputs to plain NumPy arrays.
    if isinstance(y_true, pd.Series):
        y_true = cast(np.ndarray, y_true.to_numpy())
    if isinstance(y_pred, xr.DataArray):
        y_pred = y_pred.values

    # Registry of supported metrics; az.r2_score expects samples on the
    # leading axis, hence the transpose inside the lambda.
    available = {
        "r_squared": lambda yt, yp: az.r2_score(yt, yp.T)["r2"],
        "rmse": root_mean_squared_error,
        "nrmse": nrmse,
        "mae": mean_absolute_error,
        "nmae": nmae,
        "mape": mean_absolute_percentage_error,
    }

    if metrics_to_calculate is None:
        metrics_to_calculate = list(available.keys())

    invalid_metrics = set(metrics_to_calculate) - set(available.keys())
    if invalid_metrics:
        raise ValueError(
            f"Invalid metrics: {invalid_metrics}. "
            f"Valid options are: {list(available.keys())}"
        )

    # Evaluate each metric along the date dimension, once per sample column.
    n_samples = y_pred.shape[1]
    return {
        metric: np.array(
            [available[metric](y_true, y_pred[:, idx]) for idx in range(n_samples)]
        )
        for metric in metrics_to_calculate
    }
def summarize_metric_distributions(
    metric_distributions: dict[str, npt.NDArray],
    hdi_prob: float = 0.94,
) -> dict[str, dict[str, float]]:
    """Summarize metric distributions with point estimates and HDIs.

    Parameters
    ----------
    metric_distributions : dict of str to npt.NDArray
        Dictionary of metric distributions as returned by
        calculate_metric_distributions.
    hdi_prob : float, optional
        The probability mass of the highest density interval. Defaults to 0.94.

    Returns
    -------
    dict of str to dict
        A dictionary containing summary statistics for each metric.
        List of summary statistics calculated for each metric:

            * `mean`: Mean of the metric distribution.
            * `median`: Median of the metric distribution.
            * `std`: Standard deviation of the metric distribution.
            * `min`: Minimum value of the metric distribution.
            * `max`: Maximum value of the metric distribution.
            * `{hdi_prob:.0%}_hdi_lower`: Lower bound of the Highest Density
              Interval (e.g. ``94%_hdi_lower`` for the default ``hdi_prob``).
            * `{hdi_prob:.0%}_hdi_upper`: Upper bound of the Highest Density
              Interval (e.g. ``94%_hdi_upper`` for the default ``hdi_prob``).
    """
    metric_summaries = {}
    for metric, distribution in metric_distributions.items():
        # az.hdi returns a (lower, upper) pair for a 1-D sample array.
        hdi = az.hdi(distribution, hdi_prob=hdi_prob)
        metric_summaries[metric] = {
            "mean": np.mean(distribution),
            "median": np.median(distribution),
            "std": np.std(distribution),
            "min": np.min(distribution),
            "max": np.max(distribution),
            # HDI keys embed the probability so callers can tell which
            # interval was requested, e.g. "94%_hdi_lower".
            f"{hdi_prob:.0%}_hdi_lower": hdi[0],
            f"{hdi_prob:.0%}_hdi_upper": hdi[1],
        }
    return metric_summaries
def compute_summary_metrics(
    y_true: npt.NDArray | pd.Series,
    y_pred: npt.NDArray | xr.DataArray,
    metrics_to_calculate: list[str] | None = None,
    hdi_prob: float = 0.94,
) -> dict[str, dict[str, float]]:
    """Evaluate the model by calculating metric distributions and summarizing them.

    This method combines the functionality of `calculate_metric_distributions`
    and `summarize_metric_distributions`.

    Parameters
    ----------
    y_true : npt.NDArray | pd.Series
        The true values of the target variable.
    y_pred : npt.NDArray | xr.DataArray
        The predicted values of the target variable.
    metrics_to_calculate : list of str or None, optional
        List of metrics to calculate. Options include:

            * `r_squared`: Bayesian R-squared.
            * `rmse`: Root Mean Squared Error.
            * `nrmse`: Normalized Root Mean Squared Error.
            * `mae`: Mean Absolute Error.
            * `nmae`: Normalized Mean Absolute Error.
            * `mape`: Mean Absolute Percentage Error.

        Defaults to all metrics if None.
    hdi_prob : float, optional
        The probability mass of the highest density interval. Defaults to 0.94.

    Returns
    -------
    dict of str to dict
        A dictionary containing summary statistics for each metric.
        List of summary statistics calculated for each metric:

            * `mean`: Mean of the metric distribution.
            * `median`: Median of the metric distribution.
            * `std`: Standard deviation of the metric distribution.
            * `min`: Minimum value of the metric distribution.
            * `max`: Maximum value of the metric distribution.
            * `{hdi_prob:.0%}_hdi_lower`: Lower bound of the Highest Density
              Interval (e.g. ``89%_hdi_lower`` when ``hdi_prob=0.89``).
            * `{hdi_prob:.0%}_hdi_upper`: Upper bound of the Highest Density
              Interval (e.g. ``89%_hdi_upper`` when ``hdi_prob=0.89``).

    Examples
    --------
    Evaluation (error and model metrics) for a PyMC-Marketing MMM.

    .. code-block:: python

        import pandas as pd

        from pymc_marketing.mmm import (
            GeometricAdstock,
            LogisticSaturation,
            MMM,
        )
        from pymc_marketing.paths import data_dir
        from pymc_marketing.mmm.evaluation import compute_summary_metrics

        # Usual PyMC-Marketing demo model code
        file_path = data_dir / "mmm_example.csv"
        data = pd.read_csv(file_path, parse_dates=["date_week"])

        X = data.drop("y", axis=1)
        y = data["y"]

        mmm = MMM(
            adstock=GeometricAdstock(l_max=8),
            saturation=LogisticSaturation(),
            date_column="date_week",
            channel_columns=["x1", "x2"],
            control_columns=[
                "event_1",
                "event_2",
                "t",
            ],
            yearly_seasonality=2,
        )
        mmm.fit(X, y)

        # Generate posterior predictive samples
        posterior_preds = mmm.sample_posterior_predictive(X)

        # Evaluate the model
        results = compute_summary_metrics(
            y_true=mmm.y,
            y_pred=posterior_preds.y,
            metrics_to_calculate=["r_squared", "rmse", "mae"],
            hdi_prob=0.89,
        )

        # Print the results neatly
        for metric, stats in results.items():
            print(f"{metric}:")
            for stat, value in stats.items():
                print(f"  {stat}: {value:.4f}")
            print()

        # r_squared:
        #   mean: 0.9055
        #   median: 0.9061
        #   std: 0.0098
        #   min: 0.8669
        #   max: 0.9371
        #   89%_hdi_lower: 0.8891
        #   89%_hdi_upper: 0.9198
        #
        # rmse:
        #   mean: 351.9120
        #   median: 351.0219
        #   std: 19.4732
        #   min: 290.6544
        #   max: 418.0821
        #   89%_hdi_lower: 317.0673
        #   89%_hdi_upper: 378.1048
        #
        # mae:
        #   mean: 281.6953
        #   median: 281.2757
        #   std: 16.3375
        #   min: 234.1462
        #   max: 337.9461
        #   89%_hdi_lower: 255.7273
        #   89%_hdi_upper: 307.2391
    """
    metric_distributions = calculate_metric_distributions(
        y_true,
        y_pred,
        metrics_to_calculate,
    )
    metric_summaries = summarize_metric_distributions(metric_distributions, hdi_prob)
    return metric_summaries