import numpy as np
import pandas as pd

import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm import SoftPlusHSGP
from pymc_marketing.model_graph import deterministics_to_flat
from pymc_marketing.prior import Prior

# Reproducible seed derived from the example's title string.
seed = sum(ord(char) for char in "New data predictions with SoftPlusHSGP")
rng = np.random.default_rng(seed)

# Priors handed to the HSGP (presumably amplitude `eta` and lengthscale `ls`).
eta = Prior("Exponential", lam=1)
ls = Prior("InverseGamma", alpha=2, beta=1)

# m and L are the HSGP approximation settings; the created variable is
# indexed by the ("time", "channel") dims.
hsgp = SoftPlusHSGP(eta=eta, ls=ls, m=20, L=150, dims=("time", "channel"))

# Observed inputs: n integer time steps 0..n-1, one per weekly (Monday) date.
n = 52
X = np.arange(n)

channels = ["A", "B", "C"]
dates = pd.date_range("2022-01-01", periods=n, freq="W-MON")
coords = dict(time=dates, channel=channels)

# Register the time index as pm.Data (replaced later via pm.set_data), build
# the HSGP curve "f" on it, and draw samples from the prior.
with pm.Model(coords=coords) as model:
    data = pm.Data("data", X, dims="time")
    hsgp.register_data(data).create_variable("f")
    idata = pm.sample_prior_predictive(random_seed=rng)

prior = idata.prior

# Prediction grid: starts at n - 1 (the last observed step) and runs n_new
# steps beyond it, so the first new point/date coincides with the last
# observed one (presumably so the forecast visually continues the prior curve).
n_new = 10
X_new = np.arange(n - 1, n + n_new)
last_date = dates[-1]
new_dates = pd.date_range(last_date, periods=len(X_new), freq="W-MON")

new_inputs = {"data": X_new}
# NOTE(review): deterministics_to_flat swaps the deterministics listed by
# hsgp.deterministics_to_replace("f") inside this context; confirm against
# the pymc-marketing docs why SoftPlusHSGP needs this for new-data prediction.
with deterministics_to_flat(model, hsgp.deterministics_to_replace("f")):
    pm.set_data(new_data=new_inputs, coords={"time": new_dates})
    post = pm.sample_posterior_predictive(
        prior, var_names=["f"], random_seed=rng
    )

# Plot chain 0 and one randomly chosen draw; give each channel its own
# matplotlib cycle colour so prior and prediction of a channel match.
chain = 0
draw = rng.choice(prior.sizes["draw"])
colors = [f"C{idx}" for idx, _ in enumerate(channels)]

def get_sample(curve):
    """Select one (chain, draw) of ``curve`` and return it as a wide table.

    Uses the module-level ``chain`` and ``draw`` labels. The selection is
    flattened to a MultiIndexed Series and its innermost index level is
    unstacked into columns (assumes the remaining dims are (time, channel),
    giving one column per channel).
    """
    selected = curve.loc[chain, draw]
    flat = selected.to_series()
    return flat.unstack()

# Solid lines: the selected prior draw on the observed dates.
prior_frame = get_sample(prior["f"])
ax = prior_frame.plot(color=colors)

# Dashed lines: predictions on the new dates, on the same axes, no extra legend.
forecast_frame = get_sample(post.posterior_predictive["f"])
forecast_frame.plot(ax=ax, color=colors, linestyle="--", legend=False)

ax.set(xlabel="time", ylabel="f", title="New data predictions")
plt.show()