• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

dguerri / k8s-ssh-gateway-controller / 26014014641

17 May 2026 06:33PM UTC coverage: 83.864% (-0.4%) from 84.239%
26014014641

push

github

dguerri
feat(ssh): add diagnostic goroutine to log SSH transport shutdown reason

Log the underlying SSH transport exit error via `ssh.Client.Wait()` in a
background goroutine. This surfaces SSH_MSG_DISCONNECT reason strings
sent by the server, which are otherwise hidden behind the generic EOF
returned to `SendRequest`, making connection failures easier to diagnose.

5 of 9 new or added lines in 1 file covered. (55.56%)

4 existing lines in 1 file now uncovered.

1580 of 1884 relevant lines covered (83.86%)

6.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.43
/ssh/ssh_tunnel_manager.go
1
package ssh
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "encoding/base64"
7
        "fmt"
8
        "io"
9
        "log/slog"
10
        "net"
11
        "strconv"
12
        "strings"
13
        "sync"
14
        "time"
15

16
        "golang.org/x/crypto/ssh"
17
)
18

19
// sshClient defines the methods used by the SSH tunnel manager.
20
type sshClient interface {
21
        Listen(network, addr string) (net.Listener, error)
22
        SendRequest(string, bool, []byte) (bool, []byte, error)
23
        HandleChannelOpen(string) <-chan ssh.NewChannel
24
        Close() error
25
}
26

27
// sshDialFunc is a function type for establishing SSH connections.
28
type sshDialFunc func(network, addr string, cfg *ssh.ClientConfig) (sshClient, error)
29

30
// netDialFunc opens a plain network connection; it has the shape of net.Dial.
type netDialFunc func(network, address string) (net.Conn, error)
32

33
// Default dial functions - can be overridden via config for testing.
34
var (
35
        defaultSSHDial sshDialFunc = func(network, addr string, cfg *ssh.ClientConfig) (sshClient, error) {
×
36
                return ssh.Dial(network, addr, cfg)
×
37
        }
×
38
        defaultNetDial netDialFunc = func(network string, address string) (net.Conn, error) {
×
39
                return net.Dial(network, address)
×
40
        }
×
41
)
42

43
// Legacy global variables for backward compatibility with existing tests.
44
// Deprecated: Use SSHConnectionConfig.SSHDialFunc and NetDialFunc instead.
45
var (
46
        sshDial = defaultSSHDial
47
        netDial = defaultNetDial
48
)
49

50
// ExtractAddrFunc is a callback that receives text emitted by the SSH server
// and returns the addresses (URIs) found in it, or an error.
type ExtractAddrFunc func(string) ([]string, error)
53

54
// SSHConnectionConfig contains configuration for SSH connection.
55
type SSHConnectionConfig struct {
56
        RemoteAddrFunc             ExtractAddrFunc
57
        SSHDialFunc                sshDialFunc
58
        NetDialFunc                netDialFunc
59
        ServerAddress              string
60
        Username                   string
61
        HostKey                    string
62
        PrivateKey                 []byte
63
        ConnectTimeout             time.Duration
64
        FwdReqTimeout              time.Duration
65
        KeepAliveInterval          time.Duration
66
        AddressVerificationTimeout time.Duration // Optional: timeout for address verification (default: 30s)
67
}
68

69
// ForwardingConfig describes one remote forwarding. RemoteHost and RemotePort
// are sent to the SSH server in the forwarding request; InternalHost and
// InternalPort presumably name the local target that forwarded connections
// are dialed to (the channel handling code is not fully shown here).
type ForwardingConfig struct {
	RemoteHost   string // host requested on the SSH server
	InternalHost string // local target host
	RemotePort   int    // port requested on the SSH server
	InternalPort int    // local target port
}
76

77
// SSHTunnelManager manages SSH tunnels and forwardings.
78
type SSHTunnelManager struct {
79
        externalCtx                context.Context
80
        client                     sshClient
81
        connectionCtx              context.Context
82
        signer                     ssh.Signer
83
        forwardings                map[string]*ForwardingConfig
84
        remoteAddrFunc             ExtractAddrFunc
85
        netDialFunc                netDialFunc
86
        sshDialFunc                sshDialFunc
87
        addrNotifications          map[string]chan []string
88
        assignedAddrs              map[string][]string
89
        connectionCancel           context.CancelFunc
90
        hostKey                    string
91
        sshUser                    string
92
        sshServerAddress           string
93
        fwdReqTimeout              time.Duration
94
        connTimeout                time.Duration
95
        keepAliveInterval          time.Duration
96
        addressVerificationTimeout time.Duration
97
        clientMu                   sync.RWMutex
98
        addrNotifMu                sync.RWMutex
99
        captureReady               chan struct{}
100
        proxyProtocol              int
101
        connected                  bool
102
}
103

104
// forwardingKey builds the map key identifying a forwarding session, in the
// form "remoteHost:remotePort".
func forwardingKey(remoteHost string, remotePort int) string {
	return remoteHost + ":" + strconv.Itoa(remotePort)
}
108

109
// NewSSHTunnelManager creates a new SSHTunnelManager with the specified configuration.
110
// externalCtx is used to cancel the connection and forwardings when the context is done.
111
func NewSSHTunnelManager(externalCtx context.Context, config *SSHConnectionConfig) (*SSHTunnelManager, error) {
27✔
112
        signer, err := ssh.ParsePrivateKey(config.PrivateKey)
27✔
113
        if err != nil {
28✔
114
                slog.With("function", "NewSSHTunnelManager").Error("unable to parse private key", "error", err)
1✔
115
                return nil, fmt.Errorf("unable to parse private key: %w", err)
1✔
116
        }
1✔
117

118
        // Use provided dial functions or fall back to defaults/globals
119
        sshDialFn := config.SSHDialFunc
26✔
120
        if sshDialFn == nil {
50✔
121
                sshDialFn = sshDial // Use global for backward compatibility
24✔
122
        }
24✔
123
        netDialFn := config.NetDialFunc
26✔
124
        if netDialFn == nil {
51✔
125
                netDialFn = netDial // Use global for backward compatibility
25✔
126
        }
25✔
127

128
        addrVerifyTimeout := config.AddressVerificationTimeout
26✔
129
        if addrVerifyTimeout == 0 {
46✔
130
                addrVerifyTimeout = addressVerificationTimeout
20✔
131
        }
20✔
132

133
        m := &SSHTunnelManager{
26✔
134
                sshServerAddress:           config.ServerAddress,
26✔
135
                sshUser:                    config.Username,
26✔
136
                hostKey:                    config.HostKey,
26✔
137
                signer:                     signer,
26✔
138
                connTimeout:                config.ConnectTimeout,
26✔
139
                fwdReqTimeout:              config.FwdReqTimeout,
26✔
140
                keepAliveInterval:          config.KeepAliveInterval,
26✔
141
                addressVerificationTimeout: addrVerifyTimeout,
26✔
142
                externalCtx:                externalCtx,
26✔
143
                connected:                  false,
26✔
144
                remoteAddrFunc:             config.RemoteAddrFunc,
26✔
145
                forwardings:                make(map[string]*ForwardingConfig),
26✔
146
                assignedAddrs:              make(map[string][]string),
26✔
147
                addrNotifications:          make(map[string]chan []string),
26✔
148
                sshDialFunc:                sshDialFn,
26✔
149
                netDialFunc:                netDialFn,
26✔
150
        }
26✔
151

26✔
152
        return m, nil
26✔
153
}
154

155
// Connect attempts to establish the SSH connection.
156
// It returns nil if already connected or if connection succeeds.
157
func (m *SSHTunnelManager) Connect() error {
24✔
158
        m.clientMu.Lock()
24✔
159
        defer m.clientMu.Unlock()
24✔
160

24✔
161
        if m.connected {
24✔
162
                return nil
×
163
        }
×
164

165
        if err := m.connectClient(); err != nil {
26✔
166
                return err
2✔
167
        }
2✔
168

169
        // Start background tasks
170
        go m.handleChannels()
22✔
171
        go m.monitorConnection()
22✔
172

22✔
173
        slog.With("function", "Connect").Info("ssh connection established")
22✔
174
        return nil
22✔
175
}
176

177
// monitorConnection monitors the SSH connection and sends keepalive requests.
178
// If a keepalive fails, it closes the connection.
179
func (m *SSHTunnelManager) monitorConnection() {
22✔
180
        ticker := time.NewTicker(m.keepAliveInterval)
22✔
181
        defer ticker.Stop()
22✔
182

22✔
183
        // Capture the context we are monitoring
22✔
184
        m.clientMu.RLock()
22✔
185
        ctx := m.connectionCtx
22✔
186
        m.clientMu.RUnlock()
22✔
187

22✔
188
        if ctx == nil {
22✔
189
                return
×
190
        }
×
191

192
        for {
45✔
193
                select {
23✔
194
                case <-m.externalCtx.Done():
12✔
195
                        return
12✔
196
                case <-ctx.Done():
9✔
197
                        return
9✔
198
                case <-ticker.C:
2✔
199
                        m.clientMu.RLock()
2✔
200
                        if !m.connected || m.client == nil || m.connectionCtx.Err() != nil {
2✔
201
                                m.clientMu.RUnlock()
×
202
                                return
×
203
                        }
×
204
                        client := m.client
2✔
205
                        m.clientMu.RUnlock()
2✔
206

2✔
207
                        // Send keepalive with timeout to avoid hanging indefinitely
2✔
208
                        type keepaliveResult struct {
2✔
209
                                err error
2✔
210
                        }
2✔
211
                        resultCh := make(chan keepaliveResult, 1)
2✔
212

2✔
213
                        go func() {
4✔
214
                                _, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
2✔
215
                                resultCh <- keepaliveResult{err: err}
2✔
216
                        }()
2✔
217

218
                        // Wait for keepalive response with timeout
219
                        keepaliveTimeout := m.keepAliveInterval * 2 // 2x keepalive interval
2✔
220
                        select {
2✔
221
                        case result := <-resultCh:
2✔
222
                                if result.err != nil {
3✔
223
                                        slog.With("function", "monitorConnection").Error("ssh keepalive failed, closing connection", "error", result.err)
1✔
224
                                        m.clientMu.Lock()
1✔
225
                                        m.closeClient()
1✔
226
                                        m.clientMu.Unlock()
1✔
227
                                        return
1✔
228
                                }
1✔
229
                                slog.With("function", "monitorConnection").Debug("ssh keepalive sent")
1✔
230
                        case <-time.After(keepaliveTimeout):
×
231
                                slog.With("function", "monitorConnection").Error("ssh keepalive timeout, closing connection", "timeout", keepaliveTimeout)
×
232
                                m.clientMu.Lock()
×
233
                                m.closeClient()
×
234
                                m.clientMu.Unlock()
×
235
                                return
×
236
                        }
237
                }
238
        }
239
}
240

241
// closeClient closes the current SSH client connection.
242
// Must be called with a write lock on m.clientMu
243
func (m *SSHTunnelManager) closeClient() {
4✔
244
        if !m.connected {
4✔
245
                return
×
246
        }
×
247
        m.connected = false
4✔
248

4✔
249
        if m.connectionCancel != nil {
8✔
250
                m.connectionCancel()
4✔
251
        }
4✔
252
        if m.client != nil {
8✔
253
                if err := m.client.Close(); err != nil {
4✔
254
                        slog.With("function", "disconnect").Error("failed to close SSH client", "error", err)
×
255
                }
×
256
        }
257
        m.client = nil
4✔
258
        m.connectionCancel = nil
4✔
259

4✔
260
        // Clear assigned addresses as they are invalid on disconnect
4✔
261
        m.addrNotifMu.Lock()
4✔
262
        m.assignedAddrs = make(map[string][]string)
4✔
263
        m.addrNotifMu.Unlock()
4✔
264

4✔
265
        // Clear forwardings as they need to be re-established
4✔
266
        m.forwardings = make(map[string]*ForwardingConfig)
4✔
267
}
268

269
// channelForwardMsg is the payload of the "tcpip-forward" and
// "cancel-tcpip-forward" global requests (RFC 4254, section 7.1). It is
// serialized with ssh.Marshal in declaration order, so the field order is part
// of the wire format.
type channelForwardMsg struct {
	addr  string // address to bind on the server
	rport uint32 // port to bind on the server
}
273

274
// forwardedTCPPayload is the extra data of a "forwarded-tcpip" channel open
// (RFC 4254, section 7.2), decoded with ssh.Unmarshal. The field order matches
// the wire format, and the fields are exported so ssh.Unmarshal can set them.
type forwardedTCPPayload struct {
	Addr       string // address that was connected
	Port       uint32 // port that was connected
	OriginAddr string // originator address
	OriginPort uint32 // originator port
}
280

281
// StartForwarding starts a new forwarding based on the provided configuration.
282
func (m *SSHTunnelManager) StartForwarding(fwd ForwardingConfig) error {
17✔
283
        m.clientMu.Lock()
17✔
284
        defer m.clientMu.Unlock()
17✔
285

17✔
286
        if !m.connected || m.client == nil {
18✔
287
                slog.With("function", "StartForwarding").Warn("client not ready")
1✔
288
                return &ErrSSHClientNotReady{}
1✔
289
        }
1✔
290
        key := forwardingKey(fwd.RemoteHost, fwd.RemotePort)
16✔
291

16✔
292
        if _, exists := m.forwardings[key]; exists {
18✔
293
                err := &ErrSSHForwardingExists{Key: key}
2✔
294
                slog.With("function", "StartForwarding").Error(err.Error())
2✔
295
                return err
2✔
296
        }
2✔
297

298
        err := m.sendForwarding(&fwd, ForwardStart)
14✔
299
        if err != nil {
18✔
300
                slog.With("function", "StartForwarding").Error("failed to send forwarding request", "error", err)
4✔
301
                return err
4✔
302
        }
4✔
303

304
        // Store the forwarding session
305
        m.forwardings[key] = &fwd
10✔
306

10✔
307
        slog.With("function", "StartForwarding").Info("started forwarding", "key", key)
10✔
308

10✔
309
        return nil
10✔
310
}
311

312
// StopForwarding stops an existing forwarding based on the provided configuration.
313
func (m *SSHTunnelManager) StopForwarding(fwd *ForwardingConfig) error {
4✔
314
        m.clientMu.Lock()
4✔
315
        defer m.clientMu.Unlock()
4✔
316

4✔
317
        if !m.connected || m.client == nil {
5✔
318
                return &ErrSSHClientNotReady{}
1✔
319
        }
1✔
320
        key := forwardingKey(fwd.RemoteHost, fwd.RemotePort)
3✔
321
        forwardingSession, exists := m.forwardings[key]
3✔
322
        if !exists {
5✔
323
                err := &ErrSSHForwardingNotFound{Key: key}
2✔
324
                slog.With("function", "StopForwarding").Warn(err.Error())
2✔
325
                return err
2✔
326
        }
2✔
327

328
        // Cancel forwarding may still fail, but we should remove the forwardingSession anyway.
329
        delete(m.forwardings, key)
1✔
330

1✔
331
        // Clean up assigned addresses
1✔
332
        m.addrNotifMu.Lock()
1✔
333
        delete(m.assignedAddrs, key)
1✔
334
        m.addrNotifMu.Unlock()
1✔
335

1✔
336
        err := m.sendForwarding(forwardingSession, ForwardCancel)
1✔
337
        if err != nil {
1✔
338
                slog.With("function", "StopForwarding").Error("failed to send cancel request", "error", err)
×
339
                return err
×
340
        }
×
341

342
        slog.With("function", "StopForwarding").Info("stopped forwarding", "key", key)
1✔
343

1✔
344
        return nil
1✔
345
}
346

347
// GetAssignedAddresses returns the assigned addresses (URIs) for a forwarding configuration.
348
// Returns nil if no addresses have been assigned yet or the forwarding doesn't exist.
349
func (m *SSHTunnelManager) GetAssignedAddresses(remoteHost string, remotePort int) []string {
7✔
350
        key := forwardingKey(remoteHost, remotePort)
7✔
351
        m.addrNotifMu.RLock()
7✔
352
        defer m.addrNotifMu.RUnlock()
7✔
353

7✔
354
        if addrs, ok := m.assignedAddrs[key]; ok {
13✔
355
                // Return a copy to avoid race conditions
6✔
356
                result := make([]string, len(addrs))
6✔
357
                copy(result, addrs)
6✔
358
                return result
6✔
359
        }
6✔
360
        return nil
1✔
361
}
362

363
// IsConnected checks if the SSH client is connected.
364
func (m *SSHTunnelManager) IsConnected() bool {
4✔
365
        m.clientMu.RLock()
4✔
366
        defer m.clientMu.RUnlock()
4✔
367
        return m.connected
4✔
368
}
4✔
369

370
// SetProxyProtocol sets the PROXY protocol version (0=disabled, 1 or 2).
371
// If the version changes and the manager is connected, it disconnects so
372
// the next Connect() re-establishes the session with the new setting.
373
func (m *SSHTunnelManager) SetProxyProtocol(version int) {
4✔
374
        m.clientMu.Lock()
4✔
375
        defer m.clientMu.Unlock()
4✔
376

4✔
377
        if m.proxyProtocol == version {
5✔
378
                return
1✔
379
        }
1✔
380

381
        slog.With("function", "SetProxyProtocol").Info("proxy protocol version changed",
3✔
382
                "old", m.proxyProtocol, "new", version)
3✔
383
        m.proxyProtocol = version
3✔
384

3✔
385
        if m.connected {
4✔
386
                slog.With("function", "SetProxyProtocol").Info("disconnecting to apply new proxy protocol setting")
1✔
387
                m.closeClient()
1✔
388
        }
1✔
389
}
390

391
// GetProxyProtocol returns the current PROXY protocol version.
392
func (m *SSHTunnelManager) GetProxyProtocol() int {
4✔
393
        m.clientMu.RLock()
4✔
394
        defer m.clientMu.RUnlock()
4✔
395
        return m.proxyProtocol
4✔
396
}
4✔
397

398
// ForwardRequest selects which forwarding request is sent to the SSH server.
type ForwardRequest string

// Supported forwarding requests.
const (
	// ForwardStart opens a remote forwarding ("tcpip-forward").
	ForwardStart ForwardRequest = "start"
	// ForwardCancel removes a remote forwarding ("cancel-tcpip-forward").
	ForwardCancel ForwardRequest = "cancel"
)
405

406
// Tuning constants for forwarding address verification.
const (
	// addressVerificationTimeout is the default time to wait for the server
	// to report the address of a new forwarding.
	addressVerificationTimeout = 30 * time.Second
	// addrNotificationChannelSize is the buffer size of each per-forwarding
	// address notification channel.
	addrNotificationChannelSize = 5
)
411

412
// matchesRequestedHost reports whether any of the URIs reported by the SSH
// server corresponds to the requested forwarding.
//
// An empty or "0.0.0.0" requestedHost is a wildcard and always matches.
// Otherwise a URI matches when either:
//   - it is an http:// or https:// URI containing requestedHost anywhere
//     (e.g. requested "dev", got "https://user-dev.tuns.sh"), or
//   - it contains "requestedHost:requestedPort" as a complete host:port, that
//     is, not immediately followed by another digit. For example, requested
//     "example.com" port 8080 matches "tcp://example.com:8080", but port 80
//     does not.
func matchesRequestedHost(uris []string, requestedHost string, requestedPort int) bool {
	if requestedHost == "" || requestedHost == "0.0.0.0" {
		// No specific hostname requested, any assignment is fine
		return true
	}

	expectedTCP := fmt.Sprintf("%s:%d", requestedHost, requestedPort)
	for _, uri := range uris {
		isHTTP := strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://")
		if isHTTP && strings.Contains(uri, requestedHost) {
			return true
		}
		if containsHostPort(uri, expectedTCP) {
			return true
		}
	}
	return false
}

// containsHostPort reports whether hostPort occurs in uri with no digit right
// after it. A plain substring test would wrongly accept "host:80" inside
// "host:8080".
func containsHostPort(uri, hostPort string) bool {
	for offset := 0; ; {
		idx := strings.Index(uri[offset:], hostPort)
		if idx < 0 {
			return false
		}
		end := offset + idx + len(hostPort)
		if end == len(uri) || uri[end] < '0' || uri[end] > '9' {
			return true
		}
		// The port continues with more digits; look for a later occurrence.
		offset += idx + 1
	}
}
438

439
// sendForwarding sends a request to the SSH server to start or cancel a TCP forwarding.
440
// 'req' controls the request: ForwardStart -> "tcpip-forward", ForwardCancel -> "cancel-tcpip-forward".
441
// For ForwardStart requests, it waits for the SSH server to report the assigned address and verifies
442
// it matches the requested hostname.
443
// It returns an error if the request fails or is denied by the server.
444
// Must be called with a lock on m.clientMu
445
func (m *SSHTunnelManager) sendForwarding(fwd *ForwardingConfig, req ForwardRequest) error {
16✔
446
        // Wait for the capture session (Shell) to be ready before sending the
16✔
447
        // tcpip-forward request. This ensures the server output (forwarding
16✔
448
        // addresses) will be captured. Without this, the server may output
16✔
449
        // addresses before the Shell session is listening, and they'd be lost.
16✔
450
        if req == ForwardStart && m.captureReady != nil {
30✔
451
                slog.With("function", "sendForwarding").Debug("waiting for capture session to be ready")
14✔
452
                select {
14✔
453
                case <-m.captureReady:
14✔
454
                        slog.With("function", "sendForwarding").Debug("capture session ready, proceeding")
14✔
455
                case <-m.connectionCtx.Done():
×
456
                        return fmt.Errorf("connection closed while waiting for capture session")
×
457
                }
458
        }
459

460
        // For ForwardStart with address collection enabled, register the notification
461
        // channel BEFORE sending the SSH request. The server can emit the assigned URI
462
        // on stdout concurrently with (or even before) returning the request response;
463
        // if we registered after sendForwardingOnce, notifyURIWaiters would have no
464
        // channel to send to and the URI would be dropped, leading to an empty
465
        // assignedAddrs while m.forwardings[key] is populated — a state that produces
466
        // an infinite "forwarding already exists, adopting it" reconcile loop.
467
        var notifCh chan []string
16✔
468
        var key string
16✔
469
        if req == ForwardStart && m.remoteAddrFunc != nil {
22✔
470
                key = forwardingKey(fwd.RemoteHost, fwd.RemotePort)
6✔
471
                notifCh = make(chan []string, addrNotificationChannelSize)
6✔
472
                m.addrNotifMu.Lock()
6✔
473
                m.addrNotifications[key] = notifCh
6✔
474
                m.addrNotifMu.Unlock()
6✔
475

6✔
476
                defer func() {
12✔
477
                        m.addrNotifMu.Lock()
6✔
478
                        delete(m.addrNotifications, key)
6✔
479
                        m.addrNotifMu.Unlock()
6✔
480
                        close(notifCh)
6✔
481
                }()
6✔
482
        }
483

484
        // Send the request once
485
        err := m.sendForwardingOnce(fwd, req)
16✔
486
        if err != nil {
18✔
487
                return err
2✔
488
        }
2✔
489

490
        // For ForwardCancel, we're done after successful request
491
        if req == ForwardCancel {
16✔
492
                return nil
2✔
493
        }
2✔
494

495
        if notifCh == nil {
18✔
496
                // No remoteAddrFunc provided, assume success
6✔
497
                return nil
6✔
498
        }
6✔
499

500
        return m.awaitAssignedAddresses(fwd, key, notifCh)
6✔
501
}
502

503
// awaitAssignedAddresses blocks until the SSH server reports the assigned URIs
504
// for a freshly-sent tcpip-forward request, or the verification timeout fires.
505
// On success it stores the URIs in m.assignedAddrs. For specific hostnames it
506
// verifies the returned URIs match the request and cancels the forwarding on
507
// mismatch or timeout.
508
// Must be called with a lock on m.clientMu.
509
func (m *SSHTunnelManager) awaitAssignedAddresses(fwd *ForwardingConfig, key string, notifCh <-chan []string) error {
6✔
510
        verifyCtx, verifyCancel := context.WithTimeout(m.externalCtx, m.addressVerificationTimeout)
6✔
511
        defer verifyCancel()
6✔
512

6✔
513
        needsVerification := fwd.RemoteHost != "" && fwd.RemoteHost != "0.0.0.0" && fwd.RemoteHost != "localhost"
6✔
514

6✔
515
        select {
6✔
516
        case <-verifyCtx.Done():
1✔
517
                return m.handleVerificationTimeout(fwd, needsVerification)
1✔
518
        case uris := <-notifCh:
5✔
519
                return m.handleAssignedURIs(fwd, key, uris, needsVerification)
5✔
520
        }
521
}
522

523
// handleVerificationTimeout reacts to the address-verification timeout. For
524
// hostname-specific requests it cancels the forwarding and returns an error;
525
// for wildcard/generic requests it logs and reports success (no specific
526
// address was required).
527
// Must be called with a lock on m.clientMu.
528
func (m *SSHTunnelManager) handleVerificationTimeout(fwd *ForwardingConfig, needsVerification bool) error {
1✔
529
        if needsVerification {
2✔
530
                slog.With("function", "sendForwarding").Error("timeout waiting for address verification, canceling forwarding",
1✔
531
                        "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort)
1✔
532
                _ = m.sendForwardingOnce(fwd, ForwardCancel)
1✔
533
                return fmt.Errorf("timeout waiting for address verification for %s", fwd.RemoteHost)
1✔
534
        }
1✔
535
        slog.With("function", "sendForwarding").Warn("timeout waiting for address verification",
×
536
                "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort)
×
537
        return nil
×
538
}
539

540
// handleAssignedURIs validates URIs returned by the SSH server. For specific
541
// hostnames it cancels the forwarding on hostname mismatch; otherwise it
542
// stores the URIs in m.assignedAddrs.
543
// Must be called with a lock on m.clientMu.
544
func (m *SSHTunnelManager) handleAssignedURIs(fwd *ForwardingConfig, key string, uris []string, needsVerification bool) error {
5✔
545
        if needsVerification && !matchesRequestedHost(uris, fwd.RemoteHost, fwd.RemotePort) {
6✔
546
                slog.With("function", "sendForwarding").Warn("wrong hostname assigned",
1✔
547
                        "requested_host", fwd.RemoteHost, "received_uris", uris)
1✔
548
                _ = m.sendForwardingOnce(fwd, ForwardCancel)
1✔
549
                return fmt.Errorf("wrong hostname assigned: %v", uris)
1✔
550
        }
1✔
551

552
        if needsVerification {
5✔
553
                slog.With("function", "sendForwarding").Info("verified correct hostname assigned",
1✔
554
                        "remote_host", fwd.RemoteHost, "uris", uris)
1✔
555
        } else {
4✔
556
                slog.With("function", "sendForwarding").Debug("storing assigned addresses for wildcard forwarding",
3✔
557
                        "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort, "uris", uris)
3✔
558
        }
3✔
559

560
        m.addrNotifMu.Lock()
4✔
561
        m.assignedAddrs[key] = uris
4✔
562
        m.addrNotifMu.Unlock()
4✔
563
        return nil
4✔
564
}
565

566
// sendForwardingOnce sends a single forwarding request without retry logic.
567
// Must be called with a lock on m.clientMu
568
func (m *SSHTunnelManager) sendForwardingOnce(fwd *ForwardingConfig, req ForwardRequest) error {
18✔
569
        // Validate port range to prevent integer overflow
18✔
570
        if fwd.RemotePort < 0 || fwd.RemotePort > 65535 {
18✔
571
                return fmt.Errorf("invalid remote port %d: must be between 0 and 65535", fwd.RemotePort)
×
572
        }
×
573

574
        forwardMessage := channelForwardMsg{
18✔
575
                addr:  fwd.RemoteHost,
18✔
576
                rport: uint32(fwd.RemotePort), // #nosec G115 -- Port validated above
18✔
577
        }
18✔
578

18✔
579
        ctx, cancel := context.WithTimeout(m.externalCtx, m.fwdReqTimeout)
18✔
580
        defer cancel()
18✔
581

18✔
582
        var reqType string
18✔
583
        switch req {
18✔
584
        case ForwardStart:
14✔
585
                reqType = "tcpip-forward"
14✔
586
        case ForwardCancel:
4✔
587
                reqType = "cancel-tcpip-forward"
4✔
588
        default:
×
589
                return fmt.Errorf("ssh: unknown forwarding request type: %q", req)
×
590
        }
591

592
        slog.With("function", "sendForwardingOnce").Info("sending SSH request",
18✔
593
                "request_type", reqType,
18✔
594
                "remote_host", fwd.RemoteHost,
18✔
595
                "remote_port", fwd.RemotePort,
18✔
596
                "internal_host", fwd.InternalHost,
18✔
597
                "internal_port", fwd.InternalPort,
18✔
598
                "marshaled_addr", forwardMessage.addr,
18✔
599
                "marshaled_port", forwardMessage.rport)
18✔
600

18✔
601
        resCh := make(chan struct {
18✔
602
                err error
18✔
603
                ok  bool
18✔
604
        })
18✔
605

18✔
606
        go func() {
36✔
607
                ok, _, err := m.client.SendRequest(reqType, true, ssh.Marshal(&forwardMessage))
18✔
608
                select {
18✔
609
                case resCh <- struct {
610
                        err error
611
                        ok  bool
612
                }{err, ok}:
17✔
613
                case <-ctx.Done():
×
614
                }
615
        }()
616

617
        select {
18✔
618
        case <-ctx.Done():
1✔
619
                slog.With("function", "sendForwardingOnce").Error("request timed out",
1✔
620
                        "request", reqType, "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort)
1✔
621
                return fmt.Errorf("ssh: %s request timed out", reqType)
1✔
622
        case res := <-resCh:
17✔
623
                if res.err != nil {
17✔
624
                        slog.With("function", "sendForwardingOnce").Error("request failed with error",
×
625
                                "request", reqType, "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort, "error", res.err)
×
626
                        return res.err
×
627
                }
×
628
                if !res.ok {
18✔
629
                        slog.With("function", "sendForwardingOnce").Error("request denied by server",
1✔
630
                                "request", reqType, "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort)
1✔
631
                        return fmt.Errorf("ssh: %s request denied by server", reqType)
1✔
632
                }
1✔
633
                slog.With("function", "sendForwardingOnce").Info("request accepted by server",
16✔
634
                        "request", reqType, "remote_host", fwd.RemoteHost, "remote_port", fwd.RemotePort)
16✔
635
                return nil
16✔
636
        }
637
}
638

639
// Stop gracefully shuts down the SSHTunnelManager.
640
func (m *SSHTunnelManager) Stop() {
1✔
641
        m.clientMu.Lock()
1✔
642
        defer m.clientMu.Unlock()
1✔
643

1✔
644
        if m.client != nil && m.connected {
2✔
645
                for key, forwardingSession := range m.forwardings {
2✔
646
                        if err := m.sendForwarding(forwardingSession, ForwardCancel); err != nil {
1✔
647
                                slog.With("function", "Stop").Error("failed to stop forwarding", "key", key, "error", err)
×
648
                        } else {
1✔
649
                                slog.With("function", "Stop").Info("stopped forwarding", "key", key)
1✔
650
                        }
1✔
651
                }
652
        }
653
        // Clear forwardings map
654
        m.forwardings = make(map[string]*ForwardingConfig)
1✔
655

1✔
656
        m.closeClient()
1✔
657

1✔
658
        slog.Info("ssh tunnel manager stopped, all forwardings and connections closed")
1✔
659
}
660

661
// handleChannels manages the lifecycle of a forwarding sessions
662
func (m *SSHTunnelManager) handleChannels() {
22✔
663
        m.clientMu.RLock()
22✔
664
        if m.client == nil {
22✔
UNCOV
665
                m.clientMu.RUnlock()
×
UNCOV
666
                slog.With("function", "handleChannels").Error("client is nil, cannot handle channels")
×
UNCOV
667
                return
×
UNCOV
668
        }
×
669
        tcpipChan := m.client.HandleChannelOpen("forwarded-tcpip")
22✔
670
        connectionCtx := m.connectionCtx
22✔
671
        m.clientMu.RUnlock()
22✔
672

22✔
673
        for {
49✔
674
                // slog.Debug("waiting for new channels")
27✔
675
                select {
27✔
676
                case <-connectionCtx.Done():
17✔
677
                        slog.With("function", "handleChannels").Debug("connection context canceled, stopping channel handling")
17✔
678
                        return
17✔
679
                case ch := <-tcpipChan:
10✔
680
                        if ch == nil {
15✔
681
                                slog.With("function", "handleChannels").Debug("received nil channel, stopping channel handling")
5✔
682
                                return
5✔
683
                        }
5✔
684

685
                        logger := slog.With(
5✔
686
                                slog.String("channel_type", ch.ChannelType()),
5✔
687
                                slog.String("extra_data", string(ch.ExtraData())),
5✔
688
                        )
5✔
689
                        logger.Debug("received new channel from SSH server")
5✔
690

5✔
691
                        switch channelType := ch.ChannelType(); channelType {
5✔
692
                        case "forwarded-tcpip":
5✔
693
                                var payload forwardedTCPPayload
5✔
694
                                if err := ssh.Unmarshal(ch.ExtraData(), &payload); err != nil {
5✔
695
                                        logger.Error("unable to parse forwarded-tcpip payload", slog.Any("error", err))
×
696
                                        _ = ch.Reject(ssh.ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) // #nosec G104 -- Best effort rejection in error path
×
697
                                        continue
×
698
                                }
699

700
                                logger.Debug("forwarded-tcpip channel opened", "remote_addr", payload.Addr, "remote_port", payload.Port, "origin_addr", payload.OriginAddr, "origin_port", payload.OriginPort)
5✔
701

5✔
702
                                key := forwardingKey(payload.Addr, int(payload.Port))
5✔
703
                                m.clientMu.RLock()
5✔
704
                                fwd, exists := m.forwardings[key]
5✔
705
                                m.clientMu.RUnlock()
5✔
706
                                if exists {
7✔
707
                                        go func(ch ssh.NewChannel) {
4✔
708
                                                remoteConn, reqs, acceptErr := ch.Accept()
2✔
709
                                                if acceptErr != nil {
2✔
710
                                                        logger.Error("failed to accept channel", "error", acceptErr)
×
711
                                                        return
×
712
                                                }
×
713
                                                logger.Debug("channel accepted")
2✔
714

2✔
715
                                                // Log all requests sent by the SSH server on this channel, and discard them.
2✔
716
                                                go func() {
4✔
717
                                                        for req := range reqs {
2✔
718
                                                                logger.Debug("received request on channel", "request_type", req.Type, "want_reply", req.WantReply, "payload", string(req.Payload))
×
719
                                                                if req.WantReply {
×
720
                                                                        _ = req.Reply(false, nil) // #nosec G104 -- Best effort reply, already in goroutine
×
721
                                                                }
×
722
                                                        }
723
                                                }()
724

725
                                                go func() {
4✔
726
                                                        defer func() {
4✔
727
                                                                _ = remoteConn.Close() // #nosec G104 -- Cleanup in defer, error logged below
2✔
728
                                                        }()
2✔
729

730
                                                        addr := net.JoinHostPort(fwd.InternalHost, strconv.Itoa(fwd.InternalPort))
2✔
731
                                                        localConn, localErr := m.netDialFunc("tcp", addr)
2✔
732
                                                        if localErr != nil {
2✔
733
                                                                logger.Error("failed to connect to local address", "error", localErr)
×
734
                                                                return
×
735
                                                        }
×
736

737
                                                        defer func() {
4✔
738
                                                                _ = localConn.Close() // #nosec G104 -- Cleanup in defer, error logged below
2✔
739
                                                        }()
2✔
740

741
                                                        // When either Copy direction finishes, fully close both
742
                                                        // connections so the peer goroutine's blocking Read returns and
743
                                                        // the SSH channel is torn down (not just half-closed). With only
744
                                                        // CloseWrite, an upstream SSH server like tuns.sh keeps the
745
                                                        // public-facing TCP connection open until the channel itself
746
                                                        // closes, which never happens if a client doesn't hang up first.
747
                                                        wg := &sync.WaitGroup{}
2✔
748
                                                        wg.Add(2)
2✔
749

2✔
750
                                                        go func() {
4✔
751
                                                                defer wg.Done()
2✔
752
                                                                n, err := io.Copy(remoteConn, localConn)
2✔
753
                                                                slog.With("function", "handleChannels").Debug("copied data from local to remote", "bytes", n, "error", err)
2✔
754
                                                                _ = remoteConn.Close() // #nosec G104 -- Best effort close
2✔
755
                                                                _ = localConn.Close()  // #nosec G104 -- Best effort close
2✔
756
                                                        }()
2✔
757

758
                                                        go func() {
4✔
759
                                                                defer wg.Done()
2✔
760
                                                                n, err := io.Copy(localConn, remoteConn)
2✔
761
                                                                slog.With("function", "handleChannels").Debug("copied data from remote to local", "bytes", n, "error", err)
2✔
762
                                                                _ = localConn.Close()  // #nosec G104 -- Best effort close
2✔
763
                                                                _ = remoteConn.Close() // #nosec G104 -- Best effort close
2✔
764
                                                        }()
2✔
765

766
                                                        wg.Wait()
2✔
767
                                                        slog.With("function", "handleChannels").Debug("channel closed")
2✔
768
                                                }()
769
                                        }(ch)
770
                                        logger.Debug("forwarding established", "key", key)
2✔
771
                                } else {
3✔
772
                                        logger.Warn("unable to find forwarding session")
3✔
773
                                        _ = ch.Reject(ssh.ConnectionFailed, "unable to find forwarding session") // #nosec G104 -- Best effort rejection
3✔
774
                                }
3✔
775
                        default:
×
776
                                logger.Warn("unknown channel type received", "channel_type", channelType)
×
777
                                _ = ch.Reject(ssh.UnknownChannelType, "unknown channel type") // #nosec G104 -- Best effort rejection
×
778
                        }
779
                }
780
        }
781
}
782

783
// connectClient establishes a new SSH client connection.
784
// Must be called with a write lock on m.clientMu
785
func (m *SSHTunnelManager) connectClient() error {
24✔
786
        config := &ssh.ClientConfig{
24✔
787
                User: m.sshUser,
24✔
788
                Auth: []ssh.AuthMethod{
24✔
789
                        ssh.PublicKeys(m.signer),
24✔
790
                        ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
24✔
791
                                // serveo.net and similar services use keyboard-interactive for authentication
×
792
                                // but don't actually require answers to questions
×
793
                                answers := make([]string, len(questions))
×
794
                                return answers, nil
×
795
                        }),
×
796
                },
797
                Timeout: m.connTimeout,
798
        }
799

800
        if m.hostKey != "" {
26✔
801
                config.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
4✔
802
                        actual := sha256.Sum256(key.Marshal())
2✔
803
                        actualFingerprint := "SHA256:" + base64.StdEncoding.EncodeToString(actual[:])
2✔
804

2✔
805
                        if actualFingerprint != m.hostKey {
3✔
806
                                slog.With("function", "connect").Error("host key verification failed", "expected", m.hostKey, "got", actualFingerprint)
1✔
807
                                return fmt.Errorf("host key verification failed: expected %s, got %s", m.hostKey, actualFingerprint)
1✔
808
                        }
1✔
809
                        return nil
1✔
810
                }
811
        } else {
22✔
812
                // Security: Using InsecureIgnoreHostKey is acceptable here because:
22✔
813
                // 1. User explicitly chose not to provide SSH_HOST_KEY environment variable
22✔
814
                // 2. This is typically used for development/testing environments
22✔
815
                // 3. We log a warning to make the user aware of the security implication
22✔
816
                // For production use, SSH_HOST_KEY should always be set
22✔
817
                slog.With("function", "connect").Warn("no host key provided, falling back to InsecureIgnoreHostKey")
22✔
818
                config.HostKeyCallback = ssh.InsecureIgnoreHostKey() // #nosec G106 -- Intentional fallback for dev/test
22✔
819
        }
22✔
820

821
        client, err := m.sshDialFunc("tcp", m.sshServerAddress, config)
24✔
822
        if err != nil {
26✔
823
                rerr := &ErrSSHConnectionFailed{Err: err}
2✔
824
                slog.With("function", "connect").Error(rerr.Error())
2✔
825
                return rerr
2✔
826
        }
2✔
827

828
        m.client = client
22✔
829
        m.connected = true
22✔
830
        m.connectionCtx, m.connectionCancel = context.WithCancel(m.externalCtx)
22✔
831

22✔
832
        // Diagnostic: log the underlying transport shutdown reason. ssh.Conn.Wait
22✔
833
        // blocks until the SSH transport loop exits and returns whatever error
22✔
834
        // caused it — including SSH_MSG_DISCONNECT reason strings from the server,
22✔
835
        // which are otherwise hidden behind the EOF surfaced to SendRequest.
22✔
836
        if real, ok := client.(*ssh.Client); ok {
22✔
NEW
837
                go func() {
×
NEW
838
                        err := real.Wait()
×
NEW
839
                        slog.With("function", "monitorClientWait").Info("ssh client transport shut down", "error", err)
×
NEW
840
                }()
×
841
        }
842

843
        // captureReady is closed when the capture session (Shell) is ready to receive
844
        // server output. Forwardings wait on this before sending tcpip-forward requests,
845
        // ensuring the Shell session is listening when the server outputs addresses.
846
        m.captureReady = make(chan struct{})
22✔
847
        if m.remoteAddrFunc == nil {
38✔
848
                slog.With("function", "connect").Debug("remoteAddrFunc is not set, skipping server output capture")
16✔
849
                close(m.captureReady)
16✔
850
        } else {
22✔
851
                slog.With("function", "connect").Debug("remoteAddrFunc is set, capturing server output")
6✔
852
                go m.captureServerOutput()
6✔
853
        }
6✔
854

855
        return nil
22✔
856
}
857

858
// captureServerOutput captures the output from the SSH server and processes it using the remoteAddrFunc.
859
// It reads from the server's stdout and stderr and applies the remoteAddrFunc to extract URIs from the output.
860
// This function runs in a goroutine and will stop when the connection context is done.
861
func (m *SSHTunnelManager) captureServerOutput() {
6✔
862
        // Ensure captureReady is always closed so forwardings don't block forever,
6✔
863
        // even if session setup fails.
6✔
864
        captureSignaled := false
6✔
865
        defer func() {
12✔
866
                if !captureSignaled {
12✔
867
                        close(m.captureReady)
6✔
868
                }
6✔
869
        }()
870

871
        session, err := m.createSSHSession()
6✔
872
        if err != nil {
12✔
873
                return
6✔
874
        }
6✔
875
        defer func() {
×
876
                if err := session.Close(); err != nil {
×
877
                        slog.With("function", "captureServerOutput").Error("failed to close session", "error", err)
×
878
                }
×
879
        }()
880

881
        stdout, err := session.StdoutPipe()
×
882
        if err != nil {
×
883
                slog.With("function", "captureServerOutput").Error("failed to get stdout pipe", "error", err)
×
884
                return
×
885
        }
×
886

887
        stderr, err := session.StderrPipe()
×
888
        if err != nil {
×
889
                slog.With("function", "captureServerOutput").Error("failed to get stderr pipe", "error", err)
×
890
                return
×
891
        }
×
892

893
        // Start readers and signal readiness BEFORE Start/Shell so that
894
        // tcpip-forward requests can proceed concurrently. The readers will
895
        // block on Read() until the session produces output.
896
        go m.readServerOutput(stdout, "stdout")
×
897
        go m.readServerOutput(stderr, "stderr")
×
898

×
899
        close(m.captureReady)
×
900
        captureSignaled = true
×
901
        slog.With("function", "captureServerOutput").Debug("capture session ready")
×
902

×
903
        if m.proxyProtocol > 0 {
×
904
                // Send proxy-protocol as an exec command on this same session.
×
905
                // Equivalent to `ssh server proxy-protocol=N`. The server outputs
×
906
                // forwarding addresses on this session's stdout/stderr.
×
907
                cmd := fmt.Sprintf("proxy-protocol=%d", m.proxyProtocol)
×
908
                slog.With("function", "captureServerOutput").Info("starting session with proxy protocol", "command", cmd)
×
909
                if err := session.Start(cmd); err != nil {
×
910
                        slog.With("function", "captureServerOutput").Error("failed to start proxy protocol session", "error", err)
×
911
                        return
×
912
                }
×
913
        } else {
×
914
                if err := session.Shell(); err != nil {
×
915
                        slog.With("function", "captureServerOutput").Error("failed to start remote session", "error", err)
×
916
                        return
×
917
                }
×
918
        }
919

920
        <-m.connectionCtx.Done()
×
921
        _ = session.Close() // #nosec G104 -- Cleanup on shutdown, error not actionable
×
922
}
923

924
// createSSHSession creates a new SSH session from the current client.
925
func (m *SSHTunnelManager) createSSHSession() (*ssh.Session, error) {
6✔
926
        realClient, ok := m.client.(*ssh.Client)
6✔
927
        if !ok {
12✔
928
                slog.With("function", "captureServerOutput").Warn("cannot capture server output, client is not *ssh.Client")
6✔
929
                return nil, fmt.Errorf("client is not *ssh.Client")
6✔
930
        }
6✔
931

932
        session, err := realClient.NewSession()
×
933
        if err != nil {
×
934
                slog.With("function", "captureServerOutput").Error("failed to create SSH session", "error", err)
×
935
                return nil, err
×
936
        }
×
937

938
        return session, nil
×
939
}
940

941
// readServerOutput reads and processes output from the SSH server.
942
func (m *SSHTunnelManager) readServerOutput(reader io.Reader, streamName string) {
×
943
        buf := make([]byte, 4096)
×
944
        for {
×
945
                select {
×
946
                case <-m.connectionCtx.Done():
×
947
                        slog.With("function", "captureServerOutput").Debug("connection context canceled, stopping server output capture")
×
948
                        return
×
949
                default:
×
950
                        n, err := reader.Read(buf)
×
951
                        if err != nil {
×
952
                                if err != io.EOF {
×
953
                                        slog.With("function", "captureServerOutput", "stream", streamName).Error("read error", "error", err)
×
954
                                }
×
955
                                return
×
956
                        }
957
                        if n > 0 {
×
958
                                m.processServerData(buf[:n], streamName)
×
959
                        }
×
960
                }
961
        }
962
}
963

964
// processServerData extracts URIs from server output and notifies waiting goroutines.
965
func (m *SSHTunnelManager) processServerData(data []byte, streamName string) {
1✔
966
        dataStr := string(data)
1✔
967
        slog.With("function", "captureServerOutput", "stream", streamName).Debug(dataStr)
1✔
968

1✔
969
        uris, err := m.remoteAddrFunc(dataStr)
1✔
970
        if err != nil {
1✔
971
                slog.With("function", "captureServerOutput", "stream", streamName).Error("failed to extract URIs from data", "error", err)
×
972
                return
×
973
        }
×
974

975
        if len(uris) == 0 {
1✔
976
                return
×
977
        }
×
978

979
        slog.With("function", "captureServerOutput", "stream", streamName).Debug("extracted URIs from server output", "uris_count", len(uris))
1✔
980
        for _, uri := range uris {
2✔
981
                slog.With("function", "captureServerOutput", "stream", streamName).Info("extracted URI from server output", "uri", uri)
1✔
982
        }
1✔
983

984
        m.notifyURIWaiters(uris)
1✔
985
}
986

987
// notifyURIWaiters sends extracted URIs to all registered notification channels.
988
func (m *SSHTunnelManager) notifyURIWaiters(uris []string) {
1✔
989
        // Copy channels under read lock to avoid holding lock during I/O
1✔
990
        m.addrNotifMu.RLock()
1✔
991
        channels := make(map[string]chan []string, len(m.addrNotifications))
1✔
992
        for key, ch := range m.addrNotifications {
2✔
993
                channels[key] = ch
1✔
994
        }
1✔
995
        m.addrNotifMu.RUnlock()
1✔
996

1✔
997
        // Send to channels outside the lock
1✔
998
        for key, ch := range channels {
2✔
999
                select {
1✔
1000
                case ch <- uris:
1✔
1001
                        slog.With("function", "captureServerOutput").Debug("notified waiter about addresses", "key", key, "uris", uris)
1✔
1002
                default:
×
1003
                        // Channel full or no receiver, skip
1004
                }
1005
        }
1006
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc