import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from pymc_marketing.mmm.utils import create_new_spend_data
az.style.use("arviz-white")

# Example inputs: `spend` and `spend_leading_up` each hold one value per
# channel (two channels here).
spend = np.array([1, 2])
adstock_max_lag = 3
one_time = True
spend_leading_up = np.array([4, 3])
channel_spend = create_new_spend_data(spend, adstock_max_lag, one_time, spend_leading_up)

# x-axis runs from -adstock_max_lag to +adstock_max_lag inclusive.
# NOTE(review): assumes channel_spend is 2-D with one row per entry of this
# range and one column per channel -- confirm against create_new_spend_data.
time_since_spend = np.arange(-adstock_max_lag, adstock_max_lag + 1)

_, ax = plt.subplots()
# Plot one series per channel (column). Labels are derived from the data
# rather than hard-coded, so the legend stays correct if `spend` gains or
# loses channels, and this works without matplotlib's list-of-labels support.
for channel_number, channel_values in enumerate(channel_spend.T, start=1):
    ax.plot(
        time_since_spend,
        channel_values,
        "o",
        label=f"Channel {channel_number}",
    )
ax.legend()
ax.set(
    xticks=time_since_spend,
    # Integer y ticks up to the largest spend; suits the integer example inputs.
    yticks=np.arange(0, channel_spend.max() + 1),
    xlabel="Time since spend",
    ylabel="Spend",
    title="One time spend with spends leading up",
)
plt.show()