# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class that represents a prior distribution.
The `Prior` class is a wrapper around PyMC distributions that allows the user
to create outside of the PyMC model.
This is the alternative to using the dictionaries in PyMC-Marketing models.
Examples
--------
Create a normal prior.
.. code-block:: python
from pymc_marketing.prior import Prior
normal = Prior("Normal")
Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.
.. code-block:: python
hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
Create a non-centered hierarchical normal prior with the `centered` parameter.
.. code-block:: python
non_centered_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
# Only change needed to make it non-centered
centered=False,
)
Create a hierarchical beta prior by using Beta distribution, distributions for
the parameters, and specifying the dims.
.. code-block:: python
hierarchical_beta = Prior(
"Beta",
alpha=Prior("HalfNormal"),
beta=Prior("HalfNormal"),
dims="channel",
)
Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.
.. code-block:: python
transformed_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
transform="sigmoid",
dims="channel",
)
Create a prior with a custom transform function by registering it with
`register_tensor_transform`.
.. code-block:: python
from pymc_marketing.prior import register_tensor_transform
def custom_transform(x):
return x**2
register_tensor_transform("square", custom_transform)
custom_distribution = Prior("Normal", transform="square")
"""
from __future__ import annotations
import copy
from collections.abc import Callable
from inspect import signature
from typing import Any, Protocol, runtime_checkable
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import InstanceOf, validate_call
from pydantic.dataclasses import dataclass
from pymc.distributions.shape_utils import Dims
from pymc_marketing.deserialize import deserialize, register_deserialization
class UnsupportedShapeError(Exception):
"""Error for when the shapes from variables are not compatible."""
class UnsupportedDistributionError(Exception):
"""Error for when an unsupported distribution is used."""
class UnsupportedParameterizationError(Exception):
"""The follow parameterization is not supported."""
class MuAlreadyExistsError(Exception):
"""Error for when 'mu' is present in Prior."""
def __init__(self, distribution: Prior) -> None:
self.distribution = distribution
self.message = f"The mu parameter is already defined in {distribution}"
super().__init__(self.message)
class UnknownTransformError(Exception):
"""Error for when an unknown transform is used."""
def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
"""Remove leading 'x' from the args."""
while args and args[0] == "x":
args.pop(0)
return args
[docs]
def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
"""Take a tensor of dims `dims` and align it to `desired_dims`.
Doesn't check for validity of the dims
Examples
--------
1D to 2D with new dim
.. code-block:: python
x = np.array([1, 2, 3])
dims = "channel"
desired_dims = ("channel", "group")
handle_dims(x, dims, desired_dims)
"""
x = pt.as_tensor_variable(x)
if np.ndim(x) == 0:
return x
dims = dims if isinstance(dims, tuple) else (dims,)
desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)
if difference := set(dims).difference(desired_dims):
raise UnsupportedShapeError(
f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
f"{difference} is missing from the desired dims."
)
aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)
missing_dims = aligned_dims.sum(axis=0) == 0
new_idx = aligned_dims.argmax(axis=0)
args = [
"x" if missing else idx
for (idx, missing) in zip(new_idx, missing_dims, strict=False)
]
args = _remove_leading_xs(args)
return x.dimshuffle(*args)
DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]
[docs]
def create_dim_handler(desired_dims: Dims) -> DimHandler:
"""Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""
def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
return handle_dims(x, dims, desired_dims)
return func
def _dims_to_str(obj: tuple[str, ...]) -> str:
if len(obj) == 1:
return f'"{obj[0]}"'
return (
"(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"
)
def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
if not hasattr(pm, name):
raise UnsupportedDistributionError(
f"PyMC doesn't have a distribution of name {name!r}"
)
return getattr(pm, name)
Transform = Callable[[pt.TensorLike], pt.TensorLike]
CUSTOM_TRANSFORMS: dict[str, Transform] = {}
def _get_transform(name: str):
if name in CUSTOM_TRANSFORMS:
return CUSTOM_TRANSFORMS[name]
for module in (pt, pm.math):
if hasattr(module, name):
break
else:
module = None
if not module:
msg = (
f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
"If this is a custom function, register it with the "
"`pymc_marketing.prior.register_tensor_transform` function before "
"previous function call."
)
raise UnknownTransformError(msg)
return getattr(module, name)
def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}
[docs]
@runtime_checkable
class VariableFactory(Protocol):
"""Protocol for something that works like a Prior class."""
dims: tuple[str, ...]
[docs]
def create_variable(self, name: str) -> pt.TensorVariable:
"""Create a TensorVariable."""
[docs]
def sample_prior(
factory: VariableFactory,
coords=None,
name: str = "var",
wrap: bool = False,
**sample_prior_predictive_kwargs,
) -> xr.Dataset:
"""Sample the prior for an arbitrary VariableFactory.
Parameters
----------
factory : VariableFactory
The factory to sample from.
coords : dict[str, list[str]], optional
The coordinates for the variable, by default None.
Only required if the dims are specified.
name : str, optional
The name of the variable, by default "var".
wrap : bool, optional
Whether to wrap the variable in a `pm.Deterministic` node, by default False.
sample_prior_predictive_kwargs : dict
Additional arguments to pass to `pm.sample_prior_predictive`.
Returns
-------
xr.Dataset
The dataset of the prior samples.
Example
-------
Sample from an arbitrary variable factory.
.. code-block:: python
import pymc as pm
import pytensor.tensor as pt
from pymc_marketing.prior import sample_prior
class CustomVariableDefinition:
def __init__(self, dims, n: int):
self.dims = dims
self.n = n
def create_variable(self, name: str) -> "TensorVariable":
x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
return pt.sum([x**n for n in range(1, self.n + 1)], axis=0)
cubic = CustomVariableDefinition(dims=("channel",), n=3)
coords = {"channel": ["C1", "C2", "C3"]}
# Doesn't include the return value
prior = sample_prior(cubic, coords=coords)
prior_with = sample_prior(cubic, coords=coords, wrap=True)
"""
coords = coords or {}
if isinstance(factory.dims, str):
dims = (factory.dims,)
else:
dims = factory.dims
if missing_keys := set(dims) - set(coords.keys()):
raise KeyError(f"Coords are missing the following dims: {missing_keys}")
with pm.Model(coords=coords) as model:
if wrap:
pm.Deterministic(name, factory.create_variable(name), dims=factory.dims)
else:
factory.create_variable(name)
return pm.sample_prior_predictive(
model=model,
**sample_prior_predictive_kwargs,
).prior
[docs]
class Prior:
"""A class to represent a prior distribution.
This is the alternative to using the dictionaries in PyMC-Marketing models
but provides added flexibility and functionality.
Make use of the various helper methods to understand the distributions
better.
- `preliz` attribute to get the equivalent distribution in `preliz`
- `sample_prior` method to sample from the prior
- `graph` get a dummy model graph with the distribution
- `constrain` to shift the distribution to a different range
Parameters
----------
distribution : str
The name of PyMC distribution.
dims : Dims, optional
The dimensions of the variable, by default None
centered : bool, optional
Whether the variable is centered or not, by default True.
Only allowed for Normal distribution.
transform : str, optional
The name of the transform to apply to the variable after it is
created, by default None or no transform. The transformation must
be registered with `register_tensor_transform` function or
be available in either `pytensor.tensor` or `pymc.math`.
"""
# Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
non_centered_distributions: dict[str, dict[str, float]] = {
"Normal": {"mu": 0, "sigma": 1},
"StudentT": {"mu": 0, "sigma": 1},
"ZeroSumNormal": {"sigma": 1},
}
pymc_distribution: type[pm.Distribution]
pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None
[docs]
@validate_call
def __init__(
self,
distribution: str,
*,
dims: Dims | None = None,
centered: bool = True,
transform: str | None = None,
**parameters,
) -> None:
self.distribution = distribution
self.parameters = parameters
self.dims = dims
self.centered = centered
self.transform = transform
self._checks()
@property
def distribution(self) -> str:
"""The name of the PyMC distribution."""
return self._distribution
@distribution.setter
def distribution(self, distribution: str) -> None:
if hasattr(self, "_distribution"):
raise AttributeError("Can't change the distribution")
self._distribution = distribution
self.pymc_distribution = _get_pymc_distribution(distribution)
@property
def transform(self) -> str | None:
"""The name of the transform to apply to the variable after it is created."""
return self._transform
@transform.setter
def transform(self, transform: str | None) -> None:
self._transform = transform
self.pytensor_transform = not transform or _get_transform(transform) # type: ignore
@property
def dims(self) -> Dims:
"""The dimensions of the variable."""
return self._dims
@dims.setter
def dims(self, dims) -> None:
if isinstance(dims, str):
dims = (dims,)
elif isinstance(dims, list):
dims = tuple(dims)
self._dims = dims or ()
self._param_dims_work()
self._unique_dims()
def __getitem__(self, key: str) -> Prior | Any:
"""Return the parameter of the prior."""
return self.parameters[key]
def _checks(self) -> None:
if not self.centered:
self._correct_non_centered_distribution()
self._parameters_are_at_least_subset_of_pymc()
self._convert_lists_to_numpy()
self._parameters_are_correct_type()
def _parameters_are_at_least_subset_of_pymc(self) -> None:
pymc_params = _get_pymc_parameters(self.pymc_distribution)
if not set(self.parameters.keys()).issubset(pymc_params):
msg = (
f"Parameters {set(self.parameters.keys())} "
"are not a subset of the pymc distribution "
f"parameters {set(pymc_params)}"
)
raise ValueError(msg)
def _convert_lists_to_numpy(self) -> None:
def convert(x):
if not isinstance(x, list):
return x
return np.array(x)
self.parameters = {
key: convert(value) for key, value in self.parameters.items()
}
def _parameters_are_correct_type(self) -> None:
supported_types = (
int,
float,
np.ndarray,
Prior,
pt.TensorVariable,
VariableFactory,
)
incorrect_types = {
param: type(value)
for param, value in self.parameters.items()
if not isinstance(value, supported_types)
}
if incorrect_types:
msg = (
"Parameters must be one of the following types: "
f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}"
)
raise ValueError(msg)
def _correct_non_centered_distribution(self) -> None:
if (
not self.centered
and self.distribution not in self.non_centered_distributions
):
raise UnsupportedParameterizationError(
f"{self.distribution!r} is not supported for non-centered parameterization. "
f"Choose from {list(self.non_centered_distributions.keys())}"
)
required_parameters = set(
self.non_centered_distributions[self.distribution].keys()
)
if set(self.parameters.keys()) < required_parameters:
msg = " and ".join([f"{param!r}" for param in required_parameters])
raise ValueError(
f"Must have at least {msg} parameter for non-centered for {self.distribution!r}"
)
def _unique_dims(self) -> None:
if not self.dims:
return
if len(self.dims) != len(set(self.dims)):
raise ValueError("Dims must be unique")
def _param_dims_work(self) -> None:
other_dims = set()
for value in self.parameters.values():
if hasattr(value, "dims"):
other_dims.update(value.dims)
if not other_dims.issubset(self.dims):
raise UnsupportedShapeError(
f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
)
def __str__(self) -> str:
"""Return a string representation of the prior."""
param_str = ", ".join(
[f"{param}={value}" for param, value in self.parameters.items()]
)
param_str = "" if not param_str else f", {param_str}"
dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else ""
centered_str = f", centered={self.centered}" if not self.centered else ""
transform_str = f', transform="{self.transform}"' if self.transform else ""
return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})'
def __repr__(self) -> str:
"""Return a string representation of the prior."""
return f"{self}"
def _create_parameter(self, param, value, name):
if not hasattr(value, "create_variable"):
return value
child_name = f"{name}_{param}"
return self.dim_handler(value.create_variable(child_name), value.dims)
def _create_centered_variable(self, name: str):
parameters = {
param: self._create_parameter(param, value, name)
for param, value in self.parameters.items()
}
return self.pymc_distribution(name, **parameters, dims=self.dims)
def _create_non_centered_variable(self, name: str) -> pt.TensorVariable:
def handle_variable(var_name: str):
parameter = self.parameters[var_name]
if not hasattr(parameter, "create_variable"):
return parameter
return self.dim_handler(
parameter.create_variable(f"{name}_{var_name}"),
parameter.dims,
)
defaults = self.non_centered_distributions[self.distribution]
other_parameters = {
param: handle_variable(param)
for param in self.parameters.keys()
if param not in defaults
}
offset = self.pymc_distribution(
f"{name}_offset",
**defaults,
**other_parameters,
dims=self.dims,
)
if "mu" in self.parameters:
mu = (
handle_variable("mu")
if isinstance(self.parameters["mu"], Prior)
else self.parameters["mu"]
)
else:
mu = 0
sigma = (
handle_variable("sigma")
if isinstance(self.parameters["sigma"], Prior)
else self.parameters["sigma"]
)
return pm.Deterministic(
name,
mu + sigma * offset,
dims=self.dims,
)
[docs]
def create_variable(self, name: str) -> pt.TensorVariable:
"""Create a PyMC variable from the prior.
Must be used in a PyMC model context.
Parameters
----------
name : str
The name of the variable.
Returns
-------
pt.TensorVariable
The PyMC variable.
Examples
--------
Create a hierarchical normal variable in larger PyMC model.
.. code-block:: python
dist = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
coords = {"channel": ["C1", "C2", "C3"]}
with pm.Model(coords=coords):
var = dist.create_variable("var")
"""
self.dim_handler = create_dim_handler(self.dims)
if self.transform:
var_name = f"{name}_raw"
def transform(var):
return pm.Deterministic(
name, self.pytensor_transform(var), dims=self.dims
)
else:
var_name = name
def transform(var):
return var
create_variable = (
self._create_centered_variable
if self.centered
else self._create_non_centered_variable
)
var = create_variable(name=var_name)
return transform(var)
@property
def preliz(self):
"""Create an equivalent preliz distribution.
Helpful to visualize a distribution when it is univariate.
Returns
-------
preliz.distributions.Distribution
Examples
--------
Create a preliz distribution from a prior.
.. code-block:: python
from pymc_marketing.prior import Prior
dist = Prior("Gamma", alpha=5, beta=1)
dist.preliz.plot_pdf()
"""
import preliz as pz
return getattr(pz, self.distribution)(**self.parameters)
[docs]
def to_dict(self) -> dict[str, Any]:
"""Convert the prior to dictionary format.
This is equivalent to the older PyMC-Marketing dictionary format.
Returns
-------
dict[str, Any]
The dictionary format of the prior.
Examples
--------
Convert a prior to the dictionary format.
.. code-block:: python
from pymc_marketing.prior import Prior
dist = Prior("Normal", mu=0, sigma=1)
dist.to_dict()
# {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
Convert a hierarchical prior to the dictionary format.
.. code-block:: python
dist = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
dist.to_dict()
# {
# "dist": "Normal",
# "kwargs": {
# "mu": {"dist": "Normal"},
# "sigma": {"dist": "HalfNormal"},
# },
# "dims": "channel",
# }
"""
data: dict[str, Any] = {
"dist": self.distribution,
}
if self.parameters:
def handle_value(value):
if isinstance(value, Prior):
return value.to_dict()
if isinstance(value, pt.TensorVariable):
value = value.eval()
if isinstance(value, np.ndarray):
return value.tolist()
if hasattr(value, "to_dict"):
return value.to_dict()
return value
data["kwargs"] = {
param: handle_value(value) for param, value in self.parameters.items()
}
if not self.centered:
data["centered"] = False
if self.dims:
data["dims"] = self.dims
if self.transform:
data["transform"] = self.transform
return data
[docs]
@classmethod
def from_dict(cls, data) -> Prior:
"""Create a Prior from the dictionary format.
Parameters
----------
data : dict[str, Any]
The dictionary format of the prior.
Returns
-------
Prior
The prior distribution.
Examples
--------
Convert prior in the dictionary format to a Prior instance.
.. code-block:: python
from pymc_marketing.prior import Prior
data = {
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 1},
}
dist = Prior.from_dict(data)
dist
# Prior("Normal", mu=0, sigma=1)
"""
if not isinstance(data, dict):
msg = (
"Must be a dictionary representation of a prior distribution. "
f"Not of type: {type(data)}"
)
raise ValueError(msg)
dist = data["dist"]
kwargs = data.get("kwargs", {})
def handle_value(value):
if isinstance(value, dict):
return deserialize(value)
if isinstance(value, list):
return np.array(value)
return value
kwargs = {param: handle_value(value) for param, value in kwargs.items()}
centered = data.get("centered", True)
dims = data.get("dims")
if isinstance(dims, list):
dims = tuple(dims)
transform = data.get("transform")
return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs)
[docs]
def constrain(
self, lower: float, upper: float, mass: float = 0.95, kwargs=None
) -> Prior:
"""Create a new prior with a given mass constrained within the given bounds.
Wrapper around `preliz.maxent`.
Parameters
----------
lower : float
The lower bound.
upper : float
The upper bound.
mass: float = 0.95
The mass of the distribution to keep within the bounds.
kwargs : dict
Additional arguments to pass to `pz.maxent`.
Returns
-------
Prior
The maximum entropy prior with a mass constrained to the given bounds.
Examples
--------
Create a Beta distribution that is constrained to have 95% of the mass
between 0.5 and 0.8.
.. code-block:: python
dist = Prior(
"Beta",
).constrain(lower=0.5, upper=0.8)
Create a Beta distribution with mean 0.6, that is constrained to
have 95% of the mass between 0.5 and 0.8.
.. code-block:: python
dist = Prior(
"Beta",
mu=0.6,
).constrain(lower=0.5, upper=0.8)
"""
from preliz import maxent
if self.transform:
raise ValueError("Can't constrain a transformed variable")
if kwargs is None:
kwargs = {}
kwargs.setdefault("plot", False)
if kwargs["plot"]:
new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[
0
].params_dict
else:
new_parameters = maxent(
self.preliz, lower, upper, mass, **kwargs
).params_dict
return Prior(
self.distribution,
dims=self.dims,
transform=self.transform,
centered=self.centered,
**new_parameters,
)
def __eq__(self, other) -> bool:
"""Check if two priors are equal."""
if not isinstance(other, Prior):
return False
try:
np.testing.assert_equal(self.parameters, other.parameters)
except AssertionError:
return False
return (
self.distribution == other.distribution
and self.dims == other.dims
and self.centered == other.centered
and self.transform == other.transform
)
[docs]
def sample_prior(
self,
coords=None,
name: str = "var",
**sample_prior_predictive_kwargs,
) -> xr.Dataset:
"""Sample the prior distribution for the variable.
Parameters
----------
coords : dict[str, list[str]], optional
The coordinates for the variable, by default None.
Only required if the dims are specified.
name : str, optional
The name of the variable, by default "var".
sample_prior_predictive_kwargs : dict
Additional arguments to pass to `pm.sample_prior_predictive`.
Returns
-------
xr.Dataset
The dataset of the prior samples.
Example
-------
Sample from a hierarchical normal distribution.
.. code-block:: python
dist = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
coords = {"channel": ["C1", "C2", "C3"]}
prior = dist.sample_prior(coords=coords)
"""
return sample_prior(
factory=self,
coords=coords,
name=name,
**sample_prior_predictive_kwargs,
)
def __deepcopy__(self, memo) -> Prior:
"""Return a deep copy of the prior."""
if id(self) in memo:
return memo[id(self)]
copy_obj = Prior(
self.distribution,
dims=copy.copy(self.dims),
centered=self.centered,
transform=self.transform,
**copy.deepcopy(self.parameters),
)
memo[id(self)] = copy_obj
return copy_obj
[docs]
def deepcopy(self) -> Prior:
"""Return a deep copy of the prior."""
return copy.deepcopy(self)
[docs]
def to_graph(self):
"""Generate a graph of the variables.
Examples
--------
Create the graph for a 2D transformed hierarchical distribution.
.. code-block:: python
from pymc_marketing.prior import Prior
mu = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
sigma = Prior("HalfNormal", dims="channel")
dist = Prior(
"Normal",
mu=mu,
sigma=sigma,
dims=("channel", "geo"),
centered=False,
transform="sigmoid",
)
dist.to_graph()
.. image:: /_static/example-graph.png
:alt: Example graph
"""
coords = {name: ["DUMMY"] for name in self.dims}
with pm.Model(coords=coords) as model:
self.create_variable("var")
return pm.model_to_graphviz(model)
[docs]
def create_likelihood_variable(
self,
name: str,
mu: pt.TensorLike,
observed: pt.TensorLike,
) -> pt.TensorVariable:
"""Create a likelihood variable from the prior.
Will require that the distribution has a `mu` parameter
and that it has not been set in the parameters.
Parameters
----------
name : str
The name of the variable.
mu : pt.TensorLike
The mu parameter for the likelihood.
observed : pt.TensorLike
The observed data.
Returns
-------
pt.TensorVariable
The PyMC variable.
Examples
--------
Create a likelihood variable in a larger PyMC model.
.. code-block:: python
import pymc as pm
dist = Prior("Normal", sigma=Prior("HalfNormal"))
with pm.Model():
# Create the likelihood variable
mu = pm.Normal("mu", mu=0, sigma=1)
dist.create_likelihood_variable("y", mu=mu, observed=observed)
"""
if "mu" not in _get_pymc_parameters(self.pymc_distribution):
raise UnsupportedDistributionError(
f"Likelihood distribution {self.distribution!r} is not supported."
)
if "mu" in self.parameters:
raise MuAlreadyExistsError(self)
distribution = self.deepcopy()
distribution.parameters["mu"] = mu
distribution.parameters["observed"] = observed
return distribution.create_variable(name)
[docs]
def is_alternative_prior(data: Any) -> bool:
"""Check if the data is a dictionary representing a Prior (alternative check)."""
return isinstance(data, dict) and "distribution" in data
[docs]
def deserialize_alternative_prior(data: dict[str, Any]) -> Prior:
"""Alternative deserializer that recursively handles all nested parameters.
This implementation is more general and handles cases where any parameter
might be a nested prior, and also extracts centered and transform parameters.
Examples
--------
This handles cases like:
.. code-block:: yaml
distribution: Gamma
alpha: 1
beta:
distribution: HalfNormal
sigma: 1
dims: channel
dims: [brand, channel]
"""
data = copy.deepcopy(data)
distribution = data.pop("distribution")
dims = data.pop("dims", None)
centered = data.pop("centered", True)
transform = data.pop("transform", None)
parameters = data
# Recursively deserialize any nested parameters
parameters = {
key: value if not isinstance(value, dict) else deserialize(value)
for key, value in parameters.items()
}
return Prior(
distribution,
transform=transform,
centered=centered,
dims=dims,
**parameters,
)
# Register the alternative prior deserializer for more complex nested cases
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
class VariableNotFound(Exception):
"""Variable is not found."""
def _remove_random_variable(var: pt.TensorVariable) -> None:
if var.name is None:
raise ValueError("This isn't removable")
name: str = var.name
model = pm.modelcontext(None)
for idx, free_rv in enumerate(model.free_RVs):
if var == free_rv:
index_to_remove = idx
break
else:
raise VariableNotFound(f"Variable {var.name!r} not found")
var.name = None
model.free_RVs.pop(index_to_remove)
model.named_vars.pop(name)
[docs]
@dataclass
class Censored:
"""Create censored random variable.
Examples
--------
Create a censored Normal distribution:
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal")
censored_normal = Censored(normal, lower=0)
Create hierarchical censored Normal distribution:
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
censored_normal = Censored(normal, lower=0)
coords = {"channel": range(3)}
samples = censored_normal.sample_prior(coords=coords)
"""
distribution: InstanceOf[Prior]
lower: float | InstanceOf[pt.TensorVariable] = -np.inf
upper: float | InstanceOf[pt.TensorVariable] = np.inf
def __post_init__(self) -> None:
"""Check validity at initialization."""
if not self.distribution.centered:
raise ValueError(
"Censored distribution must be centered so that .dist() API can be used on distribution."
)
if self.distribution.transform is not None:
raise ValueError(
"Censored distribution can't have a transform so that .dist() API can be used on distribution."
)
@property
def dims(self) -> tuple[str, ...]:
"""The dims from the distribution to censor."""
return self.distribution.dims
@dims.setter
def dims(self, dims) -> None:
self.distribution.dims = dims
[docs]
def create_variable(self, name: str) -> pt.TensorVariable:
"""Create censored random variable."""
dist = self.distribution.create_variable(name)
_remove_random_variable(var=dist)
return pm.Censored(
name,
dist,
lower=self.lower,
upper=self.upper,
dims=self.dims,
)
[docs]
def to_dict(self) -> dict[str, Any]:
"""Convert the censored distribution to a dictionary."""
def handle_value(value):
if isinstance(value, pt.TensorVariable):
return value.eval().tolist()
return value
return {
"class": "Censored",
"data": {
"dist": self.distribution.to_dict(),
"lower": handle_value(self.lower),
"upper": handle_value(self.upper),
},
}
[docs]
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Censored:
"""Create a censored distribution from a dictionary."""
data = data["data"]
return cls( # type: ignore
distribution=deserialize(data["dist"]),
lower=data["lower"],
upper=data["upper"],
)
[docs]
def sample_prior(
self,
coords=None,
name: str = "variable",
**sample_prior_predictive_kwargs,
) -> xr.Dataset:
"""Sample the prior distribution for the variable.
Parameters
----------
coords : dict[str, list[str]], optional
The coordinates for the variable, by default None.
Only required if the dims are specified.
name : str, optional
The name of the variable, by default "var".
sample_prior_predictive_kwargs : dict
Additional arguments to pass to `pm.sample_prior_predictive`.
Returns
-------
xr.Dataset
The dataset of the prior samples.
Example
-------
Sample from a censored Gamma distribution.
.. code-block:: python
gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
dist = Censored(gamma, lower=0.5)
coords = {"channel": ["C1", "C2", "C3"]}
prior = dist.sample_prior(coords=coords)
"""
return sample_prior(
factory=self,
coords=coords,
name=name,
**sample_prior_predictive_kwargs,
)
[docs]
def to_graph(self):
"""Generate a graph of the variables.
Examples
--------
Create graph for a censored Normal distribution
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal")
censored_normal = Censored(normal, lower=0)
censored_normal.to_graph()
"""
coords = {name: ["DUMMY"] for name in self.dims}
with pm.Model(coords=coords) as model:
self.create_variable("var")
return pm.model_to_graphviz(model)
[docs]
def create_likelihood_variable(
self,
name: str,
mu: pt.TensorLike,
observed: pt.TensorLike,
) -> pt.TensorVariable:
"""Create observed censored variable.
Will require that the distribution has a `mu` parameter
and that it has not been set in the parameters.
Parameters
----------
name : str
The name of the variable.
mu : pt.TensorLike
The mu parameter for the likelihood.
observed : pt.TensorLike
The observed data.
Returns
-------
pt.TensorVariable
The PyMC variable.
Examples
--------
Create a censored likelihood variable in a larger PyMC model.
.. code-block:: python
import pymc as pm
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal", sigma=Prior("HalfNormal"))
dist = Censored(normal, lower=0)
observed = 1
with pm.Model():
# Create the likelihood variable
mu = pm.HalfNormal("mu", sigma=1)
dist.create_likelihood_variable("y", mu=mu, observed=observed)
"""
if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
raise UnsupportedDistributionError(
f"Likelihood distribution {self.distribution.distribution!r} is not supported."
)
if "mu" in self.distribution.parameters:
raise MuAlreadyExistsError(self.distribution)
distribution = self.distribution.deepcopy()
distribution.parameters["mu"] = mu
dist = distribution.create_variable(name)
_remove_random_variable(var=dist)
return pm.Censored(
name,
dist,
observed=observed,
lower=self.lower,
upper=self.upper,
dims=self.dims,
)
def _is_prior_type(data: dict) -> bool:
return "dist" in data
def _is_censored_type(data: dict) -> bool:
return data.keys() == {"class", "data"} and data["class"] == "Censored"
register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)
[docs]
class Scaled:
"""Scaled distribution for numerical stability."""
[docs]
def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None:
self.dist = dist
self.factor = factor
@property
def dims(self) -> Dims:
"""The dimensions of the scaled distribution."""
return self.dist.dims
[docs]
def create_variable(self, name: str) -> pt.TensorVariable:
"""Create a scaled variable.
Parameters
----------
name : str
The name of the variable.
Returns
-------
pt.TensorVariable
The scaled variable.
"""
var = self.dist.create_variable(f"{name}_unscaled")
return pm.Deterministic(name, var * self.factor, dims=self.dims)