turbo_broccoli.native
"Native" saving utilities:
"""
"Native" saving utilities:
* `turbo_broccoli.native.save` takes a serializable/dumpable object and
  a path, and uses the file extension to choose the correct way to save the
  object;
* `turbo_broccoli.native.load` does the opposite.
"""

# pylint: disable=unused-argument
# pylint: disable=import-outside-toplevel

from functools import partial
from pathlib import Path
from typing import Any, Callable, NoReturn

try:
    import safetensors  # pylint: disable=unused-import

    HAS_SAFETENSORS = True
except ModuleNotFoundError:
    HAS_SAFETENSORS = False

from turbo_broccoli.custom import (
    HAS_KERAS,
    HAS_NUMPY,
    HAS_PANDAS,
    HAS_PYTORCH,
    HAS_TENSORFLOW,
)
from turbo_broccoli.turbo_broccoli import load_json, save_json


def _is_dict_of(obj: Any, value_type: type, key_type: type = str) -> bool:
    """
    Returns true if `obj` is a `dict[key_type, value_type]`. Note that an
    empty dict always qualifies.
    """
    return (
        isinstance(obj, dict)
        and all(isinstance(k, key_type) for k in obj.keys())
        and all(isinstance(v, value_type) for v in obj.values())
    )


def _load_csv(path: str | Path, **kwargs) -> Any:
    """
    Loads a CSV file with `pandas.read_csv`, dropping the `Unnamed: 0` column
    if present (presumably an index column written by `DataFrame.to_csv`).
    """
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas as pd

    df = pd.read_csv(path, **kwargs)
    if "Unnamed: 0" in df.columns:
        df.drop(["Unnamed: 0"], axis=1, inplace=True)
    return df


def _load_keras(path: str | Path, **kwargs) -> Any:
    """Loads a model with `keras.saving.load_model`."""
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    return keras.saving.load_model(path, **kwargs)


def _load_np(path: str | Path, **kwargs) -> Any:
    """Loads a `.npy` or `.npz` file with `numpy.load`."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", ".npy/.npz")
    import numpy as np

    return np.load(path, **kwargs)


def _load_pq(path: str | Path, **kwargs) -> Any:
    """Loads a parquet file with `pandas.read_parquet`."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    import pandas as pd

    df = pd.read_parquet(path, **kwargs)
    return df


def _load_pt(path: str | Path, **kwargs) -> Any:
    """Loads a file with `torch.load`."""
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    return torch.load(path, **kwargs)


def _load_st(path: str | Path, **kwargs) -> Any:
    """Loads a safetensors file as a dict of numpy arrays."""
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")
    from safetensors import numpy as st

    return st.load_file(path, **kwargs)


def _raise_package_not_installed(
    package_name: str, extension: str
) -> NoReturn:
    """
    Raises a `RuntimeError` with a templated error message

    Args:
        package_name (str): e.g. "numpy"
        extension (str): e.g. "npy"; a leading dot is prepended if missing

    Raises:
        RuntimeError: always
    """
    # `startswith` (rather than `extension[0]`) also tolerates an empty string
    if not extension.startswith("."):
        extension = "." + extension
    raise RuntimeError(
        f"Cannot create or load `{extension}` file because {package_name} is "
        f"not installed. You can install {package_name} by running "
        f"python3 -m pip install {package_name}"
    )


def _raise_wrong_type(path: str | Path, obj_needs_to_be_a: str) -> NoReturn:
    """
    Raises a `TypeError` with a templated error message

    Args:
        path (str | Path): Path where the file should have been saved
        obj_needs_to_be_a (str): Description of the accepted type(s), e.g.
            "pandas DataFrame or Series"

    Raises:
        TypeError: always
    """
    raise TypeError(
        f"Could not save object to '{path}': object needs to be a "
        + obj_needs_to_be_a
    )


def _save_csv(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a pandas DataFrame or Series with its `to_csv` method."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", "csv")
    import pandas as pd

    if not isinstance(obj, (pd.DataFrame, pd.Series)):
        _raise_wrong_type(path, "pandas DataFrame or Series")
    obj.to_csv(path, **kwargs)


def _save_keras(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a keras model with `keras.saving.save_model`."""
    if not HAS_KERAS:
        _raise_package_not_installed("keras", "keras")

    import keras

    if not isinstance(obj, keras.Model):
        _raise_wrong_type(path, "keras model")
    keras.saving.save_model(obj, path, **kwargs)


def _save_npy(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a numpy array with `numpy.save`."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npy")
    import numpy as np

    if not isinstance(obj, np.ndarray):
        _raise_wrong_type(path, "numpy array")
    np.save(str(path), obj, **kwargs)


def _save_npz(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a dict of numpy arrays with `numpy.savez` (keys become names)."""
    if not HAS_NUMPY:
        _raise_package_not_installed("numpy", "npz")
    import numpy as np

    if not _is_dict_of(obj, np.ndarray):
        _raise_wrong_type(path, "dict of numpy arrays")
    np.savez(str(path), **obj, **kwargs)


def _save_pq(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a pandas DataFrame with `DataFrame.to_parquet`."""
    if not HAS_PANDAS:
        _raise_package_not_installed("pandas", ".parquet/.pq")
    import pandas as pd

    if not isinstance(obj, pd.DataFrame):
        _raise_wrong_type(path, "pandas DataFrame")
    obj.to_parquet(path, **kwargs)


def _save_pt(obj: Any, path: str | Path, **kwargs) -> None:
    """Saves a torch tensor or a dict of torch tensors with `torch.save`."""
    if not HAS_PYTORCH:
        _raise_package_not_installed("torch", "pt")
    import torch

    if not (isinstance(obj, torch.Tensor) or _is_dict_of(obj, torch.Tensor)):
        _raise_wrong_type(path, "torch tensor or a dict of torch tensors")
    torch.save(obj, path, **kwargs)


def _save_st(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves a dict of numpy arrays, tensorflow tensors, or pytorch tensors as a
    safetensors file, using the matching framework-specific safetensors
    module. Frameworks are tried in that order.
    """
    if not HAS_SAFETENSORS:
        _raise_package_not_installed("safetensors", ".safetensors/.st")

    # The framework-specific submodules are not attributes of a bare
    # `import safetensors`, so each one is imported explicitly here.
    if HAS_NUMPY:
        import numpy as np

        if _is_dict_of(obj, np.ndarray):
            from safetensors import numpy as st_np

            st_np.save_file(obj, str(path), **kwargs)
            return

    if HAS_TENSORFLOW:
        import tensorflow as tf

        if _is_dict_of(obj, tf.Tensor):
            from safetensors import tensorflow as st_tf

            st_tf.save_file(obj, str(path), **kwargs)
            return

    if HAS_PYTORCH:
        import torch

        if _is_dict_of(obj, torch.Tensor):
            from safetensors import torch as st_torch

            st_torch.save_file(obj, str(path), **kwargs)
            return

    _raise_wrong_type(
        path,
        "dict of numpy arrays, a dict of tensorflow tensors, or a dict of "
        "pytorch tensors",
    )


def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Args:
        path (str | Path): File to load; its suffix selects the loading method
        kwargs: Passed to the loading method

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[..., Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)


def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for dicts of numpy arrays, pytorch tensors or
      tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any): Object to save
        path (str | Path): Destination file; its suffix selects the method
        kwargs: Passed to the serialization method

    Raises:
        TypeError: If `obj` has the wrong type for the selected format
        RuntimeError: If the library required by the format is not installed
    """
    extension = Path(path).suffix
    methods: dict[str, Callable[..., None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
def
load(path: str | pathlib.Path, **kwargs) -> Any:
def load(path: str | Path, **kwargs) -> Any:
    """
    Loads an object from a file using format-specific (or "native") methods.
    See `turbo_broccoli.native.save` for the list of supported file extensions.

    Args:
        path (str | Path): File to load; its suffix selects the loading method
        kwargs: Passed to the loading method

    Warning:
        Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
        numpy arrays even if the object was originally a dict of e.g. torch
        tensors.
    """
    extension = Path(path).suffix
    # Both safetensors suffixes are routed to the safetensors loader;
    # unknown suffixes fall back to `load_json`.
    methods: dict[str, Callable[..., Any]] = {
        ".csv": _load_csv,
        ".h5": _load_keras,
        ".keras": _load_keras,
        ".npy": _load_np,
        ".npz": _load_np,
        ".parquet": _load_pq,
        ".pq": _load_pq,
        ".pt": _load_pt,
        ".safetensors": _load_st,
        ".st": _load_st,
        ".tf": _load_keras,
    }
    method: Callable = methods.get(extension, load_json)
    return method(path, **kwargs)
Loads an object from a file using format-specific (or "native") methods.
See `save` for the list of supported file extensions.
Warning:
Safetensors files (`.st` or `.safetensors`) will be loaded as dicts of
numpy arrays even if the object was originally a dict of e.g. torch
tensors.
def
save(obj: Any, path: str | pathlib.Path, **kwargs) -> None:
def save(obj: Any, path: str | Path, **kwargs) -> None:
    """
    Saves an object using the file extension of `path` to determine the
    serialization/dumping method:

    * `.csv`:
      [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html)
      or
      [`pandas.Series.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.Series.to_csv.html)
    * `.h5`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="h5"`
    * `.keras`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="keras"`
    * `.npy`:
      [`numpy.save`](https://numpy.org/doc/stable/reference/generated/numpy.save.html)
    * `.npz`:
      [`numpy.savez`](https://numpy.org/doc/stable/reference/generated/numpy.savez.html)
    * `.pq`, `.parquet`:
      [`pandas.DataFrame.to_parquet`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html)
    * `.pt`:
      [`torch.save`](https://pytorch.org/docs/stable/generated/torch.save.html)
    * `.safetensors`, `.st`: (for dicts of numpy arrays, pytorch tensors or
      tensorflow tensors)
      [safetensors](https://huggingface.co/docs/safetensors/index)
    * `.tf`:
      [`tf.keras.saving.save_model`](https://www.tensorflow.org/api_docs/python/tf/keras/saving/save_model)
      with `save_format="tf"`
    * `.json` and anything else: just forwarded to `turbo_broccoli.save_json`

    Args:
        obj (Any): Object to save
        path (str | Path): Destination file; its suffix selects the method
        kwargs: Passed to the serialization method
    """
    extension = Path(path).suffix
    # Both safetensors suffixes are routed to the safetensors writer;
    # unknown suffixes fall back to `save_json`.
    methods: dict[str, Callable[..., None]] = {
        ".csv": _save_csv,
        ".h5": partial(_save_keras, save_format="h5"),
        ".keras": partial(_save_keras, save_format="keras"),
        ".npy": _save_npy,
        ".npz": _save_npz,
        ".parquet": _save_pq,
        ".pq": _save_pq,
        ".pt": _save_pt,
        ".safetensors": _save_st,
        ".st": _save_st,
        ".tf": partial(_save_keras, save_format="tf"),
    }
    method = methods.get(extension, save_json)
    method(obj, path, **kwargs)
Saves an object using the file extension of `path` to determine the
serialization/dumping method:
* `.csv`: `pandas.DataFrame.to_csv` or `pandas.Series.to_csv`
* `.h5`: `tf.keras.saving.save_model` with `save_format="h5"`
* `.keras`: `tf.keras.saving.save_model` with `save_format="keras"`
* `.npy`: `numpy.save`
* `.npz`: `numpy.savez`
* `.pq`, `.parquet`: `pandas.DataFrame.to_parquet`
* `.pt`: `torch.save`
* `.safetensors`, `.st`: safetensors (for dicts of numpy arrays, pytorch tensors, or tensorflow tensors)
* `.tf`: `tf.keras.saving.save_model` with `save_format="tf"`
* `.json` and anything else: just forwarded to `turbo_broccoli.save_json`
Args:
    obj (Any): Object to save
    path (str | Path): Destination file; its suffix selects the saving method
    kwargs: Passed to the serialization method