Source code for pymc_marketing.mmm.media_transformation
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for applying media transformations to media data.
Examples
--------
Create a media transformation for online and offline media channels:
.. code-block:: python
from pymc_marketing.mmm import (
GeometricAdstock,
HillSaturation,
MediaTransformation,
MichaelisMentenSaturation,
)
# Shared media transformation for all offline media channels
offline_media_transform = MediaTransformation(
adstock=GeometricAdstock(l_max=15),
saturation=HillSaturation(),
adstock_first=True,
)
# Shared media transformation for all online media channels
online_media_transform = MediaTransformation(
adstock=GeometricAdstock(l_max=10),
saturation=MichaelisMentenSaturation(),
adstock_first=False,
)
Create a combined media configuration for offline and online media channels:
.. code-block:: python
from pymc_marketing.mmm import (
MediaConfig,
MediaConfigList,
)
media_configs = MediaConfigList(
[
MediaConfig(
name="offline",
columns=["TV", "Radio"],
media_transformation=offline_media_transform,
),
MediaConfig(
name="online",
columns=["Facebook", "Instagram", "YouTube", "TikTok"],
media_transformation=online_media_transform,
),
]
)
Apply the media transformation to media data in PyMC model:
.. code-block:: python
import pymc as pm
import pandas as pd
df: pd.DataFrame = ...
media_columns = media_configs.media_values
coords = {
"date": df["week"],
"media": media_columns,
}
with pm.Model(coords=coords) as model:
media_data = pm.Data(
"media_data", df.loc[:, media_columns].to_numpy(), dims=("date", "media")
)
transformed_media_data = media_configs(media_data)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import cast
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.shape_utils import Dims
from pymc_marketing.deserialize import register_deserialization
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
adstock_from_dict,
)
from pymc_marketing.mmm.components.saturation import (
SaturationTransformation,
saturation_from_dict,
)
[docs]
@dataclass
class MediaTransformation:
    """Wrapper for applying adstock and saturation transformation to media data.

    Parameters
    ----------
    adstock : AdstockTransformation
        The adstock transformation to apply.
    saturation : SaturationTransformation
        The saturation transformation to apply.
    adstock_first : bool
        Flag to apply the adstock transformation first.
    dims : Dims, optional
        The dimensions of the parameters. A single string is wrapped into a
        one-element tuple, any other sequence is converted to a tuple and
        ``None`` becomes the empty tuple.

    Attributes
    ----------
    first : AdstockTransformation | SaturationTransformation
        The first transformation to apply.
    second : AdstockTransformation | SaturationTransformation
        The second transformation to apply.

    Raises
    ------
    ValueError
        If the ``combined_dims`` of the adstock or of the saturation
        transformation are not a subset of ``dims``.

    """

    adstock: AdstockTransformation
    saturation: SaturationTransformation
    adstock_first: bool
    dims: Dims | None = None

    def __post_init__(self):
        """Set the transformation order, normalise ``dims`` and validate them."""
        if self.adstock_first:
            self.first, self.second = self.adstock, self.saturation
        else:
            self.first, self.second = self.saturation, self.adstock

        if isinstance(self.dims, str):
            self.dims = (self.dims,)
        else:
            # Always store a tuple: ``to_dict`` emits the dims as-is and a JSON
            # round-trip hands them back as a list, which would otherwise make
            # the restored instance compare unequal to the original one.
            self.dims = tuple(self.dims or ())

        self._check_compatible_dims()

    def _check_compatible_dims(self):
        """Raise ``ValueError`` if a transformation uses a dim missing from ``dims``."""
        # Quoted cast target: only the type checker needs to resolve ``Dims``.
        dims = cast("Dims", self.dims)

        if not set(self.adstock.combined_dims).issubset(dims):
            raise ValueError(
                f"Adstock dimensions {self.adstock.combined_dims} are not a subset of {self.dims}"
            )

        if not set(self.saturation.combined_dims).issubset(dims):
            raise ValueError(
                f"Saturation dimensions {self.saturation.combined_dims} are not a subset of {self.dims}"
            )

    def __call__(self, x):
        """Apply adstock and saturation transformation to media data.

        The transformation stored in ``first`` is applied to ``x`` and the one
        stored in ``second`` is applied to its result; both receive the
        ``dims`` configured on this instance.

        Parameters
        ----------
        x : pt.TensorLike
            The media data to transform.

        Returns
        -------
        pt.TensorVariable
            The transformed media data.

        Examples
        --------
        Apply the media transformation to media data:

        .. code-block:: python

            import pymc as pm

            from pymc_marketing.mmm import (
                GeometricAdstock,
                HillSaturation,
                MediaTransformation,
            )

            media_data = ...

            media_transformation = MediaTransformation(
                adstock=GeometricAdstock(l_max=15),
                saturation=HillSaturation(),
                adstock_first=True,
                dims="media",
            )

            coords = {
                "date": ...,
                "media": ...,
            }
            with pm.Model(coords=coords) as model:
                transformed_media_data = media_transformation(media_data)

        """
        return self.second.apply(self.first.apply(x, self.dims), self.dims)

    def to_dict(self) -> dict:
        """Convert the media transformation to a dictionary.

        Returns
        -------
        dict
            The media transformation as a dictionary with the keys
            ``"adstock"``, ``"saturation"``, ``"adstock_first"`` and ``"dims"``.

        """
        return {
            "adstock": self.adstock.to_dict(),
            "saturation": self.saturation.to_dict(),
            "adstock_first": self.adstock_first,
            "dims": self.dims,
        }

    @classmethod
    def from_dict(cls, data) -> MediaTransformation:
        """Create a media transformation from a dictionary.

        Parameters
        ----------
        data : dict
            The data to create the media transformation from. The keys
            ``"adstock"``, ``"saturation"`` and ``"adstock_first"`` are
            required, ``"dims"`` is optional.

        Returns
        -------
        MediaTransformation
            The media transformation created from the dictionary.

        """
        return cls(
            adstock=adstock_from_dict(data["adstock"]),
            saturation=saturation_from_dict(data["saturation"]),
            adstock_first=data["adstock_first"],
            dims=data.get("dims"),
        )
def _is_media_transformation(data):
return (
isinstance(data, dict)
and "adstock" in data
and "saturation" in data
and "adstock_first" in data
)
# Register the predicate/deserializer pair at import time: payloads accepted by
# ``_is_media_transformation`` are presumably rebuilt with
# ``MediaTransformation.from_dict`` (the registry itself lives in
# ``pymc_marketing.deserialize`` and is not shown here).
register_deserialization(
is_type=_is_media_transformation,
deserialize=MediaTransformation.from_dict,
)
[docs]
@dataclass
class MediaConfig:
    """Bind one media transformation to a named group of media channels.

    Parameters
    ----------
    name : str
        The name of the media transformation and prefix of all media variables.
    columns : list[str]
        The media channels to apply the transformation to.
    media_transformation : MediaTransformation
        The media transformation to apply to the media channels.

    """

    name: str
    columns: list[str]
    media_transformation: MediaTransformation

    def to_dict(self) -> dict:
        """Serialize the media configuration.

        Returns
        -------
        dict
            Dictionary with the keys ``"name"``, ``"columns"`` and
            ``"media_transformation"``; the latter holds the result of the
            transformation's own ``to_dict``.

        """
        serialized_transformation = self.media_transformation.to_dict()
        return dict(
            name=self.name,
            columns=self.columns,
            media_transformation=serialized_transformation,
        )

    @classmethod
    def from_dict(cls, data) -> MediaConfig:
        """Build a media configuration from its dictionary form.

        Parameters
        ----------
        data : dict
            Dictionary with the keys ``"name"``, ``"columns"`` and
            ``"media_transformation"``.

        Returns
        -------
        MediaConfig
            The media configuration created from the dictionary.

        """
        name = data["name"]
        columns = data["columns"]
        transformation = MediaTransformation.from_dict(data["media_transformation"])
        return cls(name=name, columns=columns, media_transformation=transformation)
def _is_media_config(data):
return (
isinstance(data, dict)
and "name" in data
and "columns" in data
and "media_transformation" in data
and _is_media_transformation(data["media_transformation"])
)
[docs]
class MediaConfigList:
    """Wrapper for a list of media configurations to apply to media data.

    Parameters
    ----------
    media_configs : list[MediaConfig]
        The media configurations to apply to the media data.

    Examples
    --------
    Different order of media transformations for online and offline media channels:

    .. code-block:: python

        from pymc_marketing.mmm import (
            GeometricAdstock,
            LogisticSaturation,
            MediaTransformation,
            MediaConfig,
            MediaConfigList,
        )

        online = MediaConfig(
            name="online",
            columns=["Facebook", "Instagram", "YouTube", "TikTok"],
            media_transformation=MediaTransformation(
                adstock=GeometricAdstock(l_max=10).set_dims_for_all_priors("online"),
                saturation=LogisticSaturation().set_dims_for_all_priors("online"),
                adstock_first=True,
                dims="online",
            ),
        )

        offline = MediaConfig(
            name="offline",
            columns=["TV", "Radio"],
            media_transformation=MediaTransformation(
                adstock=GeometricAdstock(
                    l_max=10,
                ).set_dims_for_all_priors("offline"),
                saturation=LogisticSaturation().set_dims_for_all_priors("offline"),
                adstock_first=False,
                dims="offline",
            ),
        )

        media_configs = MediaConfigList([online, offline])

    """

    def __init__(self, media_configs: list[MediaConfig]) -> None:
        self.media_configs = media_configs

    def __eq__(self, other) -> bool:
        """Check if the media configuration lists are equal.

        Parameters
        ----------
        other : MediaConfigList
            The other media configuration list to compare.

        Returns
        -------
        bool
            True if the media configuration lists are equal, False otherwise.
            ``NotImplemented`` is returned for operands that are not a
            ``MediaConfigList`` so that Python falls back to its default
            comparison instead of raising ``AttributeError``.

        """
        if not isinstance(other, MediaConfigList):
            return NotImplemented

        return self.media_configs == other.media_configs

    def __getitem__(self, key: int) -> MediaConfig:
        """Get the media configuration at the specified index.

        Parameters
        ----------
        key : int
            The index of the media configuration to get.

        Returns
        -------
        MediaConfig
            The media configuration at the specified index.

        """
        return self.media_configs[key]

    @property
    def media_values(self) -> list[str]:
        """Get the media values from the media configurations.

        Returns
        -------
        list[str]
            The media values from the media configurations in the order they appear.

        """
        return [
            column for config in self.media_configs for column in config.columns
        ]

    def to_dict(self) -> list[dict]:
        """Convert the media configuration list to a list of dictionaries.

        Returns
        -------
        list[dict]
            One dictionary per media configuration, in order.

        """
        return [config.to_dict() for config in self.media_configs]

    @classmethod
    def from_dict(cls, data: list[dict]) -> MediaConfigList:
        """Create a media configuration list from a list of dictionaries.

        Parameters
        ----------
        data : list[dict]
            The data to create the media configuration list from.

        Returns
        -------
        MediaConfigList
            The media configuration list created from the dictionaries.

        """
        return cls([MediaConfig.from_dict(config) for config in data])

    @staticmethod
    def _scoped_prefix(name: str, prefix: str) -> str:
        """Return ``prefix`` namespaced with ``name``, without stacking it twice."""
        scope = f"{name}_"
        return prefix if prefix.startswith(scope) else f"{scope}{prefix}"

    def __call__(self, x) -> pt.TensorVariable:
        """Apply media transformation to media data.

        Assumes that the columns in the data correspond to the media channels
        in the media_configs: each configuration consumes the next
        ``len(config.columns)`` columns of ``x`` along axis 1.

        The current model is looked up with ``pm.modelcontext(None)``, so this
        is meant to be called inside a ``with pm.Model(...)`` block. For every
        configuration a coordinate named ``config.name`` with the values
        ``config.columns`` is added to that model.

        The configurations are modified in place: the ``dims`` of each media
        transformation are set to ``config.name`` and the ``prefix`` of its
        adstock and saturation transformations is namespaced with
        ``config.name``. The namespacing is only applied once, so calling this
        repeatedly (or on configurations restored from a serialized, already
        namespaced state) keeps the same prefixes.

        Parameters
        ----------
        x : pt.TensorLike
            The media data to transform.

        Returns
        -------
        pt.TensorVariable
            The transformed media data, concatenated along axis 1 in the order
            of the media configurations.

        """
        model = pm.modelcontext(None)

        transformed_data = []
        start_idx = 0
        for config in self.media_configs:
            media_transformation = config.media_transformation
            media_transformation.dims = config.name
            model.add_coord(config.name, config.columns)

            end_idx = start_idx + len(config.columns)
            media_data = x[:, start_idx:end_idx]

            adstock = media_transformation.adstock
            saturation = media_transformation.saturation
            # Idempotent on purpose: blindly prepending the name on every call
            # would yield "online_online_adstock" the second time around.
            adstock.prefix = self._scoped_prefix(config.name, adstock.prefix)
            saturation.prefix = self._scoped_prefix(config.name, saturation.prefix)

            transformed_data.append(media_transformation(media_data))
            start_idx = end_idx

        return pt.concatenate(transformed_data, axis=1)
def _is_media_config_list(data):
return isinstance(data, list) and all(_is_media_config(config) for config in data)
# Register the predicate/deserializer pair at import time: lists accepted by
# ``_is_media_config_list`` are presumably rebuilt with
# ``MediaConfigList.from_dict`` (the registry itself lives in
# ``pymc_marketing.deserialize`` and is not shown here).
register_deserialization(
is_type=_is_media_config_list,
deserialize=MediaConfigList.from_dict,
)