import numpy as np
import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGP

# Reproducible randomness: the integer seed is the sum of the character
# codes of the example title, so every run starts from the same generator.
seed = sum(ord(char) for char in "Higher dimensional HSGP")
rng = np.random.default_rng(seed=seed)

# 52 evenly spaced integer time points: 0, 1, ..., 51.
n = 52
X = np.arange(0, n)

# Coordinate labels for each GP dimension; "time" has one entry per point in X.
coords = dict(
    time=range(n),
    channel=["A", "B"],
    product=["X", "Y", "Z"],
)

# The GP's dims are the coords keys in insertion order:
# ("time", "channel", "product").
hsgp = HSGP.parameterize_from_data(X=X, dims=tuple(coords))

# Draw from the prior and pull out the latent curve stored under "f".
prior = hsgp.sample_prior(coords=coords, random_seed=rng)
curve = prior["f"]

# Plot the sampled curves. The same generator is passed again (presumably
# to choose which draws get plotted). ncols=3 is assumed to give one column
# per product; confirm against plot_curve's layout.
fig, _ = hsgp.plot_curve(
    curve,
    random_seed=rng,
    subplot_kwargs=dict(figsize=(12, 8), ncols=3),
)
fig.suptitle("Higher dimensional HSGP prior")
plt.show()