turbo_broccoli.native

"Native" saving utilities:

  • save takes a serializable/dumpable object and a path, and uses the file extension to choose the correct way to save the object;
  • load does the opposite.
  1"""
  2"Native" saving utilities:
  3* `turbo_broccoli.native.save` takes a serializable/dumpable object and
  4  a path, and uses the file extension to choose the correct way to save the
  5  object;
  6* `turbo_broccoli.native.load` does the opposite.
  7"""
  8
  9# pylint: disable=unused-argument
 10# pylint: disable=import-outside-toplevel
 11
 12from functools import partial
 13from pathlib import Path
 14from typing import Any, Callable
 15
 16try:
 17    import safetensors  # pylint: disable=unused-import
 18
 19    HAS_SAFETENSORS = True
 20except ModuleNotFoundError:
 21    HAS_SAFETENSORS = False
 22
 23from turbo_broccoli.custom import (
 24    HAS_KERAS,
 25    HAS_NUMPY,
 26    HAS_PANDAS,
 27    HAS_PYTORCH,
 28    HAS_TENSORFLOW,
 29)
 30from turbo_broccoli.turbo_broccoli import load_json, save_json
 31
 32
 33def _is_dict_of(obj: Any, value_type: type, key_type: type = str) -> bool:
 34    """Returns true if `obj` is a `dict[key_type, value_type]`"""
 35    return (
 36        isinstance(obj, dict)
 37        and all(isinstance(k, key_type) for k in obj.keys())
 38        and all(isinstance(v, value_type) for v in obj.values())
 39    )
 40
 41
 42def _load_csv(path: str | Path, **kwargs) -> Any:
 43    if not HAS_PANDAS:
 44        _raise_package_not_installed("pandas", "csv")
 45    import pandas as pd
 46
 47    df = pd.read_csv(path, **kwargs)
 48    if "Unnamed: 0" in df.columns:
 49        df.drop(["Unnamed: 0"], axis=1, inplace=True)
 50    return df
 51
 52
 53def _load_keras(path: str | Path, **kwargs) -> Any:
 54    if not HAS_KERAS:
 55        _raise_package_not_installed("keras", "keras")
 56
 57    import keras
 58
 59    return keras.saving.load_model(path, **kwargs)
 60
 61
 62def _load_np(path: str | Path, **kwargs) -> Any:
 63    if not HAS_NUMPY:
 64        _raise_package_not_installed("numpy", ".npy/.npz")
 65    import numpy as np
 66
 67    return np.load(path, **kwargs)
 68
 69
 70def _load_pq(path: str | Path, **kwargs) -> Any:
 71    if not HAS_PANDAS:
 72        _raise_package_not_installed("pandas", ".parquet/.pq")
 73    import pandas as pd
 74
 75    df = pd.read_parquet(path, **kwargs)
 76    return df
 77
 78
 79def _load_pt(path: str | Path, **kwargs) -> Any:
 80    if not HAS_PYTORCH:
 81        _raise_package_not_installed("torch", "pt")
 82    import torch
 83
 84    return torch.load(path, **kwargs)
 85
 86
 87def _load_st(path: str | Path, **kwargs) -> Any:
 88    if not HAS_SAFETENSORS:
 89        _raise_package_not_installed("safetensors", ".safetensors/.st")
 90    from safetensors import numpy as st
 91
 92    return st.load_file(path, **kwargs)
 93
 94
 95def _raise_package_not_installed(package_name: str, extension: str):
 96    """
 97    Raises a `RuntimeError` with a templated error message
 98
 99    Args:
100        package_name (str): e.g. "numpy"
101        extension (str): e.g. "npy"
102    """
103    if extension[0] != ".":
104        extension = "." + extension
105    raise RuntimeError(
106        f"Cannot create or load `{extension}` file because {package_name} is "
107        f"not installed. You can install {package_name} by running "
108        f"python3 -m pip install {package_name}"
109    )
110
111
112def _raise_wrong_type(path: str | Path, obj_needs_to_be_a: str):
113    """
114    Raises a `TypeError` with a templated error message
115
116    Args:
117        path (str | Path): Path where the file should have been saved
118        extension (str): "pandas DataFrame or Series"
119    """
120    raise TypeError(
121        f"Could not save object to '{path}': object needs to be a "
122        + obj_needs_to_be_a
123    )
124
125
def _save_csv(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a pandas dataframe or series to a CSV file.

    Args:
        obj (Any): Must be a `pandas.DataFrame` or a `pandas.Series`
        path (str | Path): Destination file
        kwargs: Passed to the object's `to_csv` method

    Raises:
        RuntimeError: If pandas is not installed
        TypeError: If `obj` is neither a dataframe nor a series
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas

    accepted_types = (pandas.DataFrame, pandas.Series)
    if not isinstance(obj, accepted_types):
        _raise_wrong_type(path, "pandas DataFrame or Series")
    obj.to_csv(path, **kwargs)
134
135
def _save_keras(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a keras model.

    Args:
        obj (Any): Must be a `keras.Model`
        path (str | Path): Destination
        kwargs: Passed to `keras.saving.save_model`

    Raises:
        RuntimeError: If keras is not installed
        TypeError: If `obj` is not a keras model
    """
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    is_model = isinstance(obj, keras.Model)
    if not is_model:
        _raise_wrong_type(path, "keras model")
    keras.saving.save_model(obj, path, **kwargs)
145
146
def _save_npy(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a numpy array to a `.npy` file.

    Args:
        obj (Any): Must be a `numpy.ndarray`
        path (str | Path): Destination file; converted to `str` before being
            handed to numpy
        kwargs: Passed to `numpy.save`

    Raises:
        RuntimeError: If numpy is not installed
        TypeError: If `obj` is not a numpy array
    """
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npy")
    import numpy

    if not isinstance(obj, numpy.ndarray):
        _raise_wrong_type(path, "numpy array")
    destination = str(path)
    numpy.save(destination, obj, **kwargs)
155
156
def _save_npz(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict of numpy arrays to a `.npz` file.

    Args:
        obj (Any): Must be a dict mapping `str` keys to `numpy.ndarray`s
        path (str | Path): Destination file; converted to `str` before being
            handed to numpy
        kwargs: Passed to `numpy.savez` alongside the entries of `obj`.
            NOTE(review): `numpy.savez` treats keyword arguments as arrays to
            store, so these end up in the archive too -- confirm this is
            intended.

    Raises:
        RuntimeError: If numpy is not installed
        TypeError: If `obj` is not a dict of numpy arrays
    """
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npz")
    import numpy

    if not _is_dict_of(obj, numpy.ndarray):
        _raise_wrong_type(path, "dict of numpy arrays")
    destination = str(path)
    numpy.savez(destination, **obj, **kwargs)
165
166
def _save_pq(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a pandas dataframe to a parquet file.

    Args:
        obj (Any): Must be a `pandas.DataFrame`
        path (str | Path): Destination file
        kwargs: Passed to `pandas.DataFrame.to_parquet`

    Raises:
        RuntimeError: If pandas is not installed
        TypeError: If `obj` is not a dataframe
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    from pandas import DataFrame

    if not isinstance(obj, DataFrame):
        _raise_wrong_type(path, "pandas DataFrame")
    obj.to_parquet(path, **kwargs)
175
176
def _save_pt(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a torch tensor, or a dict of torch tensors, to a `.pt` file.

    Args:
        obj (Any): Must be a `torch.Tensor` or a dict mapping `str` keys to
            `torch.Tensor`s
        path (str | Path): Destination file
        kwargs: Passed to `torch.save`

    Raises:
        RuntimeError: If torch is not installed
        TypeError: If `obj` is neither a tensor nor a dict of tensors
    """
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    # Same short-circuit order as a single `or`: the dict check only runs
    # when `obj` is not itself a tensor.
    if not isinstance(obj, torch.Tensor) and not _is_dict_of(
        obj, torch.Tensor
    ):
        _raise_wrong_type(path, "torch tensor or a dict of torch tensors")
    torch.save(obj, path, **kwargs)
185
186
def _save_st(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict of arrays/tensors to a safetensors file.

    The first matching flavor among numpy, tensorflow and pytorch (in that
    order, and only among the installed ones) is used. Note that an empty dict
    matches the first available flavor.

    Args:
        obj (Any): Must be a dict mapping `str` keys to numpy arrays,
            tensorflow tensors, or pytorch tensors
        path (str | Path): Destination file; converted to `str` before being
            handed to safetensors
        kwargs: Passed to the `save_file` function of the selected safetensors
            flavor

    Raises:
        RuntimeError: If safetensors is not installed
        TypeError: If `obj` does not match any of the supported flavors
    """
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")

    # The flavor submodules are imported explicitly: a bare
    # `import safetensors` does not guarantee that `safetensors.numpy` & co.
    # are available as attributes of the package.
    if HAS_NUMPY:
        import numpy as np

        if _is_dict_of(obj, np.ndarray):
            from safetensors import numpy as st_numpy

            st_numpy.save_file(obj, str(path), **kwargs)
            return

    if HAS_TENSORFLOW:
        import tensorflow as tf

        if _is_dict_of(obj, tf.Tensor):
            from safetensors import tensorflow as st_tensorflow

            st_tensorflow.save_file(obj, str(path), **kwargs)
            return

    if HAS_PYTORCH:
        import torch

        if _is_dict_of(obj, torch.Tensor):
            from safetensors import torch as st_torch

            st_torch.save_file(obj, str(path), **kwargs)
            return

    # `_raise_wrong_type` raises by itself, no `raise` needed here
    _raise_wrong_type(
        path,
        "dict of numpy arrays, a dict of tensorflow tensors, or a dict of "
        "pytorch tensors",
    )
218
219
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.

    Args:
        path (str | Path): File to load; its suffix selects the loading
            method
        kwargs: Passed to the loading method

    Returns:
        The loaded object. Files whose suffix is not in the table below are
        handed to `load_json`.
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        # Both spellings are documented, so both must be dispatched;
        # otherwise `.safetensors` files would fall through to `load_json`.
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)
245
246
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any): Object to save; the accepted types depend on the extension
        path (str | Path): Destination file; its suffix selects the
            serialization method
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        # Both spellings are documented, so both must be dispatched;
        # otherwise `.safetensors` paths would fall through to `save_json`.
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
def load(path: str | pathlib.Path, **kwargs) -> Any:
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.

    Args:
        path (str | Path): File to load; its suffix selects the loading
            method
        kwargs: Passed to the loading method

    Returns:
        The loaded object. Files whose suffix is not in the table below are
        handed to `load_json`.
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[str | Path], Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        # Both spellings are documented, so both must be dispatched;
        # otherwise `.safetensors` files would fall through to `load_json`.
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)

Loads an object from a file using format-specific (or "native") methods. See save for the list of supported file extensions.

Warning: Safetensors files (.st or .safetensors) will be loaded as dicts of numpy arrays even if the object was originally a dict of e.g. torch tensors.

def save(obj: Any, path: str | pathlib.Path, **kwargs) -> None:
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for numpy arrays, pytorch tensors and tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any): Object to save; the accepted types depend on the extension
        path (str | Path): Destination file; its suffix selects the
            serialization method
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[[Any, str | Path], None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        # Both spellings are documented, so both must be dispatched;
        # otherwise `.safetensors` paths would fall through to `save_json`.
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)

Saves an object using the file extension of path to determine the serialization/dumping method:

Args: obj (Any): the object to save; path (str | Path): the destination file, whose extension selects the serialization method; kwargs: passed to the serialization method.