• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

hgrecco / serialize / 7080427447

04 Dec 2023 12:33AM UTC coverage: 83.051% (-5.8%) from 88.857%
7080427447

push

github

hgrecco
Fix README testing

8 of 8 new or added lines in 1 file covered. (100.0%)

34 existing lines in 9 files now uncovered.

343 of 413 relevant lines covered (83.05%)

3.32 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.33
/serialize/all.py
1
# -*- coding: utf-8 -*-
2
"""
4✔
3
    serialize.all
4
    ~~~~~~~~~~~~~
5

6
    Common routines for serialization and deserialization.
7

8
    :copyright: (c) 2016 by Hernan E. Grecco.
9
    :license: BSD, see LICENSE for more details.
10
"""
11

12

13
import pathlib
4✔
14
from collections import namedtuple
4✔
15
from io import BytesIO
4✔
16

17
#: Pair of callables for a registered custom class: ``to_builtin`` turns an
#: instance into builtin Python types, ``from_builtin`` turns that back into
#: an instance.
ClassHelper = namedtuple("ClassHelper", ["to_builtin", "from_builtin"])

#: Everything needed to use an available serialization format: its file
#: extension, the file/bytes dumpers and loaders, and the class-registration
#: callback.
Format = namedtuple(
    "Format",
    ["extension", "dump", "dumps", "load", "loads", "register_class"],
)

#: A format that is known but cannot be used, with the reason why.
UnavailableFormat = namedtuple("UnavailableFormat", ["extension", "msg"])

#: Unavailable format name -> UnavailableFormat explaining why.
# :type: str -> UnavailableFormat
UNAVAILABLE_FORMATS = {}

#: Available format name -> Format with its dumpers and loaders.
# :type: str -> Format
FORMATS = {}

#: Lower-cased file extension -> format name.
# :type: str -> str
FORMAT_BY_EXTENSION = {}

#: Registered class -> ClassHelper with its conversion functions.
# :type: type -> ClassHelper
CLASSES = {}

#: ``str(klass)`` of a registered class -> the ClassHelper of that class.
# :type: str -> ClassHelper
CLASSES_BY_NAME = {}
43

44

45
def _get_format(fmt):
    """Return the Format registered under the name `fmt`.

    A ValueError with a readable message is raised when `fmt` names a format
    that is registered as unavailable (the message includes the recorded
    reason), or one that is not registered at all (the message lists the
    available format names).
    """
    if fmt not in FORMATS:
        if fmt in UNAVAILABLE_FORMATS:
            reason = UNAVAILABLE_FORMATS[fmt].msg
            raise ValueError(("'%s' is an unavailable format. " % fmt) + reason)
        known = ", ".join(FORMATS.keys())
        raise ValueError(
            "'%s' is an unknown format. Valid options are %s" % (fmt, known)
        )
    return FORMATS[fmt]
63

64

65
def _get_format_from_ext(ext):
    """Return the name of the format associated with the file extension `ext`.

    `ext` is lower-cased before the lookup (extensions are stored
    lower-cased). A ValueError listing the known extensions is raised when
    no format claims `ext`.
    """
    key = ext.lower()
    try:
        return FORMAT_BY_EXTENSION[key]
    except KeyError:
        known = ", ".join(FORMAT_BY_EXTENSION.keys())
    raise ValueError(
        "'%s' is an unknown extension. Valid options are %s" % (key, known)
    )
80

81

82
def encode_helper(obj, to_builtin):
    """Wrap `obj` in the two-key dict form that `decode` recognizes.

    ``__class_name__`` holds ``str(obj.__class__)`` and ``__dumped_obj__``
    holds ``to_builtin(obj)``, the builtin-typed representation of `obj`.
    """
    class_name = str(obj.__class__)
    return {"__class_name__": class_name, "__dumped_obj__": to_builtin(obj)}
88

89

90
def encode(obj, defaultfunc=None):
    """Convert `obj` to its encoded form if its class is registered.

    The first entry of ``CLASSES`` (in insertion order) that `obj` is an
    instance of supplies the ``to_builtin`` function, and the result is
    wrapped by `encode_helper`. Objects of unregistered types are passed to
    `defaultfunc`, or returned unchanged when `defaultfunc` is None.
    """
    for klass, (to_builtin, _from_builtin) in CLASSES.items():
        if isinstance(obj, klass):
            return encode_helper(obj, to_builtin)

    return obj if defaultfunc is None else defaultfunc(obj)
103

104

105
def _traverse_dict_ec(obj, ef, td):
    """Return a new dict with every key and value of `obj` encoded."""
    encoded = {}
    for key, value in obj.items():
        # Encode the key before the value, matching a dict comprehension.
        new_key = traverse_and_encode(key, ef, td)
        encoded[new_key] = traverse_and_encode(value, ef, td)
    return encoded


def _traverse_list_ec(obj, ef, td):
    """Return a new list with every element of `obj` encoded."""
    encoded = []
    for item in obj:
        encoded.append(traverse_and_encode(item, ef, td))
    return encoded


def _traverse_tuple_ec(obj, ef, td):
    """Return a new (plain) tuple with every element of `obj` encoded."""
    return tuple(_traverse_list_ec(obj, ef, td))


#: Default container handlers for traverse_and_encode, tried in this order.
DEFAULT_TRAVERSE_EC = {
    dict: _traverse_dict_ec,
    list: _traverse_list_ec,
    tuple: _traverse_tuple_ec,
}
125

126

127
def traverse_and_encode(obj, encode_func=None, trav_dict=None):
    """Traverse a Python data structure, encoding each element with `encode_func`.

    It is used with serialization packages that do not support
    custom types.

    `encode_func` (default: `encode`) is applied to every object that is not
    an instance of one of the container types in `trav_dict`.

    `trav_dict` can be used to provide custom ways of traversing structures:
    it maps container types to functions ``func(obj, encode_func, trav_dict)``
    that rebuild the container. The first entry (in mapping order) whose type
    matches `obj` is used. It defaults to ``DEFAULT_TRAVERSE_EC`` only when
    omitted; an explicitly empty mapping disables traversal, so `encode_func`
    is applied to `obj` as a whole.
    """
    if encode_func is None:
        encode_func = encode
    if trav_dict is None:
        # `is None` rather than `or`: an explicitly empty mapping is a
        # meaningful choice and must not be replaced by the defaults.
        trav_dict = DEFAULT_TRAVERSE_EC

    for container_type, traverse in trav_dict.items():
        if isinstance(obj, container_type):
            return traverse(obj, encode_func, trav_dict)

    return encode_func(obj)
145

146

147
def decode(dct, classes_by_name=None):
    """Rebuild a registered custom object from its encoded dict form.

    If `dct` is a dict with a non-None ``__class_name__`` entry (the form
    produced by `encode_helper`), that name is looked up in `classes_by_name`
    (default: ``CLASSES_BY_NAME``, used only when the argument is omitted).
    The matching ``from_builtin`` function is then applied to the
    ``__dumped_obj__`` value.

    Anything else is returned unchanged. That covers non-dicts, dicts
    without ``__class_name__``, unknown class names, and a missing
    ``__dumped_obj__``.
    """
    if not isinstance(dct, dict):
        return dct

    class_name = dct.get("__class_name__", None)
    if class_name is None:
        return dct

    if classes_by_name is None:
        # An explicitly empty registry means "no known classes"; only fall
        # back to the global registry when nothing was supplied.
        classes_by_name = CLASSES_BY_NAME

    try:
        _, from_builtin = classes_by_name[class_name]
        dumped = dct["__dumped_obj__"]
    except KeyError:
        # Unknown class or malformed payload: leave the dict untouched.
        return dct

    return from_builtin(dumped)
167

168

169
def _traverse_dict_dc(obj, df, td):
    """Decode the dict `obj`.

    A dict carrying a ``__class_name__`` key is an encoded custom object and
    is handed to `df` as a whole. Any other dict is rebuilt with its keys and
    values decoded recursively.
    """
    if "__class_name__" not in obj:
        decoded = {}
        for key, value in obj.items():
            # Decode the key before the value, matching a dict comprehension.
            new_key = traverse_and_decode(key, df, td)
            decoded[new_key] = traverse_and_decode(value, df, td)
        return decoded

    return df(obj)


def _traverse_list_dc(obj, df, td):
    """Return a new list with every element of `obj` decoded."""
    decoded = []
    for item in obj:
        decoded.append(traverse_and_decode(item, df, td))
    return decoded


def _traverse_tuple_dc(obj, df, td):
    """Return a new (plain) tuple with every element of `obj` decoded."""
    return tuple(_traverse_list_dc(obj, df, td))


#: Default container handlers for traverse_and_decode, tried in this order.
DEFAULT_TRAVERSE_DC = {
    dict: _traverse_dict_dc,
    list: _traverse_list_dc,
    tuple: _traverse_tuple_dc,
}
192

193

194
def traverse_and_decode(obj, decode_func=None, trav_dict=None):
    """Traverse an arbitrary Python object structure, decoding as it goes.

    Containers whose type appears in `trav_dict` are rebuilt by the matching
    handler ``func(obj, decode_func, trav_dict)``; the first entry in mapping
    order wins. Every other object is returned unchanged. The handlers decide
    where `decode_func` is applied; the defaults in ``DEFAULT_TRAVERSE_DC``
    pass dicts carrying ``__class_name__`` to it.

    This is used for serialization with libraries that do not have object hooks.

    `decode_func` defaults to `decode`. `trav_dict` defaults to
    ``DEFAULT_TRAVERSE_DC`` only when omitted. An explicitly empty mapping
    disables traversal, so `obj` is returned as is.
    """
    if decode_func is None:
        decode_func = decode
    if trav_dict is None:
        # `is None` rather than `or`: an explicitly empty mapping is a
        # meaningful choice and must not be replaced by the defaults.
        trav_dict = DEFAULT_TRAVERSE_DC

    for container_type, traverse in trav_dict.items():
        if isinstance(obj, container_type):
            return traverse(obj, decode_func, trav_dict)

    return obj
211

212

213
# Sentinel for an argument that was not supplied at all. It lets an explicit
# ``None`` keep its own meaning: for the ``extension`` parameter of
# register_format, None means "associate no extension".
MISSING = object()


def unregister_format(fmt):
    """Unregister an available serialization format.

    Any file-extension mapping that points to `fmt` is removed as well. This
    stops guessing a format from that extension from resolving to a format
    that no longer exists, and lets a format registered later claim the
    extension.

    Raises KeyError if `fmt` is not a registered available format.
    """
    del FORMATS[fmt]

    stale = [ext for ext, name in FORMAT_BY_EXTENSION.items() if name == fmt]
    for ext in stale:
        del FORMAT_BY_EXTENSION[ext]
220

221

222
def register_format(
    fmt,
    dumpser=None,
    loadser=None,
    dumper=None,
    loader=None,
    extension=MISSING,
    register_class=None,
):
    """Register an available serialization format.

    `fmt` is a unique string identifying the format, such as `json`. Use a colon (`:`)
    to separate between subformats.

    `dumpser` and `dumper` should be callables with the same purpose and arguments
    as `json.dumps` and `json.dump`. If one of those is missing, it will be
    generated automatically from the other, going through an in-memory
    ``BytesIO`` buffer (so the serialized form must be bytes). If both are
    missing, calling either of them raises ValueError.

    `loadser` and `loader` should be callables with the same purpose and arguments
    as `json.loads` and `json.load`. If one of those is missing, it will be
    generated automatically from the other, again via ``BytesIO``. If both are
    missing, calling either of them raises ValueError.

    `extension` is the file extension used to guess the desired serialization format
    when loading from or dumping to a file. If not given, the part before the colon of
    `fmt` will be used. If `None`, the format will not be associated with any extension.
    Extensions are stored lower-cased, and an extension already claimed by another
    format is not reassigned.

    `register_class` is a callback made when a class is registered with
    `serialize.register_class`. When a new format is registered, it is called
    once for each previously registered class. It takes one argument, the
    class to register. See `serialize.yaml.py` for an example.

    Raises ValueError if `fmt` is already registered.
    """

    # For simplicity, overwriting an existing format is not allowed.
    if fmt in FORMATS:
        raise ValueError("%s is already defined." % fmt)

    # Default to a no-op class-registration callback.
    if not register_class:

        def register_class(klass):
            pass

    # Generate dumper/dumpser from each other if only one was given.
    if dumper and not dumpser:

        def dumpser(obj):
            buf = BytesIO()
            dumper(obj, buf)
            return buf.getvalue()

    elif not dumper and dumpser:

        def dumper(obj, fp):
            fp.write(dumpser(obj))

    elif not dumper and not dumpser:

        def raiser(*args, **kwargs):
            raise ValueError("dump/dumps is not defined for %s" % fmt)

        dumper = dumpser = raiser

    # Generate loader/loadser from each other if only one was given.
    if loader and not loadser:

        def loadser(serialized):
            return loader(BytesIO(serialized))

    elif not loader and loadser:

        def loader(fp):
            return loadser(fp.read())

    elif not loader and not loadser:

        def raiser(*args, **kwargs):
            raise ValueError("load/loads is not defined for %s" % fmt)

        loader = loadser = raiser

    if extension is MISSING:
        extension = fmt.split(":", 1)[0]

    FORMATS[fmt] = Format(extension, dumper, dumpser, loader, loadser, register_class)

    if extension:
        # Test membership with the same lower-cased key that is stored;
        # checking the raw spelling let e.g. "JSON" overwrite the mapping
        # already owned by "json".
        ext_key = extension.lower()
        if ext_key not in FORMAT_BY_EXTENSION:
            FORMAT_BY_EXTENSION[ext_key] = fmt

    # Let the new format know about classes registered before it.
    for klass in CLASSES:
        FORMATS[fmt].register_class(klass)
313

314

315
def register_unavailable(fmt, msg="", pkg="", extension=MISSING):
    """Register an unavailable serialization format.

    Unavailable formats are those known by Serialize but that cannot be used
    due to a missing requirement (e.g. the package that does the work).
    Requesting such a format by name raises a ValueError that includes `msg`.

    `msg` explains why the format is unavailable. If `pkg` is given, `msg` is
    replaced by a standard "requires the <pkg> package" message.

    `extension` follows the same rules as in `register_format`. If not given,
    the part of `fmt` before the first colon is used. If `None` (or empty), no
    extension is associated. Extensions are stored lower-cased, and one already
    claimed by another format is not reassigned.
    """
    if pkg:
        msg = "This serialization format requires the %s package." % pkg

    if extension is MISSING:
        extension = fmt.split(":", 1)[0]

    UNAVAILABLE_FORMATS[fmt] = UnavailableFormat(extension, msg)

    if extension:
        # Test membership with the same lower-cased key that is stored, so a
        # differently-cased spelling cannot steal an existing mapping.
        ext_key = extension.lower()
        if ext_key not in FORMAT_BY_EXTENSION:
            FORMAT_BY_EXTENSION[ext_key] = fmt
332

333

334
def dumps(obj, fmt):
    """Return `obj` serialized to bytes with the format named `fmt`."""
    serializer = _get_format(fmt)
    return serializer.dumps(obj)
338

339

340
def dump(obj, file, fmt=None):
    """Serialize `obj` to a file using the format specified by `fmt`.

    The file can be a writable file-like object, or a filename given as a
    ``str`` or ``pathlib.Path``. For a filename, `fmt` is not needed if it
    can be guessed from the extension, and the file is opened in ``"wb"``
    mode (created or truncated).

    Raises ValueError if the format is unknown, unavailable, or cannot be
    guessed from the extension. The format is resolved before the file is
    opened, so an invalid format never creates or truncates the file.
    """
    if isinstance(file, str):
        file = pathlib.Path(file)

    if isinstance(file, pathlib.Path):
        if fmt is None:
            fmt = _get_format_from_ext(file.suffix.lstrip("."))
        # Resolve the format *before* opening: "wb" truncates the target,
        # which must not happen when the format turns out to be invalid.
        dumper = _get_format(fmt).dump
        with file.open(mode="wb") as fp:
            dumper(obj, fp)
    else:
        _get_format(fmt).dump(obj, file)
356

357

358
def loads(serialized, fmt):
    """Return the object deserialized from the bytes `serialized` using the format named `fmt`."""
    deserializer = _get_format(fmt)
    return deserializer.loads(serialized)
362

363

364
def load(file, fmt=None):
    """Deserialize from a file using the format specified by `fmt`.

    The file can be a readable file-like object, or a filename given as a
    ``str`` or ``pathlib.Path``. For a filename, `fmt` is not needed if it
    can be guessed from the extension, and the file is opened in ``"rb"``
    mode.

    Raises ValueError if the format is unknown, unavailable, or cannot be
    guessed from the extension. The format is resolved before the file is
    opened, so a bad format is reported as such even when the file is missing.
    """
    if isinstance(file, str):
        file = pathlib.Path(file)

    if isinstance(file, pathlib.Path):
        if fmt is None:
            fmt = _get_format_from_ext(file.suffix.lstrip("."))
        # Resolve the format first (mirrors `dump`): report format errors
        # before touching the filesystem.
        loader = _get_format(fmt).load
        with file.open(mode="rb") as fp:
            return loader(fp)

    return _get_format(fmt).load(file)
380

381

382
def register_class(klass, to_builtin, from_builtin):
    """Register a custom class for serialization and deserialization.

    `to_builtin` must be a function that takes an object from the custom class
    and returns an object consisting only of Python builtin types.

    `from_builtin` must be a function that takes the output of `to_builtin` and
    returns an object from the custom class.

    In other words:
        >>> obj == from_builtin(to_builtin(obj))    # doctest: +SKIP

    The helper pair is stored under the class object and under ``str(klass)``,
    the name that `encode_helper` writes. The ``register_class`` callback of
    every already-registered format is then called with `klass`.
    """
    helper = ClassHelper(to_builtin, from_builtin)
    CLASSES[klass] = CLASSES_BY_NAME[str(klass)] = helper

    # register_format only replays classes registered *before* the format;
    # notify existing formats here so the callback fires in either order,
    # as the register_format docstring promises.
    for format_ in FORMATS.values():
        format_.register_class(klass)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc