import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGPPeriodic
from pymc_marketing.prior import Prior

# Draw prior samples from a periodic Hilbert-space GP (HSGP) over weekly
# dates, pass them through an exponential link via ``transform="exp"``, and
# plot a few of the sampled curves.

# Deterministic seed derived from a fixed label so the figure is reproducible.
seed = sum(map(ord, "Periodic GP"))
rng = np.random.default_rng(seed)

# Length of one cycle in observations (weeks). Used for the GP period, the
# series length and the plot title so the three cannot drift apart.
period = 52
n_cycles = 3

n = period * n_cycles
dates = pd.date_range("2023-01-01", periods=n, freq="W-MON")
# Integer time index fed to the GP; the dates are only used as coordinates.
X = np.arange(n)
coords = {
    "time": dates,
}
scale = Prior("Gamma", mu=0.25, sigma=0.1)
ls = Prior("InverseGamma", alpha=2, beta=1)

# HSGPPeriodic is periodic by construction and has no ``cov_func`` option
# (that parameter belongs to ``HSGP``), so none is passed here.
hsgp = HSGPPeriodic(
    scale=scale,
    m=20,
    ls=ls,
    period=period,
    dims="time",
    transform="exp",
)
hsgp.register_data(X)

prior = hsgp.sample_prior(coords=coords, random_seed=rng)
curve = prior["f"]
fig, axes = hsgp.plot_curve(
    curve,
    n_samples=3,
    random_seed=rng,
)
ax = axes[0]
ax.set(xlabel="Date", ylabel="f", title=f"HSGP with period of {period} weeks")
plt.show()