import numpy as np
import pandas as pd

import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.plot import plot_curve

# Reproducible seed derived from a descriptive label. The same Generator is
# shared by prior sampling and by plot_curve, so its consumption order matters.
seed = sum(ord(char) for char in "Arbitrary curve")
rng = np.random.default_rng(seed)

# 52 weekly timestamps. pandas' "W" alias is week-ending-Sunday, so the first
# stamp falls on the first Sunday on/after 2024-01-01, not on 2024-01-01 itself.
dates = pd.date_range("2024-01-01", periods=52, freq="W")

coords = {"date": dates, "product": ["A", "B"]}
# Dimension names in the same order as the coords dict: ("date", "product").
dims = tuple(coords)

with pm.Model(coords=coords) as model:
    # Independent Normal noise for each (date, product) cell. mu broadcasts
    # along the trailing "product" axis: product A has mean -0.5, B has +0.5.
    data = pm.Normal("data", mu=[-0.5, 0.5], sigma=1, dims=dims)
    # Running total over the date axis (axis 0), giving one random-walk-like
    # curve per product.
    cumsum = pm.Deterministic("cumsum", data.cumsum(axis=0), dims=dims)
    idata = pm.sample_prior_predictive(random_seed=rng)

# Prior draws of the accumulated curve. It carries the "date" and "product"
# dims, plus whatever sample dimensions the prior group adds.
curve = idata.prior["cumsum"]

# "date" is passed as the non-grid dimension (presumably the x-axis; per the
# plot_curve docs, each remaining coordinate gets its own panel). Reusing rng
# keeps plot_curve's random choices reproducible across runs.
fig, axes = plot_curve(
    curve,
    "date",
    random_seed=rng,
    subplot_kwargs={"figsize": (15, 5)},
)
plt.show()