Source code for pymc_marketing.mmm.fourier

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Fourier seasonality transformations.

This module provides Fourier seasonality transformations for use in
Marketing Mix Models. The Fourier seasonality is a set of sine and cosine
functions that can be used to model periodic patterns in the data.

There are three types of Fourier seasonality transformations available:

- Yearly Fourier: A yearly seasonality with a period of 365.25 days
- Monthly Fourier: A monthly seasonality with a period of 365.25 / 12 days
- Weekly Fourier: A weekly seasonality with a period of 7 days

.. plot::
    :context: close-figs

    import matplotlib.pyplot as plt
    import numpy as np
    import arviz as az
    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    plt.style.use('arviz-darkgrid')

    prior = Prior(
        "Normal",
        mu=[0, 0, -1, 0],
        sigma=Prior("Gamma", mu=0.10, sigma=0.1, dims="fourier"),
        dims=("hierarchy", "fourier"),
    )
    yearly = YearlyFourier(n_order=2, prior=prior)
    coords = {"hierarchy": ["A", "B"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    fig, _ = yearly.plot_curve(curve, subplot_kwargs={"ncols": 1})
    fig.suptitle("Yearly Fourier Seasonality")
    plt.show()

Examples
--------
Use yearly fourier seasonality for custom Marketing Mix Model.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import YearlyFourier

    yearly = YearlyFourier(n_order=3)

    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")

    dayofyear = dates.dayofyear.to_numpy()

    with pm.Model() as model:
        fourier_trend = yearly.apply(dayofyear)

Plot the prior fourier seasonality trend.

.. code-block:: python

    import matplotlib.pyplot as plt

    prior = yearly.sample_prior()
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()

Change the prior distribution of the fourier seasonality.

.. code-block:: python

    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    prior = Prior("Normal", mu=0, sigma=0.10)
    yearly = YearlyFourier(n_order=6, prior=prior)

Even make it hierarchical...

.. code-block:: python

    from pymc_marketing.mmm import YearlyFourier
    from pymc_marketing.prior import Prior

    # "fourier" is the default prefix!
    prior = Prior(
        "Laplace",
        mu=Prior("Normal", dims="fourier"),
        b=Prior("HalfNormal", sigma=0.1, dims="fourier"),
        dims=("fourier", "hierarchy"),
    )
    yearly = YearlyFourier(n_order=3, prior=prior)

All the plotting will still work! Just pass any coords.

.. code-block:: python

    import matplotlib.pyplot as plt

    coords = {"hierarchy": ["A", "B", "C"]}
    prior = yearly.sample_prior(coords=coords)
    curve = yearly.sample_curve(prior)
    yearly.plot_curve(curve)
    plt.show()

Out of sample predictions with fourier seasonality by changing the day of year
used in the model.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import YearlyFourier

    periods = 52 * 3
    dates = pd.date_range("2022-01-01", periods=periods, freq="W-MON")

    training_dates = dates[: 52 * 2]
    testing_dates = dates[52 * 2 :]

    yearly = YearlyFourier(n_order=3)

    coords = {
        "date": training_dates,
    }
    with pm.Model(coords=coords) as model:
        dayofyear = pm.Data(
            "dayofyear",
            training_dates.dayofyear.to_numpy(),
            dims="date",
        )

        trend = pm.Deterministic(
            "trend",
            yearly.apply(dayofyear),
            dims="date",
        )

        idata = pm.sample_prior_predictive().prior

    with model:
        pm.set_data(
            {"dayofyear": testing_dates.dayofyear.to_numpy()},
            coords={"date": testing_dates},
        )

        out_of_sample = pm.sample_posterior_predictive(
            idata,
            var_names=["trend"],
        ).posterior_predictive["trend"]


Use yearly and monthly fourier seasonality together.

By default, the prefix of the fourier seasonality is set to "fourier". However,
the prefix can be changed upon initialization in order to avoid variable name
conflicts.

.. code-block:: python

    import pandas as pd
    import pymc as pm

    from pymc_marketing.mmm import (
        MonthlyFourier,
        YearlyFourier,
    )

    yearly = YearlyFourier(n_order=6, prefix="yearly")
    monthly = MonthlyFourier(n_order=3, prefix="monthly")

    dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")
    dayofyear = dates.dayofyear.to_numpy()

    coords = {
        "date": dates,
    }

    with pm.Model(coords=coords) as model:
        yearly_trend = yearly.apply(dayofyear)
        monthly_trend = monthly.apply(dayofyear)

        trend = pm.Deterministic(
            "trend",
            yearly_trend + monthly_trend,
            dims="date",
        )

    with model:
        prior_samples = pm.sample_prior_predictive().prior

"""

import datetime
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    InstanceOf,
    field_serializer,
    model_validator,
)
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_WEEK, DAYS_IN_YEAR
from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, VariableFactory, create_dim_handler

X_NAME: str = "day"
NON_GRID_NAMES: frozenset[str] = frozenset({X_NAME})


def generate_fourier_modes(
    periods: pt.TensorLike,
    n_order: int,
) -> pt.TensorVariable:
    """Build the sine and cosine Fourier features for the given periods.

    Parameters
    ----------
    periods : pt.TensorLike
        Position expressed in number of periods (for example the day of the
        year divided by the number of days in a year). It is indexed with
        ``[:, None]``, so it needs at least one dimension.
    n_order : int
        Number of harmonics. Harmonics ``1, ..., n_order`` are generated.

    Returns
    -------
    pt.TensorVariable
        All sine modes followed by all cosine modes, concatenated along
        axis 1. For one-dimensional ``periods`` of length ``n`` the result has
        shape ``(n, 2 * n_order)``.

    """
    harmonics = pt.arange(1, n_order + 1)
    base_angle = 2 * pt.pi * periods

    # One column per harmonic: angle[i, k] = 2 * pi * periods[i] * (k + 1).
    angles = base_angle[:, None] * harmonics

    sin_modes = pt.sin(angles)
    cos_modes = pt.cos(angles)
    return pt.concatenate([sin_modes, cos_modes], axis=1)
class FourierBase(BaseModel):
    """Base class for Fourier seasonality transformations.

    Subclasses fix ``days_in_period`` and implement
    :meth:`_get_default_start_date` and :meth:`_get_days_in_period`.

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use. Must be greater than 0.
    days_in_period : float
        Number of days in a period. Must be greater than 0.
    prefix : str, optional
        Prefix for the fourier seasonality, by default "fourier". It is used
        as the name of the model coordinate holding the fourier nodes and as
        the dimension the prior must carry.
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta
        parameters, by default `Prior("Laplace", mu=0, b=1)`. When it has no
        dims, its dims are set to ``prefix``.
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes. By default
        None, in which case it is set to `{prefix}_beta`.

    """

    n_order: int = Field(..., gt=0)
    days_in_period: float = Field(..., gt=0)
    prefix: str = Field("fourier")
    prior: InstanceOf[Prior] | InstanceOf[VariableFactory] = Field(
        Prior("Laplace", mu=0, b=1)
    )
    variable_name: str | None = Field(None)

    model_config = ConfigDict(extra="forbid")

    def model_post_init(self, __context: Any) -> None:
        """Fill in the default variable name and the default prior dims."""
        if self.variable_name is None:
            self.variable_name = f"{self.prefix}_beta"

        if not self.prior.dims:
            if isinstance(self.prior, Prior):
                # Copy first so that a Prior instance shared with other
                # objects (e.g. the field default) is not modified in place.
                self.prior = self.prior.deepcopy()
            self.prior.dims = self.prefix

    @model_validator(mode="after")
    def _check_variable_name(self) -> Self:
        """Reject a variable name that collides with the prefix."""
        if self.variable_name == self.prefix:
            raise ValueError("Variable name cannot be the same as the prefix")
        return self

    @model_validator(mode="after")
    def _check_prior_has_right_dimensions(self) -> Self:
        """Require the prior to carry the ``prefix`` dimension."""
        if self.prefix not in self.prior.dims:
            raise ValueError(f"Prior distribution must have dimension {self.prefix}")
        return self

    @field_serializer("prior", when_used="json")
    def serialize_prior(prior: Any) -> dict[str, Any]:
        """Serialize the prior distribution.

        Declared without ``self``: it receives only the field value.

        Parameters
        ----------
        prior : VariableFactory | Prior
            The prior distribution to serialize.

        Returns
        -------
        dict[str, Any]
            The serialized prior distribution.

        """
        return prior.to_dict()

    @property
    def nodes(self) -> list[str]:
        """Fourier node names for model coordinates.

        All sine nodes come first, then all cosine nodes, e.g.
        ``["sin_1", "sin_2", "cos_1", "cos_2"]`` for ``n_order=2``.
        """
        return [
            f"{func}_{i}"
            for func in ["sin", "cos"]
            for i in range(1, self.n_order + 1)
        ]

    def get_default_start_date(
        self,
        start_date: str | datetime.datetime | None = None,
    ) -> str | datetime.datetime:
        """Get the start date for the Fourier curve.

        If `start_date` is provided, validate its type. Otherwise, provide the
        default start date based on the subclass implementation.

        Parameters
        ----------
        start_date : str or datetime.datetime, optional
            Provided start date. Can be a string or a datetime object.

        Returns
        -------
        str or datetime.datetime
            The validated start date, returned unchanged, or the subclass
            default when `start_date` is None.

        Raises
        ------
        TypeError
            If `start_date` is neither a string nor a datetime object.

        """
        if start_date is None:
            return self._get_default_start_date()

        if isinstance(start_date, (str, datetime.datetime)):
            return start_date

        raise TypeError(
            "start_date must be a datetime.datetime object, a string, or None"
        )

    @abstractmethod
    def _get_default_start_date(self) -> datetime.datetime:
        """Provide the default start date.

        Must be implemented by subclasses.

        Returns
        -------
        datetime.datetime
            The default start date.

        """
        pass  # pragma: no cover

    @abstractmethod
    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the relevant day within the characteristic periodicity.

        Must be implemented by subclasses.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates to convert.

        Returns
        -------
        pd.Index
            For each date, the day number that is passed to :meth:`apply`.

        """
        pass  # pragma: no cover

    @staticmethod
    def _get_x_coord_name(curve: xr.DataArray) -> str:
        """Return the x-axis coordinate of a sampled curve.

        "date" takes precedence over "day" when both are present.

        Raises
        ------
        ValueError
            If the curve has neither a "date" nor a "day" coordinate.

        """
        for candidate in ("date", "day"):
            if candidate in curve.coords:
                return candidate

        raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

    def apply(
        self,
        dayofperiod: pt.TensorLike,
        result_callback: Callable[[pt.TensorVariable], None] | None = None,
    ) -> pt.TensorVariable:
        """Apply fourier seasonality to day of period.

        Must be used within a PyMC model context. Adds the ``prefix``
        coordinate to the model and creates the beta variable from the prior.

        Parameters
        ----------
        dayofperiod : pt.TensorLike
            Day of year or weekday
        result_callback : Callable[[pt.TensorVariable], None], optional
            Callback function to apply to the result before it is summed over
            the prefix dimension, by default None

        Returns
        -------
        pt.TensorVariable
            Fourier seasonality, summed over the prefix dimension.

        Examples
        --------
        Save off the result before summing through the prefix dimension.

        .. code-block:: python

            import pandas as pd
            import pymc as pm

            from pymc_marketing.mmm import YearlyFourier

            fourier = YearlyFourier(n_order=3)


            def callback(result):
                pm.Deterministic("fourier_trend", result, dims=("date", "fourier"))


            dates = pd.date_range("2023-01-01", periods=52, freq="W-MON")

            coords = {
                "date": dates,
            }
            with pm.Model(coords=coords) as model:
                dayofyear = dates.dayofyear.to_numpy()
                fourier.apply(dayofyear, result_callback=callback)

        """
        periods = dayofperiod / self.days_in_period

        model = pm.modelcontext(None)
        model.add_coord(self.prefix, self.nodes)

        beta = self.prior.create_variable(self.variable_name)

        fourier_modes = generate_fourier_modes(periods=periods, n_order=self.n_order)

        # Placeholder name for the leading (observation) axis of the result.
        DUMMY_DIM = "DATE"

        prefix_idx = self.prior.dims.index(self.prefix)
        result_dims = (DUMMY_DIM, *self.prior.dims)
        dim_handler = create_dim_handler(result_dims)

        result = dim_handler(fourier_modes, (DUMMY_DIM, self.prefix)) * dim_handler(
            beta, self.prior.dims
        )
        if result_callback is not None:
            result_callback(result)

        # +1 because DUMMY_DIM was prepended to the prior dims.
        return result.sum(axis=prefix_idx + 1)

    def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset:
        """Sample the prior distributions.

        Parameters
        ----------
        coords : dict, optional
            Coordinates for the prior distribution, by default None. The
            mapping is not modified; the ``prefix`` coordinate is added to a
            copy.
        kwargs
            Additional keywords for sample_prior_predictive

        Returns
        -------
        xr.Dataset
            Prior distribution.

        """
        # Build a new dict so the caller's coords are never mutated.
        coords = {**(coords or {}), self.prefix: self.nodes}
        return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs)

    def sample_curve(
        self,
        parameters: az.InferenceData | xr.Dataset,
        use_dates: bool = False,
        start_date: str | datetime.datetime | None = None,
    ) -> xr.DataArray:
        """Create full period of the Fourier seasonality.

        Parameters
        ----------
        parameters : az.InferenceData | xr.Dataset
            Inference data or dataset containing the Fourier parameters.
            Can be posterior or prior.
        use_dates : bool, optional
            If True, use datetime coordinates for the x-axis. Defaults to False.
        start_date : str or datetime.datetime, optional
            Starting date for the Fourier curve. Only used when use_dates is
            True; if not provided, the subclass default start date is used.
            Defaults to None.

        Returns
        -------
        xr.DataArray
            Full period of the Fourier seasonality, with a "date" coordinate
            when use_dates is True and a "day" coordinate otherwise.

        """
        coords: dict[str, Any] = {}
        if use_dates:
            start_date = self.get_default_start_date(start_date=start_date)
            date_range = pd.date_range(
                start=start_date,
                periods=int(np.ceil(self.days_in_period) + 1),
                freq="D",
            )
            coords["date"] = date_range.to_numpy()
            dayofperiod = self._get_days_in_period(date_range).to_numpy()
        else:
            full_period = np.arange(self.days_in_period + 1)
            coords["day"] = full_period
            dayofperiod = full_period

        # Carry over any extra (e.g. hierarchical) coordinates of the
        # parameters; sampling dims and the fourier nodes are handled elsewhere.
        for key, values in parameters[self.variable_name].coords.items():
            if key in {"chain", "draw", self.prefix}:
                continue
            coords[key] = values.to_numpy()

        with pm.Model(coords=coords):
            name = f"{self.prefix}_trend"
            pm.Deterministic(
                name,
                self.apply(dayofperiod=dayofperiod),
                dims=tuple(coords.keys()),
            )

            return pm.sample_posterior_predictive(
                parameters,
                var_names=[name],
            ).posterior_predictive[name]

    def plot_curve(
        self,
        curve: xr.DataArray,
        n_samples: int = 10,
        hdi_probs: float | list[float] | None = None,
        random_seed: np.random.Generator | None = None,
        subplot_kwargs: dict | None = None,
        sample_kwargs: dict | None = None,
        hdi_kwargs: dict | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
        same_axes: bool = False,
        colors: Iterable[str] | None = None,
        legend: bool | None = None,
        sel_to_string: SelToString | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot the seasonality for one full period.

        Parameters
        ----------
        curve : xr.DataArray
            Sampled full period of the fourier seasonality.
        n_samples : int, optional
            Number of samples
        hdi_probs : float | list[float], optional
            HDI probabilities. Defaults to None which uses arviz default for
            stats.ci_prob which is 94%
        random_seed : int | random number generator, optional
            Random number generator. Defaults to None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        sample_kwargs : dict, optional
            Keyword arguments for plotting the samples, by default None
        hdi_kwargs : dict, optional
            Keyword arguments for plotting the HDI, by default None
        axes : npt.NDArray[plt.Axes], optional
            Matplotlib axes, by default None
        same_axes : bool, optional
            Use the same axes for all plots, by default False
        colors : Iterable[str], optional
            Colors for the different plots, by default None
        legend : bool, optional
            Show the legend, by default None
        sel_to_string : SelToString, optional
            Function to convert the selection to a string, by default None

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.

        Raises
        ------
        ValueError
            If the curve has neither a "date" nor a "day" coordinate.

        """
        x_coord_name = self._get_x_coord_name(curve)

        return plot_curve(
            curve,
            non_grid_names={x_coord_name},
            n_samples=n_samples,
            hdi_probs=hdi_probs,
            random_seed=random_seed,
            subplot_kwargs=subplot_kwargs,
            sample_kwargs=sample_kwargs,
            hdi_kwargs=hdi_kwargs,
            axes=axes,
            same_axes=same_axes,
            colors=colors,
            legend=legend,
            sel_to_string=sel_to_string,
        )

    def plot_curve_hdi(
        self,
        curve: xr.DataArray,
        hdi_kwargs: dict | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot full period of the fourier seasonality.

        Parameters
        ----------
        curve : xr.DataArray
            The curve to plot.
        hdi_kwargs : dict, optional
            Keyword arguments for the az.hdi function. Defaults to None.
        plot_kwargs : dict, optional
            Keyword arguments for the fill_between function. Defaults to None.
        subplot_kwargs : dict, optional
            Keyword arguments for plt.subplots
        axes : npt.NDArray[plt.Axes], optional
            The exact axes to plot on. Overrides any subplot_kwargs

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.

        Raises
        ------
        ValueError
            If the curve has neither a "date" nor a "day" coordinate.

        """
        x_coord_name = self._get_x_coord_name(curve)

        return plot_hdi(
            curve,
            non_grid_names={x_coord_name},
            hdi_kwargs=hdi_kwargs,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
            axes=axes,
        )

    def plot_curve_samples(
        self,
        curve: xr.DataArray,
        n: int = 10,
        rng: np.random.Generator | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        subplot_kwargs: dict[str, Any] | None = None,
        axes: npt.NDArray[plt.Axes] | None = None,
    ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
        """Plot samples from the curve.

        Parameters
        ----------
        curve : xr.DataArray
            Samples from the curve.
        n : int, optional
            Number of samples to plot, by default 10
        rng : np.random.Generator, optional
            Random number generator, by default None
        plot_kwargs : dict, optional
            Keyword arguments for the plot function, by default None
        subplot_kwargs : dict, optional
            Keyword arguments for the subplot, by default None
        axes : npt.NDArray[plt.Axes], optional
            Matplotlib axes, by default None

        Returns
        -------
        tuple[plt.Figure, npt.NDArray[plt.Axes]]
            Matplotlib figure and axes.

        Raises
        ------
        ValueError
            If the curve has neither a "date" nor a "day" coordinate.

        """
        x_coord_name = self._get_x_coord_name(curve)

        return plot_samples(
            curve,
            non_grid_names={x_coord_name},
            n=n,
            rng=rng,
            axes=axes,
            subplot_kwargs=subplot_kwargs,
            plot_kwargs=plot_kwargs,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the Fourier seasonality.

        Returns
        -------
        dict[str, Any]
            Serialized Fourier seasonality: the class name under "class" and
            the JSON-mode model dump under "data".

        """
        return {
            "class": self.__class__.__name__,
            "data": self.model_dump(mode="json"),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Self:
        """Deserialize the Fourier seasonality.

        Parameters
        ----------
        data : dict[str, Any]
            Serialized Fourier seasonality, as produced by :meth:`to_dict`.
            It is not modified.

        Returns
        -------
        FourierBase
            Deserialized Fourier seasonality

        """
        # Shallow-copy the payload so the deserialized prior object is not
        # written back into the caller's dictionary.
        payload = dict(data["data"])
        payload["prior"] = deserialize(payload["prior"])
        return cls(**payload)
class YearlyFourier(FourierBase):
    """Yearly fourier seasonality.

    .. plot::
        :context: close-figs

        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np

        from pymc_marketing.mmm import YearlyFourier
        from pymc_marketing.prior import Prior

        az.style.use("arviz-white")

        seed = sum(map(ord, "Yearly"))
        rng = np.random.default_rng(seed)

        mu = np.array([0, 0, -1, 0])
        b = 0.15
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        yearly = YearlyFourier(n_order=2, prior=dist)
        prior = yearly.sample_prior(random_seed=rng)
        curve = yearly.sample_curve(prior)
        _, axes = yearly.plot_curve(curve)
        axes[0].set(title="Yearly Fourier Seasonality")
        plt.show()

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta
        parameters, by default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None

    """

    days_in_period: float = DAYS_IN_YEAR

    def _get_default_start_date(self) -> datetime.datetime:
        """Return midnight on January 1st of the current year."""
        this_year = datetime.datetime.now().year
        return datetime.datetime(this_year, 1, 1)

    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the day of the year of each date.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates to convert.

        Returns
        -------
        pd.Index
            ``dates.dayofyear``.

        """
        day_of_year = dates.dayofyear
        return day_of_year
class MonthlyFourier(FourierBase):
    """Monthly fourier seasonality.

    .. plot::
        :context: close-figs

        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np

        from pymc_marketing.mmm import MonthlyFourier
        from pymc_marketing.prior import Prior

        az.style.use("arviz-white")

        seed = sum(map(ord, "Monthly"))
        rng = np.random.default_rng(seed)

        mu = np.array([0, 0, 0.5, 0])
        b = 0.075
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        monthly = MonthlyFourier(n_order=2, prior=dist)
        prior = monthly.sample_prior(samples=100)
        curve = monthly.sample_curve(prior)
        _, axes = monthly.plot_curve(curve)
        axes[0].set(title="Monthly Fourier Seasonality")
        plt.show()

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta
        parameters, by default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None

    """

    days_in_period: float = DAYS_IN_MONTH

    def _get_default_start_date(self) -> datetime.datetime:
        """Return midnight on the first day of the current month."""
        today = datetime.datetime.now()
        return datetime.datetime(today.year, today.month, 1)

    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the day of the year of each date.

        The day of the year (not the day of the month) is returned; it is
        divided by ``days_in_period`` when the seasonality is applied.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates to convert.

        Returns
        -------
        pd.Index
            ``dates.dayofyear``.

        """
        day_of_year = dates.dayofyear
        return day_of_year
class WeeklyFourier(FourierBase):
    """Weekly fourier seasonality.

    .. plot::
        :context: close-figs

        import arviz as az
        import matplotlib.pyplot as plt
        import numpy as np

        from pymc_marketing.mmm import WeeklyFourier
        from pymc_marketing.prior import Prior

        az.style.use("arviz-white")

        seed = sum(map(ord, "Weekly"))
        rng = np.random.default_rng(seed)

        mu = np.array([0, 0, 0.5, 0])
        b = 0.075
        dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
        weekly = WeeklyFourier(n_order=2, prior=dist)
        prior = weekly.sample_prior(samples=100)
        curve = weekly.sample_curve(prior)
        _, axes = weekly.plot_curve(curve)
        axes[0].set(title="Weekly Fourier Seasonality")
        plt.show()

    Parameters
    ----------
    n_order : int
        Number of fourier modes to use.
    prefix : str, optional
        Alternative prefix for the fourier seasonality, by default "fourier"
    prior : Prior | VariableFactory, optional
        Prior distribution or VariableFactory for the fourier seasonality beta
        parameters, by default `Prior("Laplace", mu=0, b=1)`
    variable_name : str, optional
        Name of the variable that multiplies the fourier modes, by default None

    """

    days_in_period: float = DAYS_IN_WEEK

    def _get_default_start_date(self) -> datetime.datetime:
        """Get the default start date for weekly seasonality.

        Returns midnight on the Monday of the current ISO week.
        """
        # The ISO year must be paired with the ISO week: around New Year the
        # ISO year differs from the calendar year (e.g. 2021-01-01 belongs to
        # ISO week 2020-W53), and mixing the two gives a date in the wrong
        # year or an invalid week number.
        iso_today = datetime.datetime.now().isocalendar()
        return datetime.datetime.fromisocalendar(
            year=iso_today.year, week=iso_today.week, day=1
        )

    def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
        """Return the weekday within the weekly periodicity.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            Dates to convert.

        Returns
        -------
        pd.Index
            ``dates.weekday``.

        """
        return dates.weekday
def _is_yearly_fourier(data: Any) -> bool:
    """Return True when the serialized payload is tagged as a YearlyFourier."""
    class_name = data.get("class")
    return class_name == "YearlyFourier"


def _is_monthly_fourier(data: Any) -> bool:
    """Return True when the serialized payload is tagged as a MonthlyFourier."""
    class_name = data.get("class")
    return class_name == "MonthlyFourier"


def _is_weekly_fourier(data: Any) -> bool:
    """Return True when the serialized payload is tagged as a WeeklyFourier."""
    class_name = data.get("class")
    return class_name == "WeeklyFourier"


# Register each seasonality class so serialized payloads carrying its class
# name are rebuilt through that class's ``from_dict``.
register_deserialization(
    is_type=_is_yearly_fourier,
    deserialize=YearlyFourier.from_dict,
)
register_deserialization(
    is_type=_is_monthly_fourier,
    deserialize=MonthlyFourier.from_dict,
)
register_deserialization(
    is_type=_is_weekly_fourier,
    deserialize=WeeklyFourier.from_dict,
)