import arviz as az
import matplotlib.pyplot as plt
import numpy as np

from pymc_marketing.mmm import YearlyFourier
from pymc_marketing.prior import Prior

# Global plotting style for every figure produced below.
az.style.use("arviz-white")

# Reproducible randomness: the seed is the sum of the character codes of
# the string "Yearly", and the generator is handed to the prior sampler.
seed = sum(ord(char) for char in "Yearly")
rng = np.random.default_rng(seed)

# Laplace prior over the Fourier coefficients, one location per entry of the
# "fourier" dim (4 entries here). NOTE(review): presumably this is
# 2 * n_order sine/cosine terms for n_order=2 — confirm against YearlyFourier.
mu = np.array([0, 0, -1, 0])
b = 0.15
dist = Prior("Laplace", dims="fourier", mu=mu, b=b)

# Draw coefficients from the prior and evaluate the seasonal curve they imply.
yearly = YearlyFourier(prior=dist, n_order=2)
prior = yearly.sample_prior(random_seed=rng)
curve = yearly.sample_curve(prior)

# plot_curve returns a pair; only the axes collection is needed to add a title.
_, axes = yearly.plot_curve(curve)
axes[0].set_title("Yearly Fourier Seasonality")
plt.show()