import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import pymc as pm

from pymc_marketing.mmm.linear_trend import LinearTrend
from pymc_marketing.mmm.additive_effect import LinearTrendEffect

# Reproducible integer seed derived from a descriptive label: the sum of the
# label's character code points.
seed = sum(ord(char) for char in "LinearTrend out of sample")
rng = np.random.default_rng(seed)

class MockMMM:
    """Bare attribute container standing in for an MMM instance.

    The script attaches ``dims`` and ``model`` to an instance after
    construction and passes it to the ``LinearTrendEffect`` methods.
    NOTE(review): presumably those are the only attributes the effect
    reads; confirm against its implementation.
    """

# Training index: 52 weekly timestamps (pandas "W", i.e. Sunday-anchored)
# starting from 2025-01-01.
dates = pd.date_range(start="2025-01-01", periods=52, freq="W")
coords = dict(date=dates)
model = pm.Model(coords=coords)

# Stand-in MMM: the attached model plus an empty tuple of extra dims.
mock_mmm = MockMMM()
mock_mmm.model = model
mock_mmm.dims = ()

# Piecewise-linear trend with 8 changepoints; its variables use the "trend" prefix.
trend = LinearTrend(n_changepoints=8)
effect = LinearTrendEffect(trend=trend, prefix="trend")

with mock_mmm.model:
    # Let the effect set up its data in the model, then expose its
    # contribution as a named deterministic over the "date" dimension.
    effect.create_data(mock_mmm)
    trend_contribution = effect.create_effect(mock_mmm)
    pm.Deterministic("effect", trend_contribution, dims="date")

    idata = pm.sample_prior_predictive(random_seed=rng)

# No model is fitted in this script: the prior draws are reused as the
# "posterior" group so pm.sample_posterior_predictive has draws to work from.
idata["posterior"] = idata.prior

# Forecast horizon in weeks. The extended index starts at the last training
# date, so one extra slot re-includes that date; the out-of-sample curve then
# begins exactly where the in-sample curve ends.
horizon = 10
n_new = horizon + 1
new_dates = pd.date_range(start=dates.max(), periods=n_new, freq="W")

with mock_mmm.model as forecast_model:
    # Replace the "date" coordinate with the extended index (length and labels).
    forecast_model.set_dim("date", n_new, new_dates)

    # Refresh the effect's data for the new dates.
    # NOTE(review): the third argument is presumably the feature matrix X,
    # unused by this effect; confirm against LinearTrendEffect.set_data.
    effect.set_data(mock_mmm, forecast_model, None)

    # Recompute "effect" for every stored draw on the new dates; the result is
    # attached to idata as its posterior_predictive group.
    pm.sample_posterior_predictive(
        idata,
        random_seed=rng,
        var_names=["effect"],
        extend_inferencedata=True,
    )

# Pick one draw index from chain 0 and show that draw in both groups.
# rng.integers(n) samples uniformly from 0..n-1 without building an
# n-element array first, as rng.choice(range(n)) did.
n_draws = idata.posterior.sizes["draw"]
draw = rng.integers(n_draws)
sel = dict(chain=0, draw=draw)

# Solid: the draw evaluated on the training dates.
# Dashed: the same draw index evaluated on the extended (forecast) dates.
before = idata.posterior.effect.sel(sel).to_series()
after = idata.posterior_predictive.effect.sel(sel).to_series()

ax = before.plot(color="C0", label="in-sample")
after.plot(color="C0", linestyle="dashed", ax=ax, label="out-of-sample")
# Both lines share a colour, so a legend is needed to tell them apart.
ax.legend()
plt.show()