Source code for pymc_marketing.mmm.builders.factories

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Generic recursive factory for the MMM YAML schema."""

from __future__ import annotations

import importlib
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any

from pymc_marketing.deserialize import deserialize

# Optional short-name registry -------------------------------------------------
# Maps a short alias to the object it stands for.  ``locate`` consults this
# mapping first and only falls back to importing a dotted path when the name
# is not a key here, so a registered alias may be used as the "class" value
# of a spec.  The registry ships empty.
REGISTRY: dict[str, Any] = {
    # "Prior": pymc_marketing.prior.Prior,   # <— example of a whitelisted alias
}

# -----------------------------------------------------------------------------


def locate(qualname: str) -> Any:
    """
    Look up the Python object that *qualname* names.

    Parameters
    ----------
    qualname : str
        Either a key of ``REGISTRY`` or a dotted import path such as
        ``'pkg.mod.Class'``. Registry keys win over import paths.

    Returns
    -------
    Any
        The registered object, or the attribute fetched from the imported
        module.

    Raises
    ------
    TypeError
        If *qualname* is not a string (for example a nested mapping that was
        passed where a name was expected).
    ValueError
        If *qualname* is not registered and has no module part in front of
        its last dot.

    Notes
    -----
    Unregistered names are imported with ``importlib.import_module``, so any
    importable module can be loaded through this function.
    """
    if not isinstance(qualname, str):
        raise TypeError(
            f"Expected string for qualname but got {type(qualname).__name__}: {qualname}"
        )

    if qualname in REGISTRY:
        return REGISTRY[qualname]

    # Split on the last dot: everything before it is the module path,
    # everything after it is the attribute to fetch from that module.
    last_dot = qualname.rfind(".")
    if last_dot <= 0:
        # No dot at all, or only a leading one: there is no module to import.
        raise ValueError(
            f"Cannot locate '{qualname}'. "
            "Provide a fully-qualified name or add it to REGISTRY."
        )

    module_path = qualname[:last_dot]
    attr_name = qualname[last_dot + 1 :]
    return getattr(importlib.import_module(module_path), attr_name)
def _build_kwarg(key: str, value: Any) -> Any:
    """Turn one raw keyword-argument value from a spec into its final form."""
    if isinstance(value, dict):
        if key == "priors":
            # Dict-valued entries are handed to ``deserialize``; every other
            # entry is kept exactly as given.
            return {
                name: deserialize(prior) if isinstance(prior, dict) else prior
                for name, prior in value.items()
            }
        if key == "prior" and "distribution" in value:
            return deserialize(value)
    # Everything else (including a "prior" dict without a "distribution"
    # key) goes through the generic nested-object resolution.
    return resolve(value)


def build(spec: Mapping[str, Any]) -> Any:
    """
    Instantiate the object described by *spec*.

    Parameters
    ----------
    spec : Mapping[str, Any]
        Description of the object to create.

    Returns
    -------
    Any
        The result of calling the located class with the resolved positional
        and keyword arguments.

    Raises
    ------
    KeyError
        If *spec* has no ``"class"`` entry.
    TypeError
        If ``spec["class"]`` is not a string.

    Notes
    -----
    Recognised keys

    * class  : str (mandatory)
    * kwargs : dict (optional)
    * args   : list (optional positional arguments)

    An ``args`` entry inside ``kwargs`` takes precedence over a top-level
    ``args`` entry. ``kwargs`` or ``args`` given as ``None`` (e.g. a key left
    empty in YAML) are treated as if they were absent. A ``dims`` keyword
    given as a list is converted to a tuple. The class name is passed to
    ``locate``, which may import arbitrary modules, so specs should come from
    a trusted source.
    """
    class_name = spec["class"]
    if not isinstance(class_name, str):
        raise TypeError(
            f"Expected string for 'class' but got {type(class_name).__name__}: {class_name}"
        )
    cls = locate(class_name)

    kwargs_spec = spec.get("kwargs")
    raw_kwargs: MutableMapping[str, Any] = (
        {} if kwargs_spec is None else dict(kwargs_spec)
    )

    raw_args: Sequence[Any] | None = raw_kwargs.pop("args", spec.get("args", ()))
    if raw_args is None:
        raw_args = ()

    # Convert list dimensions to tuples for model or effect classes.
    if isinstance(raw_kwargs.get("dims"), list):
        raw_kwargs["dims"] = tuple(raw_kwargs["dims"])

    kwargs = {key: _build_kwarg(key, value) for key, value in raw_kwargs.items()}
    args = [resolve(value) for value in raw_args]
    return cls(*args, **kwargs)
def resolve(value):
    """
    Build *value* if it is an object spec, otherwise hand it back unchanged.

    A spec is a mapping that contains a ``"class"`` key. A single spec is
    passed to ``build``. A non-empty list whose first element is a spec has
    every element passed to ``build`` and a new list of the results is
    returned. Any other value is returned as is.

    This is a helper function for build.
    """

    def _is_spec(candidate) -> bool:
        return isinstance(candidate, Mapping) and "class" in candidate

    if _is_spec(value):
        return build(value)

    # Only the first element decides whether a list is treated as a list of
    # specs; the remaining elements are not inspected before being built.
    is_spec_list = isinstance(value, list) and bool(value) and _is_spec(value[0])
    if not is_spec_list:
        return value

    return [build(item) for item in value]