import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.bass.model import create_bass_model
from pymc_marketing.plot import plot_curve
from pymc_marketing.prior import Prior

# Prior-predictive check for a multi-product Bass diffusion model: build the
# model without observed data, draw from its prior predictive distribution,
# and plot one adoption panel per product.

# Time axis - 3 years of monthly data. `t` is the integer time index the
# model computes with; `dates` labels that same axis with month-start dates.
n_dates = 12 * 3
dates = pd.date_range(start="2020-01-01", periods=n_dates, freq="MS")
t = np.arange(n_dates)

# Seed for the prior-predictive draws so the resulting figure is reproducible.
RANDOM_SEED = 42

# Coordinates for multiple products. The "T" coordinate carries the calendar
# dates so the plotted x-axis shows months rather than bare integer offsets.
products = ["A", "B", "C"]
coords = {"T": dates, "product": products}

# Define priors
priors = {
    # Market potential: a point mass, i.e. fixed at 10,000.
    "m": Prior("DiracDelta", c=10_000),
    # Innovation coefficient, one per product (prior mean ~= 0.02).
    "p": Prior("Beta", alpha=13.85, beta=692.43, dims="product"),
    # Imitation coefficient; no dims, so a single value shared by all
    # products (prior mean ~= 0.40).
    "q": Prior("Beta", alpha=36.2, beta=54.4),
    # Count observations for every (time, product) cell.
    "likelihood": Prior("Poisson", dims=("T", "product")),
}

# observed=None: no data is attached, so only prior draws are produced.
model = create_bass_model(t, observed=None, priors=priors, coords=coords)

# Sample from the prior predictive distribution
with model:
    idata = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

# Plot the adoption curves: one subplot per product, sized from the product
# coordinate so the grid stays in sync if products are added or removed.
fig, axes = plt.subplots(1, len(products), figsize=(10, 6))
idata.prior["y"].pipe(plot_curve, "T", axes=axes)
fig.suptitle("Bass Model Prior Predictive Adoption Curves")
fig.tight_layout()
plt.show()