Source code for pymc_marketing.customer_choice.synthetic_data
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data generation functions for consumer choice models."""
import numpy as np
import pandas as pd
[docs]
def generate_saturated_data(
total_sales_mu: int,
total_sales_sigma: float,
treatment_time: int,
n_observations: int,
market_shares_before,
market_shares_after,
market_share_labels,
random_seed: int | np.random.Generator | None = None,
):
"""Generate synthetic data for the MVITS model, assuming market is saturated.
This function generates synthetic data for the MVITS model, assuming that the market is
saturated. This makes the assumption that the total sales are normally distributed around
some average level of sales, and that the market shares are constant over time.
Parameters
----------
total_sales_mu: int
The average level of sales in the market.
total_sales_sigma: float
The standard deviation of sales in the market.
treatment_time: int
The time at which the new model is introduced.
n_observations: int
The number of observations to generate.
market_shares_before: list[float]
The market shares before the introduction of the new model.
market_shares_after: list[float]
The market shares after the introduction of the new model.
market_share_labels: list[str]
The labels for the market shares.
random_seed: np.random.Generator | int, optional
The random number generator to use.
Returns
-------
data: pd.DataFrame
The synthetic data generated.
"""
rng: np.random.Generator = (
random_seed
if isinstance(random_seed, np.random.Generator)
else np.random.default_rng(random_seed)
)
rates = np.array(
treatment_time * market_shares_before
+ (n_observations - treatment_time) * market_shares_after
)
# Generate total demand (sales) as normally distributed around some average level of sales
total = (
rng.normal(loc=total_sales_mu, scale=total_sales_sigma, size=n_observations)
).astype(int)
# Ensure total sales are never negative
total[total < 0] = 0
# Generate sales counts
counts = rng.multinomial(total, rates)
# Convert to DataFrame
data = pd.DataFrame(counts)
data.columns = market_share_labels
data.columns.name = "product"
data.index.name = "day"
data["pre"] = data.index < treatment_time
return data
[docs]
def generate_unsaturated_data(
total_sales_before: list[int],
total_sales_after: list[int],
total_sales_sigma: float,
treatment_time: int,
n_observations: int,
market_shares_before: list[list[float]],
market_shares_after: list[list[float]],
market_share_labels: list[str],
random_seed: np.random.Generator | int | None = None,
):
"""Generate synthetic data for the MVITS model.
Notably, we can define different total sales levels before and after the
introduction of the new model.
This function generates synthetic data for the MVITS model, assuming that the market is
unsaturated meaning that there are new sales to be made.
This makes the assumption that the total sales are normally distributed around
some average level of sales, and that the market shares are constant over time.
Parameters
----------
total_sales_mu: int
The average level of sales in the market.
total_sales_sigma: float
The standard deviation of sales in the market.
treatment_time: int
The time at which the new model is introduced.
n_observations: int
The number of observations to generate.
market_shares_before: list[float]
The market shares before the introduction of the new model.
market_shares_after: list[float]
The market shares after the introduction of the new model.
market_share_labels: list[str]
The labels for the market shares.
random_seed: np.random.Generator | int, optional
The random number generator to use.
Returns
-------
data: pd.DataFrame
The synthetic data generated.
"""
rng: np.random.Generator = (
random_seed
if isinstance(random_seed, np.random.Generator)
else np.random.default_rng(random_seed)
)
rates = np.array(
treatment_time * market_shares_before
+ (n_observations - treatment_time) * market_shares_after
)
total_sales_mu = np.array(
treatment_time * total_sales_before
+ (n_observations - treatment_time) * total_sales_after
)
total = (
rng.normal(loc=total_sales_mu, scale=total_sales_sigma, size=n_observations)
).astype(int)
# Ensure total sales are never negative
total[total < 0] = 0
# Generate sales counts
counts = rng.multinomial(total, rates)
# Convert to DataFrame
data = pd.DataFrame(counts)
data.columns = pd.Index(market_share_labels)
data.columns.name = "product"
data.index.name = "day"
data["pre"] = data.index < treatment_time
return data