• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nats-io / nats-server / 24949216239

24 Apr 2026 08:34AM UTC coverage: 80.645% (-2.4%) from 83.05%
24949216239

push

github

web-flow
(2.14) [ADDED] `RemoteLeafOpts.IgnoreDiscoveredServers` option (#8067)

For a given leafnode remote, if this is set to true, this remote will
ignore any server leafnode URLs returned by the hub, allowing the user
to fully manage the servers this remote can connect to.

Resolves #8002

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>

74685 of 92610 relevant lines covered (80.64%)

632737.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

66.72
/server/websocket.go
1
// Copyright 2020-2025 The NATS Authors
2
// Licensed under the Apache License, Version 2.0 (the "License");
3
// you may not use this file except in compliance with the License.
4
// You may obtain a copy of the License at
5
//
6
// http://www.apache.org/licenses/LICENSE-2.0
7
//
8
// Unless required by applicable law or agreed to in writing, software
9
// distributed under the License is distributed on an "AS IS" BASIS,
10
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
// See the License for the specific language governing permissions and
12
// limitations under the License.
13

14
package server
15

16
import (
17
        "bytes"
18
        crand "crypto/rand"
19
        "crypto/sha1"
20
        "crypto/tls"
21
        "encoding/base64"
22
        "encoding/binary"
23
        "errors"
24
        "fmt"
25
        "io"
26
        "log"
27
        mrand "math/rand"
28
        "net"
29
        "net/http"
30
        "net/url"
31
        "strconv"
32
        "strings"
33
        "sync"
34
        "sync/atomic"
35
        "time"
36
        "unicode/utf8"
37

38
        "github.com/klauspost/compress/flate"
39
)
40

41
type wsOpCode int
42

43
const (
44
        // From https://tools.ietf.org/html/rfc6455#section-5.2
45
        wsTextMessage   = wsOpCode(1)
46
        wsBinaryMessage = wsOpCode(2)
47
        wsCloseMessage  = wsOpCode(8)
48
        wsPingMessage   = wsOpCode(9)
49
        wsPongMessage   = wsOpCode(10)
50

51
        wsFinalBit = 1 << 7
52
        wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
53
        wsRsv2Bit  = 1 << 5
54
        wsRsv3Bit  = 1 << 4
55

56
        wsMaskBit = 1 << 7
57

58
        wsContinuationFrame     = 0
59
        wsMaxFrameHeaderSize    = 14 // Since LeafNode may need to behave as a client
60
        wsMaxControlPayloadSize = 125
61
        wsFrameSizeForBrowsers  = 4096 // From experiment, webrowsers behave better with limited frame size
62
        wsCompressThreshold     = 64   // Don't compress for small buffer(s)
63
        wsMaxMsgPayloadMultiple = 8
64
        wsMaxMsgPayloadLimit    = 64 * 1024 * 1024
65
        wsCloseSatusSize        = 2
66

67
        // From https://tools.ietf.org/html/rfc6455#section-11.7
68
        wsCloseStatusNormalClosure      = 1000
69
        wsCloseStatusGoingAway          = 1001
70
        wsCloseStatusProtocolError      = 1002
71
        wsCloseStatusUnsupportedData    = 1003
72
        wsCloseStatusNoStatusReceived   = 1005
73
        wsCloseStatusInvalidPayloadData = 1007
74
        wsCloseStatusPolicyViolation    = 1008
75
        wsCloseStatusMessageTooBig      = 1009
76
        wsCloseStatusInternalSrvError   = 1011
77
        wsCloseStatusTLSHandshake       = 1015
78

79
        wsFirstFrame        = true
80
        wsContFrame         = false
81
        wsFinalFrame        = true
82
        wsUncompressedFrame = false
83

84
        wsSchemePrefix    = "ws"
85
        wsSchemePrefixTLS = "wss"
86

87
        wsNoMaskingHeader       = "Nats-No-Masking"
88
        wsNoMaskingValue        = "true"
89
        wsXForwardedForHeader   = "X-Forwarded-For"
90
        wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF
91
        wsPMCExtension          = "permessage-deflate" // per-message compression
92
        wsPMCSrvNoCtx           = "server_no_context_takeover"
93
        wsPMCCliNoCtx           = "client_no_context_takeover"
94
        wsPMCReqHeaderValue     = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
95
        wsPMCFullResponse       = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_
96
        wsSecProto              = "Sec-Websocket-Protocol"
97
        wsMQTTSecProtoVal       = "mqtt"
98
        wsMQTTSecProto          = wsSecProto + ": " + wsMQTTSecProtoVal + CR_LF
99
)
100

101
var (
	// decompressorPool recycles the flate readers used to inflate compressed
	// WebSocket messages (see wsDecompressAndParse).
	decompressorPool sync.Pool

	// compressLastBlock is appended to an accumulated compressed message before
	// it is inflated. The first 4 bytes are the 0x00 0x00 0xff 0xff tail that
	// RFC 7692 (section 7.2.2) says a receiver appends before decompressing. The
	// remaining 5 bytes form an empty final stored deflate block, so the inflater
	// reports io.EOF at the end of the message.
	compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

	// wsGUID is the fixed GUID defined by the WebSocket protocol for the opening
	// handshake. From https://tools.ietf.org/html/rfc6455#section-1.3
	wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

	// Test can enable this so that server does not support "no-masking" requests.
	wsTestRejectNoMasking = false
)
109

110
// websocket holds the WebSocket-specific state attached to a client
// connection (reachable as c.ws).
type websocket struct {
	// Outbound frames queued for writing (appended by
	// wsEnqueueControlMessageLocked), and their total size in bytes.
	frames         net.Buffers
	fs             int64
	// Close frame held back until there are no more pending buffers to send.
	closeMsg       []byte
	// permessage-deflate was negotiated; compressed inbound frames are
	// rejected as a protocol error when this is false.
	compress       bool
	// A close frame has been enqueued (only one should be sent).
	closeSent      bool
	// NOTE(review): presumably set for browser clients (see
	// wsFrameSizeForBrowsers) — not referenced in this section, confirm.
	browser        bool
	nocompfrag     bool // No fragment for compressed frames
	// NOTE(review): not read in this section; presumably whether inbound
	// frames are expected to be masked — confirm.
	maskread       bool
	// Outbound frames are masked with a random 4-byte key.
	maskwrite      bool
	// NOTE(review): presumably the writer used to compress outbound frames —
	// not referenced in this section, confirm.
	compressor     *flate.Writer
	// NOTE(review): presumably credentials and client address captured during
	// the HTTP upgrade — not referenced in this section, confirm.
	cookieJwt      string
	cookieUsername string
	cookiePassword string
	cookieToken    string
	clientIP       string
}
127

128
// srvWebsocket holds the server-level WebSocket state: the HTTP server and
// listener, origin checks and the advertised connect URLs.
//
// NOTE(review): mu presumably guards the mutable fields declared before the
// "immutable" section below — confirm against the code that writes them.
type srvWebsocket struct {
	mu             sync.RWMutex
	server         *http.Server
	listener       net.Listener
	listenerErr    error
	allowedOrigins map[string][]*allowedOrigin // host will be the key
	sameOrigin     bool
	connectURLs    []string
	connectURLsMap refCountedUrlSet
	authOverride   bool   // indicates if there is auth override in websocket config
	rawHeaders     string // raw headers to be used in the upgrade response.

	// These are immutable and can be accessed without lock.
	// This is the case when generating the client INFO.
	tls  bool   // True if TLS is required (TLSConfig is specified).
	host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host).
	port int    // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port).
}
146

147
// allowedOrigin is one accepted scheme/port combination for an allowed origin.
// The origin's host is not stored here: it is the key of
// srvWebsocket.allowedOrigins.
type allowedOrigin struct {
	scheme string
	port   string
}
151

152
// wsUpgradeResult bundles the outcome of a WebSocket upgrade.
//
// NOTE(review): not used in this section. Presumably conn is the upgraded
// network connection, ws its per-connection WebSocket state and kind the kind
// of connection being accepted — confirm at the upgrade call site.
type wsUpgradeResult struct {
	conn net.Conn
	ws   *websocket
	kind int
}
157

158
// wsReadInfo is the incremental decoding state for inbound WebSocket frames on
// one connection. It lets a frame (or a fragmented message) that spans several
// reads resume where the previous read stopped. It also implements io.Reader
// and io.ByteReader over the accumulated compressed chunks, for inflation.
type wsReadInfo struct {
	rem   uint64   // Payload bytes of the current frame not yet consumed.
	fs    bool     // The next byte is expected to start a frame header.
	ff    bool     // FIN of the current message seen (no fragmented message in progress).
	fc    bool     // The current message is compressed (RSV1 set on its first frame).
	mask  bool     // Frames must be masked. Incoming leafnode connections may not have masking.
	mkpos byte     // Index (0-3) in mkey of the key byte for the next payload byte.
	mkey  [4]byte  // Masking key of the current frame.
	cbufs [][]byte // Compressed chunks accumulated for the current message.
	coff  int      // Read offset inside cbufs[0] while inflating.
	csz   uint64   // Total number of bytes held in cbufs.
}
170

171
func (r *wsReadInfo) init() {
61✔
172
        r.fs, r.ff = true, true
61✔
173
}
61✔
174

175
func (r *wsReadInfo) resetCompressedState() {
×
176
        r.fs = true
×
177
        r.ff = true
×
178
        r.fc = false
×
179
        r.rem = 0
×
180
        r.cbufs = nil
×
181
        r.coff = 0
×
182
        r.csz = 0
×
183
}
×
184

185
// Compressed WebSocket messages have to be accumulated before they can be
186
// decompressed and handed to the parser, so this transport limit needs to
187
// allow batching several max_payload-sized NATS operations while still
188
// capping resource usage on the buffered compressed path.
189
func wsMaxMessageSize(mpay int) uint64 {
39✔
190
        if mpay <= 0 {
39✔
191
                mpay = MAX_PAYLOAD_SIZE
×
192
        }
×
193
        limit := uint64(mpay) * wsMaxMsgPayloadMultiple
39✔
194
        if limit > wsMaxMsgPayloadLimit {
39✔
195
                limit = wsMaxMsgPayloadLimit
×
196
        }
×
197
        return limit
39✔
198
}
199

200
// wsGet returns a slice holding exactly `needed` bytes, taken from `buf`
// starting at `pos` and, if `buf` does not hold enough, completed with bytes
// read from `r`. The second return value is the new position in `buf`.
//
// When `buf` holds at least `needed` bytes from `pos`, the returned slice
// aliases `buf` and the position advances by `needed`. Otherwise the bytes
// available in `buf` are copied into a new slice, the remaining bytes are read
// from `r`, and the returned position is len(buf) (all of `buf` was consumed).
//
// Following the io.Reader contract, bytes returned together with an error
// (for instance the last bytes of a stream along with io.EOF) are consumed
// before the error is looked at. If they complete the request the call
// succeeds; otherwise the error is returned with a nil slice and position 0.
func wsGet(r io.Reader, buf []byte, pos, needed uint64) ([]byte, uint64, error) {
	avail := uint64(len(buf)) - pos
	if avail >= needed {
		return buf[pos : pos+needed], pos + needed, nil
	}
	b := make([]byte, needed)
	start := uint64(copy(b, buf[pos:]))
	for start < needed {
		n, err := r.Read(b[start:])
		// Account for the bytes first: a reader may return data and an error together.
		start += uint64(n)
		if err != nil && start < needed {
			return nil, 0, err
		}
	}
	return b, pos + avail, nil
}
222

223
// Returns true if this connection is from a Websocket client.
224
// Lock held on entry.
225
func (c *client) isWebsocket() bool {
29,829,971✔
226
        return c.ws != nil
29,829,971✔
227
}
29,829,971✔
228

229
// Returns a slice of byte slices corresponding to payload of websocket frames.
230
// The byte slice `buf` is filled with bytes from the connection's read loop.
231
// This function will decode the frame headers and unmask the payload(s).
232
// It is possible that the returned slices point to the given `buf` slice, so
233
// `buf` should not be overwritten until the returned slices have been parsed.
234
//
235
// Client lock MUST NOT be held on entry.
236
func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) {
×
237
        var bufs [][]byte
×
238
        err := c.wsReadLoop(r, ior, buf, func(b []byte, compressed, final bool) error {
×
239
                if compressed {
×
240
                        return errors.New("compressed websocket frames require wsReadAndParse")
×
241
                }
×
242
                bufs = append(bufs, b)
×
243
                return nil
×
244
        })
245
        return bufs, err
×
246
}
247

248
func (c *client) wsReadAndParse(r *wsReadInfo, ior io.Reader, buf []byte) error {
408✔
249
        mpay := int(atomic.LoadInt32(&c.mpay))
408✔
250
        if mpay <= 0 {
408✔
251
                mpay = MAX_PAYLOAD_SIZE
×
252
        }
×
253
        return c.wsReadLoop(r, ior, buf, func(b []byte, compressed, final bool) error {
767✔
254
                if compressed {
398✔
255
                        if err := c.wsDecompressAndParse(r, b, final, mpay); err != nil {
39✔
256
                                r.resetCompressedState()
×
257
                                return err
×
258
                        }
×
259
                        if final {
78✔
260
                                r.fc = false
39✔
261
                        }
39✔
262
                        return nil
39✔
263
                }
264
                return c.parse(b)
320✔
265
        })
266
}
267

268
// wsReadLoop decodes the WebSocket frames contained in buf. For every chunk of
// (unmasked) data-frame payload it calls handle(chunk, compressed, final):
// compressed reports whether the current message is compressed, and final is
// true when the chunk completes the message.
//
// Decoding state lives in r, so a frame whose payload extends past the end of
// buf resumes on the next call. When a frame header, or the payload of a
// control frame, is cut by the end of buf, the missing bytes are read directly
// from ior (see wsGet). Control frames (ping/pong/close) are handled in place
// by wsHandleControlFrame and never reach handle.
//
// Protocol violations return the result of wsHandleProtocolError. Errors from
// handle, from ior, or from control frame handling (io.EOF after a close
// frame) stop the loop and are returned as is.
//
// Client lock MUST NOT be held on entry.
func (c *client) wsReadLoop(r *wsReadInfo, ior io.Reader, buf []byte, handle func([]byte, bool, bool) error) error {
	var (
		tmpBuf []byte
		err    error
		pos    uint64
		max    = uint64(len(buf))
	)
	for pos != max {
		// r.fs is set when the next byte starts a new frame header.
		if r.fs {
			b0 := buf[pos]
			frameType := wsOpCode(b0 & 0xF)
			final := b0&wsFinalBit != 0
			compressed := b0&wsRsv1Bit != 0
			if b0&(wsRsv2Bit|wsRsv3Bit) != 0 {
				return c.wsHandleProtocolError("RSV2 and RSV3 must be clear")
			}
			if compressed && !c.ws.compress {
				return c.wsHandleProtocolError("compressed frame received without negotiated permessage-deflate")
			}
			pos++

			tmpBuf, pos, err = wsGet(ior, buf, pos, 1)
			if err != nil {
				return err
			}
			b1 := tmpBuf[0]

			// Clients MUST set the mask bit. If not set, reject.
			// However, LEAF by default will not have masking, unless they are forced to, by configuration.
			if r.mask && b1&wsMaskBit == 0 {
				return c.wsHandleProtocolError("mask bit missing")
			}

			// Store size in case it is < 126 (126 and 127 select the
			// 16-bit and 64-bit extended length encodings handled below).
			r.rem = uint64(b1 & 0x7F)

			switch frameType {
			case wsPingMessage, wsPongMessage, wsCloseMessage:
				if r.rem > wsMaxControlPayloadSize {
					return c.wsHandleProtocolError(
						fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes",
							wsMaxControlPayloadSize))
				}
				if !final {
					return c.wsHandleProtocolError("control frame does not have final bit set")
				}
				if compressed {
					return c.wsHandleProtocolError("control frame must not be compressed")
				}
			case wsTextMessage, wsBinaryMessage:
				if !r.ff {
					return c.wsHandleProtocolError("new message started before final frame for previous message was received")
				}
				r.ff = final
				// Compression is a property of the whole message, recorded
				// from its first frame.
				r.fc = compressed
			case wsContinuationFrame:
				// Compressed bit must be only set in the first frame
				if r.ff || compressed {
					return c.wsHandleProtocolError("invalid continuation frame")
				}
				r.ff = final
			default:
				return c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
			}

			switch r.rem {
			case 126:
				tmpBuf, pos, err = wsGet(ior, buf, pos, 2)
				if err != nil {
					return err
				}
				r.rem = uint64(binary.BigEndian.Uint16(tmpBuf))
			case 127:
				tmpBuf, pos, err = wsGet(ior, buf, pos, 8)
				if err != nil {
					return err
				}
				// The most significant bit of a 64-bit length must be 0.
				if r.rem = binary.BigEndian.Uint64(tmpBuf); r.rem&(uint64(1)<<63) != 0 {
					return c.wsHandleProtocolError("invalid 64-bit payload length")
				}
			}

			if r.mask {
				// Read masking key
				tmpBuf, pos, err = wsGet(ior, buf, pos, 4)
				if err != nil {
					return err
				}
				copy(r.mkey[:], tmpBuf)
				r.mkpos = 0
			}

			// Handle control messages in place... r.fs stays set, so the
			// next iteration expects a new frame header.
			if wsIsControlFrame(frameType) {
				pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos)
				if err != nil {
					return err
				}
				continue
			}

			// Done with the frame header
			r.fs = false
		}
		// Consume as much of the current frame's payload as buf holds.
		if pos < max {
			n := r.rem
			if pos+n > max {
				n = max - pos
			}
			b := buf[pos : pos+n]
			pos += n
			r.rem -= n
			// If needed, unmask the buffer
			if r.mask {
				r.unmask(b)
			}
			// The chunk ends the message only if this is the final frame
			// and nothing is left of its payload.
			if err := handle(b, r.fc, r.ff && r.rem == 0); err != nil {
				return err
			}
			if r.rem == 0 {
				r.fs = true
			}
		}
	}
	return nil
}
394

395
func (r *wsReadInfo) Read(dst []byte) (int, error) {
80✔
396
        if len(dst) == 0 {
80✔
397
                return 0, nil
×
398
        }
×
399
        if len(r.cbufs) == 0 {
80✔
400
                return 0, io.EOF
×
401
        }
×
402
        copied := 0
80✔
403
        rem := len(dst)
80✔
404
        for buf := r.cbufs[0]; buf != nil && rem > 0; {
160✔
405
                n := len(buf[r.coff:])
80✔
406
                if n > rem {
121✔
407
                        n = rem
41✔
408
                }
41✔
409
                copy(dst[copied:], buf[r.coff:r.coff+n])
80✔
410
                copied += n
80✔
411
                rem -= n
80✔
412
                r.coff += n
80✔
413
                buf = r.nextCBuf()
80✔
414
        }
415
        return copied, nil
80✔
416
}
417

418
func (r *wsReadInfo) nextCBuf() []byte {
7,340✔
419
        // We still have remaining data in the first buffer
7,340✔
420
        if r.coff != len(r.cbufs[0]) {
14,602✔
421
                return r.cbufs[0]
7,262✔
422
        }
7,262✔
423
        // We read the full first buffer. Reset offset.
424
        r.coff = 0
78✔
425
        // We were at the last buffer, so we are done.
78✔
426
        if len(r.cbufs) == 1 {
117✔
427
                r.cbufs = nil
39✔
428
                return nil
39✔
429
        }
39✔
430
        // Here we move to the next buffer.
431
        r.cbufs = r.cbufs[1:]
39✔
432
        return r.cbufs[0]
39✔
433
}
434

435
func (r *wsReadInfo) ReadByte() (byte, error) {
7,260✔
436
        for len(r.cbufs) > 0 && len(r.cbufs[0]) == 0 {
7,260✔
437
                r.nextCBuf()
×
438
        }
×
439
        if len(r.cbufs) == 0 {
7,260✔
440
                return 0, io.EOF
×
441
        }
×
442
        b := r.cbufs[0][r.coff]
7,260✔
443
        r.coff++
7,260✔
444
        r.nextCBuf()
7,260✔
445
        return b, nil
7,260✔
446
}
447

448
// wsDecompressAndParse accumulates one chunk of a compressed (permessage-deflate)
// message. Once the final chunk has been received, it inflates the whole
// message and feeds the output to the protocol parser.
//
//   - r:     read state; r.cbufs/r.csz accumulate the compressed chunks, and r
//            is also the io.Reader/io.ByteReader the inflater reads from.
//   - b:     compressed payload chunk; it is copied since it may point into the
//            connection's read buffer.
//   - final: true when b completes the message.
//   - mpay:  max payload; caps the decompressed output.
//
// Returns ErrMaxPayload when the accumulated compressed data would exceed
// wsMaxMessageSize(mpay), or when the decompressed data exceeds mpay. Otherwise
// it returns any error from the inflater or from c.parse. Once inflation has
// started, the accumulated chunks are released on return (see the defer).
//
// NOTE(review): compressed input may reach wsMaxMessageSize(mpay) (up to
// wsMaxMsgPayloadMultiple x mpay), but decompressed output is capped at mpay,
// so a message batching operations whose total exceeds mpay is rejected.
// Confirm this matches the intent described on wsMaxMessageSize.
func (c *client) wsDecompressAndParse(r *wsReadInfo, b []byte, final bool, mpay int) error {
	limit := wsMaxMessageSize(mpay)
	if len(b) > 0 {
		if r.csz+uint64(len(b)) > limit {
			return ErrMaxPayload
		}
		// Copy: b aliases the read buffer, which will be reused.
		r.cbufs = append(r.cbufs, append([]byte(nil), b...))
		r.csz += uint64(len(b))
	}
	if !final {
		return nil
	}
	if r.csz+uint64(len(compressLastBlock)) > limit {
		return ErrMaxPayload
	}
	// compressLastBlock is shared but only ever read (Read/ReadByte copy from it).
	r.cbufs = append(r.cbufs, compressLastBlock)
	r.csz += uint64(len(compressLastBlock))
	r.coff = 0
	// An empty pool yields nil, in which case a new reader is created.
	d, _ := decompressorPool.Get().(io.ReadCloser)
	if d == nil {
		d = flate.NewReader(r)
	} else {
		d.(flate.Resetter).Reset(r, nil)
	}
	defer func() {
		d.Close()
		decompressorPool.Put(d)
		r.cbufs = nil
		r.coff = 0
		r.csz = 0
	}()
	// Reading one byte past mpay is enough to detect an oversized message.
	lr := io.LimitedReader{R: d, N: int64(mpay + 1)}
	buf := make([]byte, 32*1024)
	total := 0
	for {
		n, err := lr.Read(buf)
		if n > 0 {
			// Only the bytes within the mpay limit are handed to the parser;
			// ErrMaxPayload is returned right after.
			pn := n
			if total+n > mpay {
				pn = mpay - total
			}
			if pn > 0 {
				if err := c.parse(buf[:pn]); err != nil {
					return err
				}
			}
			total += n
			if total > mpay {
				return ErrMaxPayload
			}
		}
		if err == nil {
			continue
		}
		// compressLastBlock terminates the deflate stream, so a complete
		// message ends with io.EOF.
		if err == io.EOF {
			return nil
		}
		return err
	}
}
508

509
// Handles the PING, PONG and CLOSE websocket control frames.
//
// The frame payload (r.rem bytes) is taken from buf at pos. If it extends past
// the end of buf, the rest is read from nc. The payload is unmasked if needed,
// and the new position in buf is returned.
//   - CLOSE: the optional status code and UTF-8 body are validated and a close
//     frame is enqueued in reply; io.EOF is then returned so the caller stops.
//   - PING: a PONG carrying the same payload is enqueued.
//   - PONG: ignored.
//
// Client lock MUST NOT be held on entry.
func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos uint64) (uint64, error) {
	var payload []byte
	var err error

	if r.rem > 0 {
		payload, pos, err = wsGet(nc, buf, pos, r.rem)
		if err != nil {
			return pos, err
		}
		if r.mask {
			r.unmask(payload)
		}
		r.rem = 0
	}
	switch frameType {
	case wsCloseMessage:
		status := wsCloseStatusNoStatusReceived
		var body string
		lp := len(payload)
		if lp == 1 {
			return pos, c.wsHandleProtocolError("close frame payload cannot be 1 byte")
		}
		// If there is a payload, the status is represented as a 2-byte
		// unsigned integer (in network byte order). Then, there may be an
		// optional body.
		hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize
		if hasStatus {
			// Decode the status
			status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize]))
			if !wsIsValidCloseStatus(status) {
				return pos, c.wsHandleProtocolError(fmt.Sprintf("invalid close status code %v", status))
			}
			// Now if there is a body, capture it and make sure this is a valid UTF-8.
			if hasBody {
				body = string(payload[wsCloseSatusSize:])
				if !utf8.ValidString(body) {
					// https://tools.ietf.org/html/rfc6455#section-5.5.1
					// If body is present, it must be a valid utf8
					status = wsCloseStatusInvalidPayloadData
					body = "invalid utf8 body in close frame"
				}
			}
		}
		// If the status indicates that nothing was received, the reply is a
		// close frame without any payload (no status code, no body).
		// From https://datatracker.ietf.org/doc/html/rfc6455#section-7.4
		// it says that code 1005 is a reserved value and MUST NOT be set as a
		// status code in a Close control frame by an endpoint.  It is
		// designated for use in applications expecting a status code to indicate
		// that no status code was actually present.
		var clm []byte
		if status != wsCloseStatusNoStatusReceived {
			clm = wsCreateCloseMessage(status, body)
		}
		c.wsEnqueueControlMessage(wsCloseMessage, clm)
		if len(clm) > 0 {
			nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy.
		}
		// Return io.EOF so that readLoop will close the connection as ClientClosed
		// after processing pending buffers.
		return pos, io.EOF
	case wsPingMessage:
		// Echo the ping payload back in a PONG.
		c.wsEnqueueControlMessage(wsPongMessage, payload)
	case wsPongMessage:
		// Nothing to do..
	}
	return pos, nil
}
580

581
// Unmask the given slice.
582
func (r *wsReadInfo) unmask(buf []byte) {
136✔
583
        p := int(r.mkpos)
136✔
584
        if len(buf) < 16 {
158✔
585
                for i := 0; i < len(buf); i++ {
197✔
586
                        buf[i] ^= r.mkey[p&3]
175✔
587
                        p++
175✔
588
                }
175✔
589
                r.mkpos = byte(p & 3)
22✔
590
                return
22✔
591
        }
592
        var k [8]byte
114✔
593
        for i := 0; i < 8; i++ {
1,026✔
594
                k[i] = r.mkey[(p+i)&3]
912✔
595
        }
912✔
596
        km := binary.BigEndian.Uint64(k[:])
114✔
597
        n := (len(buf) / 8) * 8
114✔
598
        for i := 0; i < n; i += 8 {
16,132✔
599
                tmp := binary.BigEndian.Uint64(buf[i : i+8])
16,018✔
600
                tmp ^= km
16,018✔
601
                binary.BigEndian.PutUint64(buf[i:], tmp)
16,018✔
602
        }
16,018✔
603
        buf = buf[n:]
114✔
604
        for i := 0; i < len(buf); i++ {
285✔
605
                buf[i] ^= r.mkey[p&3]
171✔
606
                p++
171✔
607
        }
171✔
608
        r.mkpos = byte(p & 3)
114✔
609
}
610

611
// Returns true if the op code corresponds to a control frame.
612
func wsIsControlFrame(frameType wsOpCode) bool {
313✔
613
        return frameType >= wsCloseMessage
313✔
614
}
313✔
615

616
// Create the frame header.
617
// Encodes the frame type and optional compression flag, and the size of the payload.
618
func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
291✔
619
        fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
291✔
620
        n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
291✔
621
        return fh[:n], key
291✔
622
}
291✔
623

624
func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) {
355✔
625
        var n int
355✔
626
        var b byte
355✔
627
        if first {
710✔
628
                b = byte(frameType)
355✔
629
        }
355✔
630
        if final {
710✔
631
                b |= wsFinalBit
355✔
632
        }
355✔
633
        if compressed {
394✔
634
                b |= wsRsv1Bit
39✔
635
        }
39✔
636
        b1 := byte(0)
355✔
637
        if useMasking {
462✔
638
                b1 |= wsMaskBit
107✔
639
        }
107✔
640
        switch {
355✔
641
        case l <= 125:
155✔
642
                n = 2
155✔
643
                fh[0] = b
155✔
644
                fh[1] = b1 | byte(l)
155✔
645
        case l < 65536:
199✔
646
                n = 4
199✔
647
                fh[0] = b
199✔
648
                fh[1] = b1 | 126
199✔
649
                binary.BigEndian.PutUint16(fh[2:], uint16(l))
199✔
650
        default:
1✔
651
                n = 10
1✔
652
                fh[0] = b
1✔
653
                fh[1] = b1 | 127
1✔
654
                binary.BigEndian.PutUint64(fh[2:], uint64(l))
1✔
655
        }
656
        var key []byte
355✔
657
        if useMasking {
462✔
658
                var keyBuf [4]byte
107✔
659
                if _, err := io.ReadFull(crand.Reader, keyBuf[:4]); err != nil {
107✔
660
                        kv := mrand.Int31()
×
661
                        binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv))
×
662
                }
×
663
                copy(fh[n:], keyBuf[:4])
107✔
664
                key = fh[n : n+4]
107✔
665
                n += 4
107✔
666
        }
667
        return n, key
355✔
668
}
669

670
// Invokes wsEnqueueControlMessageLocked under client lock.
671
//
672
// Client lock MUST NOT be held on entry
673
func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) {
30✔
674
        c.mu.Lock()
30✔
675
        c.wsEnqueueControlMessageLocked(controlMsg, payload)
30✔
676
        c.mu.Unlock()
30✔
677
}
30✔
678

679
// wsMaskBuf XORs buf in place with the 4-byte masking key, key[0] applying to
// buf[0]. Applying it twice restores the original bytes.
func wsMaskBuf(key, buf []byte) {
	for i := range buf {
		buf[i] ^= key[i%4]
	}
}
685

686
// wsMaskBufs XORs every buffer in bufs in place with the 4-byte masking key.
// The buffers are treated as one contiguous payload, so the key position
// carries over from the end of one buffer to the start of the next.
func wsMaskBufs(key []byte, bufs [][]byte) {
	i := 0
	for _, b := range bufs {
		for j := range b {
			b[j] ^= key[i&3]
			i++
		}
	}
}
697

698
// Enqueues a websocket control message.
699
// If the control message is a wsCloseMessage, then marks this client
700
// has having sent the close message (since only one should be sent).
701
// This will prevent the generic closeConnection() to enqueue one.
702
//
703
// Client lock held on entry.
704
func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) {
64✔
705
        // Control messages are never compressed and their size will be
64✔
706
        // less than wsMaxControlPayloadSize, which means the frame header
64✔
707
        // will be only 2 or 6 bytes.
64✔
708
        useMasking := c.ws.maskwrite
64✔
709
        sz := 2
64✔
710
        if useMasking {
84✔
711
                sz += 4
20✔
712
        }
20✔
713
        cm := nbPoolGet(sz + len(payload))
64✔
714
        cm = cm[:cap(cm)]
64✔
715
        n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
64✔
716
        cm = cm[:n]
64✔
717
        // Note that payload is optional.
64✔
718
        if len(payload) > 0 {
128✔
719
                cm = append(cm, payload...)
64✔
720
                if useMasking {
84✔
721
                        wsMaskBuf(key, cm[n:])
20✔
722
                }
20✔
723
        }
724
        c.out.pb += int64(len(cm))
64✔
725
        if controlMsg == wsCloseMessage {
128✔
726
                // We can't add the close message to the frames buffers
64✔
727
                // now. It will be done on a flushOutbound() when there
64✔
728
                // are no more pending buffers to send.
64✔
729
                c.ws.closeSent = true
64✔
730
                c.ws.closeMsg = cm
64✔
731
        } else {
64✔
732
                c.ws.frames = append(c.ws.frames, cm)
×
733
                c.ws.fs += int64(len(cm))
×
734
        }
×
735
        c.flushSignal()
64✔
736
}
737

738
// Enqueues a websocket close message with a status mapped from the given `reason`.
739
//
740
// Client lock held on entry
741
func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
34✔
742
        var status int
34✔
743
        switch reason {
34✔
744
        case ClientClosed:
×
745
                status = wsCloseStatusNormalClosure
×
746
        case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline,
747
                MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded,
748
                MissingAccount, AuthenticationExpired, Revocation:
2✔
749
                status = wsCloseStatusPolicyViolation
2✔
750
        case TLSHandshakeError:
×
751
                status = wsCloseStatusTLSHandshake
×
752
        case ParseError, ProtocolViolation, BadClientProtocolVersion:
×
753
                status = wsCloseStatusProtocolError
×
754
        case MaxPayloadExceeded:
×
755
                status = wsCloseStatusMessageTooBig
×
756
        case WriteError, ReadError, StaleConnection, ServerShutdown:
32✔
757
                // We used to have WriteError, ReadError and StaleConnection result in
32✔
758
                // code 1006, which the spec says that it must not be used to set the
32✔
759
                // status in the close message. So using this one instead.
32✔
760
                status = wsCloseStatusGoingAway
32✔
761
        default:
×
762
                status = wsCloseStatusInternalSrvError
×
763
        }
764
        body := wsCreateCloseMessage(status, reason.String())
34✔
765
        c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
34✔
766
        nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy.
34✔
767
}
768

769
// Create and then enqueue a close message with a protocol error and the
770
// given message. This is invoked when parsing websocket frames.
771
//
772
// Lock MUST NOT be held on entry.
773
func (c *client) wsHandleProtocolError(message string) error {
×
774
        buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
×
775
        c.wsEnqueueControlMessage(wsCloseMessage, buf)
×
776
        nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy.
×
777
        return errors.New(message)
×
778
}
×
779

780
func wsIsValidCloseStatus(code int) bool {
30✔
781
        switch code {
30✔
782
        case wsCloseStatusNoStatusReceived, 1004, 1006, wsCloseStatusTLSHandshake:
×
783
                return false
×
784
        }
785
        if code < 1000 || code >= 5000 {
30✔
786
                return false
×
787
        }
×
788
        // 1016-2999 are currently reserved.
789
        if code >= 1016 && code <= 2999 {
30✔
790
                return false
×
791
        }
×
792
        return true
30✔
793
}
794

795
// Create a close message with the given `status` and `body`.
796
// If the `body` is more than the maximum allows control frame payload size,
797
// it is truncated and "..." is added at the end (as a hint that message
798
// is not complete).
799
func wsCreateCloseMessage(status int, body string) []byte {
64✔
800
        // Since a control message payload is limited in size, we
64✔
801
        // will limit the text and add trailing "..." if truncated.
64✔
802
        // The body of a Close Message must be preceded with 2 bytes,
64✔
803
        // so take that into account for limiting the body length.
64✔
804
        if len(body) > wsMaxControlPayloadSize-2 {
64✔
805
                body = body[:wsMaxControlPayloadSize-5]
×
806
                body += "..."
×
807
        }
×
808
        buf := nbPoolGet(2 + len(body))[:2+len(body)]
64✔
809
        // We need to have a 2 byte unsigned int that represents the error status code
64✔
810
        // https://tools.ietf.org/html/rfc6455#section-5.5.1
64✔
811
        binary.BigEndian.PutUint16(buf[:2], uint16(status))
64✔
812
        copy(buf[2:], []byte(body))
64✔
813
        return buf
64✔
814
}
815

816
// Process websocket client handshake. On success, returns the raw net.Conn that
817
// will be used to create a *client object.
818
// Invoked from the HTTP server listening on websocket port.
819
func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) {
32✔
820
        kind := CLIENT
32✔
821
        if r.URL != nil {
64✔
822
                ep := r.URL.EscapedPath()
32✔
823
                if strings.HasSuffix(ep, leafNodeWSPath) {
61✔
824
                        kind = LEAF
29✔
825
                } else if strings.HasSuffix(ep, mqttWSPath) {
32✔
826
                        kind = MQTT
×
827
                }
×
828
        }
829

830
        opts := s.getOpts()
32✔
831

32✔
832
        // From https://tools.ietf.org/html/rfc6455#section-4.2.1
32✔
833
        // Point 1.
32✔
834
        if r.Method != "GET" {
32✔
835
                return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET")
×
836
        }
×
837
        // Point 2.
838
        if r.Host == _EMPTY_ {
32✔
839
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request")
×
840
        }
×
841
        // Point 3.
842
        if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
32✔
843
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'")
×
844
        }
×
845
        // Point 4.
846
        if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
32✔
847
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'")
×
848
        }
×
849
        // Point 5.
850
        key := r.Header.Get("Sec-Websocket-Key")
32✔
851
        if key == _EMPTY_ {
32✔
852
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing")
×
853
        }
×
854
        decoded, err := base64.StdEncoding.DecodeString(key)
32✔
855
        if err != nil || len(decoded) != 16 {
32✔
856
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid websocket key")
×
857
        }
×
858
        // Point 6.
859
        if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
32✔
860
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version")
×
861
        }
×
862
        // Others are optional
863
        // Point 7.
864
        if err := s.websocket.checkOrigin(r); err != nil {
32✔
865
                return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
×
866
        }
×
867
        // Point 8.
868
        // We don't have protocols, so ignore.
869
        // Point 9.
870
        // Extensions, only support for compression at the moment
871
        compress := opts.Websocket.Compression
32✔
872
        if compress {
41✔
873
                // Simply check if permessage-deflate extension is present.
9✔
874
                compress, _ = wsPMCExtensionSupport(r.Header, true)
9✔
875
        }
9✔
876
        // We will do masking if asked (unless we reject for tests)
877
        noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking
32✔
878

32✔
879
        h, ok := w.(http.Hijacker)
32✔
880
        if !ok {
32✔
881
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "websocket upgrade not supported")
×
882
        }
×
883
        conn, brw, err := h.Hijack()
32✔
884
        if err != nil {
32✔
885
                if conn != nil {
×
886
                        conn.Close()
×
887
                }
×
888
                return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error())
×
889
        }
890
        if brw.Reader.Buffered() > 0 {
32✔
891
                conn.Close()
×
892
                return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete")
×
893
        }
×
894

895
        var buf [1024]byte
32✔
896
        p := buf[:0]
32✔
897

32✔
898
        // From https://tools.ietf.org/html/rfc6455#section-4.2.2
32✔
899
        p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
32✔
900
        p = append(p, wsAcceptKey(key)...)
32✔
901
        p = append(p, _CRLF_...)
32✔
902
        if compress {
37✔
903
                p = append(p, wsPMCFullResponse...)
5✔
904
        }
5✔
905
        if noMasking {
41✔
906
                p = append(p, wsNoMaskingFullResponse...)
9✔
907
        }
9✔
908
        if kind == MQTT {
32✔
909
                p = append(p, wsMQTTSecProto...)
×
910
        }
×
911
        if s.websocket.rawHeaders != _EMPTY_ {
32✔
912
                p = append(p, s.websocket.rawHeaders...)
×
913
        }
×
914
        p = append(p, _CRLF_...)
32✔
915

32✔
916
        if _, err = conn.Write(p); err != nil {
32✔
917
                conn.Close()
×
918
                return nil, err
×
919
        }
×
920
        // If there was a deadline set for the handshake, clear it now.
921
        if opts.Websocket.HandshakeTimeout > 0 {
33✔
922
                conn.SetDeadline(time.Time{})
1✔
923
        }
1✔
924
        // Server always expect "clients" to send masked payload, unless the option
925
        // "no-masking" has been enabled.
926
        ws := &websocket{compress: compress, maskread: !noMasking}
32✔
927

32✔
928
        // Check for X-Forwarded-For header
32✔
929
        if cips, ok := r.Header[wsXForwardedForHeader]; ok {
32✔
930
                if len(cips) > 0 {
×
931
                        cip := cips[0]
×
932
                        if net.ParseIP(cip) != nil {
×
933
                                ws.clientIP = cip
×
934
                        }
×
935
                }
936
        }
937

938
        if kind == CLIENT || kind == MQTT {
35✔
939
                // Indicate if this is likely coming from a browser.
3✔
940
                if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") {
3✔
941
                        ws.browser = true
×
942
                        // Disable fragmentation of compressed frames for Safari browsers.
×
943
                        // Unfortunately, you could be running Chrome on macOS and this
×
944
                        // string will contain "Safari/" (along "Chrome/"). However, what
×
945
                        // I have found is that actual Safari browser also have "Version/".
×
946
                        // So make the combination of the two.
×
947
                        ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/")
×
948
                }
×
949

950
                if cookies := r.Cookies(); len(cookies) > 0 {
3✔
951
                        ows := &opts.Websocket
×
952
                        for _, c := range cookies {
×
953
                                if ows.JWTCookie == c.Name {
×
954
                                        ws.cookieJwt = c.Value
×
955
                                } else if ows.UsernameCookie == c.Name {
×
956
                                        ws.cookieUsername = c.Value
×
957
                                } else if ows.PasswordCookie == c.Name {
×
958
                                        ws.cookiePassword = c.Value
×
959
                                } else if ows.TokenCookie == c.Name {
×
960
                                        ws.cookieToken = c.Value
×
961
                                }
×
962
                        }
963
                }
964
        }
965
        return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil
32✔
966
}
967

968
// wsHeaderContains reports whether any value stored under header[name]
// contains, in its comma-separated list, a token equal to `value`
// (case-insensitive, ignoring surrounding spaces and tabs). The map is indexed
// with `name` as given, so callers should pass the canonical header key.
func wsHeaderContains(header http.Header, name string, value string) bool {
	for _, line := range header[name] {
		for _, tok := range strings.Split(line, ",") {
			if strings.EqualFold(strings.Trim(tok, " \t"), value) {
				return true
			}
		}
	}
	return false
}
981

982
func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) {
18✔
983
        for _, extensionList := range header["Sec-Websocket-Extensions"] {
28✔
984
                extensions := strings.Split(extensionList, ",")
10✔
985
                for _, extension := range extensions {
20✔
986
                        extension = strings.Trim(extension, " \t")
10✔
987
                        params := strings.Split(extension, ";")
10✔
988
                        for i, p := range params {
20✔
989
                                p = strings.Trim(p, " \t")
10✔
990
                                if strings.EqualFold(p, wsPMCExtension) {
20✔
991
                                        if checkPMCOnly {
15✔
992
                                                return true, false
5✔
993
                                        }
5✔
994
                                        var snc bool
5✔
995
                                        var cnc bool
5✔
996
                                        for j := i + 1; j < len(params); j++ {
15✔
997
                                                p = params[j]
10✔
998
                                                p = strings.Trim(p, " \t")
10✔
999
                                                if strings.EqualFold(p, wsPMCSrvNoCtx) {
15✔
1000
                                                        snc = true
5✔
1001
                                                } else if strings.EqualFold(p, wsPMCCliNoCtx) {
15✔
1002
                                                        cnc = true
5✔
1003
                                                }
5✔
1004
                                                if snc && cnc {
15✔
1005
                                                        return true, true
5✔
1006
                                                }
5✔
1007
                                        }
1008
                                        return true, false
×
1009
                                }
1010
                        }
1011
                }
1012
        }
1013
        return false, false
8✔
1014
}
1015

1016
// wsReturnHTTPError replies on `w` with the given HTTP `status` (its standard
// status text as body) and a "Sec-Websocket-Version: 13" header, then returns
// an error that includes the remote address and `reason`.
func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error {
	w.Header().Set("Sec-Websocket-Version", "13")
	http.Error(w, http.StatusText(status), status)
	return fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason)
}
×
1024

1025
// If the server is configured to accept any origin, then this function returns
1026
// `nil` without checking if the Origin is present and valid. This is also
1027
// the case if the request does not have the Origin header.
1028
// Otherwise, this will check that the Origin matches the same origin or
1029
// any origin in the allowed list.
1030
func (w *srvWebsocket) checkOrigin(r *http.Request) error {
32✔
1031
        w.mu.RLock()
32✔
1032
        checkSame := w.sameOrigin
32✔
1033
        listEmpty := len(w.allowedOrigins) == 0
32✔
1034
        w.mu.RUnlock()
32✔
1035
        if !checkSame && listEmpty {
64✔
1036
                return nil
32✔
1037
        }
32✔
1038
        origin := r.Header.Get("Origin")
×
1039
        if origin == _EMPTY_ {
×
1040
                origin = r.Header.Get("Sec-Websocket-Origin")
×
1041
        }
×
1042
        // If the header is not present, we will accept.
1043
        // From https://datatracker.ietf.org/doc/html/rfc6455#section-1.6
1044
        // "Naturally, when the WebSocket Protocol is used by a dedicated client
1045
        // directly (i.e., not from a web page through a web browser), the origin
1046
        // model is not useful, as the client can provide any arbitrary origin string."
1047
        if origin == _EMPTY_ {
×
1048
                return nil
×
1049
        }
×
1050
        u, err := url.ParseRequestURI(origin)
×
1051
        if err != nil {
×
1052
                return err
×
1053
        }
×
1054
        oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host)
×
1055
        if err != nil {
×
1056
                return err
×
1057
        }
×
1058
        // If checking same origin, compare with the http's request's Host.
1059
        if checkSame {
×
1060
                rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host)
×
1061
                if err != nil {
×
1062
                        return err
×
1063
                }
×
1064
                rs := "http"
×
1065
                if r.TLS != nil {
×
1066
                        rs = "https"
×
1067
                }
×
1068
                if oh != rh || op != rp || !strings.EqualFold(u.Scheme, rs) {
×
1069
                        return errors.New("not same origin")
×
1070
                }
×
1071
                // I guess it is possible to have cases where one wants to check
1072
                // same origin, but also that the origin is in the allowed list.
1073
                // So continue with the next check.
1074
        }
1075
        if !listEmpty {
×
1076
                w.mu.RLock()
×
1077
                origins := w.allowedOrigins[oh]
×
1078
                w.mu.RUnlock()
×
1079
                var allowed bool
×
1080
                for _, ao := range origins {
×
1081
                        if u.Scheme == ao.scheme && op == ao.port {
×
1082
                                allowed = true
×
1083
                                break
×
1084
                        }
1085
                }
1086
                if !allowed {
×
1087
                        return errors.New("not in the allowed list")
×
1088
                }
×
1089
        }
1090
        return nil
×
1091
}
1092

1093
// wsGetHostAndPort splits `hostport` into a lowercased host and a port.
// When the port is missing, it defaults to "443" if `tls` is true, "80"
// otherwise. A bracketed IPv6 literal without a port (e.g. "[::1]") is
// returned unbracketed ("::1"), matching what net.SplitHostPort returns for
// the same host with an explicit port ("[::1]:80"), so that hosts compare
// equal regardless of whether the port was given.
// Any other parse error from net.SplitHostPort is returned as is.
func wsGetHostAndPort(tls bool, hostport string) (string, string, error) {
	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		// If error is missing port, then use defaults based on the scheme
		if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") {
			err = nil
			host = hostport
			if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
				host = host[1 : len(host)-1]
			}
			if tls {
				port = "443"
			} else {
				port = "80"
			}
		}
	}
	return strings.ToLower(host), port, err
}
1109

1110
// Concatenate the key sent by the client with the GUID, then computes the SHA1 hash
1111
// and returns it as a based64 encoded string.
1112
func wsAcceptKey(key string) string {
61✔
1113
        h := sha1.New()
61✔
1114
        h.Write([]byte(key))
61✔
1115
        h.Write(wsGUID)
61✔
1116
        return base64.StdEncoding.EncodeToString(h.Sum(nil))
61✔
1117
}
61✔
1118

1119
func wsMakeChallengeKey() (string, error) {
46✔
1120
        p := make([]byte, 16)
46✔
1121
        if _, err := io.ReadFull(crand.Reader, p); err != nil {
46✔
1122
                return _EMPTY_, err
×
1123
        }
×
1124
        return base64.StdEncoding.EncodeToString(p), nil
46✔
1125
}
1126

1127
// Validate the websocket related options.
1128
func validateWebsocketOptions(o *Options) error {
7,834✔
1129
        wo := &o.Websocket
7,834✔
1130
        // If no port is defined, we don't care about other options
7,834✔
1131
        if wo.Port == 0 {
15,550✔
1132
                return nil
7,716✔
1133
        }
7,716✔
1134
        // Enforce TLS... unless NoTLS is set to true.
1135
        if wo.TLSConfig == nil && !wo.NoTLS {
118✔
1136
                return errors.New("websocket requires TLS configuration")
×
1137
        }
×
1138
        // Make sure that allowed origins, if specified, can be parsed.
1139
        for _, ao := range wo.AllowedOrigins {
120✔
1140
                u, err := url.ParseRequestURI(ao)
2✔
1141
                if err != nil {
2✔
1142
                        return fmt.Errorf("unable to parse allowed origin: %v", err)
×
1143
                }
×
1144
                if u.Scheme != "http" && u.Scheme != "https" {
2✔
1145
                        return fmt.Errorf("unable to parse allowed origin %q: allowed origins must be absolute URLs with http or https scheme", ao)
×
1146
                }
×
1147
                if u.Host == _EMPTY_ {
2✔
1148
                        return fmt.Errorf("unable to parse allowed origin %q: host is required", ao)
×
1149
                }
×
1150
                if _, _, err := wsGetHostAndPort(u.Scheme == "https", u.Host); err != nil {
2✔
1151
                        return fmt.Errorf("unable to parse allowed origin: %v", err)
×
1152
                }
×
1153
        }
1154
        // If there is a NoAuthUser, we need to have Users defined and
1155
        // the user to be present.
1156
        if wo.NoAuthUser != _EMPTY_ {
120✔
1157
                if err := validateNoAuthUser(o, wo.NoAuthUser); err != nil {
2✔
1158
                        return err
×
1159
                }
×
1160
        }
1161
        // Token/Username not possible if there are users/nkeys
1162
        if len(o.Users) > 0 || len(o.Nkeys) > 0 {
205✔
1163
                if wo.Username != _EMPTY_ {
87✔
1164
                        return fmt.Errorf("websocket authentication username not compatible with presence of users/nkeys")
×
1165
                }
×
1166
                if wo.Token != _EMPTY_ {
87✔
1167
                        return fmt.Errorf("websocket authentication token not compatible with presence of users/nkeys")
×
1168
                }
×
1169
        }
1170
        // Using JWT requires Trusted Keys
1171
        if wo.JWTCookie != _EMPTY_ {
119✔
1172
                if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
1✔
1173
                        return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
×
1174
                }
×
1175
        }
1176
        if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil {
118✔
1177
                return fmt.Errorf("websocket: %v", err)
×
1178
        }
×
1179

1180
        // Check for invalid headers here.
1181
        for key := range wo.Headers {
118✔
1182
                k := strings.ToLower(key)
×
1183
                switch k {
×
1184
                case "host",
1185
                        "content-length",
1186
                        "connection",
1187
                        "upgrade",
1188
                        "nats-no-masking":
×
1189
                        return fmt.Errorf("websocket: invalid header %q not allowed", key)
×
1190
                }
1191

1192
                if strings.HasPrefix(k, "sec-websocket-") {
×
1193
                        return fmt.Errorf("websocket: invalid header %q, \"Sec-WebSocket-\" prefix not allowed", key)
×
1194
                }
×
1195
        }
1196

1197
        return nil
118✔
1198
}
1199

1200
// Creates or updates the existing map
1201
func (s *Server) wsSetOriginOptions(o *WebsocketOpts) {
117✔
1202
        ws := &s.websocket
117✔
1203
        ws.mu.Lock()
117✔
1204
        defer ws.mu.Unlock()
117✔
1205
        // Copy over the option's same origin boolean
117✔
1206
        ws.sameOrigin = o.SameOrigin
117✔
1207
        // Reset the map. Will help for config reload if/when we support it.
117✔
1208
        ws.allowedOrigins = nil
117✔
1209
        if o.AllowedOrigins == nil {
233✔
1210
                return
116✔
1211
        }
116✔
1212
        for _, ao := range o.AllowedOrigins {
3✔
1213
                // We have previously checked (during options validation) that the urls
2✔
1214
                // are parseable, but if we get an error, report and skip.
2✔
1215
                u, err := url.ParseRequestURI(ao)
2✔
1216
                if err != nil {
2✔
1217
                        s.Errorf("error parsing allowed origin: %v", err)
×
1218
                        continue
×
1219
                }
1220
                h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host)
2✔
1221
                if ws.allowedOrigins == nil {
3✔
1222
                        ws.allowedOrigins = make(map[string][]*allowedOrigin, len(o.AllowedOrigins))
1✔
1223
                }
1✔
1224
                ws.allowedOrigins[h] = append(ws.allowedOrigins[h], &allowedOrigin{scheme: u.Scheme, port: p})
2✔
1225
        }
1226
}
1227

1228
// Calculate the raw headers for websocket upgrade response.
1229
func (s *Server) wsSetHeadersOptions(o *WebsocketOpts) {
117✔
1230
        var sb strings.Builder
117✔
1231
        for k, v := range o.Headers {
117✔
1232
                sb.WriteString(k)
×
1233
                sb.WriteString(": ")
×
1234
                sb.WriteString(v)
×
1235
                sb.WriteString(_CRLF_)
×
1236
        }
×
1237
        ws := &s.websocket
117✔
1238
        ws.mu.Lock()
117✔
1239
        defer ws.mu.Unlock()
117✔
1240
        ws.rawHeaders = sb.String()
117✔
1241
}
1242

1243
// Given the websocket options, we check if any auth configuration
1244
// has been provided. If so, possibly create users/nkey users and
1245
// store them in s.websocket.users/nkeys.
1246
// Also update a boolean that indicates if auth is required for
1247
// websocket clients.
1248
// Server lock is held on entry.
1249
func (s *Server) wsConfigAuth(opts *WebsocketOpts) {
7,820✔
1250
        ws := &s.websocket
7,820✔
1251
        // If any of those is specified, we consider that there is an override.
7,820✔
1252
        ws.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_
7,820✔
1253
}
7,820✔
1254

1255
func (s *Server) startWebsocketServer() {
117✔
1256
        if s.isShuttingDown() {
117✔
1257
                return
×
1258
        }
×
1259

1260
        sopts := s.getOpts()
117✔
1261
        o := &sopts.Websocket
117✔
1262

117✔
1263
        s.wsSetOriginOptions(o)
117✔
1264
        s.wsSetHeadersOptions(o)
117✔
1265

117✔
1266
        var hl net.Listener
117✔
1267
        var proto string
117✔
1268
        var err error
117✔
1269

117✔
1270
        port := o.Port
117✔
1271
        if port == -1 {
233✔
1272
                port = 0
116✔
1273
        }
116✔
1274
        hp := net.JoinHostPort(o.Host, strconv.Itoa(port))
117✔
1275

117✔
1276
        // We are enforcing (when validating the options) the use of TLS, but the
117✔
1277
        // code was originally supporting both modes. The reason for TLS only is
117✔
1278
        // that we expect users to send JWTs with bearer tokens and we want to
117✔
1279
        // avoid the possibility of it being "intercepted".
117✔
1280

117✔
1281
        s.mu.Lock()
117✔
1282
        // Do not check o.NoTLS here. If a TLS configuration is available, use it,
117✔
1283
        // regardless of NoTLS. If we don't have a TLS config, it means that the
117✔
1284
        // user has configured NoTLS because otherwise the server would have failed
117✔
1285
        // to start due to options validation.
117✔
1286
        var config *tls.Config
117✔
1287
        if o.TLSConfig != nil {
129✔
1288
                proto = wsSchemePrefixTLS
12✔
1289
                config = o.TLSConfig.Clone()
12✔
1290
                config.GetConfigForClient = s.wsGetTLSConfig
12✔
1291
        } else {
117✔
1292
                proto = wsSchemePrefix
105✔
1293
        }
105✔
1294
        hl, err = natsListen("tcp", hp)
117✔
1295
        s.websocket.listenerErr = err
117✔
1296
        if err != nil {
117✔
1297
                s.mu.Unlock()
×
1298
                s.Fatalf("Unable to listen for websocket connections: %v", err)
×
1299
                return
×
1300
        }
×
1301
        if config != nil {
129✔
1302
                hl = tls.NewListener(hl, config)
12✔
1303
        }
12✔
1304
        if port == 0 {
233✔
1305
                o.Port = hl.Addr().(*net.TCPAddr).Port
116✔
1306
        }
116✔
1307
        s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, o.Port)
117✔
1308
        if proto == wsSchemePrefix {
222✔
1309
                s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
105✔
1310
        }
105✔
1311

1312
        // These 3 are immutable and will be accessed without lock by the client
1313
        // when generating/sending the INFO protocols.
1314
        s.websocket.tls = proto == wsSchemePrefixTLS
117✔
1315
        s.websocket.host, s.websocket.port = o.Host, o.Port
117✔
1316

117✔
1317
        // This will be updated when/if the cluster changes.
117✔
1318
        s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
117✔
1319
        if err != nil {
117✔
1320
                s.Fatalf("Unable to get websocket connect URLs: %v", err)
×
1321
                hl.Close()
×
1322
                s.mu.Unlock()
×
1323
                return
×
1324
        }
×
1325
        hasLeaf := sopts.LeafNode.Port != 0
117✔
1326
        mux := http.NewServeMux()
117✔
1327
        mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
149✔
1328
                res, err := s.wsUpgrade(w, r)
32✔
1329
                if err != nil {
32✔
1330
                        s.Errorf(err.Error())
×
1331
                        return
×
1332
                }
×
1333
                switch res.kind {
32✔
1334
                case CLIENT:
3✔
1335
                        s.createWSClient(res.conn, res.ws)
3✔
1336
                case MQTT:
×
1337
                        s.createMQTTClient(res.conn, res.ws)
×
1338
                case LEAF:
29✔
1339
                        if !hasLeaf {
29✔
1340
                                s.Errorf("Not configured to accept leaf node connections")
×
1341
                                // Silently close for now. If we want to send an error back, we would
×
1342
                                // need to create the leafnode client anyway, so that is handling websocket
×
1343
                                // frames, then send the error to the remote.
×
1344
                                res.conn.Close()
×
1345
                                return
×
1346
                        }
×
1347
                        s.createLeafNode(res.conn, nil, nil, res.ws)
29✔
1348
                }
1349
        })
1350
        hs := &http.Server{
117✔
1351
                Addr:        hp,
117✔
1352
                Handler:     mux,
117✔
1353
                ReadTimeout: o.HandshakeTimeout,
117✔
1354
                ErrorLog:    log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
117✔
1355
        }
117✔
1356
        s.websocket.mu.Lock()
117✔
1357
        s.websocket.server = hs
117✔
1358
        s.websocket.listener = hl
117✔
1359
        s.websocket.mu.Unlock()
117✔
1360
        go func() {
234✔
1361
                if err := hs.Serve(hl); err != http.ErrServerClosed {
117✔
1362
                        s.Fatalf("websocket listener error: %v", err)
×
1363
                }
×
1364
                if s.isLameDuckMode() {
117✔
1365
                        // Signal that we are not accepting new clients
×
1366
                        s.ldmCh <- true
×
1367
                        // Now wait for the Shutdown...
×
1368
                        <-s.quitCh
×
1369
                        return
×
1370
                }
×
1371
                s.done <- true
117✔
1372
        }()
1373
        s.mu.Unlock()
117✔
1374
}
1375

1376
// The TLS configuration is passed to the listener when the websocket
1377
// "server" is setup. That prevents TLS configuration updates on reload
1378
// from being used. By setting this function in tls.Config.GetConfigForClient
1379
// we instruct the TLS handshake to ask for the tls configuration to be
1380
// used for a specific client. We don't care which client, we always use
1381
// the same TLS configuration.
1382
func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) {
11✔
1383
        opts := s.getOpts()
11✔
1384
        return opts.Websocket.TLSConfig, nil
11✔
1385
}
11✔
1386

1387
// createWSClient creates a client for a connection accepted by the websocket
// HTTP server, sends it the server INFO, registers it with the server and
// starts its read and write loops. ws is attached to the client unchanged.
//
// This is similar to createClient() but has some modifications
// specific to handle websocket clients.
// The comments have been kept to minimum to reduce code size.
// Check createClient() for more details.
//
// Returns:
//   - the client, registered in s.clients with its loops started, on success;
//   - the client, NOT registered and without started loops, when the server
//     is not running or is in lame duck mode (conn is closed only if the
//     server is shutting down);
//   - nil when the MaxConn limit rejects the connection, or when the client
//     is found closed right after registration (it is then closed with
//     WriteError).
func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client {
	opts := s.getOpts()

	maxPay := int32(opts.MaxPayload)
	maxSubs := int32(opts.MaxSubs)
	// A configured MaxSubs of 0 is stored as -1.
	// NOTE(review): presumably -1 means "no limit" for msubs - confirm.
	if maxSubs == 0 {
		maxSubs = -1
	}
	now := time.Now().UTC()

	c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws}

	// The client starts out bound to the global account.
	c.registerWithAccount(s.globalAccount())

	var info Info
	var authRequired bool

	// Build the INFO for this client from the server's info, under the
	// server lock.
	s.mu.Lock()
	info = s.copyInfo()
	// Check auth, override if applicable.
	if !info.AuthRequired {
		// Set info.AuthRequired since this is what is sent to the client.
		info.AuthRequired = s.websocket.authOverride
	}
	// When required, generate a per-connection nonce. It is sent in the INFO
	// and remembered in c.nonce.
	if s.nonceRequired() {
		var raw [nonceLen]byte
		nonce := raw[:]
		s.generateNonce(nonce)
		info.Nonce = string(nonce)
	}
	c.nonce = []byte(info.Nonce)
	authRequired = info.AuthRequired

	// Counted before the running/MaxConn checks below, so rejected
	// connections are included too.
	s.totalClients++
	s.mu.Unlock()

	c.mu.Lock()
	if authRequired {
		c.flags.set(expectConnect)
	}
	c.initClient()
	c.Debugf("Client connection created")
	// Send the INFO right away, before the client is registered with the server.
	c.sendProtoNow(c.generateClientInfoJSON(info, true))
	c.mu.Unlock()

	s.mu.Lock()
	// When the server is not running, or is in lame duck mode, the client is
	// neither registered nor started. conn is only closed if the server is
	// actually shutting down.
	if !s.isRunning() || s.ldm {
		if s.isShuttingDown() {
			conn.Close()
		}
		s.mu.Unlock()
		return c
	}

	// A negative MaxConn rejects every connection; a positive one caps the
	// number of registered clients.
	if opts.MaxConn < 0 || (opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn) {
		s.mu.Unlock()
		c.maxConnExceeded()
		return nil
	}
	s.clients[c.cid] = c
	s.mu.Unlock()

	c.mu.Lock()
	// Websocket clients do TLS in the websocket http server.
	// So no TLS initiation here...
	if _, ok := conn.(*tls.Conn); ok {
		c.flags.set(handshakeComplete)
	}

	// The client may already be closed at this point.
	// NOTE(review): presumably due to a failed write of the INFO above, given
	// the WriteError reason used here - confirm.
	if c.isClosed() {
		c.mu.Unlock()
		c.closeConnection(WriteError)
		return nil
	}

	// Arm the authentication timer. The timeout is expressed in seconds
	// (converted with secondsToDuration).
	if authRequired {
		timeout := opts.AuthTimeout
		// Possibly override with Websocket specific value.
		if opts.Websocket.AuthTimeout != 0 {
			timeout = opts.Websocket.AuthTimeout
		}
		c.setAuthTimer(secondsToDuration(timeout))
	}

	c.setPingTimer()

	// Start the read and write loops while c.mu is still held.
	s.startGoRoutine(func() { c.readLoop(nil) })
	s.startGoRoutine(func() { c.writeLoop() })

	c.mu.Unlock()

	return c
}
1484

1485
// wsCollapsePtoNB turns the client's pending outbound buffers (c.out.nb) into
// a list of websocket-framed buffers ready to be written to the connection.
//
// The returned buffers contain, in order:
//   - any frames already queued in c.ws.frames (partials or control frames),
//   - the pending data, framed as binary message(s), and
//   - c.ws.closeMsg, if one is set.
//
// Framing rules applied to the pending data:
//   - It is deflate-compressed only when c.ws.compress is set and the total
//     pending size exceeds wsCompressThreshold.
//   - When c.ws.browser is set, frames are limited to wsFrameSizeForBrowsers
//     payload bytes. For compressed data, the limit is dropped if
//     c.ws.nocompfrag is set.
//   - Payloads are masked in place when c.ws.maskwrite is set.
//
// c.out.pb is adjusted to account for frame headers and, when compressing,
// the compressed size in place of the uncompressed size. The second return
// value is c.ws.fs after it has been incremented by the framed bytes added
// here.
// NOTE(review): whether and where c.ws.fs is reset is not visible here.
//
// If flushing the compressor fails, the connection is marked closed with
// WriteError and (nil, 0) is returned.
//
// NOTE(review): mutates c.out and c.ws without locking; presumably the caller
// holds c.mu - confirm.
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
	nb := c.out.nb
	var mfs int // max frame payload size, 0 means unlimited
	var usz int // total uncompressed size, only computed when compressing
	if c.ws.browser {
		mfs = wsFrameSizeForBrowsers
	}
	mask := c.ws.maskwrite
	// Start with possible already framed buffers (that we could have
	// got from partials or control messages such as ws pings or pongs).
	bufs := c.ws.frames
	compress := c.ws.compress
	if compress && len(nb) > 0 {
		// First, make sure we don't compress for very small cumulative buffers.
		for _, b := range nb {
			usz += len(b)
		}
		if usz <= wsCompressThreshold {
			compress = false
			// The compressor is not used for this batch; reset it with a nil writer.
			if cp := c.ws.compressor; cp != nil {
				cp.Reset(nil)
			}
		}
	}
	if compress && len(nb) > 0 {
		// Overwrite mfs if this connection does not support fragmented compressed frames.
		if mfs > 0 && c.ws.nocompfrag {
			mfs = 0
		}
		buf := bytes.NewBuffer(nbPoolGet(usz))
		// Lazily create the compressor, or point the existing one at buf.
		cp := c.ws.compressor
		if cp == nil {
			c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
			cp = c.ws.compressor
		} else {
			cp.Reset(buf)
		}
		var csz int // compressed size including frame headers
		for i, b := range nb {
			for len(b) > 0 {
				n, err := cp.Write(b)
				if err != nil {
					// Whatever this error is, it'll be handled by the cp.Flush()
					// call below, as the same error will be returned there.
					// Let the outer loop return all the buffers back to the pool
					// and fall through naturally.
					break
				}
				b = b[n:]
			}
			// Use original slice since capacity will change to zero
			// in the loop after consuming the buffer, which will make
			// nbPoolPut discard it.
			nbPoolPut(nb[i])
		}
		if err := cp.Flush(); err != nil {
			c.Errorf("Error during compression: %v", err)
			c.markConnAsClosed(WriteError)
			cp.Reset(nil)
			return nil, 0
		}
		b := buf.Bytes()
		// Drop the last 4 bytes of the flushed deflate output: the sync-flush
		// trailer (0x00 0x00 0xff 0xff) that permessage-deflate (RFC 7692)
		// requires senders to remove.
		p := b[:len(b)-4]
		if mfs > 0 && len(p) > mfs {
			// Split the compressed payload into frames of at most mfs bytes.
			for first, final := true, false; len(p) > 0; first = false {
				lp := len(p)
				if lp > mfs {
					lp = mfs
				} else {
					final = true
				}
				// Only the first frame should be marked as compressed, so pass
				// `first` for the compressed boolean.
				fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
				if mask {
					wsMaskBuf(key, p[:lp])
				}
				bufs = append(bufs, fh[:n], p[:lp])
				csz += n + lp
				p = p[lp:]
			}
		} else {
			// Single compressed frame.
			ol := len(p)
			h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
			if mask {
				wsMaskBuf(key, p)
			}
			// NOTE(review): when ol == 0 no frame is appended, yet csz below
			// still counts len(h).
			if ol > 0 {
				bufs = append(bufs, h, p)
			}
			csz = len(h) + ol
		}
		// Make sure that the compressor no longer holds a reference to
		// the bytes.Buffer, so that the underlying memory gets cleaned
		// up after flushOutbound/flushAndClose. For this to be safe, we
		// always cp.Reset(...) before reusing the compressor again.
		cp.Reset(nil)
		// Add to pb the compressed data size (including headers), but
		// remove the original uncompressed data size that was added
		// during the queueing.
		c.out.pb += int64(csz) - int64(usz)
		c.ws.fs += int64(csz)
	} else if len(nb) > 0 {
		var total int // payload size of the frame currently being built
		if mfs > 0 {
			// We are limiting the frame size.
			// startFrame reserves a slot for a frame header and returns its index.
			startFrame := func() int {
				bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize))
				return len(bufs) - 1
			}
			// endFrame fills the header at idx for a complete (first and final),
			// uncompressed frame of `size` payload bytes, accounts for it, and
			// masks the payload buffers that follow the header.
			endFrame := func(idx, size int) {
				bufs[idx] = bufs[idx][:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size)
				bufs[idx] = bufs[idx][:n]
				c.out.pb += int64(n)
				c.ws.fs += int64(n + size)
				if mask {
					// Only this frame's payload follows idx at this point,
					// since endFrame runs before the next startFrame.
					wsMaskBufs(key, bufs[idx+1:])
				}
			}

			fhIdx := startFrame()
			for i := 0; i < len(nb); i++ {
				b := nb[i]
				// The whole buffer fits in the current frame: copy it in.
				if total+len(b) <= mfs {
					buf := nbPoolGet(len(b))
					bufs = append(bufs, append(buf, b...))
					total += len(b)
					nbPoolPut(nb[i])
					continue
				}
				// Otherwise, split the buffer into chunks of at most mfs bytes,
				// each in its own frame. If the current frame is still empty,
				// its header slot is reused for the first chunk.
				for len(b) > 0 {
					endStart := total != 0
					if endStart {
						endFrame(fhIdx, total)
					}
					total = len(b)
					if total >= mfs {
						total = mfs
					}
					if endStart {
						fhIdx = startFrame()
					}
					buf := nbPoolGet(total)
					bufs = append(bufs, append(buf, b[:total]...))
					b = b[total:]
				}
				nbPoolPut(nb[i]) // No longer needed as copied into smaller frames.
			}
			if total > 0 {
				endFrame(fhIdx, total)
			}
		} else {
			// If there is no limit on the frame size, create a single frame for
			// all pending buffers.
			for _, b := range nb {
				total += len(b)
			}
			wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total)
			c.out.pb += int64(len(wsfh))
			bufs = append(bufs, wsfh)
			idx := len(bufs)
			// The pending buffers themselves become the payload (no copy).
			bufs = append(bufs, nb...)
			if mask {
				wsMaskBufs(key, bufs[idx:])
			}
			c.ws.fs += int64(len(wsfh) + total)
		}
	}
	// A pending close message always goes last. The compressor is dropped.
	if len(c.ws.closeMsg) > 0 {
		bufs = append(bufs, c.ws.closeMsg)
		c.ws.fs += int64(len(c.ws.closeMsg))
		c.ws.closeMsg = nil
		c.ws.compressor = nil
	}
	// The queued frames are now part of bufs.
	c.ws.frames = nil
	return bufs, c.ws.fs
}
1664

1665
func isWSURL(u *url.URL) bool {
2,742✔
1666
        return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix)
2,742✔
1667
}
2,742✔
1668

1669
func isWSSURL(u *url.URL) bool {
1,800✔
1670
        return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefixTLS)
1,800✔
1671
}
1,800✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc