# Source code for pymc_marketing.mmm.constraints
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constraints for the BudgetOptimizer."""
from collections.abc import Callable, Mapping
from typing import Any, Literal

import pytensor.tensor as pt
from pymc.pytensorf import rewrite_pregrad
from pytensor import function
[docs]
class Constraint:
    """
    Represents a constraint for the BudgetOptimizer.

    Parameters
    ----------
    key : str
        Identifier for the constraint.
    constraint_type : {"eq", "ineq"}
        Type of the constraint: ``"eq"`` for equality, ``"ineq"`` for
        inequality. Checked case-insensitively; the value is stored as given.
        Following :func:`scipy.optimize.minimize`, an equality constraint
        requires the expression to be zero and an inequality constraint
        requires it to be non-negative.
    constraint_fun : callable
        ``constraint_fun(budgets_sym, total_budget_sym, optimizer)`` returning
        the symbolic constraint expression as a ``pt.TensorVariable``.

    Attributes
    ----------
    key, constraint_type, constraint_fun
        The constructor arguments, stored unchanged.

    Raises
    ------
    ValueError
        If ``constraint_type`` is not a string equal (ignoring case) to
        ``"eq"`` or ``"ineq"``.
    """

    # Constraint types understood by scipy.optimize.minimize.
    _VALID_TYPES = ("eq", "ineq")

    def __init__(
        self,
        key: str,
        constraint_type: Literal["eq", "ineq"],
        constraint_fun: Callable[
            [pt.TensorVariable, pt.TensorVariable, Any], pt.TensorVariable
        ],
    ):
        # Fail fast at construction time instead of deep inside the optimizer
        # call. The check is case-insensitive so that every spelling scipy
        # accepts keeps working.
        if (
            not isinstance(constraint_type, str)
            or constraint_type.lower() not in self._VALID_TYPES
        ):
            raise ValueError(
                f"constraint_type must be 'eq' or 'ineq', got {constraint_type!r}."
            )
        self.key = key
        self.constraint_type = constraint_type
        self.constraint_fun = constraint_fun

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(key={self.key!r}, "
            f"constraint_type={self.constraint_type!r})"
        )
[docs]
def build_default_sum_constraint(key: str = "default") -> Constraint:
    """Build the default budget-sum equality constraint.

    The constraint expression is ``sum(budgets) - total_budget`` and is marked
    as ``"eq"``, i.e. it enforces ``sum(budgets) == total_budget``.

    Parameters
    ----------
    key : str, optional
        Identifier given to the constraint. Defaults to ``"default"``.

    Returns
    -------
    Constraint
        An equality constraint on the total allocated budget.
    """

    def _sum_minus_total(
        budgets_sym: pt.TensorVariable,
        total_budget_sym: pt.TensorVariable,
        optimizer,
    ) -> pt.TensorVariable:
        # ``optimizer`` is part of the constraint_fun calling convention but
        # is not needed by this particular constraint.
        allocated = pt.sum(budgets_sym)
        return allocated - total_budget_sym

    return Constraint(key, "eq", _sum_minus_total)
[docs]
def _compile_constraint(
    constraint: Constraint, budgets, budgets_flat, total_budget, optimizer
) -> dict[str, Any]:
    """Compile a single Constraint into a scipy-style constraint dictionary."""
    # Build the symbolic constraint expression from the optimizer's variables.
    constraint_fun_output = constraint.constraint_fun(
        budgets, total_budget, optimizer
    )
    # Apply PyMC's pre-gradient rewrites, then differentiate with respect to
    # the flat budget variable that the compiled functions take as input.
    sym_jac_output = pt.grad(rewrite_pregrad(constraint_fun_output), budgets_flat)

    # Compile symbolic graphs into Python callables of the flat budgets.
    compiled_fun = function(
        inputs=[budgets_flat],
        outputs=constraint_fun_output,
        on_unused_input="ignore",
    )
    compiled_jac = function(
        inputs=[budgets_flat],
        outputs=sym_jac_output,
        on_unused_input="ignore",
    )
    return {
        "type": constraint.constraint_type,
        "fun": compiled_fun,
        "jac": compiled_jac,
    }


def compile_constraints_for_scipy(
    constraints: list[Constraint] | dict, optimizer
) -> list[dict[str, Any]]:
    """Compile constraints into the dictionaries expected by scipy.

    Parameters
    ----------
    constraints : list of Constraint or dict
        The constraints to compile. For a mapping (e.g. a dict), its values
        are used and its keys are ignored. Any other iterable of
        ``Constraint`` objects (list, tuple, generator) is also accepted.
    optimizer
        Object exposing the symbolic variables ``_budgets``,
        ``_budgets_flat`` and ``_total_budget``. It is also forwarded to each
        ``constraint_fun``.

    Returns
    -------
    list of dict
        One ``{"type": ..., "fun": ..., "jac": ...}`` dictionary per
        constraint, in input order. ``fun`` and ``jac`` are compiled PyTensor
        functions whose single input is ``optimizer._budgets_flat``.

    Raises
    ------
    ValueError
        If no constraints are provided.
    TypeError
        If any element is not a ``Constraint`` instance.
    """
    budgets = optimizer._budgets
    budgets_flat = optimizer._budgets_flat
    total_budget = optimizer._total_budget

    if isinstance(constraints, Mapping):
        constraints = list(constraints.values())
    else:
        # ``if constraints`` keeps None/empty inputs on the ValueError path
        # below; list() makes generators (always truthy) reveal emptiness.
        constraints = list(constraints) if constraints else []
    if not constraints:
        raise ValueError("No constraints provided for compilation.")

    # Validate every entry before compiling anything, so a bad entry fails
    # fast instead of after earlier constraints were already compiled.
    for constraint in constraints:
        if not isinstance(constraint, Constraint):
            raise TypeError(
                f"Expected an instance of Constraint, but received {type(constraint)}. "
                "Ensure all constraints are created using the Constraint class."
            )

    return [
        _compile_constraint(
            constraint, budgets, budgets_flat, total_budget, optimizer
        )
        for constraint in constraints
    ]