import numpy as np
import matplotlib.pyplot as plt
from pymc_marketing.mmm.transformers import hill_function
def _plot_hill_panels(x, params, titles, ylabel='Hill Saturation Sigmoid'):
    """Plot one Hill-curve panel per ``(slope, kappa)`` pair in a 1xN row.

    All panels share the y axis so the curves can be compared directly.

    Parameters
    ----------
    x : numpy.ndarray
        Input grid passed as the first argument to ``hill_function``.
    params : sequence of tuple
        ``(slope, kappa)`` pairs, passed positionally as the second and
        third arguments of ``hill_function``; one panel is drawn per pair.
    titles : sequence of str
        Panel titles, aligned with ``params``.
    ylabel : str, optional
        Label placed on the left-most panel only (the y axis is shared).
    """
    n_panels = len(params)
    # Draw on the axes returned here rather than re-selecting them with
    # plt.subplot(), which may create new axes and drop the shared y axis.
    # squeeze=False keeps `axes` 2-D even when there is a single panel.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(4 * n_panels, 4), sharey=True, squeeze=False
    )
    for ax, (slope, kappa), title in zip(axes[0], params, titles):
        # hill_function returns a symbolic tensor; .eval() materialises it.
        ax.plot(x, hill_function(x, slope, kappa).eval())
        ax.set_xlabel('x')
        ax.set_title(title)
    axes[0, 0].set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()


x = np.linspace(0, 10, 100)

# Varying slope (kappa held at 2)
slopes = [0.3, 0.7, 1.2]
_plot_hill_panels(
    x,
    [(slope, 2) for slope in slopes],
    [f'Slope = {slope}' for slope in slopes],
)

# Varying kappa (slope held at 1)
kappas = [1, 5, 10]
_plot_hill_panels(
    x,
    [(1, kappa) for kappa in kappas],
    [f'Kappa = {kappa}' for kappa in kappas],
)