import numpy as np
import pandas as pd

import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGPPeriodic
from pymc_marketing.prior import Prior

# Seed derived from the example title so every run draws the same curves.
seed = sum(map(ord, "Higher dimensional HSGP with periodic data"))
rng = np.random.default_rng(seed)

# Three years of weekly observations; X is the integer week index the GP is
# evaluated on (one entry per date).
n = 52 * 3
dates = pd.date_range("2023-01-01", periods=n, freq="W-MON")
X = np.arange(n)

coords = {
    "time": dates,
    "channel": ["A", "B"],
    "product": ["X", "Y", "Z"],
}

# Both priors are declared without `dims`, so a single scale and a single
# lengthscale are shared by every (channel, product) curve.
scale = Prior("HalfNormal", sigma=1)
ls = Prior("InverseGamma", alpha=2, beta=1)

# HSGPPeriodic always uses a periodic kernel; `period` is in units of X, so
# 52 weekly steps = one year. It takes no `cov_func` argument (that option
# belongs to `HSGP`), so none is passed here. The first entry of `dims` is the
# dimension X runs along; the remaining entries index the separate curves
# drawn per channel and product.
hsgp = HSGPPeriodic(
    X=X,
    scale=scale,
    ls=ls,
    m=20,
    period=52,
    dims=("time", "channel", "product"),
)

# Build the GP inside an explicit model so the coords above label its dims,
# then draw from the prior.
with pm.Model(coords=coords) as model:
    hsgp.create_variable("f")
    prior = pm.sample_prior_predictive(random_seed=rng).prior

curve = prior["f"]
fig, axes = hsgp.plot_curve(
    curve,
    n_samples=3,
    random_seed=rng,
    subplot_kwargs={"figsize": (12, 8), "ncols": 3},
)
plt.show()