import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import HillSaturation

# Seeded generator shared by every random step below so the example renders
# the same figure on every run.
rng = np.random.default_rng(0)

# HillSaturation is a *saturation* transformation (not an adstock one), so the
# variable is named for what it is.
saturation = HillSaturation()

# Draw parameter values from the transformation's prior distributions.
prior = saturation.sample_prior(random_seed=rng)

# Evaluate the Hill curve for the sampled parameter draws.
curve = saturation.sample_curve(prior)

# Plot the sampled curves. The same `rng` is passed as `random_seed`;
# presumably it selects which draws are plotted, so the figure stays
# reproducible (confirm against the plot_curve docs for the installed version).
saturation.plot_curve(curve, random_seed=rng)
plt.show()