turbo_broccoli.custom.numpy

numpy (de)serialization utilities.

Todo: Handle numpy's generic type (which supersedes the number type).

  1"""
  2numpy (de)serialization utilities.
  3
  4Todo:
  5    Handle numpy's `generic` type (which supersedes the `number` type).
  6"""
  7
  8from typing import Any, Callable, Tuple
  9
 10import joblib
 11import numpy as np
 12from safetensors import numpy as st
 13
 14from turbo_broccoli.context import Context
 15from turbo_broccoli.exceptions import DeserializationError, TypeNotSupported
 16
 17
 18def _json_to_dtype(dct: dict, ctx: Context) -> np.dtype:
 19    decoders = {
 20        2: _json_to_dtype_v2,
 21    }
 22    return decoders[dct["__version__"]](dct, ctx)
 23
 24
 25def _json_to_dtype_v2(dct: dict, ctx: Context) -> np.dtype:
 26    return np.lib.format.descr_to_dtype(dct["dtype"])
 27
 28
 29def _json_to_ndarray(dct: dict, ctx: Context) -> np.ndarray:
 30    ctx.raise_if_nodecode("bytes")
 31    decoders = {
 32        5: _json_to_ndarray_v5,
 33    }
 34    return decoders[dct["__version__"]](dct, ctx)
 35
 36
 37def _json_to_ndarray_v5(dct: dict, ctx: Context) -> np.ndarray:
 38    return st.load(dct["data"])["data"]
 39
 40
 41def _json_to_number(dct: dict, ctx: Context) -> np.number:
 42    decoders = {
 43        3: _json_to_number_v3,
 44    }
 45    return decoders[dct["__version__"]](dct, ctx)
 46
 47
 48def _json_to_number_v3(dct: dict, ctx: Context) -> np.number:
 49    return np.frombuffer(dct["value"], dtype=dct["dtype"])[0]
 50
 51
 52def _json_to_random_state(dct: dict, ctx: Context) -> np.number:
 53    decoders = {
 54        3: _json_to_random_state_v3,
 55    }
 56    return decoders[dct["__version__"]](dct, ctx)
 57
 58
 59def _json_to_random_state_v3(dct: dict, ctx: Context) -> np.number:
 60    return joblib.load(ctx.id_to_artifact_path(dct["data"]))
 61
 62
 63def _dtype_to_json(d: np.dtype, ctx: Context) -> dict:
 64    return {
 65        "__type__": "numpy.dtype",
 66        "__version__": 2,
 67        "dtype": np.lib.format.dtype_to_descr(d),
 68    }
 69
 70
 71def _ndarray_to_json(arr: np.ndarray, ctx: Context) -> dict:
 72    return {
 73        "__type__": "numpy.ndarray",
 74        "__version__": 5,
 75        "data": st.save({"data": arr}),
 76    }
 77
 78
 79def _number_to_json(num: np.number, ctx: Context) -> dict:
 80    return {
 81        "__type__": "numpy.number",
 82        "__version__": 3,
 83        "value": bytes(np.array(num).data),
 84        "dtype": num.dtype,
 85    }
 86
 87
 88def _random_state_to_json(obj: np.random.RandomState, ctx: Context) -> dict:
 89    path, name = ctx.new_artifact_path()
 90    with path.open(mode="wb") as fp:
 91        joblib.dump(obj, fp)
 92    return {
 93        "__type__": "numpy.random_state",
 94        "__version__": 3,
 95        "data": name,
 96    }
 97
 98
 99# pylint: disable=missing-function-docstring
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Turns a JSON document back into the numpy object it describes. The
    document layout is the one produced by `to_json`.

    The decoder is picked from the document's `__type__` field. Any `KeyError`
    raised while looking up or running the decoder is re-raised as a
    `DeserializationError`.
    """
    dispatch = {
        "numpy.ndarray": _json_to_ndarray,
        "numpy.number": _json_to_number,
        "numpy.dtype": _json_to_dtype,
        "numpy.random_state": _json_to_random_state,
    }
    try:
        decoder = dispatch[dct["__type__"]]
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
116
117
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a `numpy` object into JSON by cases. See the README for the
    precise list of supported types. Raises `TypeNotSupported` if `obj` is
    not an instance of any of the types below. The return dict has the
    following structure:

    - `numpy.ndarray`: The array is packed with safetensors and the resulting
      bytes are stored under `data`:

        ```py
        {
            "__type__": "numpy.ndarray",
            "__version__": 5,
            "data": {
                "__type__": "bytes",
                ...
            }
        }
        ```

      The `bytes` document is not produced here; see
      `turbo_broccoli.custom.bytes.to_json`. Presumably that is also where
      the size of the payload and the `TB_MAX_NBYTES` environment variable
      are taken into account (small payloads, i.e. at most `TB_MAX_NBYTES`
      bytes, being stored directly in the JSON document) -- TODO confirm.

    - `numpy.number`:

        ```py
        {
            "__type__": "numpy.number",
            "__version__": 3,
            "value": <bytes>,
            "dtype": {...},
        }
        ```

        where `value` is the raw buffer of the scalar and the `dtype`
        document follows the specification below.

    - `numpy.dtype`:

        ```py
        {
            "__type__": "numpy.dtype",
            "__version__": 2,
            "dtype": <dtype_to_descr string>,
        }
        ```

    - `numpy.random.RandomState`:

        ```py
        {
            "__type__": "numpy.random_state",
            "__version__": 3,
            "data": <uuid4>,
        }
        ```

    """
    # Checked in order; the first matching type wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (np.ndarray, _ndarray_to_json),
        (np.number, _number_to_json),
        (np.dtype, _dtype_to_json),
        (np.random.RandomState, _random_state_to_json),
    ]
    for supported_type, encoder in encoders:
        if isinstance(obj, supported_type):
            return encoder(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: turbo_broccoli.context.Context) -> Any:
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Turns a JSON document back into the numpy object it describes. The
    document layout is the one produced by `to_json`.

    The decoder is picked from the document's `__type__` field. Any `KeyError`
    raised while looking up or running the decoder is re-raised as a
    `DeserializationError`.
    """
    dispatch = {
        "numpy.ndarray": _json_to_ndarray,
        "numpy.number": _json_to_number,
        "numpy.dtype": _json_to_dtype,
        "numpy.random_state": _json_to_random_state,
    }
    try:
        decoder = dispatch[dct["__type__"]]
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc

Deserializes a dict into a numpy object. See to_json for the specification dct is expected to follow.

def to_json(obj: Any, ctx: turbo_broccoli.context.Context) -> dict:
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a `numpy` object into JSON by cases. See the README for the
    precise list of supported types. Raises `TypeNotSupported` if `obj` is
    not an instance of any of the types below. The return dict has the
    following structure:

    - `numpy.ndarray`: The array is packed with safetensors and the resulting
      bytes are stored under `data`:

        ```py
        {
            "__type__": "numpy.ndarray",
            "__version__": 5,
            "data": {
                "__type__": "bytes",
                ...
            }
        }
        ```

      The `bytes` document is not produced here; see
      `turbo_broccoli.custom.bytes.to_json`. Presumably that is also where
      the size of the payload and the `TB_MAX_NBYTES` environment variable
      are taken into account (small payloads, i.e. at most `TB_MAX_NBYTES`
      bytes, being stored directly in the JSON document) -- TODO confirm.

    - `numpy.number`:

        ```py
        {
            "__type__": "numpy.number",
            "__version__": 3,
            "value": <bytes>,
            "dtype": {...},
        }
        ```

        where `value` is the raw buffer of the scalar and the `dtype`
        document follows the specification below.

    - `numpy.dtype`:

        ```py
        {
            "__type__": "numpy.dtype",
            "__version__": 2,
            "dtype": <dtype_to_descr string>,
        }
        ```

    - `numpy.random.RandomState`:

        ```py
        {
            "__type__": "numpy.random_state",
            "__version__": 3,
            "data": <uuid4>,
        }
        ```

    """
    # Checked in order; the first matching type wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (np.ndarray, _ndarray_to_json),
        (np.number, _number_to_json),
        (np.dtype, _dtype_to_json),
        (np.random.RandomState, _random_state_to_json),
    ]
    for supported_type, encoder in encoders:
        if isinstance(obj, supported_type):
            return encoder(obj, ctx)
    raise TypeNotSupported()

Serializes a numpy object into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:

  • numpy.ndarray: An array is processed differently depending on its size and on the TB_MAX_NBYTES environment variable. If the array is small, i.e. arr.nbytes <= TB_MAX_NBYTES, then it is directly stored in the resulting JSON document as

    {
        "__type__": "numpy.ndarray",
        "__version__": 5,
        "data": {
            "__type__": "bytes",
            ...
        }
    }
    

    see turbo_broccoli.custom.bytes.to_json.

  • numpy.number:

    {
        "__type__": "numpy.number",
        "__version__": 3,
        "value": <bytes>,
        "dtype": {...},
    }
    

    where the dtype document follows the specification below.

  • numpy.dtype:

    {
        "__type__": "numpy.dtype",
        "__version__": 2,
        "dtype": <dtype_to_descr string>,
    }
    
  • numpy.random.RandomState:

    {
        "__type__": "numpy.random_state",
        "__version__": 3,
        "data": <uuid4>,
    }