turbo_broccoli.custom.keras

keras (de)serialization utilities.

  1"""keras (de)serialization utilities."""
  2
  3from functools import partial
  4from typing import Any, Callable, Tuple
  5
  6from tensorflow import keras  # pylint: disable=no-name-in-module
  7
  8from turbo_broccoli.context import Context
  9from turbo_broccoli.exceptions import DeserializationError, TypeNotSupported
 10
 11KERAS_LAYERS = {
 12    "Activation": keras.layers.Activation,
 13    "ActivityRegularization": keras.layers.ActivityRegularization,
 14    "Add": keras.layers.Add,
 15    "AdditiveAttention": keras.layers.AdditiveAttention,
 16    "AlphaDropout": keras.layers.AlphaDropout,
 17    "Attention": keras.layers.Attention,
 18    "Average": keras.layers.Average,
 19    "AveragePooling1D": keras.layers.AveragePooling1D,
 20    "AveragePooling2D": keras.layers.AveragePooling2D,
 21    "AveragePooling3D": keras.layers.AveragePooling3D,
 22    "AvgPool1D": keras.layers.AvgPool1D,
 23    "AvgPool2D": keras.layers.AvgPool2D,
 24    "AvgPool3D": keras.layers.AvgPool3D,
 25    "BatchNormalization": keras.layers.BatchNormalization,
 26    "Bidirectional": keras.layers.Bidirectional,
 27    "CategoryEncoding": keras.layers.CategoryEncoding,
 28    "CenterCrop": keras.layers.CenterCrop,
 29    "Concatenate": keras.layers.Concatenate,
 30    "Conv1D": keras.layers.Conv1D,
 31    "Conv1DTranspose": keras.layers.Conv1DTranspose,
 32    "Conv2D": keras.layers.Conv2D,
 33    "Conv2DTranspose": keras.layers.Conv2DTranspose,
 34    "Conv3D": keras.layers.Conv3D,
 35    "Conv3DTranspose": keras.layers.Conv3DTranspose,
 36    "ConvLSTM1D": keras.layers.ConvLSTM1D,
 37    "ConvLSTM2D": keras.layers.ConvLSTM2D,
 38    "ConvLSTM3D": keras.layers.ConvLSTM3D,
 39    "Convolution1D": keras.layers.Convolution1D,
 40    "Convolution1DTranspose": keras.layers.Convolution1DTranspose,
 41    "Convolution2D": keras.layers.Convolution2D,
 42    "Convolution2DTranspose": keras.layers.Convolution2DTranspose,
 43    "Convolution3D": keras.layers.Convolution3D,
 44    "Convolution3DTranspose": keras.layers.Convolution3DTranspose,
 45    "Cropping1D": keras.layers.Cropping1D,
 46    "Cropping2D": keras.layers.Cropping2D,
 47    "Cropping3D": keras.layers.Cropping3D,
 48    "Dense": keras.layers.Dense,
 49    "DepthwiseConv1D": keras.layers.DepthwiseConv1D,
 50    "DepthwiseConv2D": keras.layers.DepthwiseConv2D,
 51    "Discretization": keras.layers.Discretization,
 52    "Dot": keras.layers.Dot,
 53    "Dropout": keras.layers.Dropout,
 54    "ELU": keras.layers.ELU,
 55    "EinsumDense": keras.layers.EinsumDense,
 56    "Embedding": keras.layers.Embedding,
 57    "Flatten": keras.layers.Flatten,
 58    "GRU": keras.layers.GRU,
 59    "GRUCell": keras.layers.GRUCell,
 60    "GaussianDropout": keras.layers.GaussianDropout,
 61    "GaussianNoise": keras.layers.GaussianNoise,
 62    "GlobalAveragePooling1D": keras.layers.GlobalAveragePooling1D,
 63    "GlobalAveragePooling2D": keras.layers.GlobalAveragePooling2D,
 64    "GlobalAveragePooling3D": keras.layers.GlobalAveragePooling3D,
 65    "GlobalAvgPool1D": keras.layers.GlobalAvgPool1D,
 66    "GlobalAvgPool2D": keras.layers.GlobalAvgPool2D,
 67    "GlobalAvgPool3D": keras.layers.GlobalAvgPool3D,
 68    "GlobalMaxPool1D": keras.layers.GlobalMaxPool1D,
 69    "GlobalMaxPool2D": keras.layers.GlobalMaxPool2D,
 70    "GlobalMaxPool3D": keras.layers.GlobalMaxPool3D,
 71    "GlobalMaxPooling1D": keras.layers.GlobalMaxPooling1D,
 72    "GlobalMaxPooling2D": keras.layers.GlobalMaxPooling2D,
 73    "GlobalMaxPooling3D": keras.layers.GlobalMaxPooling3D,
 74    "GroupNormalization": keras.layers.GroupNormalization,
 75    "GroupQueryAttention": keras.layers.GroupQueryAttention,
 76    "HashedCrossing": keras.layers.HashedCrossing,
 77    "Hashing": keras.layers.Hashing,
 78    "Identity": keras.layers.Identity,
 79    "Input": keras.layers.Input,
 80    "InputLayer": keras.layers.InputLayer,
 81    "InputSpec": keras.layers.InputSpec,
 82    "IntegerLookup": keras.layers.IntegerLookup,
 83    "LSTM": keras.layers.LSTM,
 84    "LSTMCell": keras.layers.LSTMCell,
 85    "Lambda": keras.layers.Lambda,
 86    "Layer": keras.layers.Layer,
 87    "LayerNormalization": keras.layers.LayerNormalization,
 88    "LeakyReLU": keras.layers.LeakyReLU,
 89    "Masking": keras.layers.Masking,
 90    "MaxPool1D": keras.layers.MaxPool1D,
 91    "MaxPool2D": keras.layers.MaxPool2D,
 92    "MaxPool3D": keras.layers.MaxPool3D,
 93    "MaxPooling1D": keras.layers.MaxPooling1D,
 94    "MaxPooling2D": keras.layers.MaxPooling2D,
 95    "MaxPooling3D": keras.layers.MaxPooling3D,
 96    "Maximum": keras.layers.Maximum,
 97    "MelSpectrogram": keras.layers.MelSpectrogram,
 98    "Minimum": keras.layers.Minimum,
 99    "MultiHeadAttention": keras.layers.MultiHeadAttention,
100    "Multiply": keras.layers.Multiply,
101    "Normalization": keras.layers.Normalization,
102    "PReLU": keras.layers.PReLU,
103    "Permute": keras.layers.Permute,
104    "RNN": keras.layers.RNN,
105    "RandomBrightness": keras.layers.RandomBrightness,
106    "RandomContrast": keras.layers.RandomContrast,
107    "RandomCrop": keras.layers.RandomCrop,
108    "RandomFlip": keras.layers.RandomFlip,
109    "RandomHeight": keras.layers.RandomHeight,
110    "RandomRotation": keras.layers.RandomRotation,
111    "RandomTranslation": keras.layers.RandomTranslation,
112    "RandomWidth": keras.layers.RandomWidth,
113    "RandomZoom": keras.layers.RandomZoom,
114    "ReLU": keras.layers.ReLU,
115    "RepeatVector": keras.layers.RepeatVector,
116    "Rescaling": keras.layers.Rescaling,
117    "Reshape": keras.layers.Reshape,
118    "Resizing": keras.layers.Resizing,
119    "SeparableConv1D": keras.layers.SeparableConv1D,
120    "SeparableConv2D": keras.layers.SeparableConv2D,
121    "SeparableConvolution1D": keras.layers.SeparableConvolution1D,
122    "SeparableConvolution2D": keras.layers.SeparableConvolution2D,
123    "SimpleRNN": keras.layers.SimpleRNN,
124    "SimpleRNNCell": keras.layers.SimpleRNNCell,
125    "Softmax": keras.layers.Softmax,
126    "SpatialDropout1D": keras.layers.SpatialDropout1D,
127    "SpatialDropout2D": keras.layers.SpatialDropout2D,
128    "SpatialDropout3D": keras.layers.SpatialDropout3D,
129    "SpectralNormalization": keras.layers.SpectralNormalization,
130    "StackedRNNCells": keras.layers.StackedRNNCells,
131    "StringLookup": keras.layers.StringLookup,
132    "Subtract": keras.layers.Subtract,
133    "TFSMLayer": keras.layers.TFSMLayer,
134    "TextVectorization": keras.layers.TextVectorization,
135    "ThresholdedReLU": keras.layers.ThresholdedReLU,
136    "TimeDistributed": keras.layers.TimeDistributed,
137    "TorchModuleWrapper": keras.layers.TorchModuleWrapper,
138    "UnitNormalization": keras.layers.UnitNormalization,
139    "UpSampling1D": keras.layers.UpSampling1D,
140    "UpSampling2D": keras.layers.UpSampling2D,
141    "UpSampling3D": keras.layers.UpSampling3D,
142    "Wrapper": keras.layers.Wrapper,
143    "ZeroPadding1D": keras.layers.ZeroPadding1D,
144    "ZeroPadding2D": keras.layers.ZeroPadding2D,
145    "ZeroPadding3D": keras.layers.ZeroPadding3D,
146}
147
# Lookup table of the public `keras.losses` members, keyed by their exported
# name. It is used as `module_objects` when decoding `keras.loss` documents.
# Besides loss classes it contains a few plain functions (`KLD`, `MAE`, ...)
# and `Reduction`, exactly as exported by `keras.losses`.
KERAS_LOSSES = {
    name: getattr(keras.losses, name)
    for name in (
        "BinaryCrossentropy", "BinaryFocalCrossentropy", "CTC",
        "CategoricalCrossentropy", "CategoricalFocalCrossentropy",
        "CategoricalHinge", "CosineSimilarity", "Hinge", "Huber",
        "KLD", "KLDivergence", "LogCosh", "Loss",
        "MAE", "MAPE", "MSE", "MSLE",
        "MeanAbsoluteError", "MeanAbsolutePercentageError",
        "MeanSquaredError", "MeanSquaredLogarithmicError",
        "Poisson", "Reduction", "SparseCategoricalCrossentropy",
        "SquaredHinge",
    )
}
175
# Lookup table of the public `keras.metrics` members, keyed by their exported
# name. It is used as `module_objects` when decoding `keras.metric` documents.
KERAS_METRICS = {
    name: getattr(keras.metrics, name)
    for name in (
        "AUC", "Accuracy", "BinaryAccuracy", "BinaryCrossentropy",
        "BinaryIoU", "CategoricalAccuracy", "CategoricalCrossentropy",
        "CategoricalHinge", "CosineSimilarity", "F1Score", "FBetaScore",
        "FalseNegatives", "FalsePositives", "Hinge", "IoU",
        "KLDivergence", "LogCoshError", "Mean", "MeanAbsoluteError",
        "MeanAbsolutePercentageError", "MeanIoU", "MeanMetricWrapper",
        "MeanSquaredError", "MeanSquaredLogarithmicError", "Metric",
        "OneHotIoU", "OneHotMeanIoU", "Poisson", "Precision",
        "PrecisionAtRecall", "R2Score", "Recall", "RecallAtPrecision",
        "RootMeanSquaredError", "SensitivityAtSpecificity",
        "SparseCategoricalAccuracy", "SparseCategoricalCrossentropy",
        "SparseTopKCategoricalAccuracy", "SpecificityAtSensitivity",
        "SquaredHinge", "Sum", "TopKCategoricalAccuracy", "TrueNegatives",
        "TruePositives",
    )
}
222
# Lookup table of the public `keras.optimizers` members, keyed by their
# exported name. It is used as `module_objects` when decoding non-legacy
# `keras.optimizer` documents.
KERAS_OPTIMIZERS = {
    name: getattr(keras.optimizers, name)
    for name in (
        "Adadelta", "Adafactor", "Adagrad", "Adam", "AdamW", "Adamax",
        "Ftrl", "Lion", "LossScaleOptimizer", "Nadam", "Optimizer",
        "RMSprop", "SGD",
    )
}
238
# Lookup table of the `keras.optimizers.legacy` members, keyed by their
# exported name. It is used as `module_objects` when decoding a
# `keras.optimizer` document whose `legacy` flag is set.
# NOTE(review): with the keras build whose repr appears in this page, every
# entry resolves to the same placeholder class (`LegacyOptimizerWarning`);
# confirm legacy optimizers are still meant to be supported.
KERAS_LEGACY_OPTIMIZERS = {
    name: getattr(keras.optimizers.legacy, name)
    for name in ("Adagrad", "Adam", "Ftrl", "Optimizer", "RMSprop", "SGD")
}
247
248
def _json_to_layer(dct: dict, ctx: Context) -> Any:
    """
    Decodes a `keras.layer` document by dispatching on its `__version__`.
    Only version 2 is known; any other version raises `KeyError`.
    """
    decoder = {2: _json_to_layer_v2}[dct["__version__"]]
    return decoder(dct, ctx)
254
255
def _json_to_layer_v2(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a keras layer from the serialized config stored under `data`,
    resolving class names against `KERAS_LAYERS`.
    """
    config = dct["data"]
    return keras.utils.deserialize_keras_object(
        config, module_objects=KERAS_LAYERS
    )
261
262
def _json_to_loss(dct: dict, ctx: Context) -> Any:
    """
    Decodes a `keras.loss` document by dispatching on its `__version__`.
    Only version 2 is known; any other version raises `KeyError`.
    """
    decoder = {2: _json_to_loss_v2}[dct["__version__"]]
    return decoder(dct, ctx)
268
269
def _json_to_loss_v2(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a keras loss from the serialized config stored under `data`,
    resolving names against `KERAS_LOSSES`.
    """
    config = dct["data"]
    return keras.utils.deserialize_keras_object(
        config, module_objects=KERAS_LOSSES
    )
275
276
def _json_to_metric(dct: dict, ctx: Context) -> Any:
    """
    Decodes a `keras.metric` document by dispatching on its `__version__`.
    Only version 2 is known; any other version raises `KeyError`.
    """
    decoder = {2: _json_to_metric_v2}[dct["__version__"]]
    return decoder(dct, ctx)
282
283
def _json_to_metric_v2(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a keras metric from the serialized config stored under `data`,
    resolving class names against `KERAS_METRICS`.
    """
    config = dct["data"]
    return keras.utils.deserialize_keras_object(
        config, module_objects=KERAS_METRICS
    )
289
290
def _json_to_model(dct: dict, ctx: Context) -> Any:
    """
    Decodes a `keras.model` document by dispatching on its `__version__`.
    Only version 5 is known; any other version raises `KeyError`.
    """
    decoder = {5: _json_to_model_v5}[dct["__version__"]]
    return decoder(dct, ctx)
296
297
def _json_to_model_v5(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a `keras.Model` from a version 5 `keras.model` document.

    Two document shapes are handled:

    - inline (a `model` key is present): the architecture is rebuilt with
      `keras.models.model_from_config`, the weights are restored with
      `set_weights`, and the model is compiled with the stored `metrics`
      plus `loss` / `optimizer` when those are not null;
    - artifact (an `id` key): the model is loaded with
      `keras.models.load_model` from the artifact that `id` refers to.

    Raises:
        KeyError: if a required key is missing. `from_json` converts this
            into a `DeserializationError`.
    """
    if "model" in dct:
        model = keras.models.model_from_config(dct["model"])
        model.set_weights(dct["weights"])
        kwargs = {"metrics": dct["metrics"]}
        for k in ["loss", "optimizer"]:
            if dct.get(k) is not None:
                kwargs[k] = dct[k]
        model.compile(**kwargs)
        return model
    # The artifact's extension depends on the format the model was *saved*
    # with, which `_model_to_json` records under "format". The current
    # context's format may differ, so it is only a fallback for documents
    # that lack the key.
    saved_format = dct.get("format", ctx.keras_format)
    if saved_format == "keras":
        path = ctx.id_to_artifact_path(dct["id"], extension="keras")
    else:
        path = ctx.id_to_artifact_path(dct["id"])
    return keras.models.load_model(path)
314
315
def _json_to_optimizer(dct: dict, ctx: Context) -> Any:
    """
    Decodes a `keras.optimizer` document by dispatching on its
    `__version__` (2 or 3); any other version raises `KeyError`.
    """
    by_version = {2: _json_to_optimizer_v2, 3: _json_to_optimizer_v3}
    decoder = by_version[dct["__version__"]]
    return decoder(dct, ctx)
322
323
def _json_to_optimizer_v2(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a keras optimizer from the serialized config stored under
    `data`, resolving class names against `KERAS_OPTIMIZERS`.
    """
    config = dct["data"]
    return keras.utils.deserialize_keras_object(
        config, module_objects=KERAS_OPTIMIZERS
    )
329
330
def _json_to_optimizer_v3(dct: dict, ctx: Context) -> Any:
    """
    Rebuilds a keras optimizer from the serialized config stored under
    `data`. The document's `legacy` flag selects whether class names are
    resolved against `KERAS_LEGACY_OPTIMIZERS` or `KERAS_OPTIMIZERS`.
    """
    config = dct["data"]
    if dct["legacy"]:
        registry = KERAS_LEGACY_OPTIMIZERS
    else:
        registry = KERAS_OPTIMIZERS
    return keras.utils.deserialize_keras_object(
        config, module_objects=registry
    )
338
339
def _generic_to_json(
    obj: Any,
    ctx: Context,
    *,
    type_: str,
) -> dict:
    """
    Builds a version 2 document of type `keras.<type_>` whose `data` is
    the output of `keras.utils.serialize_keras_object(obj)`. `ctx` is
    unused; it is accepted to match the encoder signature.
    """
    type_tag = "keras." + type_
    payload = keras.utils.serialize_keras_object(obj)
    return {"__type__": type_tag, "__version__": 2, "data": payload}
351
352
def _model_to_json(model: keras.Model, ctx: Context) -> dict:
    """
    Serializes a `keras.Model` into a version 5 `keras.model` document.

    When `ctx.keras_format` is `"json"`, the architecture, weights, loss,
    metrics and optimizer are embedded in the document itself. Otherwise
    the model is saved to a new artifact (with a `.keras` extension for
    the `"keras"` format), and the document records the format and the
    artifact's id.
    """
    fmt = ctx.keras_format
    if fmt == "json":
        return {
            "__type__": "keras.model",
            "__version__": 5,
            "loss": getattr(model, "loss", None),
            "metrics": getattr(model, "metrics", []),
            "model": keras.utils.serialize_keras_object(model),
            "optimizer": getattr(model, "optimizer", None),
            "weights": model.weights,
        }
    path, name = (
        ctx.new_artifact_path(extension="keras")
        if fmt == "keras"
        else ctx.new_artifact_path()
    )
    # NOTE(review): recent Keras releases deprecate `save_format` and infer
    # the format from the file extension; confirm the `h5`/`tf` formats still
    # work with the installed keras version.
    model.save(path, save_format=fmt)
    return {
        "__type__": "keras.model",
        "__version__": 5,
        "format": fmt,
        "id": name,
    }
375
376
def _optimizer_to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a keras optimizer into a version 3 `keras.optimizer`
    document. The `legacy` flag records whether `obj` is an instance of
    `keras.optimizers.legacy.Optimizer`, which tells the decoder which
    lookup table to use.
    """
    document: dict = {"__type__": "keras.optimizer", "__version__": 3}
    document["data"] = keras.utils.serialize_keras_object(obj)
    document["legacy"] = isinstance(obj, keras.optimizers.legacy.Optimizer)
    return document
384
385
386# pylint: disable=missing-function-docstring
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a keras object from a turbo_broccoli JSON document.

    The document's `__type__` (`keras.model`, `keras.layer`, `keras.loss`,
    `keras.metric` or `keras.optimizer`) selects the decoder.

    Raises:
        DeserializationError: if `__type__` is missing or unknown, or if the
            selected decoder raises `KeyError` (for example an unknown
            `__version__` or a missing key).
    """
    type_decoders = {
        "keras.model": _json_to_model,
        "keras.layer": _json_to_layer,
        "keras.loss": _json_to_loss,
        "keras.metric": _json_to_metric,
        "keras.optimizer": _json_to_optimizer,
    }
    try:
        decoder = type_decoders[dct["__type__"]]
        return decoder(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
400
401
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a keras object into JSON by cases. See the README for the
    precise list of supported types. Most keras objects are simply
    serialized using `keras.utils.serialize_keras_object`. Here are the
    exceptions:

    - `keras.Model` (the model must have weights). If `TB_KERAS_FORMAT` is
      `json`, the document will look like

        ```py
        {
            "__type__": "keras.model",
            "__version__": 5,
            "loss": {...} or null,
            "metrics": [...],
            "model": {...},
            "optimizer": {...} or null,
            "weights": [...],
        }
        ```

      if `TB_KERAS_FORMAT` is anything else (e.g. `h5`, `keras`, or `tf`),
      the model is saved to an artifact and the document will look like

        ```py
        {
            "__type__": "keras.model",
            "__version__": 5,
            "format": <str>,
            "id": <uuid4>
        }
        ```

      where `id` points to an artifact. Note that if the keras saving format is
      `keras`, the artifact will have the `.keras` extension instead of the
      usual `.tb`. Tensorflow/keras [forces this
      behaviour](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model).

    - Optimizers get a version 3 document carrying an extra `legacy` flag.

    Raises:
        TypeNotSupported: if `obj` matches none of the supported types.
    """
    # Checked in order: the first matching type wins.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (keras.Model, _model_to_json),  # must be first: a Model is a Layer
        (keras.metrics.Metric, partial(_generic_to_json, type_="metric")),
        (keras.layers.Layer, partial(_generic_to_json, type_="layer")),
        (keras.losses.Loss, partial(_generic_to_json, type_="loss")),
        (keras.optimizers.Optimizer, _optimizer_to_json),
        (keras.optimizers.legacy.Optimizer, _optimizer_to_json),
    ]
    for t, f in encoders:
        if isinstance(obj, t):
            return f(obj, ctx)
    raise TypeNotSupported()
KERAS_LAYERS = {'Activation': <class 'keras.src.layers.activations.activation.Activation'>, 'ActivityRegularization': <class 'keras.src.layers.regularization.activity_regularization.ActivityRegularization'>, 'Add': <class 'keras.src.layers.merging.add.Add'>, 'AdditiveAttention': <class 'keras.src.layers.attention.additive_attention.AdditiveAttention'>, 'AlphaDropout': <class 'keras.src.legacy.layers.AlphaDropout'>, 'Attention': <class 'keras.src.layers.attention.attention.Attention'>, 'Average': <class 'keras.src.layers.merging.average.Average'>, 'AveragePooling1D': <class 'keras.src.layers.pooling.average_pooling1d.AveragePooling1D'>, 'AveragePooling2D': <class 'keras.src.layers.pooling.average_pooling2d.AveragePooling2D'>, 'AveragePooling3D': <class 'keras.src.layers.pooling.average_pooling3d.AveragePooling3D'>, 'AvgPool1D': <class 'keras.src.layers.pooling.average_pooling1d.AveragePooling1D'>, 'AvgPool2D': <class 'keras.src.layers.pooling.average_pooling2d.AveragePooling2D'>, 'AvgPool3D': <class 'keras.src.layers.pooling.average_pooling3d.AveragePooling3D'>, 'BatchNormalization': <class 'keras.src.layers.normalization.batch_normalization.BatchNormalization'>, 'Bidirectional': <class 'keras.src.layers.rnn.bidirectional.Bidirectional'>, 'CategoryEncoding': <class 'keras.src.layers.preprocessing.category_encoding.CategoryEncoding'>, 'CenterCrop': <class 'keras.src.layers.preprocessing.center_crop.CenterCrop'>, 'Concatenate': <class 'keras.src.layers.merging.concatenate.Concatenate'>, 'Conv1D': <class 'keras.src.layers.convolutional.conv1d.Conv1D'>, 'Conv1DTranspose': <class 'keras.src.layers.convolutional.conv1d_transpose.Conv1DTranspose'>, 'Conv2D': <class 'keras.src.layers.convolutional.conv2d.Conv2D'>, 'Conv2DTranspose': <class 'keras.src.layers.convolutional.conv2d_transpose.Conv2DTranspose'>, 'Conv3D': <class 'keras.src.layers.convolutional.conv3d.Conv3D'>, 'Conv3DTranspose': <class 'keras.src.layers.convolutional.conv3d_transpose.Conv3DTranspose'>, 
'ConvLSTM1D': <class 'keras.src.layers.rnn.conv_lstm1d.ConvLSTM1D'>, 'ConvLSTM2D': <class 'keras.src.layers.rnn.conv_lstm2d.ConvLSTM2D'>, 'ConvLSTM3D': <class 'keras.src.layers.rnn.conv_lstm3d.ConvLSTM3D'>, 'Convolution1D': <class 'keras.src.layers.convolutional.conv1d.Conv1D'>, 'Convolution1DTranspose': <class 'keras.src.layers.convolutional.conv1d_transpose.Conv1DTranspose'>, 'Convolution2D': <class 'keras.src.layers.convolutional.conv2d.Conv2D'>, 'Convolution2DTranspose': <class 'keras.src.layers.convolutional.conv2d_transpose.Conv2DTranspose'>, 'Convolution3D': <class 'keras.src.layers.convolutional.conv3d.Conv3D'>, 'Convolution3DTranspose': <class 'keras.src.layers.convolutional.conv3d_transpose.Conv3DTranspose'>, 'Cropping1D': <class 'keras.src.layers.reshaping.cropping1d.Cropping1D'>, 'Cropping2D': <class 'keras.src.layers.reshaping.cropping2d.Cropping2D'>, 'Cropping3D': <class 'keras.src.layers.reshaping.cropping3d.Cropping3D'>, 'Dense': <class 'keras.src.layers.core.dense.Dense'>, 'DepthwiseConv1D': <class 'keras.src.layers.convolutional.depthwise_conv1d.DepthwiseConv1D'>, 'DepthwiseConv2D': <class 'keras.src.layers.convolutional.depthwise_conv2d.DepthwiseConv2D'>, 'Discretization': <class 'keras.src.layers.preprocessing.discretization.Discretization'>, 'Dot': <class 'keras.src.layers.merging.dot.Dot'>, 'Dropout': <class 'keras.src.layers.regularization.dropout.Dropout'>, 'ELU': <class 'keras.src.layers.activations.elu.ELU'>, 'EinsumDense': <class 'keras.src.layers.core.einsum_dense.EinsumDense'>, 'Embedding': <class 'keras.src.layers.core.embedding.Embedding'>, 'Flatten': <class 'keras.src.layers.reshaping.flatten.Flatten'>, 'GRU': <class 'keras.src.layers.rnn.gru.GRU'>, 'GRUCell': <class 'keras.src.layers.rnn.gru.GRUCell'>, 'GaussianDropout': <class 'keras.src.layers.regularization.gaussian_dropout.GaussianDropout'>, 'GaussianNoise': <class 'keras.src.layers.regularization.gaussian_noise.GaussianNoise'>, 'GlobalAveragePooling1D': <class 
'keras.src.layers.pooling.global_average_pooling1d.GlobalAveragePooling1D'>, 'GlobalAveragePooling2D': <class 'keras.src.layers.pooling.global_average_pooling2d.GlobalAveragePooling2D'>, 'GlobalAveragePooling3D': <class 'keras.src.layers.pooling.global_average_pooling3d.GlobalAveragePooling3D'>, 'GlobalAvgPool1D': <class 'keras.src.layers.pooling.global_average_pooling1d.GlobalAveragePooling1D'>, 'GlobalAvgPool2D': <class 'keras.src.layers.pooling.global_average_pooling2d.GlobalAveragePooling2D'>, 'GlobalAvgPool3D': <class 'keras.src.layers.pooling.global_average_pooling3d.GlobalAveragePooling3D'>, 'GlobalMaxPool1D': <class 'keras.src.layers.pooling.global_max_pooling1d.GlobalMaxPooling1D'>, 'GlobalMaxPool2D': <class 'keras.src.layers.pooling.global_max_pooling2d.GlobalMaxPooling2D'>, 'GlobalMaxPool3D': <class 'keras.src.layers.pooling.global_max_pooling3d.GlobalMaxPooling3D'>, 'GlobalMaxPooling1D': <class 'keras.src.layers.pooling.global_max_pooling1d.GlobalMaxPooling1D'>, 'GlobalMaxPooling2D': <class 'keras.src.layers.pooling.global_max_pooling2d.GlobalMaxPooling2D'>, 'GlobalMaxPooling3D': <class 'keras.src.layers.pooling.global_max_pooling3d.GlobalMaxPooling3D'>, 'GroupNormalization': <class 'keras.src.layers.normalization.group_normalization.GroupNormalization'>, 'GroupQueryAttention': <class 'keras.src.layers.attention.grouped_query_attention.GroupedQueryAttention'>, 'HashedCrossing': <class 'keras.src.layers.preprocessing.hashed_crossing.HashedCrossing'>, 'Hashing': <class 'keras.src.layers.preprocessing.hashing.Hashing'>, 'Identity': <class 'keras.src.layers.core.identity.Identity'>, 'Input': <function Input>, 'InputLayer': <class 'keras.src.layers.core.input_layer.InputLayer'>, 'InputSpec': <class 'keras.src.layers.input_spec.InputSpec'>, 'IntegerLookup': <class 'keras.src.layers.preprocessing.integer_lookup.IntegerLookup'>, 'LSTM': <class 'keras.src.layers.rnn.lstm.LSTM'>, 'LSTMCell': <class 'keras.src.layers.rnn.lstm.LSTMCell'>, 'Lambda': <class 
'keras.src.layers.core.lambda_layer.Lambda'>, 'Layer': <class 'keras.src.layers.layer.Layer'>, 'LayerNormalization': <class 'keras.src.layers.normalization.layer_normalization.LayerNormalization'>, 'LeakyReLU': <class 'keras.src.layers.activations.leaky_relu.LeakyReLU'>, 'Masking': <class 'keras.src.layers.core.masking.Masking'>, 'MaxPool1D': <class 'keras.src.layers.pooling.max_pooling1d.MaxPooling1D'>, 'MaxPool2D': <class 'keras.src.layers.pooling.max_pooling2d.MaxPooling2D'>, 'MaxPool3D': <class 'keras.src.layers.pooling.max_pooling3d.MaxPooling3D'>, 'MaxPooling1D': <class 'keras.src.layers.pooling.max_pooling1d.MaxPooling1D'>, 'MaxPooling2D': <class 'keras.src.layers.pooling.max_pooling2d.MaxPooling2D'>, 'MaxPooling3D': <class 'keras.src.layers.pooling.max_pooling3d.MaxPooling3D'>, 'Maximum': <class 'keras.src.layers.merging.maximum.Maximum'>, 'MelSpectrogram': <class 'keras.src.layers.preprocessing.audio_preprocessing.MelSpectrogram'>, 'Minimum': <class 'keras.src.layers.merging.minimum.Minimum'>, 'MultiHeadAttention': <class 'keras.src.layers.attention.multi_head_attention.MultiHeadAttention'>, 'Multiply': <class 'keras.src.layers.merging.multiply.Multiply'>, 'Normalization': <class 'keras.src.layers.preprocessing.normalization.Normalization'>, 'PReLU': <class 'keras.src.layers.activations.prelu.PReLU'>, 'Permute': <class 'keras.src.layers.reshaping.permute.Permute'>, 'RNN': <class 'keras.src.layers.rnn.rnn.RNN'>, 'RandomBrightness': <class 'keras.src.layers.preprocessing.random_brightness.RandomBrightness'>, 'RandomContrast': <class 'keras.src.layers.preprocessing.random_contrast.RandomContrast'>, 'RandomCrop': <class 'keras.src.layers.preprocessing.random_crop.RandomCrop'>, 'RandomFlip': <class 'keras.src.layers.preprocessing.random_flip.RandomFlip'>, 'RandomHeight': <class 'keras.src.legacy.layers.RandomHeight'>, 'RandomRotation': <class 'keras.src.layers.preprocessing.random_rotation.RandomRotation'>, 'RandomTranslation': <class 
'keras.src.layers.preprocessing.random_translation.RandomTranslation'>, 'RandomWidth': <class 'keras.src.legacy.layers.RandomWidth'>, 'RandomZoom': <class 'keras.src.layers.preprocessing.random_zoom.RandomZoom'>, 'ReLU': <class 'keras.src.layers.activations.relu.ReLU'>, 'RepeatVector': <class 'keras.src.layers.reshaping.repeat_vector.RepeatVector'>, 'Rescaling': <class 'keras.src.layers.preprocessing.rescaling.Rescaling'>, 'Reshape': <class 'keras.src.layers.reshaping.reshape.Reshape'>, 'Resizing': <class 'keras.src.layers.preprocessing.resizing.Resizing'>, 'SeparableConv1D': <class 'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D'>, 'SeparableConv2D': <class 'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D'>, 'SeparableConvolution1D': <class 'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D'>, 'SeparableConvolution2D': <class 'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D'>, 'SimpleRNN': <class 'keras.src.layers.rnn.simple_rnn.SimpleRNN'>, 'SimpleRNNCell': <class 'keras.src.layers.rnn.simple_rnn.SimpleRNNCell'>, 'Softmax': <class 'keras.src.layers.activations.softmax.Softmax'>, 'SpatialDropout1D': <class 'keras.src.layers.regularization.spatial_dropout.SpatialDropout1D'>, 'SpatialDropout2D': <class 'keras.src.layers.regularization.spatial_dropout.SpatialDropout2D'>, 'SpatialDropout3D': <class 'keras.src.layers.regularization.spatial_dropout.SpatialDropout3D'>, 'SpectralNormalization': <class 'keras.src.layers.normalization.spectral_normalization.SpectralNormalization'>, 'StackedRNNCells': <class 'keras.src.layers.rnn.stacked_rnn_cells.StackedRNNCells'>, 'StringLookup': <class 'keras.src.layers.preprocessing.string_lookup.StringLookup'>, 'Subtract': <class 'keras.src.layers.merging.subtract.Subtract'>, 'TFSMLayer': <class 'keras.src.export.export_lib.TFSMLayer'>, 'TextVectorization': <class 'keras.src.layers.preprocessing.text_vectorization.TextVectorization'>, 'ThresholdedReLU': <class 
'keras.src.legacy.layers.ThresholdedReLU'>, 'TimeDistributed': <class 'keras.src.layers.rnn.time_distributed.TimeDistributed'>, 'TorchModuleWrapper': <class 'keras.src.utils.torch_utils.TorchModuleWrapper'>, 'UnitNormalization': <class 'keras.src.layers.normalization.unit_normalization.UnitNormalization'>, 'UpSampling1D': <class 'keras.src.layers.reshaping.up_sampling1d.UpSampling1D'>, 'UpSampling2D': <class 'keras.src.layers.reshaping.up_sampling2d.UpSampling2D'>, 'UpSampling3D': <class 'keras.src.layers.reshaping.up_sampling3d.UpSampling3D'>, 'Wrapper': <class 'keras.src.layers.core.wrapper.Wrapper'>, 'ZeroPadding1D': <class 'keras.src.layers.reshaping.zero_padding1d.ZeroPadding1D'>, 'ZeroPadding2D': <class 'keras.src.layers.reshaping.zero_padding2d.ZeroPadding2D'>, 'ZeroPadding3D': <class 'keras.src.layers.reshaping.zero_padding3d.ZeroPadding3D'>}
KERAS_LOSSES = {'BinaryCrossentropy': <class 'keras.src.losses.losses.BinaryCrossentropy'>, 'BinaryFocalCrossentropy': <class 'keras.src.losses.losses.BinaryFocalCrossentropy'>, 'CTC': <class 'keras.src.losses.losses.CTC'>, 'CategoricalCrossentropy': <class 'keras.src.losses.losses.CategoricalCrossentropy'>, 'CategoricalFocalCrossentropy': <class 'keras.src.losses.losses.CategoricalFocalCrossentropy'>, 'CategoricalHinge': <class 'keras.src.losses.losses.CategoricalHinge'>, 'CosineSimilarity': <class 'keras.src.losses.losses.CosineSimilarity'>, 'Hinge': <class 'keras.src.losses.losses.Hinge'>, 'Huber': <class 'keras.src.losses.losses.Huber'>, 'KLD': <function kl_divergence>, 'KLDivergence': <class 'keras.src.losses.losses.KLDivergence'>, 'LogCosh': <class 'keras.src.losses.losses.LogCosh'>, 'Loss': <class 'keras.src.losses.loss.Loss'>, 'MAE': <function mean_absolute_error>, 'MAPE': <function mean_absolute_percentage_error>, 'MSE': <function mean_squared_error>, 'MSLE': <function mean_squared_logarithmic_error>, 'MeanAbsoluteError': <class 'keras.src.losses.losses.MeanAbsoluteError'>, 'MeanAbsolutePercentageError': <class 'keras.src.losses.losses.MeanAbsolutePercentageError'>, 'MeanSquaredError': <class 'keras.src.losses.losses.MeanSquaredError'>, 'MeanSquaredLogarithmicError': <class 'keras.src.losses.losses.MeanSquaredLogarithmicError'>, 'Poisson': <class 'keras.src.losses.losses.Poisson'>, 'Reduction': <class 'keras.src.legacy.losses.Reduction'>, 'SparseCategoricalCrossentropy': <class 'keras.src.losses.losses.SparseCategoricalCrossentropy'>, 'SquaredHinge': <class 'keras.src.losses.losses.SquaredHinge'>}
KERAS_METRICS = {'AUC': <class 'keras.src.metrics.confusion_metrics.AUC'>, 'Accuracy': <class 'keras.src.metrics.accuracy_metrics.Accuracy'>, 'BinaryAccuracy': <class 'keras.src.metrics.accuracy_metrics.BinaryAccuracy'>, 'BinaryCrossentropy': <class 'keras.src.metrics.probabilistic_metrics.BinaryCrossentropy'>, 'BinaryIoU': <class 'keras.src.metrics.iou_metrics.BinaryIoU'>, 'CategoricalAccuracy': <class 'keras.src.metrics.accuracy_metrics.CategoricalAccuracy'>, 'CategoricalCrossentropy': <class 'keras.src.metrics.probabilistic_metrics.CategoricalCrossentropy'>, 'CategoricalHinge': <class 'keras.src.metrics.hinge_metrics.CategoricalHinge'>, 'CosineSimilarity': <class 'keras.src.metrics.regression_metrics.CosineSimilarity'>, 'F1Score': <class 'keras.src.metrics.f_score_metrics.F1Score'>, 'FBetaScore': <class 'keras.src.metrics.f_score_metrics.FBetaScore'>, 'FalseNegatives': <class 'keras.src.metrics.confusion_metrics.FalseNegatives'>, 'FalsePositives': <class 'keras.src.metrics.confusion_metrics.FalsePositives'>, 'Hinge': <class 'keras.src.metrics.hinge_metrics.Hinge'>, 'IoU': <class 'keras.src.metrics.iou_metrics.IoU'>, 'KLDivergence': <class 'keras.src.metrics.probabilistic_metrics.KLDivergence'>, 'LogCoshError': <class 'keras.src.metrics.regression_metrics.LogCoshError'>, 'Mean': <class 'keras.src.metrics.reduction_metrics.Mean'>, 'MeanAbsoluteError': <class 'keras.src.metrics.regression_metrics.MeanAbsoluteError'>, 'MeanAbsolutePercentageError': <class 'keras.src.metrics.regression_metrics.MeanAbsolutePercentageError'>, 'MeanIoU': <class 'keras.src.metrics.iou_metrics.MeanIoU'>, 'MeanMetricWrapper': <class 'keras.src.metrics.reduction_metrics.MeanMetricWrapper'>, 'MeanSquaredError': <class 'keras.src.metrics.regression_metrics.MeanSquaredError'>, 'MeanSquaredLogarithmicError': <class 'keras.src.metrics.regression_metrics.MeanSquaredLogarithmicError'>, 'Metric': <class 'keras.src.metrics.metric.Metric'>, 'OneHotIoU': <class 
'keras.src.metrics.iou_metrics.OneHotIoU'>, 'OneHotMeanIoU': <class 'keras.src.metrics.iou_metrics.OneHotMeanIoU'>, 'Poisson': <class 'keras.src.metrics.probabilistic_metrics.Poisson'>, 'Precision': <class 'keras.src.metrics.confusion_metrics.Precision'>, 'PrecisionAtRecall': <class 'keras.src.metrics.confusion_metrics.PrecisionAtRecall'>, 'R2Score': <class 'keras.src.metrics.regression_metrics.R2Score'>, 'Recall': <class 'keras.src.metrics.confusion_metrics.Recall'>, 'RecallAtPrecision': <class 'keras.src.metrics.confusion_metrics.RecallAtPrecision'>, 'RootMeanSquaredError': <class 'keras.src.metrics.regression_metrics.RootMeanSquaredError'>, 'SensitivityAtSpecificity': <class 'keras.src.metrics.confusion_metrics.SensitivityAtSpecificity'>, 'SparseCategoricalAccuracy': <class 'keras.src.metrics.accuracy_metrics.SparseCategoricalAccuracy'>, 'SparseCategoricalCrossentropy': <class 'keras.src.metrics.probabilistic_metrics.SparseCategoricalCrossentropy'>, 'SparseTopKCategoricalAccuracy': <class 'keras.src.metrics.accuracy_metrics.SparseTopKCategoricalAccuracy'>, 'SpecificityAtSensitivity': <class 'keras.src.metrics.confusion_metrics.SpecificityAtSensitivity'>, 'SquaredHinge': <class 'keras.src.metrics.hinge_metrics.SquaredHinge'>, 'Sum': <class 'keras.src.metrics.reduction_metrics.Sum'>, 'TopKCategoricalAccuracy': <class 'keras.src.metrics.accuracy_metrics.TopKCategoricalAccuracy'>, 'TrueNegatives': <class 'keras.src.metrics.confusion_metrics.TrueNegatives'>, 'TruePositives': <class 'keras.src.metrics.confusion_metrics.TruePositives'>}
KERAS_OPTIMIZERS = {'Adadelta': <class 'keras.src.optimizers.adadelta.Adadelta'>, 'Adafactor': <class 'keras.src.optimizers.adafactor.Adafactor'>, 'Adagrad': <class 'keras.src.optimizers.adagrad.Adagrad'>, 'Adam': <class 'keras.src.optimizers.adam.Adam'>, 'AdamW': <class 'keras.src.optimizers.adamw.AdamW'>, 'Adamax': <class 'keras.src.optimizers.adamax.Adamax'>, 'Ftrl': <class 'keras.src.optimizers.ftrl.Ftrl'>, 'Lion': <class 'keras.src.optimizers.lion.Lion'>, 'LossScaleOptimizer': <class 'keras.src.optimizers.loss_scale_optimizer.LossScaleOptimizer'>, 'Nadam': <class 'keras.src.optimizers.nadam.Nadam'>, 'Optimizer': <class 'keras.src.optimizers.optimizer.Optimizer'>, 'RMSprop': <class 'keras.src.optimizers.rmsprop.RMSprop'>, 'SGD': <class 'keras.src.optimizers.sgd.SGD'>}
KERAS_LEGACY_OPTIMIZERS = {'Adagrad': <class 'keras.src.optimizers.LegacyOptimizerWarning'>, 'Adam': <class 'keras.src.optimizers.LegacyOptimizerWarning'>, 'Ftrl': <class 'keras.src.optimizers.LegacyOptimizerWarning'>, 'Optimizer': <class 'keras.src.optimizers.LegacyOptimizerWarning'>, 'RMSprop': <class 'keras.src.optimizers.LegacyOptimizerWarning'>, 'SGD': <class 'keras.src.optimizers.LegacyOptimizerWarning'>}
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a keras object from a JSON document. The document's
    `__type__` field selects which decoder handles it: `keras.model`,
    `keras.layer`, `keras.loss`, `keras.metric`, or `keras.optimizer`.

    Args:
        dct: The JSON document (as a dict) to decode.
        ctx: The (de)serialization context, forwarded unchanged to the
            selected decoder.

    Returns:
        Whatever the selected decoder returns.

    Raises:
        DeserializationError: If `dct` has no `__type__` key, if `__type__`
            is not one of the types listed above, or if a `KeyError` escapes
            the selected decoder (chained as the cause).
    """
    # Exact-key lookup on `__type__`; the order of entries carries no meaning.
    dispatch: dict[str, Callable[[dict, Context], Any]] = {
        "keras.model": _json_to_model,
        "keras.layer": _json_to_layer,
        "keras.loss": _json_to_loss,
        "keras.metric": _json_to_metric,
        "keras.optimizer": _json_to_optimizer,
    }
    try:
        decoder = dispatch[dct["__type__"]]
        # Kept inside the `try` on purpose: KeyErrors raised while decoding
        # are also reported as DeserializationError.
        return decoder(dct, ctx)
    except KeyError as err:
        raise DeserializationError() from err
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a keras object into JSON by cases. See the README for the
    precise list of supported types. Most keras objects will simply be
    serialized using `keras.utils.serialize_keras_object`. Here are the
    exceptions:

    - `keras.Model` (the model must have weights). If `TB_KERAS_FORMAT` is
      `json`, the document will look like

        ```py
        {
            "__type__": "keras.model",
            "__version__": 5,
            "loss": {...} or null,
            "metrics": [...],
            "model": {...},
            "optimizer": {...} or null,
            "weights": [...],
        }
        ```

      if `TB_KERAS_FORMAT` is `h5`, `keras`, or `tf`, the document will look
      like

        ```py
        {
            "__type__": "keras.model",
            "__version__": 5,
            "format": <str>,
            "id": <uuid4>
        }
        ```

      where `id` points to an artifact. Note that if the keras saving format is
      `keras`, the artifact will have the `.keras` extension instead of the
      usual `.tb`. Tensorflow/keras [forces this
      behaviour](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model).

    Args:
        obj: The keras object to serialize.
        ctx: The (de)serialization context, forwarded to the selected encoder.

    Returns:
        The JSON document (as a dict) produced by the matching encoder.

    Raises:
        TypeNotSupported: If `obj` is not an instance of any of the keras
            types handled here.
    """
    # Checked top to bottom; the first `isinstance` match wins, so more
    # specific types must come before their base classes.
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        # `keras.Model` subclasses `keras.layers.Layer`, so it must precede
        # the generic layer encoder.
        (keras.Model, _model_to_json),
        # Placed before `Layer`, presumably because metrics are layers in
        # some keras versions (tf.keras 2.x) — NOTE(review): confirm.
        (keras.metrics.Metric, partial(_generic_to_json, type_="metric")),
        (keras.layers.Layer, partial(_generic_to_json, type_="layer")),
        (keras.losses.Loss, partial(_generic_to_json, type_="loss")),
        (keras.optimizers.Optimizer, _optimizer_to_json),
        (keras.optimizers.legacy.Optimizer, _optimizer_to_json),
    ]
    for base_type, encoder in encoders:
        if isinstance(obj, base_type):
            return encoder(obj, ctx)
    raise TypeNotSupported()

Serializes a keras object into JSON by cases. See the README for the precise list of supported types. Most keras objects will simply be serialized using keras.utils.serialize_keras_object. Here are the exceptions:

  • keras.Model (the model must have weights). If TB_KERAS_FORMAT is json, the document will look like

    {
    
        "__type__": "keras.model",
        "__version__": 5,
        "loss": {...} or null,
        "metrics": [...],
        "model": {...},
        "optimizer": {...} or null,
        "weights": [...],
    }
    

    if TB_KERAS_FORMAT is h5, keras, or tf, the document will look like

    {
    
        "__type__": "keras.model",
        "__version__": 5,
        "format": <str>,
        "id": <uuid4>
    }
    

    where id points to an artifact. Note that if the keras saving format is keras, the artifact will have the .keras extension instead of the usual .tb. Tensorflow/keras forces this behaviour (see https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model).