Source code for pymc_marketing.mmm.validating

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Validating methods for MMM classes."""

from collections.abc import Callable
from warnings import warn

import pandas as pd

__all__ = [
    "ValidateChannelColumns",
    "ValidateControlColumns",
    "ValidateDateColumn",
    "ValidateTargetColumn",
    "validation_method_X",
    "validation_method_y",
]


[docs] def validation_method_y(method: Callable) -> Callable: """Tag a method as a validation method for the target column.""" if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["validation_y"] = True # type: ignore return method
[docs] def validation_method_X(method: Callable) -> Callable: """Tag a method as a validation method for the predictor columns.""" if not hasattr(method, "_tags"): method._tags = {} # type: ignore method._tags["validation_X"] = True # type: ignore return method
[docs] class ValidateTargetColumn: """Validate the target column."""
[docs] @validation_method_y def validate_target(self, data: pd.Series) -> None: """Validate the target column. Parameters ---------- data : pd.Series The data to validate. Raises ------ ValueError: If the target column is not valid. """ if len(data) == 0: raise ValueError("y must have at least one element")
[docs] class ValidateDateColumn: """Validate the date column.""" date_column: str
[docs] @validation_method_X def validate_date_col(self, data: pd.DataFrame) -> None: """Validate the date column. Parameters ---------- data : pd.DataFrame The data to validate. Raises ------ ValueError: If the date column is not valid. """ if self.date_column not in data.columns: raise ValueError(f"date_col {self.date_column} not in data") if not data[self.date_column].is_unique: raise ValueError(f"date_col {self.date_column} has repeated values")
[docs] class ValidateChannelColumns: """Validate the channel columns.""" channel_columns: list[str] | tuple[str]
[docs] @validation_method_X def validate_channel_columns(self, data: pd.DataFrame) -> None: """Validate the channel columns. Parameters ---------- data : pd.DataFrame The data to validate. Raises ------ ValueError: If the channel columns are not valid. """ if not isinstance(self.channel_columns, list | tuple): raise ValueError("channel_columns must be a list or tuple") if len(self.channel_columns) == 0: raise ValueError("channel_columns must not be empty") if not set(self.channel_columns).issubset(data.columns): raise ValueError(f"channel_columns {self.channel_columns} not in data") if len(set(self.channel_columns)) != len(self.channel_columns): raise ValueError( f"channel_columns {self.channel_columns} contains duplicates" ) if (data.filter(list(self.channel_columns)) < 0).any().any(): warn( f"channel_columns {self.channel_columns} contains negative values", UserWarning, stacklevel=2, )
[docs] class ValidateControlColumns: """Validate the control columns.""" control_columns: list[str] | None
[docs] @validation_method_X def validate_control_columns(self, data: pd.DataFrame) -> None: """Validate the control columns. Parameters ---------- data : pd.DataFrame The data to validate. Raises ------ ValueError: If the control columns are not valid. """ if self.control_columns is None: return None if not isinstance(self.control_columns, list | tuple): raise ValueError("control_columns must be None, a list or tuple") if len(self.control_columns) == 0: raise ValueError( "If control_columns is not None, then it must not be empty" ) if not set(self.control_columns).issubset(data.columns): raise ValueError(f"control_columns {self.control_columns} not in data") if len(set(self.control_columns)) != len(self.control_columns): raise ValueError( f"control_columns {self.control_columns} contains duplicates" )