• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mendersoftware / mender-server / 1491991335

11 Oct 2024 01:37PM UTC coverage: 74.12% (+0.005%) from 74.115%
1491991335

push

gitlab-ci

web-flow
Merge pull request #98 from mzedel/fix/readme

fix: fixed an issue that prevented enterprise demo tenant creation

4399 of 6347 branches covered (69.31%)

Branch coverage included in aggregate %.

41997 of 56249 relevant lines covered (74.66%)

29.24 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.4
/backend/services/deviceconnect/api/http/management.go
1
// Copyright 2023 Northern.tech AS
2
//
3
//    Licensed under the Apache License, Version 2.0 (the "License");
4
//    you may not use this file except in compliance with the License.
5
//    You may obtain a copy of the License at
6
//
7
//        http://www.apache.org/licenses/LICENSE-2.0
8
//
9
//    Unless required by applicable law or agreed to in writing, software
10
//    distributed under the License is distributed on an "AS IS" BASIS,
11
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
//    See the License for the specific language governing permissions and
13
//    limitations under the License.
14

15
package http
16

17
import (
18
        "bufio"
19
        "context"
20
        "encoding/binary"
21
        "encoding/json"
22
        "io"
23
        "net/http"
24
        "strconv"
25
        "sync"
26
        "time"
27

28
        "github.com/gin-gonic/gin"
29
        validation "github.com/go-ozzo/ozzo-validation/v4"
30
        "github.com/gorilla/websocket"
31
        natsio "github.com/nats-io/nats.go"
32
        "github.com/pkg/errors"
33
        "github.com/vmihailenco/msgpack/v5"
34

35
        "github.com/mendersoftware/mender-server/pkg/identity"
36
        "github.com/mendersoftware/mender-server/pkg/log"
37
        "github.com/mendersoftware/mender-server/pkg/requestid"
38
        "github.com/mendersoftware/mender-server/pkg/rest.utils"
39
        "github.com/mendersoftware/mender-server/pkg/ws"
40
        "github.com/mendersoftware/mender-server/pkg/ws/menderclient"
41
        "github.com/mendersoftware/mender-server/pkg/ws/shell"
42

43
        "github.com/mendersoftware/mender-server/services/deviceconnect/app"
44
        "github.com/mendersoftware/mender-server/services/deviceconnect/client/nats"
45
        "github.com/mendersoftware/mender-server/services/deviceconnect/model"
46
)
47

48
// HTTP errors
49
var (
50
        ErrMissingUserAuthentication = errors.New(
51
                "missing or non-user identity in the authorization headers",
52
        )
53
        ErrMsgSessionLimit = "session byte limit exceeded"
54

55
        //The name of the field holding a number of milliseconds to sleep between
56
        //the consecutive writes of session recording data. Note that it does not have
57
        //anything to do with the sleep between the keystrokes send, lines printed,
58
        //or screen blinks, we are only aware of the stream of bytes.
59
        PlaybackSleepIntervalMsField = "sleep_ms"
60

61
        //The name of the field in the query parameter to GET that holds the id of a session
62
        PlaybackSessionIDField = "sessionId"
63

64
        //The threshold between the shell commands received (keystrokes) above which the
65
        //delay control message is saved (1.5 seconds)
66
        keyStrokeDelayRecordingThresholdNs = int64(1500 * 1000000)
67

68
        //The key stroke delay is recorded in two bytes, so this is the maximal
69
        //possible delay. We round down to this if the real delay is larger
70
        keyStrokeMaxDelayRecording = int64(65535 * 1000000)
71
)
72

73
const (
	// channelSize is the buffer capacity of the channels that carry NATS
	// messages towards the websocket writer.
	channelSize = 25 // TODO make configurable

	// PropertyUserID is the message header property key under which the ID
	// of the user owning the session is stored.
	PropertyUserID = "user_id"
)
78

79
var wsUpgrader = websocket.Upgrader{
80
        Subprotocols: []string{"protomsg/msgpack"},
81
        CheckOrigin:  allowAllOrigins,
82
        Error: func(
83
                w http.ResponseWriter, r *http.Request, s int, e error,
84
        ) {
3✔
85
                w.WriteHeader(s)
3✔
86
                enc := json.NewEncoder(w)
3✔
87
                _ = enc.Encode(rest.Error{
3✔
88
                        Err:       e.Error(),
3✔
89
                        RequestID: requestid.FromContext(r.Context())},
3✔
90
                )
3✔
91
        },
3✔
92
}
93

94
// ManagementController container for end-points
95
type ManagementController struct {
96
        app  app.App
97
        nats nats.Client
98
}
99

100
// NewManagementController returns a new ManagementController
101
func NewManagementController(
102
        app app.App,
103
        nc nats.Client,
104
) *ManagementController {
3✔
105
        return &ManagementController{
3✔
106
                app:  app,
3✔
107
                nats: nc,
3✔
108
        }
3✔
109
}
3✔
110

111
// GetDevice returns a device
112
func (h ManagementController) GetDevice(c *gin.Context) {
3✔
113
        ctx := c.Request.Context()
3✔
114

3✔
115
        idata := identity.FromContext(ctx)
3✔
116
        if idata == nil || !idata.IsUser {
3✔
117
                c.JSON(http.StatusBadRequest, gin.H{
×
118
                        "error": ErrMissingUserAuthentication.Error(),
×
119
                })
×
120
                return
×
121
        }
×
122
        tenantID := idata.Tenant
3✔
123
        deviceID := c.Param("deviceId")
3✔
124

3✔
125
        device, err := h.app.GetDevice(ctx, tenantID, deviceID)
3✔
126
        if err == app.ErrDeviceNotFound {
4✔
127
                c.JSON(http.StatusNotFound, gin.H{
1✔
128
                        "error": err.Error(),
1✔
129
                })
1✔
130
                return
1✔
131
        } else if err != nil {
5✔
132
                c.JSON(http.StatusBadRequest, gin.H{
1✔
133
                        "error": err.Error(),
1✔
134
                })
1✔
135
                return
1✔
136
        }
1✔
137

138
        c.JSON(http.StatusOK, device)
3✔
139
}
140

141
// Connect extracts identity from request, checks user permissions
142
// and calls ConnectDevice
143
func (h ManagementController) Connect(c *gin.Context) {
2✔
144
        ctx := c.Request.Context()
2✔
145
        l := log.FromContext(ctx)
2✔
146

2✔
147
        idata := identity.FromContext(ctx)
2✔
148
        if !idata.IsUser {
2✔
149
                c.JSON(http.StatusBadRequest, gin.H{
×
150
                        "error": ErrMissingUserAuthentication.Error(),
×
151
                })
×
152
                return
×
153
        }
×
154

155
        tenantID := idata.Tenant
2✔
156
        userID := idata.Subject
2✔
157
        deviceID := c.Param("deviceId")
2✔
158

2✔
159
        session := &model.Session{
2✔
160
                TenantID:           tenantID,
2✔
161
                UserID:             userID,
2✔
162
                DeviceID:           deviceID,
2✔
163
                StartTS:            time.Now(),
2✔
164
                BytesRecordedMutex: &sync.Mutex{},
2✔
165
                Types:              []string{},
2✔
166
        }
2✔
167

2✔
168
        // Prepare the user session
2✔
169
        err := h.app.PrepareUserSession(ctx, session)
2✔
170
        if err == app.ErrDeviceNotFound || err == app.ErrDeviceNotConnected {
4✔
171
                c.JSON(http.StatusNotFound, gin.H{
2✔
172
                        "error": err.Error(),
2✔
173
                })
2✔
174
                return
2✔
175
        } else if _, ok := errors.Cause(err).(validation.Errors); ok {
4✔
176
                c.JSON(http.StatusBadRequest, gin.H{
×
177
                        "error": err.Error(),
×
178
                })
×
179
                return
×
180
        } else if err != nil {
3✔
181
                l.Error(err)
1✔
182
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
183
                        "error": err.Error(),
1✔
184
                })
1✔
185
                return
1✔
186
        }
1✔
187
        defer func() {
4✔
188
                err := h.app.FreeUserSession(ctx, session.ID, session.Types)
2✔
189
                if err != nil {
2✔
190
                        l.Warnf("failed to free session: %s", err.Error())
×
191
                }
×
192
        }()
193

194
        deviceChan := make(chan *natsio.Msg, channelSize)
2✔
195
        sub, err := h.nats.ChanSubscribe(session.Subject(tenantID), deviceChan)
2✔
196
        if err != nil {
2✔
197
                l.Error(err)
×
198
                c.JSON(http.StatusInternalServerError, gin.H{
×
199
                        "error": "failed to establish internal device session",
×
200
                })
×
201
                return
×
202
        }
×
203
        //nolint:errcheck
204
        defer sub.Unsubscribe()
2✔
205

2✔
206
        // upgrade get request to websocket protocol
2✔
207
        conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
2✔
208
        if err != nil {
4✔
209
                err = errors.Wrap(err, "unable to upgrade the request to websocket protocol")
2✔
210
                l.Error(err)
2✔
211
                // upgrader.Upgrade has already responded
2✔
212
                return
2✔
213
        }
2✔
214
        conn.SetReadLimit(int64(app.MessageSizeLimit))
2✔
215

2✔
216
        //nolint:errcheck
2✔
217
        h.ConnectServeWS(ctx, conn, session, deviceChan)
2✔
218
}
219

220
func (h ManagementController) Playback(c *gin.Context) {
2✔
221
        ctx := c.Request.Context()
2✔
222
        l := log.FromContext(ctx)
2✔
223

2✔
224
        idata := identity.FromContext(ctx)
2✔
225
        if !idata.IsUser {
2✔
226
                c.JSON(http.StatusBadRequest, gin.H{
×
227
                        "error": ErrMissingUserAuthentication.Error(),
×
228
                })
×
229
                return
×
230
        }
×
231

232
        tenantID := idata.Tenant
2✔
233
        userID := idata.Subject
2✔
234
        sessionID := c.Param(PlaybackSessionIDField)
2✔
235
        session := &model.Session{
2✔
236
                TenantID:           tenantID,
2✔
237
                UserID:             userID,
2✔
238
                StartTS:            time.Now(),
2✔
239
                BytesRecordedMutex: &sync.Mutex{},
2✔
240
        }
2✔
241
        sleepInterval := c.Param(PlaybackSleepIntervalMsField)
2✔
242
        sleepMilliseconds := uint(app.DefaultPlaybackSleepIntervalMs)
2✔
243
        if len(sleepInterval) > 1 {
2✔
244
                n, err := strconv.ParseUint(sleepInterval, 10, 32)
×
245
                if err != nil {
×
246
                        sleepMilliseconds = uint(n)
×
247
                }
×
248
        }
249

250
        l.Infof("Playing back the session session_id=%s", sessionID)
2✔
251

2✔
252
        // upgrade get request to websocket protocol
2✔
253
        conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
2✔
254
        if err != nil {
4✔
255
                err = errors.Wrap(err, "unable to upgrade the request to websocket protocol")
2✔
256
                l.Error(err)
2✔
257
                return
2✔
258
        }
2✔
259
        conn.SetReadLimit(int64(app.MessageSizeLimit))
1✔
260

1✔
261
        deviceChan := make(chan *natsio.Msg, channelSize)
1✔
262
        errChan := make(chan error, 1)
1✔
263

1✔
264
        //nolint:errcheck
1✔
265
        go h.websocketWriter(ctx,
1✔
266
                conn,
1✔
267
                session,
1✔
268
                deviceChan,
1✔
269
                errChan,
1✔
270
                bufio.NewWriterSize(io.Discard, app.RecorderBufferSize),
1✔
271
                bufio.NewWriterSize(io.Discard, app.RecorderBufferSize))
1✔
272

1✔
273
        go func() {
2✔
274
                err = h.app.GetSessionRecording(ctx,
1✔
275
                        sessionID,
1✔
276
                        app.NewPlayback(deviceChan, sleepMilliseconds))
1✔
277
                if err != nil {
1✔
278
                        err = errors.Wrap(err, "unable to get the session.")
×
279
                        errChan <- err
×
280
                        return
×
281
                }
×
282
        }()
283
        // We need to keep reading in order to keep ping/pong handlers functioning.
284
        for ; err == nil; _, _, err = conn.NextReader() {
2✔
285
        }
1✔
286
}
287

288
func websocketPing(conn *websocket.Conn) bool {
1✔
289
        pongWaitString := strconv.Itoa(int(pongWait.Seconds()))
1✔
290
        if err := conn.WriteControl(
1✔
291
                websocket.PingMessage,
1✔
292
                []byte(pongWaitString),
1✔
293
                time.Now().Add(writeWait),
1✔
294
        ); err != nil {
1✔
295
                return false
×
296
        }
×
297
        return true
1✔
298
}
299

300
func writerFinalizer(conn *websocket.Conn, e *error, l *log.Logger) {
2✔
301
        err := *e
2✔
302
        if err != nil {
4✔
303
                if !websocket.IsUnexpectedCloseError(errors.Cause(err)) {
4✔
304
                        errMsg := err.Error()
2✔
305
                        errBody := make([]byte, len(errMsg)+2)
2✔
306
                        binary.BigEndian.PutUint16(errBody,
2✔
307
                                websocket.CloseInternalServerErr)
2✔
308
                        copy(errBody[2:], errMsg)
2✔
309
                        errClose := conn.WriteControl(
2✔
310
                                websocket.CloseMessage,
2✔
311
                                errBody,
2✔
312
                                time.Now().Add(writeWait),
2✔
313
                        )
2✔
314
                        if errClose != nil {
3✔
315
                                err = errors.Wrapf(err,
1✔
316
                                        "error sending websocket close frame: %s",
1✔
317
                                        errClose.Error(),
1✔
318
                                )
1✔
319
                        }
1✔
320
                }
321
                l.Errorf("websocket closed with error: %s", err.Error())
2✔
322
        }
323
        conn.Close()
2✔
324
}
325

326
// websocketWriter is the go-routine responsible for the writing end of the
327
// websocket. The routine forwards messages posted on the NATS session subject
328
// and periodically pings the connection. If the connection times out or a
329
// protocol violation occurs, the routine closes the connection.
330
func (h ManagementController) websocketWriter(
331
        ctx context.Context,
332
        conn *websocket.Conn,
333
        session *model.Session,
334
        deviceChan <-chan *natsio.Msg,
335
        errChan <-chan error,
336
        recorderBuffered *bufio.Writer,
337
        controlRecorderBuffered *bufio.Writer,
338
) (err error) {
2✔
339
        l := log.FromContext(ctx)
2✔
340
        defer writerFinalizer(conn, &err, l)
2✔
341

2✔
342
        // handle the ping-pong connection health check
2✔
343
        err = conn.SetReadDeadline(time.Now().Add(pongWait))
2✔
344
        if err != nil {
2✔
345
                l.Error(err)
×
346
                return err
×
347
        }
×
348

349
        pingPeriod := (pongWait * 9) / 10
2✔
350
        ticker := time.NewTicker(pingPeriod)
2✔
351
        defer ticker.Stop()
2✔
352
        conn.SetPongHandler(func(string) error {
3✔
353
                ticker.Reset(pingPeriod)
1✔
354
                return conn.SetReadDeadline(time.Now().Add(pongWait))
1✔
355
        })
1✔
356
        conn.SetPingHandler(func(msg string) error {
3✔
357
                ticker.Reset(pingPeriod)
1✔
358
                err := conn.SetReadDeadline(time.Now().Add(pongWait))
1✔
359
                if err != nil {
1✔
360
                        return err
×
361
                }
×
362
                return conn.WriteControl(
1✔
363
                        websocket.PongMessage,
1✔
364
                        []byte(msg),
1✔
365
                        time.Now().Add(writeWait),
1✔
366
                )
1✔
367
        })
368

369
        defer recorderBuffered.Flush()
2✔
370
        defer controlRecorderBuffered.Flush()
2✔
371
        recordedBytes := 0
2✔
372
        controlBytes := 0
2✔
373

2✔
374
        sessOverLimit := false
2✔
375
        sessOverLimitHandled := false
2✔
376

2✔
377
        lastKeystrokeAt := time.Now().UTC().UnixNano()
2✔
378
Loop:
2✔
379
        for {
4✔
380
                var forwardedMsg []byte
2✔
381

2✔
382
                select {
2✔
383
                case msg := <-deviceChan:
2✔
384
                        mr := &ws.ProtoMsg{}
2✔
385
                        err = msgpack.Unmarshal(msg.Data, mr)
2✔
386
                        if err != nil {
2✔
387
                                return err
×
388
                        }
×
389

390
                        forwardedMsg = msg.Data
2✔
391

2✔
392
                        if mr.Header.Proto == ws.ProtoTypeShell {
4✔
393
                                switch mr.Header.MsgType {
2✔
394
                                case shell.MessageTypeShellCommand:
2✔
395

2✔
396
                                        if recordedBytes >= app.MessageSizeLimit ||
2✔
397
                                                controlBytes >= app.MessageSizeLimit {
3✔
398
                                                sessOverLimit = true
1✔
399

1✔
400
                                                errMsg := h.handleSessLimit(ctx,
1✔
401
                                                        session,
1✔
402
                                                        &sessOverLimitHandled,
1✔
403
                                                )
1✔
404

1✔
405
                                                //override original message with shell error
1✔
406
                                                if errMsg != nil {
2✔
407
                                                        forwardedMsg = errMsg
1✔
408
                                                }
1✔
409
                                        } else {
2✔
410
                                                if err = recordSession(ctx,
2✔
411
                                                        mr,
2✔
412
                                                        recorderBuffered,
2✔
413
                                                        controlRecorderBuffered,
2✔
414
                                                        &recordedBytes,
2✔
415
                                                        &controlBytes,
2✔
416
                                                        &lastKeystrokeAt,
2✔
417
                                                        session,
2✔
418
                                                ); err != nil {
2✔
419
                                                        return err
×
420
                                                }
×
421
                                        }
422

423
                                case shell.MessageTypeStopShell:
×
424
                                        l.Debugf("session logging: recorderBuffered.Flush()"+
×
425
                                                " at %d on stop shell", recordedBytes)
×
426
                                        recorderBuffered.Flush()
×
427
                                }
428
                        }
429

430
                        if !sessOverLimit {
4✔
431
                                err = conn.WriteMessage(websocket.BinaryMessage, forwardedMsg)
2✔
432
                                if err != nil {
2✔
433
                                        l.Error(err)
×
434
                                        break Loop
×
435
                                }
436
                        }
437
                case <-ctx.Done():
1✔
438
                        break Loop
1✔
439
                case <-ticker.C:
1✔
440
                        if !websocketPing(conn) {
1✔
441
                                err = errors.New("connection timeout")
×
442
                                break Loop
×
443
                        }
444
                case err := <-errChan:
2✔
445
                        return err
2✔
446
                }
447
        }
448
        return err
1✔
449
}
450

451
func (h ManagementController) handleSessLimit(ctx context.Context,
452
        session *model.Session,
453
        handled *bool,
454
) []byte {
1✔
455
        l := log.FromContext(ctx)
1✔
456

1✔
457
        // possible error return message (ws->user)
1✔
458
        var retMsg []byte
1✔
459

1✔
460
        // attempt to clean up once
1✔
461
        if !(*handled) {
2✔
462
                sendLimitErrDevice(ctx, session, h.nats)
1✔
463
                userErrMsg, err := prepLimitErrUser(ctx, session)
1✔
464
                if err != nil {
1✔
465
                        l.Errorf("session limit: " +
×
466
                                "failed to notify user")
×
467
                }
×
468

469
                retMsg = userErrMsg
1✔
470

1✔
471
                err = h.app.FreeUserSession(ctx, session.ID, session.Types)
1✔
472
                if err != nil {
1✔
473
                        l.Warnf("failed to free session"+
×
474
                                "that went over limit: %s", err.Error())
×
475
                }
×
476

477
                *handled = true
1✔
478
        }
479

480
        return retMsg
1✔
481
}
482

483
func recordSession(ctx context.Context,
484
        msg *ws.ProtoMsg,
485
        recorder io.Writer,
486
        recorderCtrl io.Writer,
487
        recBytes *int,
488
        ctrlBytes *int,
489
        lastKeystrokeAt *int64,
490
        session *model.Session) error {
2✔
491
        l := log.FromContext(ctx)
2✔
492

2✔
493
        b, e := recorder.Write(msg.Body)
2✔
494
        if e != nil {
2✔
495
                l.Errorf("session logging: "+
×
496
                        "recorderBuffered.Write"+
×
497
                        "(len=%d)=%d,%+v",
×
498
                        len(msg.Body), b, e)
×
499
        }
×
500
        timeNowUTC := time.Now().UTC().UnixNano()
2✔
501
        keystrokeDelay := timeNowUTC - (*lastKeystrokeAt)
2✔
502
        if keystrokeDelay >= keyStrokeDelayRecordingThresholdNs {
2✔
503
                if keystrokeDelay > keyStrokeMaxDelayRecording {
×
504
                        keystrokeDelay = keyStrokeMaxDelayRecording
×
505
                }
×
506

507
                controlMsg := app.Control{
×
508
                        Type:   app.DelayMessage,
×
509
                        Offset: *recBytes,
×
510
                        DelayMs: uint16(float64(keystrokeDelay) *
×
511
                                0.000001),
×
512
                        TerminalHeight: 0,
×
513
                        TerminalWidth:  0,
×
514
                }
×
515
                n, _ := recorderCtrl.Write(
×
516
                        controlMsg.MarshalBinary())
×
517
                l.Debugf("saving control delay message: %+v/%d",
×
518
                        controlMsg, n)
×
519
                (*ctrlBytes) += n
×
520
        }
521

522
        (*lastKeystrokeAt) = timeNowUTC
2✔
523

2✔
524
        (*recBytes) += len(msg.Body)
2✔
525
        session.BytesRecordedMutex.Lock()
2✔
526
        session.BytesRecorded = *recBytes
2✔
527
        session.BytesRecordedMutex.Unlock()
2✔
528

2✔
529
        return nil
2✔
530
}
531

532
// prepLimitErrUser preps a session limit exceeded error for the user (shell cmd + err status)
533
func prepLimitErrUser(ctx context.Context, session *model.Session) ([]byte, error) {
1✔
534
        userErrMsg := ws.ProtoMsg{
1✔
535
                Header: ws.ProtoHdr{
1✔
536
                        Proto:     ws.ProtoTypeShell,
1✔
537
                        MsgType:   shell.MessageTypeShellCommand,
1✔
538
                        SessionID: session.ID,
1✔
539
                        Properties: map[string]interface{}{
1✔
540
                                "status": shell.ErrorMessage,
1✔
541
                        },
1✔
542
                },
1✔
543
                Body: []byte(ErrMsgSessionLimit),
1✔
544
        }
1✔
545

1✔
546
        return msgpack.Marshal(userErrMsg)
1✔
547
}
1✔
548

549
// sendLimitErrDevice preps and sends
550
// session limit exceeded error to device (stop shell + err status)
551
// this is best effort, log and swallow errors
552
func sendLimitErrDevice(ctx context.Context, session *model.Session, nats nats.Client) {
1✔
553
        l := log.FromContext(ctx)
1✔
554

1✔
555
        msg := ws.ProtoMsg{
1✔
556
                Header: ws.ProtoHdr{
1✔
557
                        Proto:     ws.ProtoTypeShell,
1✔
558
                        MsgType:   shell.MessageTypeStopShell,
1✔
559
                        SessionID: session.ID,
1✔
560
                        Properties: map[string]interface{}{
1✔
561
                                "status":       shell.ErrorMessage,
1✔
562
                                PropertyUserID: session.UserID,
1✔
563
                        },
1✔
564
                },
1✔
565
                Body: []byte(ErrMsgSessionLimit),
1✔
566
        }
1✔
567
        data, err := msgpack.Marshal(msg)
1✔
568
        if err != nil {
1✔
569
                l.Errorf(
×
570
                        "session limit: "+
×
571
                                "failed to prep stop session"+
×
572
                                "%s message to device: %s, error %v",
×
573
                        session.ID,
×
574
                        session.DeviceID,
×
575
                        err,
×
576
                )
×
577
        }
×
578
        err = nats.Publish(model.GetDeviceSubject(
1✔
579
                session.TenantID, session.DeviceID),
1✔
580
                data,
1✔
581
        )
1✔
582
        if err != nil {
1✔
583
                l.Errorf(
×
584
                        "session limit: failed to send stop session"+
×
585
                                "%s message to device: %s, error %v",
×
586
                        session.ID,
×
587
                        session.DeviceID,
×
588
                        err,
×
589
                )
×
590
        }
×
591
}
592

593
// ConnectServeWS starts a websocket connection with the device
594
// Currently this handler only properly handles a single terminal session.
595
func (h ManagementController) ConnectServeWS(
596
        ctx context.Context,
597
        conn *websocket.Conn,
598
        sess *model.Session,
599
        deviceChan chan *natsio.Msg,
600
) (err error) {
2✔
601
        l := log.FromContext(ctx)
2✔
602
        id := identity.FromContext(ctx)
2✔
603
        errChan := make(chan error, 1)
2✔
604
        remoteTerminalRunning := false
2✔
605

2✔
606
        defer func() {
4✔
607
                if err != nil {
4✔
608
                        select {
2✔
609
                        case errChan <- err:
2✔
610

611
                        case <-time.After(time.Second):
×
612
                                l.Warn("Failed to propagate error to client")
×
613
                        }
614
                }
615
                if remoteTerminalRunning {
3✔
616
                        msg := ws.ProtoMsg{
1✔
617
                                Header: ws.ProtoHdr{
1✔
618
                                        Proto:     ws.ProtoTypeShell,
1✔
619
                                        MsgType:   shell.MessageTypeStopShell,
1✔
620
                                        SessionID: sess.ID,
1✔
621
                                        Properties: map[string]interface{}{
1✔
622
                                                "status":       shell.ErrorMessage,
1✔
623
                                                PropertyUserID: sess.UserID,
1✔
624
                                        },
1✔
625
                                },
1✔
626
                                Body: []byte("user disconnected"),
1✔
627
                        }
1✔
628
                        data, _ := msgpack.Marshal(msg)
1✔
629
                        errPublish := h.nats.Publish(model.GetDeviceSubject(
1✔
630
                                id.Tenant, sess.DeviceID),
1✔
631
                                data,
1✔
632
                        )
1✔
633
                        if errPublish != nil {
1✔
634
                                l.Warnf(
×
635
                                        "failed to propagate stop session "+
×
636
                                                "message to device: %s",
×
637
                                        errPublish.Error(),
×
638
                                )
×
639
                        }
×
640
                }
641
                close(errChan)
2✔
642
        }()
643

644
        controlRecorder := h.app.GetControlRecorder(ctx, sess.ID)
2✔
645
        controlRecorderBuffered := bufio.NewWriterSize(controlRecorder, app.RecorderBufferSize)
2✔
646
        defer controlRecorderBuffered.Flush()
2✔
647

2✔
648
        sessionRecorder := h.app.GetRecorder(ctx, sess.ID)
2✔
649
        sessionRecorderBuffered := bufio.NewWriterSize(sessionRecorder, app.RecorderBufferSize)
2✔
650
        defer sessionRecorderBuffered.Flush()
2✔
651

2✔
652
        // websocketWriter is responsible for closing the websocket
2✔
653
        //nolint:errcheck
2✔
654
        go h.websocketWriter(ctx,
2✔
655
                conn,
2✔
656
                sess,
2✔
657
                deviceChan,
2✔
658
                errChan,
2✔
659
                sessionRecorderBuffered,
2✔
660
                controlRecorderBuffered)
2✔
661

2✔
662
        return h.connectServeWSProcessMessages(ctx, conn, sess, deviceChan,
2✔
663
                &remoteTerminalRunning, controlRecorderBuffered)
2✔
664
}
665

666
func (h ManagementController) connectServeWSProcessMessages(
667
        ctx context.Context,
668
        conn *websocket.Conn,
669
        sess *model.Session,
670
        deviceChan chan *natsio.Msg,
671
        remoteTerminalRunning *bool,
672
        controlRecorderBuffered *bufio.Writer,
673
) (err error) {
2✔
674
        l := log.FromContext(ctx)
2✔
675
        id := identity.FromContext(ctx)
2✔
676
        logTerminal := false
2✔
677
        logPortForward := false
2✔
678

2✔
679
        var data []byte
2✔
680
        controlBytes := 0
2✔
681
        ignoreControlMessages := false
2✔
682
        for {
4✔
683
                _, data, err = conn.ReadMessage()
2✔
684
                if err != nil {
4✔
685
                        if _, ok := err.(*websocket.CloseError); ok {
4✔
686
                                return nil
2✔
687
                        }
2✔
688
                        return err
1✔
689
                }
690
                m := &ws.ProtoMsg{}
2✔
691
                err = msgpack.Unmarshal(data, m)
2✔
692
                if err != nil {
3✔
693
                        return err
1✔
694
                }
1✔
695

696
                m.Header.SessionID = sess.ID
2✔
697
                if m.Header.Properties == nil {
3✔
698
                        m.Header.Properties = make(map[string]interface{})
1✔
699
                }
1✔
700
                m.Header.Properties[PropertyUserID] = sess.UserID
2✔
701
                data, _ = msgpack.Marshal(m)
2✔
702
                switch m.Header.Proto {
2✔
703
                case ws.ProtoTypeShell:
2✔
704
                        // send the audit log for remote terminal
2✔
705
                        if !logTerminal {
4✔
706
                                if err := h.app.LogUserSession(ctx, sess,
2✔
707
                                        model.SessionTypeTerminal); err != nil {
2✔
708
                                        return err
×
709
                                }
×
710
                                sess.Types = append(sess.Types, model.SessionTypeTerminal)
2✔
711
                                logTerminal = true
2✔
712
                        }
713
                        // handle remote terminal-specific messages
714
                        switch m.Header.MsgType {
2✔
715
                        case shell.MessageTypeSpawnShell:
1✔
716
                                *remoteTerminalRunning = true
1✔
717
                        case shell.MessageTypeStopShell:
1✔
718
                                *remoteTerminalRunning = false
1✔
719
                        case shell.MessageTypeResizeShell:
×
720
                                if ignoreControlMessages {
×
721
                                        continue
×
722
                                }
723
                                if controlBytes >= app.MessageSizeLimit {
×
724
                                        l.Infof("session_id=%s control data limit reached.",
×
725
                                                sess.ID)
×
726
                                        //see https://northerntech.atlassian.net/browse/MEN-4448
×
727
                                        ignoreControlMessages = true
×
728
                                        continue
×
729
                                }
730

731
                                controlBytes += sendResizeMessage(m, sess, controlRecorderBuffered)
×
732
                        }
733
                case ws.ProtoTypePortForward:
×
734
                        if !logPortForward {
×
735
                                if err := h.app.LogUserSession(ctx, sess,
×
736
                                        model.SessionTypePortForward); err != nil {
×
737
                                        return err
×
738
                                }
×
739
                                sess.Types = append(sess.Types, model.SessionTypePortForward)
×
740
                                logPortForward = true
×
741
                        }
742
                }
743

744
                err = h.nats.Publish(model.GetDeviceSubject(id.Tenant, sess.DeviceID), data)
2✔
745
                if err != nil {
2✔
746
                        return err
×
747
                }
×
748
        }
749
}
750

751
func sendResizeMessage(m *ws.ProtoMsg,
752
        sess *model.Session,
753
        controlRecorderBuffered *bufio.Writer) (n int) {
×
754
        if _, ok := m.Header.Properties[model.ResizeMessageTermHeightField]; ok {
×
755
                return 0
×
756
        }
×
757
        if _, ok := m.Header.Properties[model.ResizeMessageTermWidthField]; ok {
×
758
                return 0
×
759
        }
×
760

761
        var height uint16 = 0
×
762
        switch m.Header.Properties[model.ResizeMessageTermHeightField].(type) {
×
763
        case uint8:
×
764
                height = uint16(m.Header.Properties[model.ResizeMessageTermHeightField].(uint8))
×
765
        case int8:
×
766
                height = uint16(m.Header.Properties[model.ResizeMessageTermHeightField].(int8))
×
767
        }
768

769
        var width uint16 = 0
×
770
        switch m.Header.Properties[model.ResizeMessageTermWidthField].(type) {
×
771
        case uint8:
×
772
                width = uint16(m.Header.Properties[model.ResizeMessageTermWidthField].(uint8))
×
773
        case int8:
×
774
                width = uint16(m.Header.Properties[model.ResizeMessageTermWidthField].(int8))
×
775
        }
776

777
        sess.BytesRecordedMutex.Lock()
×
778
        controlMsg := app.Control{
×
779
                Type:           app.ResizeMessage,
×
780
                Offset:         sess.BytesRecorded,
×
781
                DelayMs:        0,
×
782
                TerminalHeight: height,
×
783
                TerminalWidth:  width,
×
784
        }
×
785
        sess.BytesRecordedMutex.Unlock()
×
786

×
787
        n, _ = controlRecorderBuffered.Write(
×
788
                controlMsg.MarshalBinary(),
×
789
        )
×
790
        return n
×
791
}
792

793
func (h ManagementController) CheckUpdate(c *gin.Context) {
1✔
794
        h.sendMenderCommand(c, menderclient.MessageTypeMenderClientCheckUpdate)
1✔
795
}
1✔
796

797
func (h ManagementController) SendInventory(c *gin.Context) {
1✔
798
        h.sendMenderCommand(c, menderclient.MessageTypeMenderClientSendInventory)
1✔
799
}
1✔
800

801
func (h ManagementController) sendMenderCommand(c *gin.Context, msgType string) {
1✔
802
        ctx := c.Request.Context()
1✔
803

1✔
804
        idata := identity.FromContext(ctx)
1✔
805
        if idata == nil || !idata.IsUser {
1✔
806
                c.JSON(http.StatusBadRequest, gin.H{
×
807
                        "error": ErrMissingUserAuthentication.Error(),
×
808
                })
×
809
                return
×
810
        }
×
811
        tenantID := idata.Tenant
1✔
812
        deviceID := c.Param("deviceId")
1✔
813

1✔
814
        device, err := h.app.GetDevice(ctx, tenantID, deviceID)
1✔
815
        if err == app.ErrDeviceNotFound {
2✔
816
                c.JSON(http.StatusNotFound, gin.H{
1✔
817
                        "error": err.Error(),
1✔
818
                })
1✔
819
                return
1✔
820
        } else if err != nil {
3✔
821
                c.JSON(http.StatusBadRequest, gin.H{
1✔
822
                        "error": err.Error(),
1✔
823
                })
1✔
824
                return
1✔
825
        } else if device.Status != model.DeviceStatusConnected {
3✔
826
                c.JSON(http.StatusConflict, gin.H{
1✔
827
                        "error": app.ErrDeviceNotConnected,
1✔
828
                })
1✔
829
                return
1✔
830
        }
1✔
831

832
        msg := &ws.ProtoMsg{
1✔
833
                Header: ws.ProtoHdr{
1✔
834
                        Proto:   ws.ProtoTypeMenderClient,
1✔
835
                        MsgType: msgType,
1✔
836
                        Properties: map[string]interface{}{
1✔
837
                                PropertyUserID: idata.Subject,
1✔
838
                        },
1✔
839
                },
1✔
840
        }
1✔
841
        data, _ := msgpack.Marshal(msg)
1✔
842

1✔
843
        err = h.nats.Publish(model.GetDeviceSubject(idata.Tenant, device.ID), data)
1✔
844
        if err != nil {
2✔
845
                c.JSON(http.StatusInternalServerError, gin.H{
1✔
846
                        "error": err.Error(),
1✔
847
                })
1✔
848
        }
1✔
849

850
        c.JSON(http.StatusAccepted, nil)
1✔
851
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc