Source code for pymc_marketing.model_graph

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Functions to manipulate PyMC models as graphs."""

import pymc as pm
from pymc.model.fgraph import (
    extract_dims,
    fgraph_from_model,
    model_free_rv,
    model_from_fgraph,
)
from pymc.pytensorf import toposort_replace
from pytensor.graph import rewrite_graph
from pytensor.tensor.basic import infer_shape_db


[docs] def deterministics_to_flat(model: pm.Model, names: list[str]) -> pm.Model: """Replace all specified Deterministic nodes in a pm.Model with Flat. This is useful to capture some state from a model and to then sample from the model using that state. For example, capturing the mean of a distribution or a value of a deterministic variable. See :class:`pymc_marketing.mmm.hsgp.SoftPlusHSGP` for an example of how this is used to keep a variable centered around 1.0 during sampling but stay continuous with new values. Parameters ---------- model : pm.Model PyMC model to be transformed names : list[str] Names of the deterministic variables to be replaced by flat Returns ------- new_model : pm.Model New model with all priors replaced by flat priors Examples -------- Replace single Deterministic with Flat and sample as if it were zeros. .. code-block:: python import pymc as pm import numpy as np import xarray as xr from pymc_marketing.model_graph import deterministics_to_flat with pm.Model() as model: x = pm.Normal("x", mu=0, sigma=1) y = pm.Deterministic("y", x**2) z = pm.Deterministic("z", x + y) new_model = deterministics_to_flat(model, ["y"]) chains, draws = 2, 100 mock_posterior = xr.Dataset( { "y": (("chain", "draw"), np.zeros((chains, draws))), }, coords={"chain": np.arange(chains), "draw": np.arange(draws)}, ) x_z_given_y = pm.sample_posterior_predictive( mock_posterior, model=new_model, var_names=["x", "z"], ).posterior_predictive np.testing.assert_allclose( x_z_given_y["x"], x_z_given_y["z"], ) """ fg, memo = fgraph_from_model(model, inlined_views=True) model_variables = [x for x in set(model.deterministics) if x.name in names] replacements = {} for variable in model_variables: model_var = memo[variable] dims = extract_dims(model_var) new_rv = pm.Flat.dist(shape=model_var.shape) new_rv.name = model_var.name replacements[model_var] = model_free_rv( new_rv, new_rv.type(name=model_var.name), None, *dims, ) toposort_replace(fg, replacements=tuple(replacements.items())) fg = rewrite_graph( fg, include=("ShapeOpt",), custom_rewrite=infer_shape_db.default_query, clone=False, ) new_model = model_from_fgraph(fg, mutate_fgraph=True) return new_model