turbo_broccoli.custom.tensorflow
Tensorflow (de)serialization utilities.
"""Tensorflow (de)serialization utilities."""

from typing import Any, Callable, Tuple

import tensorflow as tf
from safetensors import tensorflow as st

from turbo_broccoli.context import Context
from turbo_broccoli.exceptions import DeserializationError, TypeNotSupported


def _json_to_sparse_tensor(dct: dict, ctx: Context) -> tf.SparseTensor:
    """
    Rebuilds a `tf.SparseTensor` from its JSON document by dispatching on
    `dct["__version__"]`. An unknown or missing version raises `KeyError`
    (which `from_json` converts into a `DeserializationError`).
    """
    decoders = {
        2: _json_to_sparse_tensor_v2,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_sparse_tensor_v2(dct: dict, ctx: Context) -> tf.SparseTensor:
    """
    Version 2 decoder: builds a `tf.SparseTensor` from the `"shape"`,
    `"indices"` and `"values"` entries of the document. `ctx` is unused.
    """
    return tf.SparseTensor(
        dense_shape=dct["shape"],
        indices=dct["indices"],
        values=dct["values"],
    )


def _json_to_tensor(dct: dict, ctx: Context) -> tf.Tensor:
    """
    Rebuilds a `tf.Tensor` from its JSON document by dispatching on
    `dct["__version__"]`. An unknown or missing version raises `KeyError`
    (which `from_json` converts into a `DeserializationError`).
    """
    # NOTE(review): presumably raises when the context forbids decoding
    # "bytes" documents, which the tensor payload relies on -- confirm against
    # `Context.raise_if_nodecode`.
    ctx.raise_if_nodecode("bytes")
    decoders = {
        4: _json_to_tensor_v4,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_tensor_v4(dct: dict, ctx: Context) -> tf.Tensor:
    """
    Version 4 decoder: loads the safetensors payload stored under `"data"`
    and returns the tensor saved under the key `"data"` inside it (the
    counterpart of `_tensor_to_json`). `ctx` is unused.
    """
    return st.load(dct["data"])["data"]


def _json_to_variable(dct: dict, ctx: Context) -> tf.Variable:
    """
    Rebuilds a `tf.Variable` from its JSON document by dispatching on
    `dct["__version__"]`. An unknown or missing version raises `KeyError`
    (which `from_json` converts into a `DeserializationError`).
    """
    decoders = {
        3: _json_to_variable_v3,
    }
    return decoders[dct["__version__"]](dct, ctx)


def _json_to_variable_v3(dct: dict, ctx: Context) -> tf.Variable:
    """
    Version 3 decoder: builds a `tf.Variable` from the `"value"`, `"name"`
    and `"trainable"` entries of the document. `ctx` is unused.
    """
    return tf.Variable(
        initial_value=dct["value"],
        name=dct["name"],
        trainable=dct["trainable"],
    )


def _ragged_tensor_to_json(obj: tf.RaggedTensor, ctx: Context) -> dict:
    """
    Placeholder encoder for `tf.RaggedTensor`: always raises
    `NotImplementedError` since ragged tensors are not supported.
    """
    raise NotImplementedError(
        "Serialization of ragged tensors is not supported"
    )


def _sparse_tensor_to_json(obj: tf.SparseTensor, ctx: Context) -> dict:
    """
    Encodes a `tf.SparseTensor` as a version 2 `tensorflow.sparse_tensor`
    document. `"indices"` and `"values"` are left as tensors; `"shape"` is
    `list(obj.dense_shape)`. `ctx` is unused.
    """
    return {
        "__type__": "tensorflow.sparse_tensor",
        "__version__": 2,
        "indices": obj.indices,
        "shape": list(obj.dense_shape),
        "values": obj.values,
    }


def _tensor_to_json(obj: tf.Tensor, ctx: Context) -> dict:
    """
    Encodes a `tf.Tensor` as a version 4 `tensorflow.tensor` document whose
    `"data"` entry is the safetensors serialization of `{"data": obj}`.
    `ctx` is unused.
    """
    return {
        "__type__": "tensorflow.tensor",
        "__version__": 4,
        "data": st.save({"data": obj}),
    }


def _variable_to_json(var: tf.Variable, ctx: Context) -> dict:
    """
    Encodes a `tf.Variable` as a version 3 `tensorflow.variable` document
    holding the variable's name, its value tensor (`var.value()`) and its
    `trainable` flag. `ctx` is unused.
    """
    return {
        "__type__": "tensorflow.variable",
        "__version__": 3,
        "name": var.name,
        "value": var.value(),
        "trainable": var.trainable,
    }


def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a document produced by `to_json`, dispatching on
    `dct["__type__"]`.

    Any `KeyError` raised during decoding (missing or unknown `__type__`,
    unknown `__version__`, missing field) is re-raised as a
    `DeserializationError` chained to the original exception.
    """
    decoders = {
        "tensorflow.sparse_tensor": _json_to_sparse_tensor,
        "tensorflow.tensor": _json_to_tensor,
        "tensorflow.variable": _json_to_variable,
    }
    try:
        type_name = dct["__type__"]
        return decoders[type_name](dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc


def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported.

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": {...},
        }
        ```

      where the first two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below).

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises `TypeNotSupported` if `obj` is none of the types above, and
    `NotImplementedError` if it is a `tf.RaggedTensor`.
    """
    # The first matching entry wins, so the order of this list is significant.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a tensorflow document by dispatching on `dct["__type__"]`.

    Any `KeyError` raised while looking up the type or while running the
    selected decoder is re-raised as a `DeserializationError` chained to the
    original exception.
    """
    try:
        decode = {
            "tensorflow.sparse_tensor": _json_to_sparse_tensor,
            "tensorflow.tensor": _json_to_tensor,
            "tensorflow.variable": _json_to_variable,
        }[dct["__type__"]]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a tensorflow object into JSON by cases. See the README for the
    precise list of supported types. The return dict has the following
    structure:

    - `tf.RaggedTensor`: Not supported.

    - `tf.SparseTensor`:

        ```py
        {
            "__type__": "tensorflow.sparse_tensor",
            "__version__": 2,
            "indices": {...},
            "values": {...},
            "shape": {...},
        }
        ```

      where the first two `{...}` placeholders result in the serialization of
      `tf.Tensor` (see below).

    - other `tf.Tensor` subtypes:

        ```py
        {
            "__type__": "tensorflow.tensor",
            "__version__": 4,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `tf.Variable`:

        ```py
        {
            "__type__": "tensorflow.variable",
            "__version__": 3,
            "name": <str>,
            "value": {...},
            "trainable": <bool>,
        }
        ```

      where `{...}` is the document produced by serializing the value tensor of
      the variable, see above.

    Raises `TypeNotSupported` if `obj` is an instance of none of the types
    listed above.
    """
    # The first matching entry wins, so the order of this list is significant.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (tf.RaggedTensor, _ragged_tensor_to_json),
        (tf.SparseTensor, _sparse_tensor_to_json),
        (tf.Tensor, _tensor_to_json),
        (tf.Variable, _variable_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
Serializes a tensorflow object into JSON by cases. See the README for the precise list of supported types. The return dict has the following structure:
- `tf.RaggedTensor`: Not supported.

- `tf.SparseTensor`:

  ```py
  {
      "__type__": "tensorflow.sparse_tensor",
      "__version__": 2,
      "indices": {...},
      "values": {...},
      "shape": {...},
  }
  ```

  where the first two `{...}` placeholders result in the serialization of `tf.Tensor` (see below).

- other `tf.Tensor` subtypes:

  ```py
  {
      "__type__": "tensorflow.tensor",
      "__version__": 4,
      "data": {
          "__type__": "bytes",
          ...
      },
  }
  ```

  see `turbo_broccoli.custom.bytes.to_json`.

- `tf.Variable`:

  ```py
  {
      "__type__": "tensorflow.variable",
      "__version__": 3,
      "name": <str>,
      "value": {...},
      "trainable": <bool>,
  }
  ```

  where `{...}` is the document produced by serializing the value tensor of the variable, see above.