spesmilo / electrum / 5980432291856384

27 Nov 2025 03:43PM UTC. Coverage: 62.193% (remained the same).

Pull Request #10110: initial support bolt12 offers
CI: CirrusCI
Commits (accumulator):
  onion_message: raise specific exceptions if blinded path could not be generated.
  add feature filter for blinded payment path peers
  do not check frozen setting for onion message channel peers
  improve docs

379 of 698 new or added lines in 12 files covered (54.3%).
517 existing lines in 5 files now uncovered.
23701 of 38109 relevant lines covered (62.19%).
0.62 hits per line.

Source file: /electrum/lnmsg.py (93.1% of lines covered)
import itertools
import os
import csv
import io
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional, Mapping
from types import MappingProxyType
from collections import OrderedDict

import electrum_ecc as ecc

from . import bitcoin
from .lnutil import OnionFailureCodeMetaFlag


class FailedToParseMsg(Exception):
    msg_type_int: Optional[int] = None
    msg_type_name: Optional[str] = None


class UnknownMsgType(FailedToParseMsg): pass
class UnknownOptionalMsgType(UnknownMsgType): pass
class UnknownMandatoryMsgType(UnknownMsgType): pass
class MalformedMsg(FailedToParseMsg): pass
class UnknownMsgFieldType(MalformedMsg): pass
class UnexpectedEndOfStream(MalformedMsg): pass
class FieldEncodingNotMinimal(MalformedMsg): pass
class UnknownMandatoryTLVRecordType(MalformedMsg): pass
class MsgTrailingGarbage(MalformedMsg): pass
class MsgInvalidFieldOrder(MalformedMsg): pass
class UnexpectedFieldSizeForEncoder(MalformedMsg): pass
class MsgInvalidSignature(MalformedMsg): pass


def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
    cur_pos = fd.tell()
    end_pos = fd.seek(0, io.SEEK_END)
    fd.seek(cur_pos)
    return end_pos - cur_pos


def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None:
    # note: it's faster to read n bytes and then check if we read n, than
    #       to assert we can read at least n and then read n bytes.
    nremaining = _num_remaining_bytes_to_read(fd)
    if nremaining < n:
        raise UnexpectedEndOfStream(f"wants to read {n} bytes but only {nremaining} bytes left")


def write_bigsize_int(i: int) -> bytes:
    assert i >= 0, i
    if i < 0xfd:
        return int.to_bytes(i, length=1, byteorder="big", signed=False)
    elif i < 0x1_0000:
        return b"\xfd" + int.to_bytes(i, length=2, byteorder="big", signed=False)
    elif i < 0x1_0000_0000:
        return b"\xfe" + int.to_bytes(i, length=4, byteorder="big", signed=False)
    else:
        return b"\xff" + int.to_bytes(i, length=8, byteorder="big", signed=False)


def read_bigsize_int(fd: io.BytesIO) -> Optional[int]:
    try:
        first = fd.read(1)[0]
    except IndexError:
        return None  # end of file
    if first < 0xfd:
        return first
    elif first == 0xfd:
        buf = fd.read(2)
        if len(buf) != 2:
            raise UnexpectedEndOfStream()
        val = int.from_bytes(buf, byteorder="big", signed=False)
        if not (0xfd <= val < 0x1_0000):
            raise FieldEncodingNotMinimal()
        return val
    elif first == 0xfe:
        buf = fd.read(4)
        if len(buf) != 4:
            raise UnexpectedEndOfStream()
        val = int.from_bytes(buf, byteorder="big", signed=False)
        if not (0x1_0000 <= val < 0x1_0000_0000):
            raise FieldEncodingNotMinimal()
        return val
    elif first == 0xff:
        buf = fd.read(8)
        if len(buf) != 8:
            raise UnexpectedEndOfStream()
        val = int.from_bytes(buf, byteorder="big", signed=False)
        if not (0x1_0000_0000 <= val):
            raise FieldEncodingNotMinimal()
        return val
    raise Exception()
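As an illustrative aside (not part of lnmsg.py): a doctest-style round-trip of the BigSize helpers above, assuming write_bigsize_int and read_bigsize_int from this module are in scope. Values below 0xfd encode as a single byte; larger values get a 0xfd/0xfe/0xff prefix, and non-minimal encodings are rejected with FieldEncodingNotMinimal on read.

>>> import io
>>> write_bigsize_int(252).hex()
'fc'
>>> write_bigsize_int(253).hex()
'fd00fd'
>>> read_bigsize_int(io.BytesIO(bytes.fromhex('fd00fd')))
253
>>> read_bigsize_int(io.BytesIO(bytes.fromhex('fe00010000')))
65536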


# TODO: maybe if field_type is not "byte", we could return a list of type_len sized chunks?
#       if field_type is a numeric, we could return a list of ints?
def _read_primitive_field(
        *,
        fd: io.BytesIO,
        field_type: str,
        count: Union[int, str]
) -> Union[bytes, int, str]:
    if not fd:
        raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count == "...":
        pass
    else:
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return b""
    type_len = None
    if field_type == 'byte':
        type_len = 1
    elif field_type in ('u8', 'u16', 'u32', 'u64'):
        if field_type == 'u8':
            type_len = 1
        elif field_type == 'u16':
            type_len = 2
        elif field_type == 'u32':
            type_len = 4
        else:
            assert field_type == 'u64'
            type_len = 8
        assert count == 1, count
        buf = fd.read(type_len)
        if len(buf) != type_len:
            raise UnexpectedEndOfStream()
        return int.from_bytes(buf, byteorder="big", signed=False)
    elif field_type in ('tu16', 'tu32', 'tu64'):
        if field_type == 'tu16':
            type_len = 2
        elif field_type == 'tu32':
            type_len = 4
        else:
            assert field_type == 'tu64'
            type_len = 8
        assert count == 1, count
        raw = fd.read(type_len)
        if len(raw) > 0 and raw[0] == 0x00:
            raise FieldEncodingNotMinimal()
        return int.from_bytes(raw, byteorder="big", signed=False)
    elif field_type == 'bigsize':
        assert count == 1, count
        val = read_bigsize_int(fd)
        if val is None:
            raise UnexpectedEndOfStream()
        return val
    elif field_type == 'chain_hash':
        type_len = 32
    elif field_type == 'channel_id':
        type_len = 32
    elif field_type == 'sha256':
        type_len = 32
    elif field_type == 'signature':
        type_len = 64
    elif field_type == 'bip340sig':
        type_len = 64
    elif field_type == 'point':
        type_len = 33
    elif field_type == 'short_channel_id':
        type_len = 8
    elif field_type == 'sciddir_or_pubkey':
        buf = fd.read(1)
        if buf[0] in [0, 1]:
            type_len = 9
        elif buf[0] in [2, 3]:
            type_len = 33
        else:
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
        buf += fd.read(type_len - 1)
        if len(buf) != type_len:
            raise UnexpectedEndOfStream()
        return buf
    elif field_type == 'utf8':
        if count != '...':
            raise Exception(f"utf8 fields can only have unbounded count")

    if count == "...":
        total_len = -1  # read all
    else:
        if type_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * type_len

    buf = fd.read(total_len)
    if total_len >= 0 and len(buf) != total_len:
        raise UnexpectedEndOfStream()

    if field_type == 'utf8':
        return buf.decode('utf-8')

    return buf


# TODO: maybe for "value" we could accept a list with len "count" of appropriate items
def _write_primitive_field(
        *,
        fd: io.BytesIO,
        field_type: str,
        count: Union[int, str],
        value: Union[bytes, int, str]
) -> None:
    if not fd:
        raise Exception()
    if isinstance(count, int):
        assert count >= 0, f"{count!r} must be non-neg int"
    elif count == "...":
        pass
    else:
        raise Exception(f"unexpected field count: {count!r}")
    if count == 0:
        return
    type_len = None
    if field_type == 'byte':
        type_len = 1
    elif field_type == 'u8':
        type_len = 1
    elif field_type == 'u16':
        type_len = 2
    elif field_type == 'u32':
        type_len = 4
    elif field_type == 'u64':
        type_len = 8
    elif field_type in ('tu16', 'tu32', 'tu64'):
        if field_type == 'tu16':
            type_len = 2
        elif field_type == 'tu32':
            type_len = 4
        else:
            assert field_type == 'tu64'
            type_len = 8
        assert count == 1, count
        if isinstance(value, int):
            value = int.to_bytes(value, length=type_len, byteorder="big", signed=False)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        while len(value) > 0 and value[0] == 0x00:
            value = value[1:]
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'bigsize':
        assert count == 1, count
        if isinstance(value, int):
            value = write_bigsize_int(value)
        if not isinstance(value, (bytes, bytearray)):
            raise Exception(f"can only write bytes into fd. got: {value!r}")
        nbytes_written = fd.write(value)
        if nbytes_written != len(value):
            raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")
        return
    elif field_type == 'chain_hash':
        type_len = 32
    elif field_type == 'channel_id':
        type_len = 32
    elif field_type == 'sha256':
        type_len = 32
    elif field_type == 'signature':
        type_len = 64
    elif field_type == 'bip340sig':
        type_len = 64
    elif field_type == 'point':
        type_len = 33
    elif field_type == 'short_channel_id':
        type_len = 8
    elif field_type == 'sciddir_or_pubkey':
        assert isinstance(value, bytes)
        if value[0] in [0, 1]:
            type_len = 9  # short_channel_id
        elif value[0] in [2, 3]:
            type_len = 33  # point
        else:
            raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
    elif field_type == 'utf8':
        if count != '...':
            raise Exception(f"utf8 fields can only have unbounded count")
        value = value.encode('utf-8')
    total_len = -1
    if count != "...":
        if type_len is None:
            raise UnknownMsgFieldType(f"unknown field type: {field_type!r}")
        total_len = count * type_len
        if isinstance(value, int) and (count == 1 or field_type == 'byte'):
            value = int.to_bytes(value, length=total_len, byteorder="big", signed=False)
    if not isinstance(value, (bytes, bytearray)):
        raise Exception(f"can only write bytes into fd. got: {value!r}")
    if count != "..." and total_len != len(value):
        raise UnexpectedFieldSizeForEncoder(f"expected: {total_len}, got {len(value)}")
    nbytes_written = fd.write(value)
    if nbytes_written != len(value):
        raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")


def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes, bytes]:
    if not fd: raise Exception()
    pos_start = fd.seek(0, 1)
    tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
    tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len)
    pos_end = fd.seek(0, 1)
    rawlen = pos_end - pos_start
    fd.seek(-rawlen, 1)
    rawbytes = fd.read(rawlen)
    return tlv_type, tlv_val, rawbytes


def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
    if not fd: raise Exception()
    tlv_len = len(tlv_val)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_type)
    _write_primitive_field(fd=fd, field_type="bigsize", count=1, value=tlv_len)
    _write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)


def _resolve_field_count(field_count_str: str, *, vars_dict: Mapping, allow_any=False) -> Union[int, str]:
    """Returns an evaluated field count, typically an int.
    If allow_any is True, the return value can be a str with value=="...".
    """
    if field_count_str == "":
        field_count = 1
    elif field_count_str == "...":
        if not allow_any:
            raise Exception("field count is '...' but allow_any is False")
        return field_count_str
    else:
        try:
            field_count = int(field_count_str)
        except ValueError:
            field_count = vars_dict[field_count_str]
            if isinstance(field_count, (bytes, bytearray)):
                field_count = int.from_bytes(field_count, byteorder="big")
    assert isinstance(field_count, int)
    return field_count


def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
    msg_type_int = 0
    for component in value.split("|"):
        try:
            msg_type_int |= int(component)
        except ValueError:
            msg_type_int |= OnionFailureCodeMetaFlag[component]
    return msg_type_int


def batched(iterable, n):  # itertools.batched available from python >=3.12
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch


def _tlv_merkle_root(tlvs: List[Sequence[bytes]]) -> bytes:
    first_tlv = None
    tlv_merkle_nodes = []

    for tlvt, tlv in tlvs:
        if first_tlv is None:
            first_tlv = tlv
        tlv_val = tlv
        tlv_record_type = write_bigsize_int(tlvt)
        merkle_leaf_hash = bitcoin.bip340_tagged_hash(b'LnLeaf', tlv_val)
        merkle_nonce = bitcoin.bip340_tagged_hash(b'LnNonce' + first_tlv, tlv_record_type)

        # ascending order
        msg = merkle_leaf_hash + merkle_nonce if merkle_leaf_hash < merkle_nonce else merkle_nonce + merkle_leaf_hash
        merkle_node_hash = bitcoin.bip340_tagged_hash(b'LnBranch', msg)

        tlv_merkle_nodes.append(merkle_node_hash)

    while len(tlv_merkle_nodes) > 1:
        target = []
        for batch in batched(tlv_merkle_nodes, 2):
            if len(batch) == 1:
                target.append(batch[0])
            else:
                msg = batch[0] + batch[1] if batch[0] < batch[1] else batch[1] + batch[0]
                merkle_node_hash = bitcoin.bip340_tagged_hash(b'LnBranch', msg)
                target.append(merkle_node_hash)
        tlv_merkle_nodes = target

    return tlv_merkle_nodes[0]


class LNSerializer:

    def __init__(self, *, for_onion_wire: bool = False):
        # TODO msg_type could be 'int' everywhere...
        self.msg_scheme_from_type = {}  # type: Dict[bytes, List[Sequence[str]]]
        self.msg_type_from_name = {}  # type: Dict[str, bytes]

        self.in_tlv_stream_get_tlv_record_scheme_from_type = {}  # type: Dict[str, Dict[int, List[Sequence[str]]]]
        self.in_tlv_stream_get_record_type_from_name = {}  # type: Dict[str, Dict[str, int]]
        self.in_tlv_stream_get_record_name_from_type = {}  # type: Dict[str, Dict[int, str]]

        self.subtypes = {}  # type: Dict[str, Dict[str, Sequence[str]]]

        if for_onion_wire:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "onion_wire.csv")
        else:
            path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv")
        with open(path, newline='') as f:
            csvreader = csv.reader(f)
            for row in csvreader:
                #print(f">>> {row!r}")
                if row[0] == "msgtype":
                    # msgtype,<msgname>,<value>[,<option>]
                    msg_type_name = row[1]
                    if for_onion_wire:
                        msg_type_int = _parse_msgtype_intvalue_for_onion_wire(str(row[2]))
                    else:
                        msg_type_int = int(row[2])
                    msg_type_bytes = msg_type_int.to_bytes(2, 'big')
                    assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}"
                    assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}"
                    row[2] = msg_type_int
                    self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)]
                    self.msg_type_from_name[msg_type_name] = msg_type_bytes
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert msg_type_name == row[1]
                    self.msg_scheme_from_type[msg_type_bytes].append(tuple(row))
                elif row[0] == "tlvtype":
                    # tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
                    tlv_stream_name = row[1]
                    tlv_record_name = row[2]
                    tlv_record_type = int(row[3])
                    row[3] = tlv_record_type
                    if tlv_stream_name not in self.in_tlv_stream_get_tlv_record_scheme_from_type:
                        self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name] = OrderedDict()
                        self.in_tlv_stream_get_record_type_from_name[tlv_stream_name] = {}
                        self.in_tlv_stream_get_record_name_from_type[tlv_stream_name] = {}
                    assert tlv_record_type not in self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_name not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    assert tlv_record_type not in self.in_tlv_stream_get_record_type_from_name[tlv_stream_name], f"type collision? for {tlv_stream_name}/{tlv_record_name}"
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type] = [tuple(row)]
                    self.in_tlv_stream_get_record_type_from_name[tlv_stream_name][tlv_record_name] = tlv_record_type
                    self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type] = tlv_record_name
                    if max(self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name].keys()) > tlv_record_type:
                        raise Exception(f"tlv record types must be listed in monotonically increasing order for stream. "
                                        f"stream={tlv_stream_name}")
                elif row[0] == "tlvdata":
                    # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                    assert tlv_stream_name == row[1]
                    assert tlv_record_name == row[2]
                    self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name][tlv_record_type].append(tuple(row))
                elif row[0] == "subtype":
                    # subtype,<subtypename>
                    subtypename = row[1]
                    assert subtypename not in self.subtypes, f"duplicate declaration of subtype {subtypename}"
                    self.subtypes[subtypename] = {}
                elif row[0] == "subtypedata":
                    # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                    subtypename = row[1]
                    fieldname = row[2]
                    assert subtypename in self.subtypes, f"subtypedata definition for subtype {subtypename} declared before subtype"
                    assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
                    self.subtypes[subtypename][fieldname] = tuple(row)
                else:
                    pass  # TODO: raise?

    def write_field(
            self,
            *,
            fd: io.BytesIO,
            field_type: str,
            count: Union[int, str],
            value: Union[Sequence[Mapping[str, Any]], Mapping[str, Any]],
    ) -> None:
        assert fd

        if field_type not in self.subtypes:
            _write_primitive_field(fd=fd, field_type=field_type, count=count, value=value)
            return

        if isinstance(count, int):
            assert count >= 0, f"{count!r} must be non-neg int"
        elif count == "...":
            pass
        else:
            raise Exception(f"unexpected field count: {count!r}")
        if count == 0:
            return

        if count == 1:
            assert isinstance(value, (MappingProxyType, dict)) or isinstance(value, (list, tuple)), type(value)
            values = [value] if isinstance(value, (MappingProxyType, dict)) else value
        else:
            assert isinstance(value, (tuple, list)), f'{field_type=}, expected value of type list/tuple for {count=}'
            values = value

        if count == '...':
            count = len(values)
        else:
            assert count == len(values), f'{field_type=}, expected {count} but got {len(values)}'
        if count == 0:
            return

        for record in values:
            for subtypename, row in self.subtypes[field_type].items():
                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                subtype_field_name = row[2]
                subtype_field_type = row[3]
                subtype_field_count_str = row[4]

                subtype_field_count = _resolve_field_count(
                    subtype_field_count_str,
                    vars_dict=record,
                    allow_any=True)

                if subtype_field_name not in record:
                    raise Exception(f'complex field type {field_type} missing element {subtype_field_name}')

                self.write_field(
                    fd=fd,
                    field_type=subtype_field_type,
                    count=subtype_field_count,
                    value=record[subtype_field_name])

    def read_field(
            self,
            *,
            fd: io.BytesIO,
            field_type: str,
            count: Union[int, str]
    ) -> Union[bytes, List[Dict[str, Any]], Dict[str, Any]]:
        assert fd

        if field_type not in self.subtypes:
            return _read_primitive_field(fd=fd, field_type=field_type, count=count)

        if isinstance(count, int):
            assert count >= 0, f"{count!r} must be non-neg int"
        elif count == "...":
            pass
        else:
            raise Exception(f"unexpected field count: {count!r}")
        if count == 0:
            return b""

        parsedlist = []

        while _num_remaining_bytes_to_read(fd):
            parsed = {}
            for subtypename, row in self.subtypes[field_type].items():
                # subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
                subtype_field_name = row[2]
                subtype_field_type = row[3]
                subtype_field_count_str = row[4]

                subtype_field_count = _resolve_field_count(
                    subtype_field_count_str,
                    vars_dict=parsed,
                    allow_any=True)

                parsed[subtype_field_name] = self.read_field(
                    fd=fd,
                    field_type=subtype_field_type,
                    count=subtype_field_count)
            parsedlist.append(parsed)

            # fd might contain more bytes, but we got passed a count. break when we have 'count' items.
            # (e.g. nested complex types)
            if isinstance(count, int) and len(parsedlist) == count:
                break

        return parsedlist if count == '...' or count > 1 else parsedlist[0]

    def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, signing_key: bytes = None, **kwargs) -> None:
        sign_over_tlvs = []

        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        for tlv_record_type, scheme in scheme_map.items():  # note: tlv_record_type is monotonically increasing
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            if tlv_record_name not in kwargs:
                # skip record_name if not in kwargs, unless we need to generate it
                if tlv_record_name != 'signature' or signing_key is None:
                    continue
                else:
                    # calculate signature over previously serialized tlv records
                    # and store in kwargs for inclusion in tlv stream
                    merkle_root = _tlv_merkle_root(sign_over_tlvs)
                    priv = ecc.ECPrivkey(signing_key)
                    tag = b'lightning' + tlv_stream_name.encode('ascii') + b'signature'
                    signature = priv.schnorr_sign(bitcoin.bip340_tagged_hash(tag, merkle_root))
                    kwargs[tlv_record_name] = {'sig': signature}
            with io.BytesIO() as tlv_record_fd:
                for row in scheme:
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(field_count_str,
                                                           vars_dict=kwargs[tlv_record_name],
                                                           allow_any=True)
                        field_value = kwargs[tlv_record_name][field_name]
                        self.write_field(
                            fd=tlv_record_fd,
                            field_type=field_type,
                            count=field_count,
                            value=field_value)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")

                tlv_val = tlv_record_fd.getvalue()
                _write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_val)

                # if we need to sign the tlvs, we need the entire TLV, so serialize again
                # and collect in `sign_over_tlvs`
                # NOTE: assumption: there are no fields after 'signature' (240)
                if signing_key and tlv_record_name != 'signature':
                    with io.BytesIO() as tlvfd:
                        _write_tlv_record(fd=tlvfd, tlv_type=tlv_record_type, tlv_val=tlv_val)
                        sign_over_tlvs.append((tlv_record_type, tlvfd.getvalue()))

    def read_tlv_stream(self, *,
                        fd: io.BytesIO,
                        tlv_stream_name: str,
                        signing_key_path: Optional[Sequence[str]] = None) -> Dict[str, Dict[str, Any]]:
        sign_over_tlvs = []
        signature_seen = False
        parsed = {}  # type: Dict[str, Dict[str, Any]]
        scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
        last_seen_tlv_record_type = -1  # type: int
        while _num_remaining_bytes_to_read(fd) > 0:
            tlv_record_type, tlv_record_val, rawbytes = _read_tlv_record(fd=fd)
            if not (tlv_record_type > last_seen_tlv_record_type):
                raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
                                           f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
            last_seen_tlv_record_type = tlv_record_type
            try:
                scheme = scheme_map[tlv_record_type]
            except KeyError:
                if tlv_record_type % 2 == 0:
                    # unknown "even" type: hard fail
                    raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
                else:
                    # unknown "odd" type: skip it
                    continue
            tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]
            # collect tlvs for signature check
            if signing_key_path:
                if tlv_record_name == 'signature':
                    signature_seen = True
                    # verify
                    merkle_root = _tlv_merkle_root(sign_over_tlvs)
                    signature = tlv_record_val
                    tag = b'lightning' + tlv_stream_name.encode('ascii') + b'signature'
                    tagh = bitcoin.bip340_tagged_hash(tag, merkle_root)
                    signing_key = parsed
                    for key in signing_key_path:  # walk signing_key_path
                        signing_key = signing_key[key]
                    assert isinstance(signing_key, bytes)
                    correct = ecc.ECPubkey(signing_key).schnorr_verify(signature, tagh)
                    if not correct:
                        raise MsgInvalidSignature(f"invalid signature in {'.'.join(signing_key_path)}")
                else:
                    sign_over_tlvs.append((tlv_record_type, rawbytes))
            parsed[tlv_record_name] = {}
            with io.BytesIO(tlv_record_val) as tlv_record_fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "tlvtype":
                        pass
                    elif row[0] == "tlvdata":
                        # tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
                        assert tlv_stream_name == row[1]
                        assert tlv_record_name == row[2]
                        field_name = row[3]
                        field_type = row[4]
                        field_count_str = row[5]
                        field_count = _resolve_field_count(
                            field_count_str,
                            vars_dict=parsed[tlv_record_name],
                            allow_any=True)
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[tlv_record_name][field_name] = self.read_field(
                            fd=tlv_record_fd,
                            field_type=field_type,
                            count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
                if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
                    raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")
        if signing_key_path and not signature_seen:
            raise MalformedMsg(f"signature expected but missing")
        return parsed

    def encode_msg(self, msg_type: str, **kwargs) -> bytes:
        """
        Encode kwargs into a Lightning message (bytes)
        of the type given in the msg_type string
        """
        #print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}")
        msg_type_bytes = self.msg_type_from_name[msg_type]
        scheme = self.msg_scheme_from_type[msg_type_bytes]
        with io.BytesIO() as fd:
            fd.write(msg_type_bytes)
            for row in scheme:
                if row[0] == "msgtype":
                    pass
                elif row[0] == "msgdata":
                    # msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
                    field_name = row[2]
                    field_type = row[3]
                    field_count_str = row[4]
                    #print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}")
                    field_count = _resolve_field_count(field_count_str, vars_dict=kwargs)
                    if field_name == "tlvs":
                        tlv_stream_name = field_type
                        if tlv_stream_name in kwargs:
                            self.write_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name, **(kwargs[tlv_stream_name]))
                        continue
                    try:
                        field_value = kwargs[field_name]
                    except KeyError:
                        field_value = 0  # default mandatory fields to zero
                    #print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}")
                    _write_primitive_field(fd=fd, field_type=field_type, count=field_count, value=field_value)
                    #print(f">>> encode_msg. so far: {fd.getvalue().hex()}")
                else:
                    raise Exception(f"unexpected row in scheme: {row!r}")
            return fd.getvalue()

    def decode_msg(self, data: bytes) -> Tuple[str, dict]:
        """
        Decode Lightning message by reading the first
        two bytes to determine message type.

        Returns message type string and parsed message contents dict,
        or raises FailedToParseMsg.
        """
        #print(f"decode_msg >>> {data.hex()}")
        assert len(data) >= 2
        msg_type_bytes = data[:2]
        msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False)
        try:
            scheme = self.msg_scheme_from_type[msg_type_bytes]
        except KeyError:
            if msg_type_int % 2 == 0:  # even types must be understood: "mandatory"
                raise UnknownMandatoryMsgType(f"msg_type={msg_type_int}")
            else:  # odd types are ok not to understand: "optional"
                raise UnknownOptionalMsgType(f"msg_type={msg_type_int}")
        assert scheme[0][2] == msg_type_int
        msg_type_name = scheme[0][1]
        parsed = {}
        try:
            with io.BytesIO(data[2:]) as fd:
                for row in scheme:
                    #print(f"row: {row!r}")
                    if row[0] == "msgtype":
                        pass
                    elif row[0] == "msgdata":
                        field_name = row[2]
                        field_type = row[3]
                        field_count_str = row[4]
                        field_count = _resolve_field_count(field_count_str, vars_dict=parsed)
                        if field_name == "tlvs":
                            tlv_stream_name = field_type
                            d = self.read_tlv_stream(fd=fd, tlv_stream_name=tlv_stream_name)
                            parsed[tlv_stream_name] = d
                            continue
                        #print(f">> count={field_count}. parsed={parsed}")
                        parsed[field_name] = _read_primitive_field(fd=fd, field_type=field_type, count=field_count)
                    else:
                        raise Exception(f"unexpected row in scheme: {row!r}")
        except FailedToParseMsg as e:
            e.msg_type_int = msg_type_int
            e.msg_type_name = msg_type_name
            raise
        return msg_type_name, parsed


_inst = LNSerializer()
encode_msg = _inst.encode_msg
decode_msg = _inst.decode_msg


OnionWireSerializer = LNSerializer(for_onion_wire=True)
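For orientation, a minimal usage sketch of the module-level encode_msg/decode_msg helpers defined above (not part of lnmsg.py). It assumes the package is importable as electrum.lnmsg and uses the BOLT #1 ping message as defined in the bundled peer_wire.csv; the field names shown come from that definition:

from electrum.lnmsg import encode_msg, decode_msg

# encode a ping; fields omitted from kwargs (here 'ignored') default to zero bytes
raw = encode_msg('ping', num_pong_bytes=4, byteslen=4)
name, payload = decode_msg(raw)
assert name == 'ping'
assert payload['num_pong_bytes'] == 4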