import numpy as np
import pandas as pd
import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm.events import EventEffect, GaussianBasis
from pymc_marketing.plot import plot_curve
from pymc_marketing.prior import Prior

# Deterministic seed derived from the word "Events" so the example is reproducible.
seed = sum(ord(character) for character in "Events")
rng = np.random.default_rng(seed)

# Two example events: one spanning two calendar days, one spanning six.
# Rows are written as (event, start, end) records; the date columns are then
# parsed into datetimes in place (``assign`` keeps the column order).
df_events = pd.DataFrame(
    [
        ("single day", "2025-01-01", "2025-01-02"),
        ("multi day", "2025-01-20", "2025-01-25"),
    ],
    columns=["event", "start_date", "end_date"],
).assign(
    start_date=lambda frame: pd.to_datetime(frame["start_date"]),
    end_date=lambda frame: pd.to_datetime(frame["end_date"]),
)

def difference_in_days(model_dates, event_dates):
    """Return the signed distance in days from each event date to each model date.

    Parameters
    ----------
    model_dates : array-like of datetimes
        Pandas objects are converted with ``to_numpy``; anything else is used
        as is and must support ``[:, np.newaxis]`` indexing (expected 1-D).
    event_dates : array-like of datetimes
        Pandas objects are converted with ``to_numpy``; anything else is used
        as is.

    Returns
    -------
    np.ndarray
        For 1-D inputs, a 2-D float array where ``result[i, j]`` equals
        ``model_dates[i] - event_dates[j]`` measured in (possibly fractional)
        days: negative before the event date, positive after it.
    """
    # Convert model dates first, then event dates (same order as before).
    model_array, event_array = (
        values.to_numpy() if hasattr(values, "to_numpy") else values
        for values in (model_dates, event_dates)
    )
    # Broadcast model dates down the rows against event dates across the
    # columns; dividing by one day turns the timedeltas into float day counts.
    return (model_array[:, np.newaxis] - event_array) / np.timedelta64(1, "D")


def create_basis_matrix(
    df_events: pd.DataFrame, model_dates: pd.DatetimeIndex | np.ndarray
) -> np.ndarray:
    """Build the signed "days from event" matrix used as input to an event basis.

    For every model date and every event (row of ``df_events``) the entry is:

    * ``0`` while the date lies inside the window ``[start_date, end_date]``
      (both ends inclusive),
    * the negative number of days until ``start_date`` before the event,
    * the positive number of days since ``end_date`` after the event.

    Parameters
    ----------
    df_events : pd.DataFrame
        One row per event with ``start_date`` and ``end_date`` datetime columns.
    model_dates : pd.DatetimeIndex or np.ndarray
        1-D dates of the model, e.g. the output of ``pd.date_range``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(len(model_dates), len(df_events))``.

    Raises
    ------
    ValueError
        If any event ends before it starts.
    """
    start_dates = df_events["start_date"]
    end_dates = df_events["end_date"]

    # The nearest-edge selection below is only correct when end_date >= start_date;
    # an inverted window would yield a plausible-looking but wrong distance.
    inverted = (end_dates < start_dates).to_numpy()
    if inverted.any():
        raise ValueError(
            "Events must not end before they start; offending rows: "
            f"{df_events.index[inverted].tolist()}"
        )

    start_ref = difference_in_days(model_dates, start_dates)
    end_ref = difference_in_days(model_dates, end_dates)

    inside_window = (start_ref >= 0) & (end_ref <= 0)
    # Outside the window, report the distance to the closer edge: the start
    # edge before the event (negative), the end edge after it (positive).
    nearest_edge = np.where(np.abs(start_ref) < np.abs(end_ref), start_ref, end_ref)
    return np.where(inside_window, 0, nearest_edge)


# Per-event effect size with a Normal(mu=1, sigma=1) prior.
effect_size = Prior("Normal", mu=1, sigma=1, dims="event")

# Gaussian basis whose width parameter ``sigma`` gets a per-event
# Gamma(mu=7, sigma=1) prior.
gaussian = GaussianBasis(
    priors={"sigma": Prior("Gamma", mu=7, sigma=1, dims="event")},
)
effect = EventEffect(basis=gaussian, effect_size=effect_size, dims=("event",))

# Daily dates starting 2024-12-01 for 3 * 31 days, covering the period
# before, during and after both January 2025 example events.
dates = pd.date_range("2024-12-01", periods=3 * 31, freq="D")

# Signed day offsets (0 inside each event window) fed to the basis.
X = create_basis_matrix(df_events, model_dates=dates)

coords = {
    "date": dates,
    "event": df_events["event"].to_numpy(),
}
with pm.Model(coords=coords) as model:
    pm.Deterministic("effect", effect.apply(X), dims=("date", "event"))
    idata = pm.sample_prior_predictive(random_seed=rng)

# Plot prior draws of the effect curve along the date axis;
# ``ncols=1`` presumably stacks the per-event panels in one column.
fig, axes = plot_curve(
    idata.prior["effect"],
    "date",
    random_seed=rng,
    subplot_kwargs={"ncols": 1},
)
fig.suptitle("Gaussian Event Effect")
plt.show()