turbo_broccoli.custom.pytorch

PyTorch (de)serialization utilities.

  1"""Pytorch (de)serialization utilities."""
  2
  3from typing import Any, Callable, Tuple
  4
  5import safetensors.torch as st
  6from torch import Tensor
  7from torch.nn import Module
  8from torch.utils.data import ConcatDataset, StackDataset, Subset, TensorDataset
  9
 10from turbo_broccoli.context import Context
 11from turbo_broccoli.exceptions import DeserializationError, TypeNotSupported
 12
 13
 14def _concatdataset_to_json(obj: ConcatDataset, ctx: Context) -> dict:
 15    return {
 16        "__type__": "pytorch.concatdataset",
 17        "__version__": 1,
 18        "datasets": obj.datasets,
 19    }
 20
 21
 22def _json_to_concatdataset(dct: dict, ctx: Context) -> ConcatDataset:
 23    decoders = {1: _json_to_concatdataset_v1}
 24    return decoders[dct["__version__"]](dct, ctx)
 25
 26
 27def _json_to_module(dct: dict, ctx: Context) -> Module:
 28    ctx.raise_if_nodecode("bytes")
 29    decoders = {
 30        3: _json_to_module_v3,
 31    }
 32    return decoders[dct["__version__"]](dct, ctx)
 33
 34
 35def _json_to_module_v3(dct: dict, ctx: Context) -> Module:
 36    parts = dct["__type__"].split(".")
 37    type_name = ".".join(parts[2:])  # remove "pytorch.module." prefix
 38    module: Module = ctx.pytorch_module_types[type_name]()
 39    state = st.load(dct["state"])
 40    module.load_state_dict(state)
 41    return module
 42
 43
 44def _json_to_concatdataset_v1(dct: dict, ctx: Context) -> ConcatDataset:
 45    return ConcatDataset(dct["datasets"])
 46
 47
 48def _json_to_stackdataset(dct: dict, ctx: Context) -> StackDataset:
 49    decoders = {1: _json_to_stackdataset_v1}
 50    return decoders[dct["__version__"]](dct, ctx)
 51
 52
 53def _json_to_stackdataset_v1(dct: dict, ctx: Context) -> StackDataset:
 54    d = dct["datasets"]
 55    if isinstance(d, dict):
 56        return StackDataset(**d)
 57    return StackDataset(*d)
 58
 59
 60def _json_to_subset(dct: dict, ctx: Context) -> Subset:
 61    decoders = {1: _json_to_subset_v1}
 62    return decoders[dct["__version__"]](dct, ctx)
 63
 64
 65def _json_to_subset_v1(dct: dict, ctx: Context) -> Subset:
 66    return Subset(dct["dataset"], dct["indices"])
 67
 68
 69def _json_to_tensor(dct: dict, ctx: Context) -> Tensor:
 70    ctx.raise_if_nodecode("bytes")
 71    decoders = {
 72        3: _json_to_tensor_v3,
 73    }
 74    return decoders[dct["__version__"]](dct, ctx)
 75
 76
 77def _json_to_tensor_v3(dct: dict, ctx: Context) -> Tensor:
 78    data = dct["data"]
 79    return Tensor() if data is None else st.load(data)["data"]
 80
 81
 82def _json_to_tensordataset(dct: dict, ctx: Context) -> TensorDataset:
 83    decoders = {1: _json_to_tensordataset_v1}
 84    return decoders[dct["__version__"]](dct, ctx)
 85
 86
 87def _json_to_tensordataset_v1(dct: dict, ctx: Context) -> TensorDataset:
 88    return TensorDataset(*dct["tensors"])
 89
 90
 91def _module_to_json(module: Module, ctx: Context) -> dict:
 92    return {
 93        "__type__": "pytorch.module." + module.__class__.__name__,
 94        "__version__": 3,
 95        "state": st.save(module.state_dict()),
 96    }
 97
 98
 99def _stackdataset_to_json(obj: StackDataset, ctx: Context) -> dict:
100    return {
101        "__type__": "pytorch.stackdataset",
102        "__version__": 1,
103        "datasets": obj.datasets,
104    }
105
106
def _subset_to_json(obj: Subset, ctx: Context) -> dict:
    """
    Builds the version 1 JSON document of a `Subset`.

    The `dataset` and `indices` attributes of `obj` are stored as-is (not
    copied) under the keys of the same name. `ctx` is not used by this
    encoder.
    """
    document: dict = {"__type__": "pytorch.subset", "__version__": 1}
    document["dataset"] = obj.dataset
    document["indices"] = obj.indices
    return document
114
115
def _tensor_to_json(tens: Tensor, ctx: Context) -> dict:
    """
    Builds the version 3 JSON document of a tensor.

    The tensor is detached, moved to the CPU and made contiguous, then
    serialized by safetensors under the entry name `"data"`. A tensor without
    any element is not serialized at all: its `"data"` is `None`, so only the
    fact that it is empty is recorded. `ctx` is not used by this encoder.
    """
    cpu_tensor = tens.detach().cpu().contiguous()
    payload = None
    if cpu_tensor.numel() > 0:
        payload = st.save({"data": cpu_tensor})
    return {
        "__type__": "pytorch.tensor",
        "__version__": 3,
        "data": payload,
    }
123
124
def _tensordataset_to_json(obj: TensorDataset, ctx: Context) -> dict:
    """
    Builds the version 1 JSON document of a `TensorDataset`.

    The `tensors` attribute of `obj` is stored as-is (not copied) under the
    `"tensors"` key. `ctx` is not used by this encoder.
    """
    document: dict = {"__type__": "pytorch.tensordataset", "__version__": 1}
    document["tensors"] = obj.tensors
    return document
131
132
133# pylint: disable=missing-function-docstring
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a PyTorch document, picking the decoder from `"__type__"`.

    Any type starting with `"pytorch.module."` is handed to the module
    decoder; the other types are looked up in a fixed table (concat dataset,
    stack dataset, subset, tensor, tensor dataset).

    Every `KeyError` raised while decoding (missing `"__type__"`, unknown
    type, or a `KeyError` coming out of the selected decoder) is re-raised as
    a `DeserializationError` chained to the original exception.
    """
    decoders = {
        "pytorch.concatdataset": _json_to_concatdataset,
        "pytorch.stackdataset": _json_to_stackdataset,
        "pytorch.subset": _json_to_subset,
        "pytorch.tensor": _json_to_tensor,
        "pytorch.tensordataset": _json_to_tensordataset,
    }
    try:
        type_name = dct["__type__"]
        if type_name.startswith("pytorch.module."):
            decode = _json_to_module
        else:
            decode = decoders[type_name]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
149
150
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a supported PyTorch object into JSON by cases. The first type
    that `obj` is an instance of, among `Module`, `Tensor`, `ConcatDataset`,
    `StackDataset`, `Subset` and `TensorDataset` (tried in that order),
    selects the encoder. If none matches, `TypeNotSupported` is raised. See
    the README for the precise list of supported types. The return dict has
    one of the following structures:

    - Tensor:

        ```py
        {
            "__type__": "pytorch.tensor",
            "__version__": 3,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`. `"data"` is `None` for a
      tensor without any element.

    - Module:

        ```py
        {
            "__type__": "pytorch.module.<class name>",
            "__version__": 3,
            "state": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `ConcatDataset` and `StackDataset`:

        ```py
        {
            "__type__": "pytorch.concatdataset",  # or "pytorch.stackdataset"
            "__version__": 1,
            "datasets": ...,
        }
        ```

    - `Subset`:

        ```py
        {
            "__type__": "pytorch.subset",
            "__version__": 1,
            "dataset": ...,
            "indices": ...,
        }
        ```

    - `TensorDataset`:

        ```py
        {
            "__type__": "pytorch.tensordataset",
            "__version__": 1,
            "tensors": ...,
        }
        ```

    """
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (Module, _module_to_json),
        (Tensor, _tensor_to_json),
        (ConcatDataset, _concatdataset_to_json),
        (StackDataset, _stackdataset_to_json),
        (Subset, _subset_to_json),
        (TensorDataset, _tensordataset_to_json),
    ]
    # First match wins, so the order of the table above matters.
    encode = next(
        (
            encoder
            for supported_type, encoder in encoders
            if isinstance(obj, supported_type)
        ),
        None,
    )
    if encode is None:
        raise TypeNotSupported()
    return encode(obj, ctx)
def from_json(dct: dict, ctx: turbo_broccoli.context.Context) -> Any:
def from_json(dct: dict, ctx: Context) -> Any:
    """
    Deserializes a PyTorch document, picking the decoder from `"__type__"`.

    Any type starting with `"pytorch.module."` is handed to the module
    decoder; the other types are looked up in a fixed table (concat dataset,
    stack dataset, subset, tensor, tensor dataset).

    Every `KeyError` raised while decoding (missing `"__type__"`, unknown
    type, or a `KeyError` coming out of the selected decoder) is re-raised as
    a `DeserializationError` chained to the original exception.
    """
    decoders = {
        "pytorch.concatdataset": _json_to_concatdataset,
        "pytorch.stackdataset": _json_to_stackdataset,
        "pytorch.subset": _json_to_subset,
        "pytorch.tensor": _json_to_tensor,
        "pytorch.tensordataset": _json_to_tensordataset,
    }
    try:
        type_name = dct["__type__"]
        if type_name.startswith("pytorch.module."):
            decode = _json_to_module
        else:
            decode = decoders[type_name]
        return decode(dct, ctx)
    except KeyError as exc:
        raise DeserializationError() from exc
def to_json(obj: Any, ctx: turbo_broccoli.context.Context) -> dict:
def to_json(obj: Any, ctx: Context) -> dict:
    """
    Serializes a supported PyTorch object into JSON by cases. The first type
    that `obj` is an instance of, among `Module`, `Tensor`, `ConcatDataset`,
    `StackDataset`, `Subset` and `TensorDataset` (tried in that order),
    selects the encoder. If none matches, `TypeNotSupported` is raised. See
    the README for the precise list of supported types. The return dict has
    one of the following structures:

    - Tensor:

        ```py
        {
            "__type__": "pytorch.tensor",
            "__version__": 3,
            "data": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`. `"data"` is `None` for a
      tensor without any element.

    - Module:

        ```py
        {
            "__type__": "pytorch.module.<class name>",
            "__version__": 3,
            "state": {
                "__type__": "bytes",
                ...
            },
        }
        ```

      see `turbo_broccoli.custom.bytes.to_json`.

    - `ConcatDataset` and `StackDataset`:

        ```py
        {
            "__type__": "pytorch.concatdataset",  # or "pytorch.stackdataset"
            "__version__": 1,
            "datasets": ...,
        }
        ```

    - `Subset`:

        ```py
        {
            "__type__": "pytorch.subset",
            "__version__": 1,
            "dataset": ...,
            "indices": ...,
        }
        ```

    - `TensorDataset`:

        ```py
        {
            "__type__": "pytorch.tensordataset",
            "__version__": 1,
            "tensors": ...,
        }
        ```

    """
    encoders: list[Tuple[type, Callable[[Any, Context], dict]]] = [
        (Module, _module_to_json),
        (Tensor, _tensor_to_json),
        (ConcatDataset, _concatdataset_to_json),
        (StackDataset, _stackdataset_to_json),
        (Subset, _subset_to_json),
        (TensorDataset, _tensordataset_to_json),
    ]
    # First match wins, so the order of the table above matters.
    encode = next(
        (
            encoder
            for supported_type, encoder in encoders
            if isinstance(obj, supported_type)
        ),
        None,
    )
    if encode is None:
        raise TypeNotSupported()
    return encode(obj, ctx)

Serializes a supported PyTorch object (tensor, module, or dataset) into JSON by cases. See the README for the precise list of supported types. The returned dict has the following structure: