• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pomerium / pomerium / 19149648094

06 Nov 2025 08:57PM UTC coverage: 56.177% (+0.1%) from 56.051%
19149648094

push

github

web-flow
ssh: initial implementation of reverse tunnel EDS (#5915)

This implements the reverse tunnel Endpoint Discovery Service (EDS) endpoint.

There are still more tests to be written, but those will be added in a
follow-up.

246 of 308 new or added lines in 7 files covered. (79.87%)

12 existing lines in 5 files now uncovered.

28467 of 50674 relevant lines covered (56.18%)

96.58 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.08
/pkg/ssh/stream.go
1
package ssh
2

3
import (
4
        "context"
5
        "iter"
6
        "sync"
7
        "sync/atomic"
8
        "time"
9

10
        corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
11
        envoy_config_endpoint_v3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
12
        gossh "golang.org/x/crypto/ssh"
13
        "google.golang.org/grpc/codes"
14
        "google.golang.org/grpc/status"
15
        "google.golang.org/protobuf/types/known/anypb"
16

17
        extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
18
        "github.com/pomerium/pomerium/config"
19
        "github.com/pomerium/pomerium/internal/log"
20
        "github.com/pomerium/pomerium/pkg/grpc/databroker"
21
        "github.com/pomerium/pomerium/pkg/protoutil"
22
        "github.com/pomerium/pomerium/pkg/slices"
23
        "github.com/pomerium/pomerium/pkg/ssh/portforward"
24
)
25

26
// Protocol identifiers used while handling an SSH stream.
const (
	// Authentication method names accepted in auth requests.
	MethodPublicKey           = "publickey"
	MethodKeyboardInteractive = "keyboard-interactive"

	// Channel types accepted in ChannelOpen messages.
	ChannelTypeSession     = "session"
	ChannelTypeDirectTcpip = "direct-tcpip"

	// ServiceConnection is the only service an auth request may target.
	ServiceConnection = "ssh-connection"
)
35

36
type KeyboardInteractiveQuerier interface {
37
        // Prompts the client and returns their responses to the given prompts.
38
        Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
39
}
40

41
// AuthMethodResponse is the outcome of evaluating one authentication method.
type AuthMethodResponse[T any] struct {
	// Allow is non-nil when the method succeeded.
	Allow *T
	// RequireAdditionalMethods names further methods the client must complete
	// before authentication can finish.
	RequireAdditionalMethods []string
}
45

46
type (
47
        PublicKeyAuthMethodResponse           = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
48
        KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
49
)
50

51
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface
52

53
type AuthInterface interface {
54
        HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
55
        HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
56
        EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
57
        FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error)
58
        DeleteSession(ctx context.Context, info StreamAuthInfo) error
59
        GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
60
}
61

62
type ClusterStatsListener interface {
63
        HandleClusterStatsUpdate(*envoy_config_endpoint_v3.ClusterStats)
64
}
65

66
type EndpointDiscoveryInterface interface {
67
        UpdateClusterEndpoints(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{})
68
}
69

70
// AuthMethodValue tracks the outcome of one auth method on a stream: whether
// it was attempted and, if it succeeded, the allow response it produced.
type AuthMethodValue[T any] struct {
	attempted bool
	Value     *T
}

// Update records an attempt of the method. A nil value means the attempt did
// not produce an allow response.
func (v *AuthMethodValue[T]) Update(value *T) {
	v.attempted, v.Value = true, value
}

// IsValid reports whether this method does not block authentication. A
// method that was never attempted is valid; an attempted method is valid only
// if it produced a value.
func (v *AuthMethodValue[T]) IsValid() bool {
	return !v.attempted || v.Value != nil
}
87

88
type StreamAuthInfo struct {
89
        Username                   *string
90
        Hostname                   *string
91
        StreamID                   uint64
92
        SourceAddress              string
93
        ChannelType                string
94
        PublicKeyFingerprintSha256 []byte
95
        PublicKeyAllow             AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
96
        KeyboardInteractiveAllow   AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
97
        InitialAuthComplete        bool
98
}
99

100
func (i *StreamAuthInfo) allMethodsValid() bool {
84✔
101
        return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
84✔
102
}
84✔
103

104
type StreamState struct {
105
        StreamAuthInfo
106
        RemainingUnauthenticatedMethods []string
107
        DownstreamChannelInfo           *extensions_ssh.SSHDownstreamChannelInfo
108
}
109

110
// TUIDefaultMode is the mode passed to ChannelHandler.Run when an internal
// session channel is served.
type TUIDefaultMode int

const (
	// TUIModeInternalCLI is the zero value and the initial mode.
	TUIModeInternalCLI TUIDefaultMode = iota
	// TUIModeTunnelStatus is selected after a tcpip-forward request succeeds.
	TUIModeTunnelStatus
)
116

117
// StreamHandler handles a single SSH stream
118
type StreamHandler struct {
119
        auth       AuthInterface
120
        discovery  EndpointDiscoveryInterface
121
        config     *config.Config
122
        downstream *extensions_ssh.DownstreamConnectEvent
123
        writeC     chan *extensions_ssh.ServerMessage
124
        readC      chan *extensions_ssh.ClientMessage
125
        reauthC    chan struct{}
126
        terminateC chan error
127

128
        state        *StreamState
129
        close        func()
130
        portForwards *portforward.Manager
131

132
        expectingInternalChannel bool
133
        internalSession          atomic.Pointer[ChannelHandler]
134

135
        tuiDefaultModeLock sync.Mutex
136
        tuiDefaultMode     TUIDefaultMode
137
}
138

139
// Compile-time check that *StreamHandler implements StreamHandlerInterface.
var (
	_ StreamHandlerInterface = (*StreamHandler)(nil)
)
140

141
func NewStreamHandler(
142
        auth AuthInterface,
143
        discovery EndpointDiscoveryInterface,
144
        cfg *config.Config,
145
        downstream *extensions_ssh.DownstreamConnectEvent,
146
        onClosed func(),
147
) *StreamHandler {
178✔
148
        writeC := make(chan *extensions_ssh.ServerMessage, 32)
178✔
149
        sh := &StreamHandler{
178✔
150
                auth:       auth,
178✔
151
                discovery:  discovery,
178✔
152
                config:     cfg,
178✔
153
                downstream: downstream,
178✔
154
                writeC:     make(chan *extensions_ssh.ServerMessage, 32),
178✔
155
                readC:      make(chan *extensions_ssh.ClientMessage, 32),
178✔
156
                reauthC:    make(chan struct{}),
178✔
157
                terminateC: make(chan error, 1),
178✔
158
                close: func() {
343✔
159
                        onClosed()
165✔
160
                        close(writeC)
165✔
161
                },
165✔
162
        }
163
        sh.portForwards = portforward.NewManager(cfg, sh)
178✔
164
        sh.portForwards.AddUpdateListener(sh)
178✔
165
        return sh
178✔
166
}
167

168
// EvaluateRoute implements portforward.RouteEvaluator.
169
func (sh *StreamHandler) EvaluateRoute(_ portforward.RouteInfo) error {
100✔
170
        // Temporary stub - this is implemented separately
100✔
171
        return nil
100✔
172
}
100✔
173

174
// OnClusterEndpointsUpdated implements portforward.UpdateListener.
175
func (sh *StreamHandler) OnClusterEndpointsUpdated(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{}) {
284✔
176
        sh.discovery.UpdateClusterEndpoints(added, removed)
284✔
177
}
284✔
178

179
// OnPermissionsUpdated implements portforward.UpdateListener.
180
func (sh *StreamHandler) OnPermissionsUpdated(_ []portforward.Permission) {
238✔
181
}
238✔
182

183
// OnRoutesUpdated implements portforward.UpdateListener.
184
func (sh *StreamHandler) OnRoutesUpdated(_ []portforward.RouteInfo) {
224✔
185
}
224✔
186

187
func (sh *StreamHandler) Terminate(err error) {
11✔
188
        sh.terminateC <- err
11✔
189
}
11✔
190

191
func (sh *StreamHandler) Close() {
165✔
192
        sh.close()
165✔
193
}
165✔
194

195
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
98✔
196
        return sh.expectingInternalChannel
98✔
197
}
98✔
198

199
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
632✔
200
        return sh.readC
632✔
201
}
632✔
202

203
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
314✔
204
        return sh.writeC
314✔
205
}
314✔
206

207
// Reauth blocks until authorization policy is reevaluated.
208
func (sh *StreamHandler) Reauth() {
46✔
209
        sh.reauthC <- struct{}{}
46✔
210
}
46✔
211

212
func (sh *StreamHandler) periodicReauth() (cancel func()) {
176✔
213
        t := time.NewTicker(1 * time.Minute)
176✔
214
        go func() {
352✔
215
                for range t.C {
176✔
216
                        sh.Reauth()
×
217
                }
×
218
        }()
219
        return t.Stop
176✔
220
}
221

222
// Prompt implements KeyboardInteractiveQuerier.
223
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
34✔
224
        sh.sendInfoPrompts(prompts)
34✔
225
        select {
34✔
226
        case <-ctx.Done():
2✔
227
                return nil, context.Cause(ctx)
2✔
228
        case err := <-sh.terminateC:
×
229
                return nil, err
×
230
        case req := <-sh.readC:
32✔
231
                switch msg := req.Message.(type) {
32✔
232
                case *extensions_ssh.ClientMessage_InfoResponse:
30✔
233
                        if msg.InfoResponse.Method != MethodKeyboardInteractive {
32✔
234
                                return nil, status.Errorf(codes.Internal, "received invalid info response")
2✔
235
                        }
2✔
236
                        r, _ := msg.InfoResponse.Response.UnmarshalNew()
28✔
237
                        respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
28✔
238
                        if !ok {
30✔
239
                                return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
2✔
240
                        }
2✔
241
                        return respInfo, nil
26✔
242
                default:
2✔
243
                        return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
2✔
244
                }
245
        }
246
}
247

248
func (sh *StreamHandler) Run(ctx context.Context) error {
178✔
249
        if sh.state != nil {
180✔
250
                panic("Run called twice")
2✔
251
        }
252
        sh.state = &StreamState{
176✔
253
                RemainingUnauthenticatedMethods: []string{MethodPublicKey},
176✔
254
                StreamAuthInfo: StreamAuthInfo{
176✔
255
                        StreamID:      sh.downstream.StreamId,
176✔
256
                        SourceAddress: sh.downstream.SourceAddress.GetSocketAddress().GetAddress(),
176✔
257
                },
176✔
258
        }
176✔
259
        cancelReauth := sh.periodicReauth()
176✔
260
        defer cancelReauth()
176✔
261
        for {
626✔
262
                select {
450✔
263
                case <-ctx.Done():
119✔
264
                        return context.Cause(ctx)
119✔
265
                case <-sh.reauthC:
46✔
266
                        if err := sh.reauth(ctx); err != nil {
52✔
267
                                return err
6✔
268
                        }
6✔
269
                case err := <-sh.terminateC:
11✔
270
                        return err
11✔
271
                case req := <-sh.readC:
274✔
272
                        switch req := req.Message.(type) {
274✔
273
                        case *extensions_ssh.ClientMessage_Event:
32✔
274
                                switch event := req.Event.Event.(type) {
32✔
275
                                case *extensions_ssh.StreamEvent_DownstreamConnected:
2✔
276
                                        // this was already received as the first message in the stream
2✔
277
                                        return status.Errorf(codes.Internal, "received duplicate downstream connected event")
2✔
278
                                case *extensions_ssh.StreamEvent_UpstreamConnected:
26✔
279
                                        log.Ctx(ctx).Debug().
26✔
280
                                                Msg("ssh: upstream connected")
26✔
281
                                case *extensions_ssh.StreamEvent_DownstreamDisconnected:
2✔
282
                                        log.Ctx(ctx).Debug().
2✔
283
                                                Uint64("stream-id", sh.downstream.StreamId).
2✔
284
                                                Str("reason", event.DownstreamDisconnected.Reason).
2✔
285
                                                Msg("ssh: downstream disconnected")
2✔
286
                                case nil:
2✔
287
                                        return status.Errorf(codes.Internal, "received invalid event")
2✔
288
                                }
289
                        case *extensions_ssh.ClientMessage_AuthRequest:
180✔
290
                                if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
214✔
291
                                        return err
34✔
292
                                }
34✔
293
                        case *extensions_ssh.ClientMessage_GlobalRequest:
60✔
294
                                if err := sh.handleGlobalRequest(ctx, req.GlobalRequest); err != nil {
60✔
NEW
295
                                        return err
×
NEW
296
                                }
×
297
                        default:
2✔
298
                                return status.Errorf(codes.Internal, "received invalid client message type %#T", req)
2✔
299
                        }
300
                }
301
        }
302
}
303

304
func (sh *StreamHandler) handleGlobalRequest(ctx context.Context, globalRequest *extensions_ssh.GlobalRequest) error {
60✔
305
        sh.tuiDefaultModeLock.Lock()
60✔
306
        defer sh.tuiDefaultModeLock.Unlock()
60✔
307
        switch request := globalRequest.Request.(type) {
60✔
308
        case *extensions_ssh.GlobalRequest_TcpipForwardRequest:
40✔
309
                reqHost := request.TcpipForwardRequest.RemoteAddress
40✔
310
                reqPort := request.TcpipForwardRequest.RemotePort
40✔
311
                log.Ctx(ctx).Debug().
40✔
312
                        Uint64("stream-id", sh.state.StreamID).
40✔
313
                        Str("host", reqHost).
40✔
314
                        Msg("got tcpip-forward request")
40✔
315

40✔
316
                serverPort, err := sh.portForwards.AddPermission(reqHost, reqPort)
40✔
317
                if err != nil {
40✔
NEW
318
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
NEW
319
                                Success:      false,
×
NEW
320
                                DebugMessage: err.Error(),
×
NEW
321
                        })
×
NEW
322
                        return nil
×
NEW
323
                }
×
324

325
                log.Ctx(ctx).Debug().
40✔
326
                        Uint64("stream-id", sh.state.StreamID).
40✔
327
                        Msg("sending global request success")
40✔
328

40✔
329
                // https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
40✔
330
                if globalRequest.WantReply && reqPort == 0 {
40✔
NEW
331
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
NEW
332
                                Success: true,
×
NEW
333
                                Response: &extensions_ssh.GlobalRequestResponse_TcpipForwardResponse{
×
NEW
334
                                        TcpipForwardResponse: &extensions_ssh.TcpipForwardResponse{
×
NEW
335
                                                ServerPort: serverPort.Value,
×
NEW
336
                                        },
×
NEW
337
                                },
×
NEW
338
                        })
×
NEW
339
                }
×
340

341
                sh.tuiDefaultMode = TUIModeTunnelStatus
40✔
342
                return nil
40✔
343
        case *extensions_ssh.GlobalRequest_CancelTcpipForwardRequest:
20✔
344
                err := sh.portForwards.RemovePermission(
20✔
345
                        request.CancelTcpipForwardRequest.RemoteAddress,
20✔
346
                        request.CancelTcpipForwardRequest.RemotePort)
20✔
347
                if err != nil {
20✔
NEW
348
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
NEW
349
                                Success:      false,
×
NEW
350
                                DebugMessage: err.Error(),
×
NEW
351
                        })
×
352
                } else if globalRequest.WantReply {
20✔
NEW
353
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
NEW
354
                                Success: true,
×
NEW
355
                        })
×
NEW
356
                }
×
357
                return nil
20✔
NEW
358
        default:
×
NEW
359
                return status.Errorf(codes.Unimplemented, "received unknown global request")
×
360
        }
361
}
362

NEW
363
func (sh *StreamHandler) sendGlobalRequestResponse(response *extensions_ssh.GlobalRequestResponse) {
×
NEW
364
        sh.writeC <- &extensions_ssh.ServerMessage{
×
NEW
365
                Message: &extensions_ssh.ServerMessage_GlobalRequestResponse{
×
NEW
366
                        GlobalRequestResponse: response,
×
NEW
367
                },
×
NEW
368
        }
×
NEW
369
}
×
370

371
func (sh *StreamHandler) ServeChannel(
372
        stream extensions_ssh.StreamManagement_ServeChannelServer,
373
        metadata *extensions_ssh.FilterMetadata,
374
) error {
64✔
375
        // The first channel message on this stream should be a ChannelOpen
64✔
376
        channelOpen, err := stream.Recv()
64✔
377
        if err != nil {
67✔
378
                return err
3✔
379
        }
3✔
380
        rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
61✔
381
        if !ok {
63✔
382
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
383
        }
2✔
384
        var msg ChannelOpenMsg
59✔
385
        if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
61✔
386
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
387
        }
2✔
388

389
        sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
57✔
390
                ChannelType:               msg.ChanType,
57✔
391
                DownstreamChannelId:       msg.PeersID,
57✔
392
                InternalUpstreamChannelId: metadata.ChannelId,
57✔
393
                InitialWindowSize:         msg.PeersWindow,
57✔
394
                MaxPacketSize:             msg.MaxPacketSize,
57✔
395
        }
57✔
396
        sh.state.ChannelType = msg.ChanType
57✔
397
        channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
57✔
398
        switch msg.ChanType {
57✔
399
        case ChannelTypeSession:
41✔
400
                ch := NewChannelHandler(channel, sh.config)
41✔
401
                if !sh.internalSession.CompareAndSwap(nil, ch) {
41✔
402
                        return channel.SendMessage(ChannelOpenFailureMsg{
×
403
                                PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
×
404
                                Reason:  Prohibited,
×
405
                                Message: "multiple concurrent internal session channels not supported",
×
406
                        })
×
407
                }
×
408
                if err := channel.SendMessage(ChannelOpenConfirmMsg{
41✔
409
                        PeersID:       sh.state.DownstreamChannelInfo.DownstreamChannelId,
41✔
410
                        MyID:          sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
41✔
411
                        MyWindow:      ChannelWindowSize,
41✔
412
                        MaxPacketSize: ChannelMaxPacket,
41✔
413
                }); err != nil {
41✔
414
                        return err
×
415
                }
×
416
                var mode TUIDefaultMode
41✔
417
                sh.tuiDefaultModeLock.Lock()
41✔
418
                mode = sh.tuiDefaultMode
41✔
419
                sh.tuiDefaultModeLock.Unlock()
41✔
420

41✔
421
                err := ch.Run(stream.Context(), mode)
41✔
422
                sh.internalSession.Store(nil)
41✔
423
                return err
41✔
424
        case ChannelTypeDirectTcpip:
14✔
425
                if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
18✔
426
                        return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
4✔
427
                }
4✔
428
                var subMsg ChannelOpenDirectMsg
10✔
429
                if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
11✔
430
                        return err
1✔
431
                }
1✔
432
                action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
9✔
433
                if err != nil {
11✔
434
                        return err
2✔
435
                }
2✔
436
                return channel.SendControlAction(action)
7✔
437
        default:
2✔
438
                return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
2✔
439
        }
440
}
441

442
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
180✔
443
        if req.Protocol != "ssh" {
182✔
444
                return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
2✔
445
        }
2✔
446
        if req.Service != ServiceConnection {
180✔
447
                return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
2✔
448
        }
2✔
449
        if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
182✔
450
                return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
6✔
451
        }
6✔
452

453
        if sh.state.Username == nil {
280✔
454
                if req.Username == "" {
112✔
455
                        return status.Errorf(codes.InvalidArgument, "username missing")
2✔
456
                }
2✔
457
                sh.state.Username = &req.Username
108✔
458
        } else if *sh.state.Username != req.Username {
62✔
459
                return status.Errorf(codes.InvalidArgument, "inconsistent username")
2✔
460
        }
2✔
461
        if sh.state.Hostname == nil {
274✔
462
                sh.state.Hostname = &req.Hostname
108✔
463
        } else if *sh.state.Hostname != req.Hostname {
170✔
464
                return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
4✔
465
        }
4✔
466

467
        updateMethods := func(add []string) {
308✔
468
                sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
146✔
469
                sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
146✔
470
        }
146✔
471
        log.Ctx(ctx).Debug().
162✔
472
                Str("method", req.AuthMethod).
162✔
473
                Str("username", *sh.state.Username).
162✔
474
                Str("hostname", *sh.state.Hostname).
162✔
475
                Msg("ssh: handling auth request")
162✔
476

162✔
477
        var partial bool
162✔
478
        switch req.AuthMethod {
162✔
479
        case MethodPublicKey:
124✔
480
                methodReq, _ := req.MethodRequest.UnmarshalNew()
124✔
481
                pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
124✔
482
                if !ok {
126✔
483
                        return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
2✔
484
                }
2✔
485
                response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
122✔
486
                if err != nil {
124✔
487
                        return err
2✔
488
                } else if response.Allow != nil {
220✔
489
                        partial = true
98✔
490
                        sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
98✔
491
                }
98✔
492
                sh.state.PublicKeyAllow.Update(response.Allow)
120✔
493
                updateMethods(response.RequireAdditionalMethods)
120✔
494
        case MethodKeyboardInteractive:
36✔
495
                methodReq, _ := req.MethodRequest.UnmarshalNew()
36✔
496
                kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
36✔
497
                if !ok {
38✔
498
                        return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
2✔
499
                }
2✔
500
                response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
34✔
501
                if err != nil {
42✔
502
                        return err
8✔
503
                }
8✔
504
                partial = response.Allow != nil
26✔
505
                sh.state.KeyboardInteractiveAllow.Update(response.Allow)
26✔
506
                updateMethods(response.RequireAdditionalMethods)
26✔
507
        default:
2✔
508
                return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
2✔
509
        }
510
        log.Ctx(ctx).Debug().
146✔
511
                Str("method", req.AuthMethod).
146✔
512
                Bool("partial", partial).
146✔
513
                Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
146✔
514
                Msg("ssh: auth request complete")
146✔
515

146✔
516
        if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
230✔
517
                // if there are no methods remaining, the user is allowed if all attempted
84✔
518
                // methods have a valid response in the state
84✔
519
                sh.state.InitialAuthComplete = true
84✔
520
                log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
84✔
521
                sh.sendAllowResponse()
84✔
522
        } else {
146✔
523
                log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
62✔
524
                sh.sendDenyResponseWithRemainingMethods(partial)
62✔
525
        }
62✔
526
        return nil
146✔
527
}
528

529
func (sh *StreamHandler) reauth(ctx context.Context) error {
46✔
530
        if !sh.state.InitialAuthComplete {
86✔
531
                return nil
40✔
532
        }
40✔
533
        return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
6✔
534
}
535

536
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
10✔
537
        if hostname == "" {
11✔
538
                return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
1✔
539
        }
1✔
540
        if sh.state.Hostname == nil {
9✔
541
                panic("bug: PrepareHandoff called but state is missing a hostname")
×
542
        }
543
        if *sh.state.Hostname != "" {
9✔
544
                panic("bug: PrepareHandoff called but previous hostname is not empty")
×
545
        }
546
        *sh.state.Hostname = hostname
9✔
547
        err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
9✔
548
        if err != nil {
10✔
549
                return nil, status.Error(codes.PermissionDenied, err.Error())
1✔
550
        }
1✔
551
        log.Ctx(ctx).Debug().
8✔
552
                Str("hostname", *sh.state.Hostname).
8✔
553
                Str("username", *sh.state.Username).
8✔
554
                Msg("ssh: initiating handoff to upstream")
8✔
555
        upstreamAllow := sh.buildUpstreamAllowResponse()
8✔
556
        action := &extensions_ssh.SSHChannelControlAction{
8✔
557
                Action: &extensions_ssh.SSHChannelControlAction_HandOff{
8✔
558
                        HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
8✔
559
                                DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
8✔
560
                                DownstreamPtyInfo:     ptyInfo,
8✔
561
                                UpstreamAuth:          upstreamAllow,
8✔
562
                        },
8✔
563
                },
8✔
564
        }
8✔
565
        return action, nil
8✔
566
}
567

568
func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) {
14✔
569
        return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo)
14✔
570
}
14✔
571

572
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
12✔
573
        return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
12✔
574
}
12✔
575

576
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
6✔
577
        return func(yield func(*config.Policy) bool) {
12✔
578
                for route := range sh.config.Options.GetAllPolicies() {
30✔
579
                        if route.IsSSH() {
34✔
580
                                if !yield(route) {
12✔
581
                                        return
2✔
582
                                }
2✔
583
                        }
584
                }
585
        }
586
}
587

588
// DownstreamChannelID implements StreamHandlerInterface.
589
func (sh *StreamHandler) DownstreamChannelID() uint32 {
212✔
590
        return sh.state.DownstreamChannelInfo.DownstreamChannelId
212✔
591
}
212✔
592

593
// Hostname implements StreamHandlerInterface.
594
func (sh *StreamHandler) Hostname() *string {
42✔
595
        return sh.state.Hostname
42✔
596
}
42✔
597

598
// Username implements StreamHandlerInterface.
599
func (sh *StreamHandler) Username() *string {
79✔
600
        return sh.state.Username
79✔
601
}
79✔
602

603
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
62✔
604
        sh.writeC <- &extensions_ssh.ServerMessage{
62✔
605
                Message: &extensions_ssh.ServerMessage_AuthResponse{
62✔
606
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
62✔
607
                                Response: &extensions_ssh.AuthenticationResponse_Deny{
62✔
608
                                        Deny: &extensions_ssh.DenyResponse{
62✔
609
                                                Partial: partial,
62✔
610
                                                Methods: sh.state.RemainingUnauthenticatedMethods,
62✔
611
                                        },
62✔
612
                                },
62✔
613
                        },
62✔
614
                },
62✔
615
        }
62✔
616
}
62✔
617

618
func (sh *StreamHandler) sendAllowResponse() {
84✔
619
        var allow *extensions_ssh.AllowResponse
84✔
620
        if *sh.state.Hostname == "" {
144✔
621
                sh.expectingInternalChannel = true
60✔
622
                allow = sh.buildInternalAllowResponse()
60✔
623
        } else {
84✔
624
                allow = sh.buildUpstreamAllowResponse()
24✔
625
        }
24✔
626

627
        sh.writeC <- &extensions_ssh.ServerMessage{
84✔
628
                Message: &extensions_ssh.ServerMessage_AuthResponse{
84✔
629
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
84✔
630
                                Response: &extensions_ssh.AuthenticationResponse_Allow{
84✔
631
                                        Allow: allow,
84✔
632
                                },
84✔
633
                        },
84✔
634
                },
84✔
635
        }
84✔
636
}
637

638
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
34✔
639
        sh.writeC <- &extensions_ssh.ServerMessage{
34✔
640
                Message: &extensions_ssh.ServerMessage_AuthResponse{
34✔
641
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
34✔
642
                                Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
34✔
643
                                        InfoRequest: &extensions_ssh.InfoRequest{
34✔
644
                                                Method:  MethodKeyboardInteractive,
34✔
645
                                                Request: protoutil.NewAny(prompts),
34✔
646
                                        },
34✔
647
                                },
34✔
648
                        },
34✔
649
                },
34✔
650
        }
34✔
651
}
34✔
652

653
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
32✔
654
        var allowedMethods []*extensions_ssh.AllowedMethod
32✔
655
        if value := sh.state.PublicKeyAllow.Value; value != nil {
64✔
656
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
32✔
657
                        Method:     MethodPublicKey,
32✔
658
                        MethodData: protoutil.NewAny(value),
32✔
659
                })
32✔
660
        }
32✔
661
        if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
46✔
662
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
14✔
663
                        Method:     MethodKeyboardInteractive,
14✔
664
                        MethodData: protoutil.NewAny(value),
14✔
665
                })
14✔
666
        }
14✔
667
        return &extensions_ssh.AllowResponse{
32✔
668
                Username: *sh.state.Username,
32✔
669
                Target: &extensions_ssh.AllowResponse_Upstream{
32✔
670
                        Upstream: &extensions_ssh.UpstreamTarget{
32✔
671
                                Hostname:       *sh.state.Hostname,
32✔
672
                                DirectTcpip:    sh.state.ChannelType == ChannelTypeDirectTcpip,
32✔
673
                                AllowedMethods: allowedMethods,
32✔
674
                        },
32✔
675
                },
32✔
676
        }
32✔
677
}
678

679
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
60✔
680
        return &extensions_ssh.AllowResponse{
60✔
681
                Username: *sh.state.Username,
60✔
682
                Target: &extensions_ssh.AllowResponse_Internal{
60✔
683
                        Internal: &extensions_ssh.InternalTarget{
60✔
684
                                SetMetadata: &corev3.Metadata{
60✔
685
                                        TypedFilterMetadata: map[string]*anypb.Any{
60✔
686
                                                "com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
60✔
687
                                                        StreamId: sh.downstream.StreamId,
60✔
688
                                                }),
60✔
689
                                        },
60✔
690
                                },
60✔
691
                        },
60✔
692
                },
60✔
693
        }
60✔
694
}
60✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc