turbo_broccoli.custom.tensorflow

Tensorflow (de)serialization utilities.

  1"""Tensorflow (de)serialization utilities."""
  2
  3from typing import Any, Callable, Tuple
  4
  5import tensorflow as tf
  6from safetensors import tensorflow as st
  7
  8from turbo_broccoli.context import Context
  9from turbo_broccoli.exceptions import DeserializationError, TypeNotSupported
 10
 11
 12def _json_to_sparse_tensor(dct: dict, ctx: Context) -> tf.Tensor:
 13    decoders = {
 14        2: _json_to_sparse_tensor_v2,
 15    }
 16    return decoders[dct["__version__"]](dct, ctx)
 17
 18
 19def _json_to_sparse_tensor_v2(dct: dict, ctx: Context) -> tf.Tensor:
 20    return tf.SparseTensor(
 21        dense_shape=dct["shape"],
 22        indices=dct["indices"],
 23        values=dct["values"],
 24    )
 25
 26
 27def _json_to_tensor(dct: dict, ctx: Context) -> tf.Tensor:
 28    ctx.raise_if_nodecode("bytes")
 29    decoders = {
 30        4: _json_to_tensor_v4,
 31    }
 32    return decoders[dct["__version__"]](dct, ctx)
 33
 34
 35def _json_to_tensor_v4(dct: dict, ctx: Context) -> tf.Tensor:
 36    return st.load(dct["data"])["data"]
 37
 38
 39def _json_to_variable(dct: dict, ctx: Context) -> tf.Variable:
 40    decoders = {
 41        3: _json_to_variable_v3,
 42    }
 43    return decoders[dct["__version__"]](dct, ctx)
 44
 45
 46def _json_to_variable_v3(dct: dict, ctx: Context) -> tf.Variable:
 47    return tf.Variable(
 48        initial_value=dct["value"],
 49        name=dct["name"],
 50        trainable=dct["trainable"],
 51    )
 52
 53
 54def _ragged_tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
 55    raise NotImplementedError(
 56        "Serialization of ragged tensors is not supported"
 57    )
 58
 59
 60def _sparse_tensor_to_json(obj: tf.SparseTensor, ctx: Context) -> dict:
 61    return {
 62        "__type__": "tensorflow.sparse_tensor",
 63        "__version__": 2,
 64        "indices": obj.indices,
 65        "shape": list(obj.dense_shape),
 66        "values": obj.values,
 67    }
 68
 69
 70def _tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
 71    return {
 72        "__type__": "tensorflow.tensor",
 73        "__version__": 4,
 74        "data": st.save({"data": obj}),
 75    }
 76
 77
 78def _variable_to_json(var: tf.Variable, ctx: Context) -> dict:
 79    return {
 80        "__type__": "tensorflow.variable",
 81        "__version__": 3,
 82        "name": var.name,
 83        "value": var.value(),
 84        "trainable": var.trainable,
 85    }
 86
 87
 88# pylint: disable=missing-function-docstring
 89def from_json(dct: dict, ctx: Context) -> Any:
 90    decoders = {
 91        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
 92        "tensorflow.tensor": _json_to_tensor,
 93        "tensorflow.variable": _json_to_variable,
 94    }
 95    try:
 96        type_name = dct["__type__"]
 97        return decoders[type_name](dct, ctx)
 98    except KeyError as exc:
 99        raise DeserializationError() from exc
100
101
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The returned dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported; a `NotImplementedError` is raised.

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": [...],
        }
        ```

      where the two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below), and `shape` is a list built from the sparse
      tensor's `dense_shape`.

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises:
        NotImplementedError: If `obj` is a `tf.RaggedTensor`.
        TypeNotSupported: If `obj` matches none of the types above.
    """
    # Checked in order; the first matching type wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a tensorflow document by dispatching on its `__type__`
    entry.

    Recognized types are `tensorflow.sparse_tensor`, `tensorflow.tensor` and
    `tensorflow.variable`. Any `KeyError` raised during the lookup or inside
    the selected decoder (e.g. a missing entry or an unknown `__version__`)
    is re-raised as a `DeserializationError`.
    """
    dispatch = {
        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
        "tensorflow.tensor": _json_to_tensor,
        "tensorflow.variable": _json_to_variable,
    }
    try:
        decode = dispatch[dct["__type__"]]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The returned dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported; a `NotImplementedError` is raised.

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": [...],
        }
        ```

      where the two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below), and `shape` is a list built from the sparse
      tensor's `dense_shape`.

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises:
        NotImplementedError: If `obj` is a `tf.RaggedTensor`.
        TypeNotSupported: If `obj` matches none of the types above.
    """
    # Checked in order; the first matching type wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()

Serializes a tensorflow object into JSON by cases. See the README for the precise list of supported types. The returned dict has the following structure:

  • tf.RaggedTensor: Not supported.

  • tf.SparseTensor:

    {
        "__type__": "tensorflow.sparse_tensor",
        "__version__": 2,
        "indices": {...},
        "values": {...},
        "shape": {...},
    }
    

    where the first two {...} placeholders result in the serialization of tf.Tensor (see below).

  • other tf.Tensor subtypes:

    {
        "__type__": "tensorflow.tensor",
        "__version__": 4,
        "data": {
            "__type__": "bytes",
            ...
        },
    }
    

    see turbo_broccoli.custom.bytes.to_json.

  • tf.Variable:

    {
        "__type__": "tensorflow.variable",
        "__version__": 3,
        "name": <str>,
        "value": {...},
        "trainable": <bool>,
    }
    

    where {...} is the document produced by serializing the value tensor of the variable, see above.