import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import geometric_adstock
# Compare geometric adstock response curves for several (alpha, normalize)
# settings. A unit impulse (one unit of spend at t = 0, zero afterwards) is
# passed through geometric_adstock and each response is drawn on one axes.
plt.style.use('arviz-darkgrid')

l_max = 12  # forwarded to geometric_adstock as its l_max argument

# (alpha, normalize) pairs to compare, one curve each.
params = [
    (0.01, False),
    (0.5, False),
    (0.9, False),
    (0.5, True),
    (0.9, True),
]

# Unit impulse: 1 at the first of 15 time steps, 0 everywhere else.
spend = np.zeros(15)
spend[0] = 1
x = np.arange(spend.size)

ax = plt.subplot(111)

for alpha_value, is_normalized in params:
    # `.eval()` materializes the transformed series; presumably
    # geometric_adstock returns a symbolic expression — TODO confirm.
    response = geometric_adstock(
        spend, alpha=alpha_value, l_max=l_max, normalize=is_normalized
    ).eval()
    curve_label = f'alpha = {alpha_value}\nnormalize = {is_normalized}'
    ax.plot(x, response, label=curve_label)

ax.set_xlabel('time since spend', fontsize=12)
ax.set_title(f'Geometric Adstock with l_max = {l_max}', fontsize=14)
ax.set_ylabel('f(time since spend)', fontsize=12)

# Narrow the axes to 65% of its width so the legend, anchored just past the
# right edge, still fits inside the figure.
pos = ax.get_position()
ax.set_position([pos.x0, pos.y0, pos.width * 0.65, pos.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

plt.show()