# Source code for pymc_marketing.mmm.causal
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Causal identification class."""
import warnings
import pandas as pd
# dowhy is an optional dependency: when it is missing, bind ``CausalModel`` to a
# stand-in so importing this module still succeeds and the failure is deferred
# until someone actually tries to build a causal model.
try:
    from dowhy import CausalModel
except ImportError:

    class LazyCausalModel:
        """Stand-in for dowhy's CausalModel that raises ImportError when instantiated."""

        def __init__(self, *args, **kwargs):
            # Arguments are accepted (and ignored) so any call signature reaches the error.
            raise ImportError(
                "To use Causal Graph functionality, please install the optional dependencies with: "
                "pip install pymc-marketing[dag]"
            )

    CausalModel = LazyCausalModel
[docs]
class CausalGraphModel:
"""Represent a causal model based on a Directed Acyclic Graph (DAG).
Provides methods to analyze causal relationships and determine the minimal adjustment set
for backdoor adjustment between treatment and outcome variables.
Parameters
----------
causal_model : CausalModel
An instance of dowhy's CausalModel, representing the causal graph and its relationships.
treatment : list[str]
A list of treatment variable names.
outcome : str
The outcome variable name.
References
----------
.. [1] https://github.com/microsoft/dowhy
"""
[docs]
def __init__(
self, causal_model: CausalModel, treatment: list[str] | tuple[str], outcome: str
) -> None:
self.causal_model = causal_model
self.treatment = treatment
self.outcome = outcome
[docs]
@classmethod
def build_graphical_model(
cls, graph: str, treatment: list[str] | tuple[str], outcome: str
) -> "CausalGraphModel":
"""Create a CausalGraphModel from a string representation of a graph.
Parameters
----------
graph : str
A string representation of the graph (e.g., String in DOT format).
treatment : list[str]
A list of treatment variable names.
outcome : str
The outcome variable name.
Returns
-------
CausalGraphModel
An instance of CausalGraphModel constructed from the given graph string.
"""
causal_model = CausalModel(
data=pd.DataFrame(), graph=graph, treatment=treatment, outcome=outcome
)
return cls(causal_model, treatment, outcome)
[docs]
def get_backdoor_paths(self) -> list[list[str]]:
"""Find all backdoor paths between the combined treatment and outcome variables.
Returns
-------
list[list[str]]
A list of backdoor paths, where each path is represented as a list of variable names.
References
----------
.. [1] Causal Inference in Statistics: A Primer
By Judea Pearl, Madelyn Glymour, Nicholas P. Jewell · 2016
"""
# Use DoWhy's internal method to get backdoor paths for all treatments combined
return self.causal_model._graph.get_backdoor_paths(
nodes1=self.treatment, nodes2=[self.outcome]
)
[docs]
def get_unique_adjustment_nodes(self) -> list[str]:
"""Compute the minimal adjustment set required for backdoor adjustment across all treatments.
Returns
-------
list[str]
A list of unique adjustment variables needed to block all backdoor paths.
"""
paths = self.get_backdoor_paths()
# Flatten paths and exclude treatments and outcome from adjustment set
adjustment_nodes = set(
node
for path in paths
for node in path
if node not in self.treatment and node != self.outcome
)
return list(adjustment_nodes)
[docs]
def compute_adjustment_sets(
self,
channel_columns: list[str] | tuple[str],
control_columns: list[str] | None = None,
) -> list[str] | None:
"""Compute minimal adjustment sets and handle warnings."""
channel_columns = list(channel_columns)
if control_columns is None:
return control_columns
self.adjustment_set = self.get_unique_adjustment_nodes()
common_controls = set(control_columns).intersection(self.adjustment_set)
unique_controls = set(control_columns) - set(self.adjustment_set)
if unique_controls:
warnings.warn(
f"Columns {unique_controls} are not in the adjustment set. Controls are being modified.",
stacklevel=2,
)
control_columns = list(common_controls - set(channel_columns))
self.minimal_adjustment_set = control_columns + list(channel_columns)
for column in self.adjustment_set:
if column not in control_columns and column not in channel_columns:
warnings.warn(
f"""Column {column} in adjustment set not found in data.
Not controlling for this may induce bias in treatment effect estimates.""",
stacklevel=2,
)
return control_columns