• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

llamerada-jp / colonio / 14953395849

11 May 2025 07:06AM UTC coverage: 65.743%. First build
14953395849

Pull #101

github

llamerada-jp
refine node id module

Signed-off-by: Yuji Ito <llamerada.jp@gmail.com>
Pull Request #101: refine node id module

126 of 189 new or added lines in 12 files covered. (66.67%)

2372 of 3608 relevant lines covered (65.74%)

44.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.57
/internal/network/node_accessor/node_link.go
1
/*
2
 * Copyright 2017- Yuji Ito <llamerada.jp@gmail.com>
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16
package node_accessor
17

18
import (
19
        "bytes"
20
        "context"
21
        "fmt"
22
        "log/slog"
23
        "math/bits"
24
        "sync"
25
        "time"
26

27
        proto "github.com/llamerada-jp/colonio/api/colonio/v1alpha"
28
        "github.com/llamerada-jp/colonio/config"
29
        "github.com/llamerada-jp/colonio/internal/shared"
30
        proto3 "google.golang.org/protobuf/proto"
31
)
32

33
// nodeLinkState describes which phase of its lifecycle a nodeLink is in.
type nodeLinkState int

// Values of nodeLinkState, numbered sequentially from zero.
const (
	// nodeLinkStateConnecting is the state of a freshly created link.
	nodeLinkStateConnecting nodeLinkState = 0
	// nodeLinkStateOnline is set when the WebRTC link reports it is active and online.
	nodeLinkStateOnline nodeLinkState = 1
	// nodeLinkStateDisabled is set by disconnect or when the WebRTC link becomes inactive.
	nodeLinkStateDisabled nodeLinkState = 2
)
40

41
// NodeLinkConfig contains the settings newNodeLink uses to create node links.
type NodeLinkConfig struct {
	// ctx is the parent context; each link derives its own cancelable context from it.
	ctx    context.Context
	// logger receives the warnings reported by node links.
	logger *slog.Logger

	// label is used for WebRTC data channel's label when isOffer is true
	label string

	// ICEServers is the list of ICE servers handed to the WebRTC link.
	ICEServers []*config.ICEServer
	// SessionTimeout is used to determine the timeout of the WebRTC session between nodes:
	// a link is disconnected when nothing has been received from the peer for this long.
	SessionTimeout time.Duration
	// KeepaliveInterval is the interval at which an empty packet is sent to tell the peer
	// that this node is still alive. The keepalive ticker is only created when this value
	// is positive.
	KeepaliveInterval time.Duration
	// BufferInterval is the maximum time packets are buffered before being sent to the peer.
	// Zero disables buffering, so every packet is sent immediately.
	BufferInterval time.Duration
	// PacketBaseBytes is a reference value for the packet size to be sent in WebRTC communication,
	// since WebRTC data channel may fail to send too large packets. Packets whose content does
	// not fit into one message are split into several fragments.
	PacketBaseBytes int
}
59

60
type nodeLinkHandler interface {
61
        nodeLinkChangeState(*nodeLink, nodeLinkState)
62
        nodeLinkUpdateICE(*nodeLink, string)
63
        nodeLinkRecvPacket(*nodeLink, *shared.Packet)
64
}
65

66
// waiting is an outgoing packet held in a nodeLink's queue until it is flushed.
type waiting struct {
	// packet is the original packet; its header fields are copied into the outgoing NodePacket.
	packet  *shared.Packet
	// content is packet.Content marshaled with protobuf, kept so its size can be measured
	// and so it can be split into fragments.
	content []byte
}
70

71
// nodeLink is a connection to one remote node over a WebRTC link. Outgoing
// packets are queued and sent in batches by flush; incoming batches are
// reassembled into shared.Packet values and passed to handler.
type nodeLink struct {
	config             *NodeLinkConfig
	handler            nodeLinkHandler
	// ctx is canceled by disconnect; routine exits once it is done.
	ctx                context.Context
	cancel             context.CancelFunc
	webrtc             webRTCLink
	// stateMtx guards state and keepaliveTimestamp.
	stateMtx           sync.RWMutex
	state              nodeLinkState
	// keepaliveTimestamp is the time data was last received from the peer
	// (initialized to the creation time); routine uses it to detect session timeout.
	keepaliveTimestamp time.Time
	// queueMtx guards queue.
	queueMtx           sync.Mutex
	queue              []*waiting
	// keepaliveTicker is nil when config.KeepaliveInterval is not positive.
	keepaliveTicker    *time.Ticker
	bufferTicker       *time.Ticker
	// stackMtx guards stackID and stack, which collect the fragments of a split
	// packet being received until its final (Index 0) fragment arrives.
	stackMtx           sync.Mutex
	stackID            uint32
	stack              []*proto.NodePacket
}
88

89
func newNodeLink(config *NodeLinkConfig, handler nodeLinkHandler, isOffer bool) (*nodeLink, error) {
51✔
90
        ctx, cancel := context.WithCancel(config.ctx)
51✔
91

51✔
92
        // The bufferTicker fires for configured intervals when some packets are in the buffer,
51✔
93
        // otherwise it disabled immediately.
51✔
94
        bufferTicker := time.NewTicker(1 * time.Second)
51✔
95
        if config.BufferInterval == 0 {
52✔
96
                bufferTicker.Stop()
1✔
97
        }
1✔
98

99
        var keepaliveTicker *time.Ticker
51✔
100
        if config.KeepaliveInterval > 0 {
102✔
101
                keepaliveTicker = time.NewTicker(config.KeepaliveInterval)
51✔
102
        }
51✔
103

104
        link := &nodeLink{
51✔
105
                config:             config,
51✔
106
                ctx:                ctx,
51✔
107
                cancel:             cancel,
51✔
108
                handler:            handler,
51✔
109
                stateMtx:           sync.RWMutex{},
51✔
110
                state:              nodeLinkStateConnecting,
51✔
111
                keepaliveTimestamp: time.Now(),
51✔
112
                queueMtx:           sync.Mutex{},
51✔
113
                queue:              make([]*waiting, 0),
51✔
114
                keepaliveTicker:    keepaliveTicker,
51✔
115
                bufferTicker:       bufferTicker,
51✔
116
                stackMtx:           sync.Mutex{},
51✔
117
                stackID:            0,
51✔
118
                stack:              nil,
51✔
119
        }
51✔
120

51✔
121
        label := ""
51✔
122
        if isOffer {
77✔
123
                label = config.label
26✔
124
        }
26✔
125
        var err error
51✔
126
        link.webrtc, err = defaultWebRTCLinkFactory(&webRTCLinkConfig{
51✔
127
                iceServers: config.ICEServers,
51✔
128
                isOffer:    isOffer,
51✔
129
                label:      label,
51✔
130
        }, &webRTCLinkEventHandler{
51✔
131
                raiseError:      link.webrtcRaiseError,
51✔
132
                changeLinkState: link.webrtcChangeLinkState,
51✔
133
                updateICE:       link.webrtcUpdateICE,
51✔
134
                recvData:        link.webrtcRecvData,
51✔
135
        })
51✔
136
        if err != nil {
51✔
137
                return nil, fmt.Errorf("failed to create WebRTCLink %w", err)
×
138
        }
×
139

140
        go link.routine()
51✔
141

51✔
142
        return link, nil
51✔
143
}
144

145
func (n *nodeLink) getLabel() string {
10✔
146
        return n.webrtc.getLabel()
10✔
147
}
10✔
148

149
func (n *nodeLink) getLocalSDP() (string, error) {
50✔
150
        sdp, err := n.webrtc.getLocalSDP()
50✔
151
        if err != nil {
50✔
152
                n.webrtc.disconnect()
×
153
                return "", fmt.Errorf("failed to get local SDP %w", err)
×
154
        }
×
155
        return sdp, nil
50✔
156
}
157

158
func (n *nodeLink) setRemoteSDP(sdp string) error {
49✔
159
        err := n.webrtc.setRemoteSDP(sdp)
49✔
160
        if err != nil {
49✔
161
                n.webrtc.disconnect()
×
162
                return fmt.Errorf("failed to set remote SDP %w", err)
×
163
        }
×
164
        return nil
49✔
165
}
166

167
func (n *nodeLink) updateICE(ice string) error {
98✔
168
        err := n.webrtc.updateICE(ice)
98✔
169
        if err != nil {
98✔
170
                n.webrtc.disconnect()
×
171
                return fmt.Errorf("failed to update ICE %w", err)
×
172
        }
×
173
        return nil
98✔
174
}
175

176
func (n *nodeLink) disconnect() error {
150✔
177
        n.stateMtx.Lock()
150✔
178
        if n.state != nodeLinkStateDisabled {
196✔
179
                n.state = nodeLinkStateDisabled
46✔
180
                defer n.handler.nodeLinkChangeState(n, nodeLinkStateDisabled)
46✔
181
        }
46✔
182
        n.stateMtx.Unlock()
150✔
183

150✔
184
        if n.bufferTicker != nil {
300✔
185
                n.bufferTicker.Stop()
150✔
186
        }
150✔
187
        if n.keepaliveTicker != nil {
300✔
188
                n.keepaliveTicker.Stop()
150✔
189
        }
150✔
190

191
        n.cancel()
150✔
192
        return n.webrtc.disconnect()
150✔
193
}
194

195
func (n *nodeLink) getLinkState() nodeLinkState {
341✔
196
        n.stateMtx.RLock()
341✔
197
        defer n.stateMtx.RUnlock()
341✔
198
        return n.state
341✔
199
}
341✔
200

201
func (n *nodeLink) sendPacket(packet *shared.Packet) error {
534✔
202
        content, err := proto3.Marshal(packet.Content)
534✔
203
        if err != nil {
534✔
204
                return fmt.Errorf("failed to marshal packet content %w", err)
×
205
        }
×
206

207
        waiting := &waiting{
534✔
208
                packet:  packet,
534✔
209
                content: content,
534✔
210
        }
534✔
211

534✔
212
        // send packet immediately if bufferInterval is 0
534✔
213
        if n.config.BufferInterval == 0 {
544✔
214
                n.queueMtx.Lock()
10✔
215
                n.queue = append(n.queue, waiting)
10✔
216
                n.queueMtx.Unlock()
10✔
217
                return n.flush()
10✔
218
        }
10✔
219

220
        // push packet to queue
221
        n.queueMtx.Lock()
524✔
222
        n.queue = append(n.queue, waiting)
524✔
223
        sum := 0
524✔
224
        count := len(n.queue)
524✔
225
        for _, w := range n.queue {
20,249✔
226
                sum += len(w.content)
19,725✔
227
        }
19,725✔
228
        n.queueMtx.Unlock()
524✔
229

524✔
230
        // send packet if passed bufferInterval or buffer size is over packetBaseBytes
524✔
231
        if sum > n.config.PacketBaseBytes {
533✔
232
                return n.flush()
9✔
233
        }
9✔
234

235
        if count == 1 {
532✔
236
                n.bufferTicker.Reset(n.config.BufferInterval)
17✔
237
        }
17✔
238

239
        return nil
515✔
240
}
241

242
func (n *nodeLink) flush() error {
123✔
243
        switch n.getLinkState() {
123✔
244
        case nodeLinkStateConnecting:
6✔
245
                return nil
6✔
246

247
        case nodeLinkStateDisabled:
×
248
                n.queueMtx.Lock()
×
249
                defer n.queueMtx.Unlock()
×
250
                n.queue = nil
×
251
                n.config.logger.Warn("link is disabled when flushing packet")
×
252
                return nil
×
253
        }
254

255
        n.queueMtx.Lock()
117✔
256
        queue := n.queue
117✔
257
        n.queue = make([]*waiting, 0)
117✔
258
        n.queueMtx.Unlock()
117✔
259

117✔
260
        p := &proto.NodePackets{}
117✔
261
        contentSize := 0
117✔
262
        for _, w := range queue {
651✔
263
                if contentSize+len(w.content) > n.config.PacketBaseBytes {
543✔
264
                        count, r := bits.Div(0, uint(len(w.content)+contentSize), uint(n.config.PacketBaseBytes))
9✔
265
                        if r != 0 {
18✔
266
                                count++
9✔
267
                        }
9✔
268
                        send := 0
9✔
269
                        for i := int(count - 1); i >= 0; i-- {
36✔
270
                                size := len(w.content) - send
27✔
271
                                if size+contentSize > n.config.PacketBaseBytes {
45✔
272
                                        size = n.config.PacketBaseBytes - contentSize
18✔
273
                                }
18✔
274
                                if i == 0 {
36✔
275
                                        p.Packets = append(p.Packets, &proto.NodePacket{
9✔
276
                                                Head: &proto.NodePacketHead{
9✔
277
                                                        DstNodeId: w.packet.DstNodeID.Proto(),
9✔
278
                                                        SrcNodeId: w.packet.SrcNodeID.Proto(),
9✔
279
                                                        HopCount:  w.packet.HopCount,
9✔
280
                                                        Mode:      uint32(w.packet.Mode),
9✔
281
                                                },
9✔
282
                                                Id:      w.packet.ID,
9✔
283
                                                Index:   uint32(i),
9✔
284
                                                Content: w.content[send : send+size],
9✔
285
                                        })
9✔
286
                                } else {
27✔
287
                                        p.Packets = append(p.Packets, &proto.NodePacket{
18✔
288
                                                Id:      w.packet.ID,
18✔
289
                                                Index:   uint32(i),
18✔
290
                                                Content: w.content[send : send+size],
18✔
291
                                        })
18✔
292
                                        send += size
18✔
293
                                        if !n.send(p) {
18✔
294
                                                return nil
×
295
                                        }
×
296
                                        p = &proto.NodePackets{}
18✔
297
                                        contentSize = 0
18✔
298
                                }
299
                        }
300

301
                } else {
525✔
302
                        p.Packets = append(p.Packets, &proto.NodePacket{
525✔
303
                                Head: &proto.NodePacketHead{
525✔
304
                                        DstNodeId: w.packet.DstNodeID.Proto(),
525✔
305
                                        SrcNodeId: w.packet.SrcNodeID.Proto(),
525✔
306
                                        HopCount:  w.packet.HopCount,
525✔
307
                                        Mode:      uint32(w.packet.Mode),
525✔
308
                                },
525✔
309
                                Id:      w.packet.ID,
525✔
310
                                Index:   0,
525✔
311
                                Content: w.content,
525✔
312
                        })
525✔
313
                        contentSize += len(w.content)
525✔
314
                }
525✔
315
        }
316
        if len(p.Packets) != 0 {
146✔
317
                if !n.send(p) {
29✔
318
                        return nil
×
319
                }
×
320
        }
321

322
        if n.config.KeepaliveInterval != 0 {
234✔
323
                n.keepaliveTicker.Reset(n.config.KeepaliveInterval)
117✔
324
        }
117✔
325
        // wait next packet idly
326
        if n.config.BufferInterval > 0 {
223✔
327
                n.bufferTicker.Reset(1 * time.Second)
106✔
328
        }
106✔
329

330
        return nil
117✔
331
}
332

333
func (n *nodeLink) routine() {
51✔
334
        // ticker to check keepalive timeout
51✔
335
        interval := n.config.KeepaliveInterval / 2
51✔
336
        if interval > 1*time.Second {
57✔
337
                interval = 1 * time.Second
6✔
338
        }
6✔
339
        ticker := time.NewTicker(interval)
51✔
340

51✔
341
        for {
316✔
342
                select {
265✔
343
                case <-n.ctx.Done():
51✔
344
                        n.disconnect()
51✔
345
                        return
51✔
346

347
                case <-n.keepaliveTicker.C:
40✔
348
                        n.queueMtx.Lock()
40✔
349
                        count := len(n.queue)
40✔
350
                        n.queueMtx.Unlock()
40✔
351
                        if count != 0 {
40✔
352
                                err := n.flush()
×
353
                                if err != nil {
×
354
                                        n.config.logger.Warn("failed to flush packet", slog.String("err", err.Error()))
×
355
                                        n.disconnect()
×
356
                                }
×
357
                        } else {
40✔
358
                                n.sendKeepalive()
40✔
359
                        }
40✔
360

361
                case <-n.bufferTicker.C:
55✔
362
                        err := n.flush()
55✔
363
                        if err != nil {
55✔
364
                                n.config.logger.Warn("failed to flush packet", slog.String("err", err.Error()))
×
365
                                n.disconnect()
×
366
                        }
×
367

368
                case <-ticker.C:
119✔
369
                        // wake up to check keepalive timeout
370
                }
371

372
                n.stateMtx.RLock()
214✔
373
                timedOut := time.Now().After(n.keepaliveTimestamp.Add(n.config.SessionTimeout))
214✔
374
                n.stateMtx.RUnlock()
214✔
375
                if timedOut {
216✔
376
                        n.disconnect()
2✔
377
                }
2✔
378
        }
379
}
380

381
func (n *nodeLink) sendKeepalive() {
40✔
382
        if n.getLinkState() != nodeLinkStateOnline {
45✔
383
                return
5✔
384
        }
5✔
385

386
        if n.send(&proto.NodePackets{}) {
70✔
387
                n.keepaliveTicker.Reset(n.config.KeepaliveInterval)
35✔
388
        }
35✔
389
}
390

391
func (n *nodeLink) send(packet *proto.NodePackets) bool {
82✔
392
        data, err := proto3.Marshal(packet)
82✔
393
        if err != nil {
82✔
394
                panic(err)
×
395
        }
396

397
        err = n.webrtc.send(data)
82✔
398
        if err != nil {
82✔
399
                n.config.logger.Warn("failed to send packet", slog.String("err", err.Error()))
×
400
                err = n.disconnect()
×
401
                if err != nil {
×
402
                        n.config.logger.Warn("failed to disconnect", slog.String("err", err.Error()))
×
403
                }
×
404
                return false
×
405
        }
406
        return true
82✔
407
}
408

409
func (n *nodeLink) webrtcRaiseError(err string) {
×
410
        n.config.logger.Warn("webrtc error", slog.String("err", err))
×
411
        n.disconnect()
×
412
}
×
413

414
func (n *nodeLink) webrtcChangeLinkState(active, online bool) {
144✔
415
        n.stateMtx.Lock()
144✔
416
        prevState := n.state
144✔
417
        if active {
193✔
418
                if online {
98✔
419
                        n.state = nodeLinkStateOnline
49✔
420
                } else {
49✔
421
                        n.state = nodeLinkStateConnecting
×
422
                }
×
423
        } else {
95✔
424
                n.state = nodeLinkStateDisabled
95✔
425
        }
95✔
426

427
        if prevState != n.state {
198✔
428
                defer func(s nodeLinkState) {
108✔
429
                        n.handler.nodeLinkChangeState(n, s)
54✔
430
                        switch s {
54✔
431
                        case nodeLinkStateOnline:
49✔
432
                                err := n.flush()
49✔
433
                                if err != nil {
49✔
434
                                        n.config.logger.Warn("failed to flush packet", slog.String("err", err.Error()))
×
435
                                        n.disconnect()
×
436
                                }
×
437

438
                        case nodeLinkStateDisabled:
5✔
439
                                n.disconnect()
5✔
440
                        }
441
                }(n.state)
442
        }
443
        n.stateMtx.Unlock()
144✔
444
}
445

446
func (n *nodeLink) webrtcUpdateICE(ice string) {
100✔
447
        n.handler.nodeLinkUpdateICE(n, ice)
100✔
448
}
100✔
449

450
// webrtcRecvData is the WebRTC callback for incoming data. data is decoded as
// a NodePackets message. Each contained NodePacket is either a complete packet
// (Index 0 with no pending fragments) or a fragment of a split packet.
// Fragments with a non-zero Index are stacked until the Index 0 fragment
// arrives. The fragments must then have descending, contiguous indices ending
// at 0, and only the Index 0 fragment may carry a head. The joined content is
// decoded into a shared.Packet and passed to handler.nodeLinkRecvPacket.
// Any decoding or format error disconnects the link and stops processing.
// keepaliveTimestamp is refreshed only when the whole message was processed
// without error.
//
// NOTE(review): the handler is called while stackMtx is held, and fragment
// indices are only validated once Index 0 arrives, so a peer could grow
// n.stack without bound before that — consider validating incrementally.
func (n *nodeLink) webrtcRecvData(data []byte) {
	p := &proto.NodePackets{}
	err := proto3.Unmarshal(data, p)
	if err != nil {
		n.config.logger.Warn("failed to unmarshal packet", slog.String("err", err.Error()))
		n.disconnect()
		return
	}

	n.stackMtx.Lock()
	defer n.stackMtx.Unlock()

	for _, packet := range p.Packets {
		// while fragments are pending, every packet must belong to the same split packet
		if n.stack != nil && n.stackID != packet.Id {
			n.config.logger.Warn("received packet id is not continuous")
			n.disconnect()
			return
		}

		var packets []*proto.NodePacket
		if packet.Index != 0 {
			// intermediate fragment: keep it until the final (Index 0) fragment arrives
			if n.stack == nil {
				n.stack = []*proto.NodePacket{packet}
				n.stackID = packet.Id
			} else {
				n.stack = append(n.stack, packet)
			}
			continue
		} else if n.stack != nil {
			// final fragment of a split packet
			packets = append(n.stack, packet)
			n.stack = nil
		} else {
			// packet that was not split
			packets = []*proto.NodePacket{packet}
		}

		contentsList := make([][]byte, 0)
		var head *proto.NodePacketHead
		// note: p shadows the outer NodePackets message within this loop
		for i, p := range packets {
			// check packet format
			// indices must count down to 0 in arrival order
			if p.Index != uint32(len(packets)-i-1) {
				n.config.logger.Warn("stacked packet index is not continuous")
				n.disconnect()
				return
			}
			// only the Index 0 fragment carries the head
			if (p.Index == 0 && p.Head == nil) || (p.Index != 0 && p.Head != nil) {
				n.config.logger.Warn("packet head is not set correctly")
				n.disconnect()
				return
			}

			contentsList = append(contentsList, p.Content)
			if p.Index == 0 {
				head = p.Head
			}
		}

		content := &proto.PacketContent{}
		var contentBin []byte
		// avoid a copy when the packet was not split
		if len(contentsList) == 1 {
			contentBin = contentsList[0]
		} else {
			contentBin = bytes.Join(contentsList, []byte{})
		}
		err := proto3.Unmarshal(contentBin, content)
		if err != nil {
			n.config.logger.Warn("failed to unmarshal packet content", slog.String("err", err.Error()))
			n.disconnect()
			return
		}
		dstNodeID, err := shared.NewNodeIDFromProto(head.DstNodeId)
		if err != nil {
			n.config.logger.Warn("failed to unmarshal destination node id", slog.String("err", err.Error()))
			n.disconnect()
			return
		}
		srcNodeID, err := shared.NewNodeIDFromProto(head.SrcNodeId)
		if err != nil {
			n.config.logger.Warn("failed to unmarshal source node id", slog.String("err", err.Error()))
			n.disconnect()
			return
		}
		// packet.Id on the right-hand side still refers to the loop's NodePacket;
		// the new variable's scope begins after this declaration.
		packet := &shared.Packet{
			DstNodeID: dstNodeID,
			SrcNodeID: srcNodeID,
			ID:        packet.Id,
			HopCount:  head.HopCount,
			Mode:      shared.PacketMode(head.Mode),
			Content:   content,
		}
		n.handler.nodeLinkRecvPacket(n, packet)
	}

	// data was received from the peer: push back the session timeout checked by routine
	n.stateMtx.Lock()
	n.keepaliveTimestamp = time.Now()
	n.stateMtx.Unlock()
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc