• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

spesmilo / electrum / 5013702581157888

28 Jan 2026 02:39PM UTC coverage: 63.596% (+0.9%) from 62.722%
5013702581157888

Pull #10451

CirrusCI

f321x
tests: add unittest for util.create_wallet_history_export

Adds unittest for create_wallet_history_export that
compares the output against reference files. This
should help to prevent regressions and ensure the layout
of the export stays static over time.
Pull Request #10451: history export: make fees bitcoin, add unittest, rm unconf tx

31 of 32 new or added lines in 1 file covered. (96.88%)

152 existing lines in 3 files now uncovered.

24349 of 38287 relevant lines covered (63.6%)

0.64 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.9
/electrum/lnmsg.py
1
import os
1✔
2
import csv
1✔
3
import io
1✔
4
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional, Mapping
1✔
5
from types import MappingProxyType
1✔
6
from collections import OrderedDict
1✔
7

8
from .lnutil import OnionFailureCodeMetaFlag
1✔
9

10

11
class FailedToParseMsg(Exception):
    """Base exception for any failure while parsing a lightning message.

    decode_msg fills in the two class attributes below (on the instance)
    before re-raising, so callers know which message type failed to parse.
    """
    # numeric message type (first two bytes of the message), if known
    msg_type_int: Optional[int] = None
    # human-readable message type name from the CSV scheme, if known
    msg_type_name: Optional[str] = None
14

15
# The two-byte message type was not found in the CSV-defined schemes.
class UnknownMsgType(FailedToParseMsg): pass
# Unknown message type with an odd type value: safe to ignore per BOLT-01.
class UnknownOptionalMsgType(UnknownMsgType): pass
# Unknown message type with an even type value: must be understood, hard failure.
class UnknownMandatoryMsgType(UnknownMsgType): pass
18

19
# The message type was recognised but its payload could not be parsed.
class MalformedMsg(FailedToParseMsg): pass
# A field's declared type string in the scheme is not recognised.
class UnknownMsgFieldType(MalformedMsg): pass
# Fewer bytes remained in the stream than a field required.
class UnexpectedEndOfStream(MalformedMsg): pass
# A bigsize/truncated int was not encoded in its shortest possible form.
class FieldEncodingNotMinimal(MalformedMsg): pass
# A TLV record with an even (mandatory) type was not understood.
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
# Extra bytes remained after all fields of a TLV record were consumed.
class MsgTrailingGarbage(MalformedMsg): pass
# TLV records did not appear in monotonically increasing type order.
class MsgInvalidFieldOrder(MalformedMsg): pass
# A value handed to an encoder did not have the expected byte length.
class UnexpectedFieldSizeForEncoder(MalformedMsg): pass
27

28

29
def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
1✔
30
    cur_pos = fd.tell()
1✔
31
    end_pos = fd.seek(0, io.SEEK_END)
1✔
32
    fd.seek(cur_pos)
1✔
33
    return end_pos - cur_pos
1✔
34

35

36
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
    """Raise UnexpectedEndOfStream if fewer than *n* bytes remain in *fd*."""
    # note: it's faster to read n bytes and then check if we read n, than
    #       to assert we can read at least n and then read n bytes.
    available = _num_remaining_bytes_to_read(fd)
    if not (available >= n):
        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {available} bytes left")
42

43

44
def write_bigsize_int(i: int) -> bytes:
    """Serialize a non-negative int using the BigSize variable-length encoding.

    Values below 0xfd are a single byte; larger values get a one-byte
    prefix (0xfd/0xfe/0xff) followed by 2/4/8 big-endian payload bytes.
    """
    assert i >= 0, i
    if i < 0xfd:
        prefix, width = b"", 1
    elif i < 0x1_0000:
        prefix, width = b"\xfd", 2
    elif i < 0x1_0000_0000:
        prefix, width = b"\xfe", 4
    else:
        prefix, width = b"\xff", 8
    return prefix + i.to_bytes(width, byteorder="big", signed=False)
54

55

56
def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
    """Read a BigSize-encoded int from *fd*. Returns None at end-of-file.

    Raises UnexpectedEndOfStream if the stream is truncated mid-value, and
    FieldEncodingNotMinimal if a shorter encoding would have sufficed.
    """
    prefix = fd.read(1)
    if len(prefix) == 0:
        return None  # end of file
    first = prefix[0]
    if first < 0xfd:
        return first
    # prefix byte -> (payload length, smallest value that requires this length)
    nbytes, min_val = {
        0xfd: (2, 0xfd),
        0xfe: (4, 0x1_0000),
        0xff: (8, 0x1_0000_0000),
    }[first]
    payload = fd.read(nbytes)
    if len(payload) != nbytes:
        raise UnexpectedEndOfStream()
    val = int.from_bytes(payload, byteorder="big", signed=False)
    if val < min_val:
        raise FieldEncodingNotMinimal()
    return val
88

89

90
# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
91
#       if field_type is a numeric, we could return a list of ints?
92
def _read_primitive_field(
1✔
93
        *,
94
        fd: io.BytesIO,
95
        field_type: str,
96
        count: Union[int, str]
97
) -> Union[bytes, int]:
98
    if not fd:
1✔
99
        raise Exception()
×
100
    if isinstance(count, int):
1✔
101
        assert count >= 0, f"{count!r} must be non-neg int"
1✔
102
    elif count == "...":
1✔
103
        pass
1✔
104
    else:
105
        raise Exception(f"unexpected field count: {count!r}")
×
106
    if count == 0:
1✔
107
        return b""
1✔
108
    type_len = None
1✔
109
    if field_type == 'byte':
1✔
110
        type_len = 1
1✔
111
    elif field_type in ('u8', 'u16', 'u32', 'u64'):
1✔
112
        if field_type == 'u8':
1✔
113
            type_len = 1
×
114
        elif field_type == 'u16':
1✔
115
            type_len = 2
1✔
116
        elif field_type == 'u32':
1✔
117
            type_len = 4
1✔
118
        else:
119
            assert field_type == 'u64'
1✔
120
            type_len = 8
1✔
121
        assert count == 1, count
1✔
122
        buf = fd.read(type_len)
1✔
123
        if len(buf) != type_len:
1✔
124
            raise UnexpectedEndOfStream()
1✔
125
        return int.from_bytes(buf, byteorder="big", signed=False)
1✔
126
    elif field_type in ('tu16', 'tu32', 'tu64'):
1✔
127
        if field_type == 'tu16':
1✔
128
            type_len = 2
×
129
        elif field_type == 'tu32':
1✔
130
            type_len = 4
1✔
131
        else:
132
            assert field_type == 'tu64'
1✔
133
            type_len = 8
1✔
134
        assert count == 1, count
1✔
135
        raw = fd.read(type_len)
1✔
136
        if len(raw) > 0 and raw[0] == 0x00:
1✔
137
            raise FieldEncodingNotMinimal()
1✔
138
        return int.from_bytes(raw, byteorder="big", signed=False)
1✔
139
    elif field_type == 'bigsize':
1✔
140
        assert count == 1, count
1✔
141
        val = read_bigsize_int(fd)
1✔
142
        if val is None:
1✔
143
            raise UnexpectedEndOfStream()
1✔
144
        return val
1✔
145
    elif field_type == 'chain_hash':
1✔
146
        type_len = 32
1✔
147
    elif field_type == 'channel_id':
1✔
148
        type_len = 32
1✔
149
    elif field_type == 'sha256':
1✔
150
        type_len = 32
1✔
151
    elif field_type == 'signature':
1✔
152
        type_len = 64
1✔
153
    elif field_type == 'point':
1✔
154
        type_len = 33
1✔
155
    elif field_type == 'short_channel_id':
1✔
156
        type_len = 8
1✔
157
    elif field_type == 'sciddir_or_pubkey':
×
158
        buf = fd.read(1)
×
159
        if buf[0] in [0, 1]:
×
160
            type_len = 9
×
161
        elif buf[0] in [2, 3]:
×
162
            type_len = 33
×
163
        else:
164
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
×
165
        buf += fd.read(type_len - 1)
×
166
        if len(buf) != type_len:
×
167
            raise UnexpectedEndOfStream()
×
168
        return buf
×
169

170
    if count == "...":
1✔
171
        total_len = -1  # read all
1✔
172
    else:
173
        if type_len is None:
1✔
174
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
×
175
        total_len = count * type_len
1✔
176

177
    buf = fd.read(total_len)
1✔
178
    if total_len >= 0 and len(buf) != total_len:
1✔
179
        raise UnexpectedEndOfStream()
1✔
180
    return buf
1✔
181

182

183
# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_primitive_field(
        *,
        fd: io.BytesIO,
        field_type: str,
        count: Union[int, str],
        value: Union[bytes, int]
) -> None:
    """Serialize a single primitive wire field into *fd*.

    Inverse of _read_primitive_field. *value* may be an int for numeric
    field types (converted to big-endian bytes of the type's width),
    otherwise bytes. count == "..." writes *value* as-is with no size check.

    Raises:
        UnexpectedFieldSizeForEncoder: if len(value) != count * type_len
            for a fixed-size field.
        UnknownMsgFieldType: if field_type is unknown and count != "...".
    """
    if not fd:
        raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count == "...":
        pass
    else:
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        # nothing to write for an empty field
        return
    type_len = None
    if field_type == 'byte':
        type_len = 1
    elif field_type == 'u8':
        type_len = 1
    elif field_type == 'u16':
        type_len = 2
    elif field_type == 'u32':
        type_len = 4
    elif field_type == 'u64':
        type_len = 8
    elif field_type in ('tu16', 'tu32', 'tu64'):
        # truncated ints: encoded minimally by stripping leading zero bytes
        if field_type == 'tu16':
            type_len = 2
        elif field_type == 'tu32':
            type_len = 4
        else:
            assert field_type == 'tu64'
            type_len = 8
        assert count == 1, count
        if isinstance(value, int):
            value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        # drop leading zero bytes to get the minimal (truncated) encoding
        while len(value) > 0 and value[0] == 0x00:
            value = value[1:]
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'bigsize':
        assert count == 1, count
        if isinstance(value, int):
            value = write_bigsize_int(value)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'chain_hash':
        type_len = 32
    elif field_type == 'channel_id':
        type_len = 32
    elif field_type == 'sha256':
        type_len = 32
    elif field_type == 'signature':
        type_len = 64
    elif field_type == 'point':
        type_len = 33
    elif field_type == 'short_channel_id':
        type_len = 8
    elif field_type == 'sciddir_or_pubkey':
        # first byte discriminates: 0/1 -> direction + scid, 2/3 -> pubkey
        assert isinstance(value, bytes)
        if value[0] in [0, 1]:
            type_len = 9  # short_channel_id
        elif value[0] in [2, 3]:
            type_len = 33  # point
        else:
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
    # common tail for the fixed-size (and unknown-with-"...") field types
    total_len = -1
    if count != "...":
        if type_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * type_len
        # an int is only unambiguous for a single element (or a byte array)
        if isinstance(value, int) and (count == 1 or field_type == 'byte'):
            value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
    if not isinstance(value, (bytes, bytearray)):
        raise Exception(f"can only write bytes into fd. got: {value!r}")
    if count != "..." and total_len != len(value):
        raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
    nbytes_written = fd.write(value)
    if nbytes_written != len(value):
        raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
275

276

277
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
    """Read one TLV record from *fd* and return (type, value).

    The length prefix is consumed internally and used to size the value.
    """
    if not fd:
        raise Exception()
    record_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    record_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    record_val = _read_primitive_field(fd=fd, field_type="byte", count=record_len)
    return record_type, record_val
283

284

285
def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    """Append one TLV record (bigsize type, bigsize length, raw value) to *fd*."""
    if not fd:
        raise Exception()
    length = len(tlv_val)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=length)
    _write_primitive_field(fd=fd, field_type="byte", count=length, value=tlv_val)
291

292

293
def _resolve_field_count(field_count_str: str, *, vars_dict: Mapping, allow_any=False) -> Union[int, str]:
1✔
294
    """Returns an evaluated field count, typically an int.
295
    If allow_any is True, the return value can be a str with value=="...".
296
    """
297
    if field_count_str == "":
1✔
298
        field_count = 1
1✔
299
    elif field_count_str == "...":
1✔
300
        if not allow_any:
1✔
301
            raise Exception("field count is '...' but allow_any is False")
×
302
        return field_count_str
1✔
303
    else:
304
        try:
1✔
305
            field_count = int(field_count_str)
1✔
306
        except ValueError:
1✔
307
            field_count = vars_dict[field_count_str]
1✔
308
            if isinstance(field_count, (bytes, bytearray)):
1✔
309
                field_count = int.from_bytes(field_count, byteorder="big")
1✔
310
    assert isinstance(field_count, int)
1✔
311
    return field_count
1✔
312

313

314
def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
1✔
315
    msg_type_int = 0
1✔
316
    for component in value.split("|"):
1✔
317
        try:
1✔
318
            msg_type_int |= int(component)
1✔
319
        except ValueError:
1✔
320
            msg_type_int |= OnionFailureCodeMetaFlag[component]
1✔
321
    return msg_type_int
1✔
322

323

324
class LNSerializer:
    """Encodes/decodes Lightning wire messages based on a CSV wire-format spec.

    The CSV files (lnwire/peer_wire.csv or lnwire/onion_wire.csv) use the
    BOLT spec-tooling row format: "msgtype"/"msgdata" rows define top-level
    messages, "tlvtype"/"tlvdata" rows define TLV stream records, and
    "subtype"/"subtypedata" rows define nested composite field types.
    """

    def __init__(self, *, for_onion_wire: bool = False):
        """Load and index the wire scheme from the bundled CSV file.

        for_onion_wire selects onion_wire.csv (onion failure messages)
        instead of peer_wire.csv (peer-to-peer messages).
        """
        # TODO msg_type could be 'int' everywhere...
        # two-byte msg type -> list of scheme rows (first row is the msgtype row)
        self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
        self.msg_type_from_name = {}  # type: Dict[str, bytes]

        # per tlv-stream-name lookup tables for TLV record schemes/names/types
        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]

        # subtype name -> {field name -> subtypedata row}
        self.subtypes = {}  # type: Dict[str, Dict[str, Sequence[str]]]

        if for_onion_wire:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
        else:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
        with open(path, newline='') as f:
            csvreader = csv.reader(f)
            # note: msg_type_name/msg_type_bytes/tlv_stream_name/... are carried
            # across loop iterations: a "msgdata"/"tlvdata" row attaches to the
            # most recent "msgtype"/"tlvtype" row.
            for row in csvreader:
                #print(f">>> {row!r}")
                if row[0] == "msgtype":
                    # msgtype,<msgname>,<value>[,<option>]
                    msg_type_name = row[1]
                    if for_onion_wire:
                        msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
                    else:
                        msg_type_int = int(row[2])
                    msg_type_bytes = msg_type_int.to_bytes(2, 'big')
                    assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
                    assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
                    row[2] = msg_type_int
                    self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
                    self.msg_type_from_name[msg_type_name] = msg_type_bytes
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert msg_type_name == row[1]
                    self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
                elif row[0] == "tlvtype":
                    # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
                    tlv_stream_name = row[1]
                    tlv_record_name = row[2]
                    tlv_record_type = int(row[3])
                    row[3] = tlv_record_type
                    if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
                        self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
                        self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
                        self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
                    assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
                    self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
                    self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
                    # write_tlv_stream relies on iterating the OrderedDict in
                    # increasing type order, so enforce it at load time
                    if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
                        raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
                                        f"stream={tlv_stream_name}")
                elif row[0] == "tlvdata":
                    # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert tlv_stream_name == row[1]
                    assert tlv_record_name == row[2]
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
                elif row[0] == "subtype":
                    # subtype,<subtypename>
                    subtypename = row[1]
                    assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
                    self.subtypes[subtypename] = {}
                elif row[0] == "subtypedata":
                    # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                    subtypename = row[1]
                    fieldname = row[2]
                    assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
                    assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
                    self.subtypes[subtypename][fieldname] = tuple(row)
                else:
                    pass  # TODO

    def write_field(
            self,
            *,
            fd: io.BytesIO,
            field_type: str,
            count: Union[int, str],
            value: Union[Sequence[Mapping[str, Any]], Mapping[str, Any]],
    ) -> None:
        """Write a field of any type (primitive or CSV-defined subtype) to *fd*.

        For subtypes, *value* is a dict (or a list of dicts when count != 1)
        keyed by the subtype's field names; each element is written
        recursively per the subtypedata rows.
        """
        assert fd

        if field_type not in self.subtypes:
            _write_primitive_field(fd=fd, field_type=field_type, count=count, value=value)
            return

        if isinstance(count, int):
            assert count >= 0, f"{count!r} must be non-neg int"
        elif count == "...":
            pass
        else:
            raise Exception(f"unexpected field count: {count!r}")
        if count == 0:
            return

        # normalize to a list of mapping records
        if count == 1:
            assert isinstance(value, (MappingProxyType, dict)) or isinstance(value, (list, tuple)), type(value)
            values = [value] if isinstance(value, (MappingProxyType, dict)) else value
        else:
            assert isinstance(value, (tuple, list)), f'{field_type=}, expected value of type list/tuple for {count=}'
            values = value

        if count == '...':
            count = len(values)
        else:
            assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}'
        if count == 0:
            return

        for record in values:
            for subtypename, row in self.subtypes[field_type].items():
                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                subtype_field_name = row[2]
                subtype_field_type = row[3]
                subtype_field_count_str = row[4]

                # count may reference a field already present in this record
                subtype_field_count = _resolve_field_count(
                    subtype_field_count_str,
                    vars_dict=record,
                    allow_any=True)

                if subtype_field_name not in record:
                    raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')

                self.write_field(
                    fd=fd,
                    field_type=subtype_field_type,
                    count=subtype_field_count,
                    value=record[subtype_field_name])

    def read_field(
            self,
            *,
            fd: io.BytesIO,
            field_type: str,
            count: Union[int, str]
    ) -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
        """Read a field of any type (primitive or CSV-defined subtype) from *fd*.

        For subtypes, records are read until *fd* is exhausted; returns a
        single dict when count == 1, else a list of dicts.
        """
        assert fd

        if field_type not in self.subtypes:
            return _read_primitive_field(fd=fd, field_type=field_type, count=count)

        if isinstance(count, int):
            assert count >= 0, f"{count!r} must be non-neg int"
        elif count == "...":
            pass
        else:
            raise Exception(f"unexpected field count: {count!r}")
        if count == 0:
            return b""

        parsedlist = []

        while _num_remaining_bytes_to_read(fd):
            parsed = {}
            for subtypename, row in self.subtypes[field_type].items():
                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                subtype_field_name = row[2]
                subtype_field_type = row[3]
                subtype_field_count_str = row[4]

                # count may reference a field parsed earlier in this record
                subtype_field_count = _resolve_field_count(
                    subtype_field_count_str,
                    vars_dict=parsed,
                    allow_any=True)

                parsed[subtype_field_name] = self.read_field(
                    fd=fd,
                    field_type=subtype_field_type,
                    count=subtype_field_count)
            parsedlist.append(parsed)

        return parsedlist if count == '...' or count > 1 else parsedlist[0]

    def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
        """Serialize the TLV records given in **kwargs (keyed by record name) to *fd*.

        Records not present in kwargs are skipped. Records are emitted in
        increasing type order (scheme_map iteration order guarantees this).
        """
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            if tlv_record_name not in kwargs:
                continue
            # build the record payload in a scratch buffer first, as the TLV
            # header needs the payload length
            with io.BytesIO() as tlv_record_fd:
                for row in scheme:
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=kwargs[tlv_record_name],
                                                           allow_any=True)
                        field_value = kwargs[tlv_record_name][field_name]
                        self.write_field(
                            fd=tlv_record_fd,
                            field_type=field_type,
                            count=field_count,
                            value=field_value)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())

    def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
        """Parse a TLV stream from *fd* until EOF.

        Returns {record name: {field name: value}}. Unknown odd record types
        are skipped; unknown even ones raise UnknownMandatoryTLVRecordType.
        Raises MsgInvalidFieldOrder if types are not strictly increasing and
        MsgTrailingGarbage if a record has bytes beyond its scheme's fields.
        """
        parsed = {}  # type: Dict[str, Dict[str, Any]]
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        last_seen_tlv_record_type = -1  # type: int
        while _num_remaining_bytes_to_read(fd) > 0:
            tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
            if not (tlv_record_type > last_seen_tlv_record_type):
                raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
                                           f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
            last_seen_tlv_record_type = tlv_record_type
            try:
                scheme = scheme_map[tlv_record_type]
            except KeyError:
                if tlv_record_type % 2 == 0:
                    # unknown "even" type: hard fail
                    raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
                else:
                    # unknown "odd" type: skip it
                    continue
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            parsed[tlv_record_name] = {}
            with io.BytesIO(tlv_record_val) as tlv_record_fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(
                            field_count_str,
                            vars_dict=parsed[tlv_record_name],
                            allow_any=True)
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[tlv_record_name][field_name] = self.read_field(
                            fd=tlv_record_fd,
                            field_type=field_type,
                            count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
                    raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
        return parsed

    def encode_msg(self, msg_type: str, **kwargs) -> bytes:
        """
        Encode kwargs into a Lightning message (bytes)
        of the type given in the msg_type string.

        Fields missing from kwargs are encoded as zero; a field named
        "tlvs" is treated as a TLV stream (omitted if absent from kwargs).
        """
        #print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
        msg_type_bytes = self.msg_type_from_name[msg_type]
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        with io.BytesIO() as fd:
            fd.write(msg_type_bytes)
            for row in scheme:
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
                    field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
                    if field_name == "tlvs":
                        tlv_stream_name = field_type
                        if tlv_stream_name in kwargs:
                            self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
                        continue
                    try:
                        field_value = kwargs[field_name]
                    except KeyError:
                        field_value = 0  # default mandatory fields to zero
                    #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
                    _write_primitive_field(fd=fd, field_type=field_type, count=field_count, value=field_value)
                    #print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
            return fd.getvalue()

    def decode_msg(self, data: bytes) -> Tuple[str, dict]:
        """
        Decode Lightning message by reading the first
        two bytes to determine message type.

        Returns message type string and parsed message contents dict,
        or raises FailedToParseMsg.
        """
        #print(f"decode_msg >>> {data.hex()}")
        assert len(data) >= 2
        msg_type_bytes = data[:2]
        msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
        try:
            scheme = self.msg_scheme_from_type[msg_type_bytes]
        except KeyError:
            if msg_type_int % 2 == 0:  # even types must be understood: "mandatory"
                raise UnknownMandatoryMsgType(f"msg_type={msg_type_int}")
            else:  # odd types are ok not to understand: "optional"
                raise UnknownOptionalMsgType(f"msg_type={msg_type_int}")
        assert scheme[0][2] == msg_type_int
        msg_type_name = scheme[0][1]
        parsed = {}
        try:
            with io.BytesIO(data[2:]) as fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "msgtype":
                        pass
                    elif row[0] == "msgdata":
                        field_name = row[2]
                        field_type = row[3]
                        field_count_str = row[4]
                        field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
                        if field_name == "tlvs":
                            tlv_stream_name = field_type
                            d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
                            parsed[tlv_stream_name] = d
                            continue
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[field_name] = _read_primitive_field(fd=fd, field_type=field_type, count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
        except FailedToParseMsg as e:
            # annotate the exception with which message failed, then re-raise
            e.msg_type_int = msg_type_int
            e.msg_type_name = msg_type_name
            raise
        return msg_type_name, parsed
664

665

666
# Shared default serializer for the peer-to-peer wire protocol (peer_wire.csv);
# its bound methods are exported as module-level functions.
_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg


# Serializer for onion failure messages (onion_wire.csv).
OnionWireSerializer = LNSerializer(for_onion_wire=True)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc