import numpy as np
import pandas as pd
import xarray as xr

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGP
from pymc_marketing.plot import plot_curve

# Deterministic seed derived from the characters of the example's title.
seed = sum(ord(char) for char in "Out of the box GP")
rng = np.random.default_rng(seed)

# One input index per weekly observation (0, 1, ..., n - 1).
n = 52
X = np.arange(n)

# Settings shared by both HSGP priors; the only difference between the two
# instances below is the ``demeaned_basis`` flag.
kwargs = {
    "X": X,
    "ls": 25,
    "eta": 1,
    "dims": "time",
    "m": 200,
    "L": 150,
    "drop_first": False,
}

hsgp = HSGP(demeaned_basis=False, **kwargs)
hsgp_demeaned = HSGP(demeaned_basis=True, **kwargs)

# Weekly (Monday-anchored) date labels for the "time" dimension, one per X.
dates = pd.date_range(start="2022-01-01", periods=n, freq="W-MON")
coords = dict(time=dates)

def sample_curve(hsgp):
    """Sample from the prior of ``hsgp`` and return its ``"f"`` variable.

    The module-level ``coords`` supply the coordinates, and the module-level
    ``rng`` generator is passed as ``random_seed``.
    """
    prior_draws = hsgp.sample_prior(coords=coords, random_seed=rng)
    return prior_draws["f"]

# Draw a prior curve from each HSGP. The DataArray names ("False"/"True")
# become the labels of the "demeaned" dimension once the two are stacked.
non_demeaned = sample_curve(hsgp).rename("False")
demeaned = sample_curve(hsgp_demeaned).rename("True")

# Merge the two named arrays into one Dataset, then stack its variables along
# a new "demeaned" dimension so both curves can be drawn on the same axes.
combined = xr.merge([non_demeaned, demeaned]).to_array("demeaned")
_, axes = combined.pipe(plot_curve, "time", same_axes=True)
# Fixed typo in the displayed title ("intercepty" -> "intercept-like").
axes[0].set(title="Demeaning the intercept-like first basis")
plt.show()