import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm import YearlyFourier
from pymc_marketing.prior import Prior

# The "arviz-darkgrid" style is registered with matplotlib by importing
# arviz (``az`` above), which is why that otherwise-unused import is needed.
plt.style.use('arviz-darkgrid')

# Number of Fourier orders. A Fourier seasonality uses one sin and one cos
# term per order, so the "fourier" dimension has 2 * n_order entries.
n_order = 2

# Prior location for each Fourier coefficient. Its length must match the
# number of Fourier terms, so it is validated here rather than surfacing
# later as an opaque shape error during sampling.
# NOTE(review): which basis function the -1 shifts (presumably cos of
# order 1) depends on YearlyFourier's term ordering — confirm via
# ``yearly.nodes``.
fourier_mu = [0, 0, -1, 0]
if len(fourier_mu) != 2 * n_order:
    raise ValueError(
        f"fourier_mu must have 2 * n_order = {2 * n_order} entries, "
        f"got {len(fourier_mu)}"
    )

# Normal prior over the coefficients with dims ("hierarchy", "fourier").
# ``mu`` is given per Fourier term and is expected to broadcast over
# "hierarchy". ``sigma`` is itself Gamma-distributed, with one scale per
# Fourier term (dims="fourier").
prior = Prior(
    "Normal",
    mu=fourier_mu,
    sigma=Prior("Gamma", mu=0.10, sigma=0.1, dims="fourier"),
    dims=("hierarchy", "fourier"),
)
yearly = YearlyFourier(n_order=n_order, prior=prior)

# Two hierarchy levels, "A" and "B", each get their own coefficients.
coords = {"hierarchy": ["A", "B"]}

# Keep the sampled draws under a separate name so ``prior`` continues to
# refer to the Prior specification above.
prior_samples = yearly.sample_prior(coords=coords)
curve = yearly.sample_curve(prior_samples)

# Lay the subplots out in a single column (presumably one panel per
# hierarchy level — confirm against YearlyFourier.plot_curve).
fig, _ = yearly.plot_curve(curve, subplot_kwargs={"ncols": 1})
fig.suptitle("Yearly Fourier Seasonality")
plt.show()