import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import (
    tanh_saturation_baselined,
    tanh_saturation,
    TanhSaturationBaselinedParameters,
)

# Parameters of the baselined tanh saturation curve shown in the plot below.
x_baseline = 400  # reference spend level, drawn as the red dashed vertical line
gain = 1  # slope of the dashed "gain (slope)" reference line
overspend_fraction = 0.7  # third positional argument of the baselined parameters

# Positional order used here: baseline, gain, overspend fraction.
params = TanhSaturationBaselinedParameters(x_baseline, gain, overspend_fraction)

# Evaluate the baselined curve on an evenly spaced grid (np.linspace's default
# of 50 samples, spelled out explicitly).
x = np.linspace(0, 1000, num=50)
y = tanh_saturation_baselined(x, *params).eval()

# Switch to the plain (saturation, cac) parameterisation so the same curve can
# be evaluated with tanh_saturation at the baseline spend.
saturation, cac0_symbolic = params.debaseline()
cac0 = cac0_symbolic.eval()  # materialise the symbolic value for plotting
saturated_ref = tanh_saturation(x_baseline, saturation, cac0).eval()

# Draw the baselined saturation curve f(x) together with the reference lines
# and shaded regions that illustrate its parameters. An explicit Axes is used
# so the drawing does not land on whatever figure happens to be current.
_, ax = plt.subplots()
ax.plot(x, y, label="f(x)")  # labelled so the main curve appears in the legend
ax.axvline(x_baseline, linestyle="dashed", color="red", label="baseline")
ax.plot(x, x * gain, linestyle="dashed", label="gain (slope)")
ax.axhline(saturated_ref, linestyle="dashed", label="f(reference)")
ax.plot(x, x / cac0, linestyle="dotted", label="1/cac (slope)")
ax.axhline(saturation, linestyle="dotted", label="saturation")
# Band between f(reference) and the saturation level.
ax.fill_between(x, saturated_ref, saturation, alpha=0.1, label="underspend fraction")
# Band between 0 (fill_between's default lower bound) and f(reference).
ax.fill_between(x, saturated_ref, alpha=0.1, label="overspend fraction")
ax.set_xlabel("x (spend)")
ax.set_ylabel("f(x)")
ax.legend()
plt.show()