import numpy as np
import matplotlib.pyplot as plt
from pymc_marketing.mmm.transformers import hill_saturation_sigmoid
# Visualise how each parameter of `hill_saturation_sigmoid(x, sigma, beta, lam)`
# shapes the curve: sweep one parameter at a time while holding the other two
# at a fixed baseline.

x = np.linspace(0, 10, 100)

# Baseline values used for the parameters that are NOT being swept.
BASE_SIGMA = 1
BASE_BETA = 2
BASE_LAM = 5


def _plot_parameter_sweep(x, values, make_curve, label, ylabel="Hill Saturation Sigmoid"):
    """Draw one panel per value in *values* on a single row of shared-y subplots.

    Parameters
    ----------
    x : array-like
        Points at which every curve is evaluated; used as the x-axis.
    values : sequence
        Parameter values to sweep; one subplot is created per value.
    make_curve : callable
        ``make_curve(value)`` returns the y-values to plot for that value.
    label : str
        Parameter name used in each panel title, rendered as ``f"{label} = {value}"``.
    ylabel : str, optional
        Label for the y-axis, placed only on the left-most panel because the
        y-axis is shared across the row.

    Returns
    -------
    (matplotlib.figure.Figure, numpy.ndarray)
        The figure and the 1-D array of its Axes, left to right.
    """
    # squeeze=False + [0] guarantees a 1-D Axes array even if len(values) == 1,
    # so the zip/indexing below never breaks on a single-panel sweep.
    fig, axes = plt.subplots(1, len(values), figsize=(12, 4), sharey=True, squeeze=False)
    axes = axes[0]
    # Draw on the Axes objects directly instead of re-selecting panels via the
    # pyplot state machine (plt.subplot), whose re-selection rules have changed
    # across matplotlib versions.
    for ax, value in zip(axes, values):
        ax.plot(x, make_curve(value))
        ax.set_xlabel("x")
        ax.set_title(f"{label} = {value}")
    axes[0].set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()
    return fig, axes


# Varying sigma (beta and lam held at baseline).
# `.eval()` materialises the value returned by hill_saturation_sigmoid so it can be plotted.
sigmas = [0.5, 1, 1.5]
fig, axes = _plot_parameter_sweep(
    x,
    sigmas,
    lambda sigma: hill_saturation_sigmoid(x, sigma, BASE_BETA, BASE_LAM).eval(),
    "Sigma",
)

# Varying beta (sigma and lam held at baseline).
betas = [1, 2, 3]
fig, axes = _plot_parameter_sweep(
    x,
    betas,
    lambda beta: hill_saturation_sigmoid(x, BASE_SIGMA, beta, BASE_LAM).eval(),
    "Beta",
)

# Varying lam (sigma and beta held at baseline).
lams = [3, 5, 7]
fig, axes = _plot_parameter_sweep(
    x,
    lams,
    lambda lam: hill_saturation_sigmoid(x, BASE_SIGMA, BASE_BETA, lam).eval(),
    "Lambda",
)