Source code for pymc_marketing.deserialize

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Deserialize into a PyMC-Marketing object.

This is a two step process:

1. Determine if the data is of the correct type.
2. Deserialize the data into a python object for PyMC-Marketing.

This is used to deserialize JSON data into PyMC-Marketing objects
throughout the package.

Examples
--------
Make use of the already registered PyMC-Marketing deserializers:

.. code-block:: python

    from pymc_marketing.deserialize import deserialize

    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
    prior = deserialize(prior_class_data)
    # Prior("Normal", mu=0, sigma=1)

Register custom class deserialization:

.. code-block:: python

    from pymc_marketing.deserialize import register_deserialization


    class MyClass:
        def __init__(self, value: int):
            self.value = value

        def to_dict(self) -> dict:
            # Example of what the to_dict method might look like.
            return {"value": self.value}


    register_deserialization(
        is_type=lambda data: data.keys() == {"value"}
        and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
    )

Deserialize data into that custom class:

.. code-block:: python

    from pymc_marketing.deserialize import deserialize

    data = {"value": 42}
    obj = deserialize(data)
    assert isinstance(obj, MyClass)


"""

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

IsType = Callable[[Any], bool]
Deserialize = Callable[[Any], Any]


[docs] @dataclass class Deserializer: """Object to store information required for deserialization. All deserializers should be stored via the :func:`register_deserialization` function instead of creating this object directly. Attributes ---------- is_type : IsType Function to determine if the data is of the correct type. deserialize : Deserialize Function to deserialize the data. Examples -------- .. code-block:: python from typing import Any class MyClass: def __init__(self, value: int): self.value = value from pymc_marketing.deserialize import Deserializer def is_type(data: Any) -> bool: return data.keys() == {"value"} and isinstance(data["value"], int) def deserialize(data: dict) -> MyClass: return MyClass(value=data["value"]) deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize) """ is_type: IsType deserialize: Deserialize
DESERIALIZERS: list[Deserializer] = [] class DeserializableError(Exception): """Error raised when data cannot be deserialized.""" def __init__(self, data: Any): self.data = data super().__init__( f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping." )
[docs] def deserialize(data: Any) -> Any: """Deserialize a dictionary into a Python object. Use the :func:`register_deserialization` function to add custom deserializations. Deserialization is a two step process due to the dynamic nature of the data: 1. Determine if the data is of the correct type. 2. Deserialize the data into a Python object. Each registered deserialization is checked in order until one is found that can deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised. A :class:`DeserializableError` is raised when the data fails to be deserialized by any of the registered deserializers. Parameters ---------- data : Any The data to deserialize. Returns ------- Any The deserialized object. Raises ------ DeserializableError Raised when the data doesn't match any registered deserializations or fails to be deserialized. Examples -------- Deserialize a :class:`pymc_marketing.prior.Prior` object: .. code-block:: python from pymc_marketing.deserialize import deserialize data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}} prior = deserialize(data) # Prior("Normal", mu=0, sigma=1) """ for mapping in DESERIALIZERS: try: is_type = mapping.is_type(data) except Exception: is_type = False if not is_type: continue try: return mapping.deserialize(data) except Exception as e: raise DeserializableError(data) from e else: raise DeserializableError(data)
[docs] def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None: """Register an arbitrary deserialization. Use the :func:`deserialize` function to then deserialize data using all registered deserialize functions. Classes from PyMC-Marketing have their deserialization mappings registered automatically. However, custom classes will need to be registered manually using this function before they can be deserialized. Parameters ---------- is_type : Callable[[Any], bool] Function to determine if the data is of the correct type. deserialize : Callable[[dict], Any] Function to deserialize the data of that type. Examples -------- Register a custom class deserialization: .. code-block:: python from pymc_marketing.deserialize import register_deserialization class MyClass: def __init__(self, value: int): self.value = value def to_dict(self) -> dict: # Example of what the to_dict method might look like. return {"value": self.value} register_deserialization( is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int), deserialize=lambda data: MyClass(value=data["value"]), ) Use that custom class deserialization: .. code-block:: python from pymc_marketing.deserialize import deserialize data = {"value": 42} obj = deserialize(data) assert isinstance(obj, MyClass) """ mapping = Deserializer(is_type=is_type, deserialize=deserialize) DESERIALIZERS.append(mapping)