import numpy as np
import pandas as pd
import xarray as xr

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGPPeriodic
from pymc_marketing.plot import plot_curve

# Seed derived from a fixed string so the prior draws below are reproducible.
seed = sum(ord(char) for char in "Periodic GP")
rng = np.random.default_rng(seed)

# Settings shared by both HSGPPeriodic models; only ``demeaned_basis`` differs.
# NOTE(review): period=52 presumably means one year of weekly steps — confirm.
scale = 0.25
ls = 1
kwargs = {
    "ls": ls,
    "scale": scale,
    "period": 52,
    "cov_func": "periodic",
    "dims": "time",
    "m": 20,
}

# 3 * 52 weekly (Monday-anchored) dates; the GP input is the integer index 0..n-1.
n = 3 * 52
dates = pd.date_range("2023-01-01", periods=n, freq="W-MON")
X = np.arange(n)
coords = {"time": dates}

# Build the plain model first, then the demeaned one, both registered on X.
hsgp, hsgp_demeaned = (
    HSGPPeriodic(demeaned_basis=flag, **kwargs).register_data(X)
    for flag in (False, True)
)

def sample_curve(hsgp):
    """Draw a prior sample from ``hsgp`` and return its ``"f"`` entry.

    The draw uses the module-level ``coords`` mapping and the shared ``rng``
    generator. Because the generator is shared, successive calls advance its
    state and produce different draws.
    """
    # NOTE(review): sample_prior presumably returns an xarray Dataset — confirm.
    prior_draws = hsgp.sample_prior(coords=coords, random_seed=rng)
    return prior_draws["f"]

# One prior curve from each model. The DataArray names ("False"/"True") become
# the coordinate labels of the "demeaned" dimension created by to_array below.
non_demeaned = sample_curve(hsgp).rename("False")
demeaned = sample_curve(hsgp_demeaned).rename("True")

# Merge into one Dataset, then stack its two variables along a new "demeaned"
# dimension so plot_curve can draw both curves against "time".
combined = xr.merge([non_demeaned, demeaned]).to_array("demeaned")
# same_axes=True presumably overlays both curves on a single axes — confirm.
_, axes = combined.pipe(plot_curve, "time", same_axes=True)
axes[0].set(title="Demeaning the intercept-like first basis")
plt.show()