import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation
# Plot inverse_scaled_logistic_saturation for several values of ``lam`` on a
# single axis, with the legend placed to the right of the plotting area.
# NOTE(review): 'arviz-darkgrid' is presumably registered by the arviz import
# (``az`` is otherwise unused) — keep that import even though it looks unused.
plt.style.use('arviz-darkgrid')

lam = np.array([0.25, 0.5, 1, 2, 4])
x = np.linspace(0, 5, 100)

ax = plt.subplot(111)
for lam_value in lam:
    # .eval() turns the transformer's output into concrete values for
    # plotting (presumably a pytensor expression — confirm).
    saturated = inverse_scaled_logistic_saturation(x, lam=lam_value).eval()
    ax.plot(x, saturated, label=f'lam = {lam_value}')

ax.set_xlabel('spend', fontsize=12)
ax.set_ylabel('f(spend)', fontsize=12)

# Shrink the axes to 80% of their width so the legend, anchored just past the
# right edge (x=1) and vertically centred (y=0.5), fits inside the figure.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()