import numpy as np
import pandas as pd

import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm import HSGP
from pymc_marketing.prior import Prior

# Derive a reproducible integer seed from the example's title string, then
# build the NumPy generator shared by every sampling call below.
seed = sum(ord(char) for char in "New data predictions")
rng = np.random.default_rng(seed)

# Hyperpriors handed to the HSGP component as its `ls` and `eta` arguments.
ls = Prior("InverseGamma", alpha=2, beta=1)
eta = Prior("Exponential", lam=1)

# Hilbert-space GP approximation: m=20 basis functions, basis extent L=150,
# producing one latent curve per (time, channel) cell.
hsgp = HSGP(
    dims=("time", "channel"),
    m=20,
    L=150,
    ls=ls,
    eta=eta,
)

# Training inputs: an integer index 0..n-1 paired with n weekly timestamps
# (anchored on Mondays) that label the "time" dimension.
n = 52
X = np.arange(n)
dates = pd.date_range("2022-01-01", periods=n, freq="W-MON")
coords = dict(time=dates, channel=["A", "B"])

model = pm.Model(coords=coords)
with model:
    # Hold the inputs in a data container (so pm.set_data can swap them out
    # later) and build the HSGP latent variable "f" on top of it.
    data = pm.Data("data", X, dims="time")
    hsgp.register_data(data).create_variable("f")
    # Draw prior samples using the shared, seeded generator.
    idata = pm.sample_prior_predictive(random_seed=rng)

# Keep just the prior group; it is reused as the trace for new-data predictions.
prior = idata.prior

# Out-of-sample horizon: continue the integer inputs where the training range
# stopped and extend the weekly calendar by the same number of steps.
n_new = 10
X_new = np.arange(n, n + n_new)
# Start one period after the last training date instead of a hard-coded
# calendar date, so the new window stays contiguous with `dates` even if `n`
# or the training start date changes (for n=52 this is 2023-01-02 onward).
new_dates = pd.date_range(
    dates[-1] + dates.freq,
    periods=n_new,
    freq=dates.freq,
)

with model:
    # Replace the model's input data and its "time" coordinate together:
    # the "data" container was declared with dims="time", so their lengths
    # must agree.
    pm.set_data(
        new_data={
            "data": X_new,
        },
        coords={"time": new_dates},
    )
    # Evaluate "f" on the new inputs, passing the prior samples in place of a
    # posterior trace.
    post = pm.sample_posterior_predictive(
        prior,
        var_names=["f"],
        random_seed=rng,
    )

# Which draw to display, and one matplotlib colour-cycle entry per channel.
chain, draw = 0, 50
colors = ["C0", "C1"]


def get_sample(curve, chain=chain, draw=draw):
    """Extract a single draw of ``curve`` as a two-dimensional table.

    Parameters
    ----------
    curve : xarray.DataArray
        Samples with named ``chain`` and ``draw`` dimensions (as in an ArviZ
        group) plus the variable's own dimensions, here ``time`` and
        ``channel``.
    chain, draw : int, optional
        Labels of the draw to extract; default to the module-level
        ``chain`` / ``draw`` selection defined just above.

    Returns
    -------
    pandas.DataFrame
        Indexed by the first remaining dimension, with one column per value
        of the last remaining dimension (``unstack`` moves the innermost
        index level into the columns).
    """
    # Select by dimension *name* rather than position so the result does not
    # depend on the order of the DataArray's dimensions.
    return curve.sel(chain=chain, draw=draw).to_series().unstack()

# Overlay the selected draw: solid lines over the original dates, dashed lines
# over the new dates, reusing the per-channel colours and a single legend.
ax = get_sample(prior["f"]).plot(color=colors)
get_sample(post.posterior_predictive["f"]).plot(
    ax=ax,
    color=colors,
    linestyle="--",
    legend=False,
)
ax.set(title="New data predictions", xlabel="time", ylabel="f")
plt.show()