• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pomerium / pomerium / 19177304437

07 Nov 2025 06:16PM UTC coverage: 56.093% (-0.02%) from 56.112%
19177304437

push

github

web-flow
ssh: upstream tunnel auth stubs (#5919)

Temporary patch to set up upstream tunnel auth. The actual policy
evaluation is not yet implemented.

13 of 37 new or added lines in 4 files covered. (35.14%)

14 existing lines in 3 files now uncovered.

28519 of 50842 relevant lines covered (56.09%)

96.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.02
/pkg/ssh/stream.go
1
package ssh
2

3
import (
4
        "context"
5
        "iter"
6
        "sync"
7
        "sync/atomic"
8
        "time"
9

10
        corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
11
        envoy_config_endpoint_v3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
12
        gossh "golang.org/x/crypto/ssh"
13
        "google.golang.org/grpc/codes"
14
        "google.golang.org/grpc/status"
15
        "google.golang.org/protobuf/types/known/anypb"
16

17
        extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
18
        "github.com/pomerium/pomerium/config"
19
        "github.com/pomerium/pomerium/internal/log"
20
        "github.com/pomerium/pomerium/pkg/grpc/databroker"
21
        "github.com/pomerium/pomerium/pkg/protoutil"
22
        "github.com/pomerium/pomerium/pkg/slices"
23
        "github.com/pomerium/pomerium/pkg/ssh/portforward"
24
)
25

26
// Authentication method names as they appear in SSH user-auth requests.
const (
	MethodPublicKey           = "publickey"
	MethodKeyboardInteractive = "keyboard-interactive"
)

// Channel types accepted by ServeChannel.
const (
	ChannelTypeSession     = "session"
	ChannelTypeDirectTcpip = "direct-tcpip"
)

// ServiceConnection is the only service name accepted in auth requests.
const ServiceConnection = "ssh-connection"
35

36
// KeyboardInteractiveQuerier lets an auth implementation run a
// keyboard-interactive prompt round-trip with the client. StreamHandler
// implements it (see StreamHandler.Prompt) and passes itself to
// AuthInterface.HandleKeyboardInteractiveMethodRequest.
type KeyboardInteractiveQuerier interface {
	// Prompts the client and returns their responses to the given prompts.
	Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
}

// AuthMethodResponse is the outcome of a single auth-method attempt.
type AuthMethodResponse[T any] struct {
	// Allow is non-nil when the method succeeded. The stream handler records
	// it in StreamAuthInfo and reports a non-nil Allow as a partial success.
	Allow *T
	// RequireAdditionalMethods lists further auth methods the client must
	// complete; they are appended to the stream's remaining methods.
	RequireAdditionalMethods []string
}

// Concrete AuthMethodResponse instantiations for the supported methods.
type (
	PublicKeyAuthMethodResponse           = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
	KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
)
50

51
//go:generate go run go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface

// AuthInterface is the authentication/authorization backend used by
// StreamHandler. Every method receives a snapshot of the stream's current
// StreamAuthInfo.
type AuthInterface interface {
	// HandlePublicKeyMethodRequest evaluates a "publickey" auth attempt.
	HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
	// HandleKeyboardInteractiveMethodRequest evaluates a
	// "keyboard-interactive" auth attempt. querier may be used to prompt the
	// client.
	HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
	// EvaluateDelayed re-evaluates policy for the stream. The stream handler
	// calls it on periodic re-auth (after initial auth) and before an upstream
	// handoff. A non-nil error denies access.
	EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
	// EvaluatePortForward authorizes a port-forward route; a non-nil error
	// denies it.
	EvaluatePortForward(ctx context.Context, info StreamAuthInfo, portForwardInfo portforward.RouteInfo) error
	// FormatSession renders the stream's session for display.
	FormatSession(ctx context.Context, info StreamAuthInfo) ([]byte, error)
	// DeleteSession removes the stream's session.
	DeleteSession(ctx context.Context, info StreamAuthInfo) error
	// GetDataBrokerServiceClient exposes a databroker client; its consumers
	// are outside this file section.
	GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
62

63
// ClusterStatsListener receives Envoy cluster statistics updates.
// NOTE(review): no implementation or caller is visible in this section.
type ClusterStatsListener interface {
	HandleClusterStatsUpdate(*envoy_config_endpoint_v3.ClusterStats)
}

// EndpointDiscoveryInterface receives port-forward cluster endpoint changes.
// StreamHandler forwards them from its port-forward manager
// (see StreamHandler.OnClusterEndpointsUpdated).
type EndpointDiscoveryInterface interface {
	// UpdateClusterEndpoints applies newly added endpoints and removes the
	// keys listed in removed.
	UpdateClusterEndpoints(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{})
}
70

71
// AuthMethodValue records the result of one authentication method on a
// stream. It distinguishes "never attempted" from "attempted and failed".
type AuthMethodValue[T any] struct {
	attempted bool
	Value     *T
}

// Update records an attempt of the method. value is the allow response
// (nil means the attempt did not succeed).
func (v *AuthMethodValue[T]) Update(value *T) {
	*v = AuthMethodValue[T]{attempted: true, Value: value}
}

// IsValid reports whether this method does not block authentication. A
// method that was never tried is valid. A method that was tried is valid
// only if it produced a value.
func (v *AuthMethodValue[T]) IsValid() bool {
	return !v.attempted || v.Value != nil
}
88

89
type StreamAuthInfo struct {
90
        Username                   *string
91
        Hostname                   *string
92
        StreamID                   uint64
93
        SourceAddress              string
94
        ChannelType                string
95
        PublicKeyFingerprintSha256 []byte
96
        PublicKeyAllow             AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
97
        KeyboardInteractiveAllow   AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
98
        InitialAuthComplete        bool
99
}
100

101
func (i *StreamAuthInfo) allMethodsValid() bool {
114✔
102
        return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
114✔
103
}
114✔
104

105
// StreamState is the mutable per-stream state owned by a StreamHandler. It
// is created by StreamHandler.Run.
type StreamState struct {
	StreamAuthInfo
	// RemainingUnauthenticatedMethods lists the auth methods the client must
	// still complete. Run initializes it to just "publickey"; successful
	// methods are removed and any additionally required methods appended.
	RemainingUnauthenticatedMethods []string
	// DownstreamChannelInfo describes the channel most recently opened by the
	// client. It is built from the ChannelOpen message in ServeChannel.
	DownstreamChannelInfo *extensions_ssh.SSHDownstreamChannelInfo
}
110

111
// TUIDefaultMode selects which view an internal session channel starts in.
type TUIDefaultMode int

// TUIModeInternalCLI is the zero value and the default. After a successful
// tcpip-forward request the handler switches to TUIModeTunnelStatus.
const (
	TUIModeInternalCLI  TUIDefaultMode = 0
	TUIModeTunnelStatus TUIDefaultMode = 1
)
117

118
// StreamHandler handles a single SSH stream. Run drives the authentication
// exchange and answers global requests. ServeChannel serves the channels the
// client opens on the stream.
type StreamHandler struct {
	auth       AuthInterface                         // policy/auth backend
	discovery  EndpointDiscoveryInterface            // receives port-forward endpoint changes
	config     *config.Config                        // current Pomerium configuration
	downstream *extensions_ssh.DownstreamConnectEvent // connect event that created this stream
	writeC     chan *extensions_ssh.ServerMessage    // outbound messages; exposed via WriteC
	readC      chan *extensions_ssh.ClientMessage    // inbound messages; consumed by Run and Prompt
	reauthC    chan struct{}                         // unbuffered; Reauth sends, Run receives
	terminateC chan error                            // Terminate sends; Run and Prompt receive

	state        *StreamState         // nil until Run is called
	close        func()               // hook invoked by Close (installed by NewStreamHandler)
	portForwards *portforward.Manager // nil until initial auth completes

	// expectingInternalChannel is set by sendAllowResponse when the client
	// authenticated with an empty hostname.
	expectingInternalChannel bool
	// internalSession holds the running internal "session" channel handler,
	// if any. ServeChannel refuses a second one while it is set.
	internalSession atomic.Pointer[ChannelHandler]

	// tuiDefaultModeLock guards tuiDefaultMode, which is written by
	// handleGlobalRequest and read by ServeChannel.
	tuiDefaultModeLock sync.Mutex
	tuiDefaultMode     TUIDefaultMode
}

// Compile-time check that StreamHandler satisfies StreamHandlerInterface.
var _ StreamHandlerInterface = (*StreamHandler)(nil)
141

142
func NewStreamHandler(
143
        auth AuthInterface,
144
        discovery EndpointDiscoveryInterface,
145
        cfg *config.Config,
146
        downstream *extensions_ssh.DownstreamConnectEvent,
147
        onClosed func(),
148
) *StreamHandler {
178✔
149
        writeC := make(chan *extensions_ssh.ServerMessage, 32)
178✔
150
        sh := &StreamHandler{
178✔
151
                auth:       auth,
178✔
152
                discovery:  discovery,
178✔
153
                config:     cfg,
178✔
154
                downstream: downstream,
178✔
155
                writeC:     make(chan *extensions_ssh.ServerMessage, 32),
178✔
156
                readC:      make(chan *extensions_ssh.ClientMessage, 32),
178✔
157
                reauthC:    make(chan struct{}),
178✔
158
                terminateC: make(chan error, 1),
178✔
159
                close: func() {
343✔
160
                        onClosed()
165✔
161
                        close(writeC)
165✔
162
                },
165✔
163
        }
164
        return sh
178✔
165
}
166

167
// EvaluateRoute implements portforward.RouteEvaluator.
168
func (sh *StreamHandler) EvaluateRoute(ctx context.Context, info portforward.RouteInfo) error {
100✔
169
        return sh.auth.EvaluatePortForward(ctx, sh.state.StreamAuthInfo, info)
100✔
170
}
100✔
171

172
// OnClusterEndpointsUpdated implements portforward.UpdateListener.
173
func (sh *StreamHandler) OnClusterEndpointsUpdated(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{}) {
220✔
174
        sh.discovery.UpdateClusterEndpoints(added, removed)
220✔
175
}
220✔
176

177
// OnPermissionsUpdated implements portforward.UpdateListener.
178
func (sh *StreamHandler) OnPermissionsUpdated(_ []portforward.Permission) {
174✔
179
}
174✔
180

181
// OnRoutesUpdated implements portforward.UpdateListener.
182
func (sh *StreamHandler) OnRoutesUpdated(_ []portforward.RouteInfo) {
160✔
183
}
160✔
184

185
func (sh *StreamHandler) Terminate(err error) {
11✔
186
        sh.terminateC <- err
11✔
187
}
11✔
188

189
func (sh *StreamHandler) Close() {
165✔
190
        sh.close()
165✔
191
}
165✔
192

193
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
98✔
194
        return sh.expectingInternalChannel
98✔
195
}
98✔
196

197
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
653✔
198
        return sh.readC
653✔
199
}
653✔
200

201
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
314✔
202
        return sh.writeC
314✔
203
}
314✔
204

205
// Reauth blocks until authorization policy is reevaluated.
206
func (sh *StreamHandler) Reauth() {
46✔
207
        sh.reauthC <- struct{}{}
46✔
208
}
46✔
209

210
func (sh *StreamHandler) periodicReauth() (cancel func()) {
176✔
211
        t := time.NewTicker(1 * time.Minute)
176✔
212
        go func() {
352✔
213
                for range t.C {
176✔
214
                        sh.Reauth()
×
215
                }
×
216
        }()
217
        return t.Stop
176✔
218
}
219

220
// Prompt implements KeyboardInteractiveQuerier.
221
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
34✔
222
        sh.sendInfoPrompts(prompts)
34✔
223
        select {
34✔
224
        case <-ctx.Done():
2✔
225
                return nil, context.Cause(ctx)
2✔
226
        case err := <-sh.terminateC:
×
227
                return nil, err
×
228
        case req := <-sh.readC:
32✔
229
                switch msg := req.Message.(type) {
32✔
230
                case *extensions_ssh.ClientMessage_InfoResponse:
30✔
231
                        if msg.InfoResponse.Method != MethodKeyboardInteractive {
32✔
232
                                return nil, status.Errorf(codes.Internal, "received invalid info response")
2✔
233
                        }
2✔
234
                        r, _ := msg.InfoResponse.Response.UnmarshalNew()
28✔
235
                        respInfo, ok := r.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
28✔
236
                        if !ok {
30✔
237
                                return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
2✔
238
                        }
2✔
239
                        return respInfo, nil
26✔
240
                default:
2✔
241
                        return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
2✔
242
                }
243
        }
244
}
245

246
func (sh *StreamHandler) Run(ctx context.Context) error {
178✔
247
        if sh.state != nil {
180✔
248
                panic("Run called twice")
2✔
249
        }
250
        sh.state = &StreamState{
176✔
251
                RemainingUnauthenticatedMethods: []string{MethodPublicKey},
176✔
252
                StreamAuthInfo: StreamAuthInfo{
176✔
253
                        StreamID:      sh.downstream.StreamId,
176✔
254
                        SourceAddress: sh.downstream.SourceAddress.GetSocketAddress().GetAddress(),
176✔
255
                },
176✔
256
        }
176✔
257
        cancelReauth := sh.periodicReauth()
176✔
258
        defer cancelReauth()
176✔
259
        for {
656✔
260
                select {
480✔
261
                case <-ctx.Done():
119✔
262
                        return context.Cause(ctx)
119✔
263
                case <-sh.reauthC:
46✔
264
                        if err := sh.reauth(ctx); err != nil {
52✔
265
                                return err
6✔
266
                        }
6✔
267
                case err := <-sh.terminateC:
11✔
268
                        return err
11✔
269
                case req := <-sh.readC:
304✔
270
                        switch req := req.Message.(type) {
304✔
271
                        case *extensions_ssh.ClientMessage_Event:
32✔
272
                                switch event := req.Event.Event.(type) {
32✔
273
                                case *extensions_ssh.StreamEvent_DownstreamConnected:
2✔
274
                                        // this was already received as the first message in the stream
2✔
275
                                        return status.Errorf(codes.Internal, "received duplicate downstream connected event")
2✔
276
                                case *extensions_ssh.StreamEvent_UpstreamConnected:
26✔
277
                                        log.Ctx(ctx).Debug().
26✔
278
                                                Msg("ssh: upstream connected")
26✔
279
                                case *extensions_ssh.StreamEvent_DownstreamDisconnected:
2✔
280
                                        log.Ctx(ctx).Debug().
2✔
281
                                                Uint64("stream-id", sh.downstream.StreamId).
2✔
282
                                                Str("reason", event.DownstreamDisconnected.Reason).
2✔
283
                                                Msg("ssh: downstream disconnected")
2✔
284
                                case nil:
2✔
285
                                        return status.Errorf(codes.Internal, "received invalid event")
2✔
286
                                }
287
                        case *extensions_ssh.ClientMessage_AuthRequest:
210✔
288
                                if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
244✔
289
                                        return err
34✔
290
                                }
34✔
291
                        case *extensions_ssh.ClientMessage_GlobalRequest:
60✔
292
                                if err := sh.handleGlobalRequest(ctx, req.GlobalRequest); err != nil {
60✔
293
                                        return err
×
294
                                }
×
295
                        default:
2✔
296
                                return status.Errorf(codes.Internal, "received invalid client message type %#T", req)
2✔
297
                        }
298
                }
299
        }
300
}
301

302
func (sh *StreamHandler) handleGlobalRequest(ctx context.Context, globalRequest *extensions_ssh.GlobalRequest) error {
60✔
303
        sh.tuiDefaultModeLock.Lock()
60✔
304
        defer sh.tuiDefaultModeLock.Unlock()
60✔
305
        switch request := globalRequest.Request.(type) {
60✔
306
        case *extensions_ssh.GlobalRequest_TcpipForwardRequest:
40✔
307
                if sh.portForwards == nil {
40✔
NEW
308
                        return status.Errorf(codes.InvalidArgument, "cannot request port-forward before auth is complete")
×
NEW
309
                }
×
310
                reqHost := request.TcpipForwardRequest.RemoteAddress
40✔
311
                reqPort := request.TcpipForwardRequest.RemotePort
40✔
312
                log.Ctx(ctx).Debug().
40✔
313
                        Uint64("stream-id", sh.state.StreamID).
40✔
314
                        Str("host", reqHost).
40✔
315
                        Msg("got tcpip-forward request")
40✔
316

40✔
317
                serverPort, err := sh.portForwards.AddPermission(reqHost, reqPort)
40✔
318
                if err != nil {
40✔
319
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
320
                                Success:      false,
×
321
                                DebugMessage: err.Error(),
×
322
                        })
×
323
                        return nil
×
324
                }
×
325

326
                log.Ctx(ctx).Debug().
40✔
327
                        Uint64("stream-id", sh.state.StreamID).
40✔
328
                        Msg("sending global request success")
40✔
329

40✔
330
                // https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
40✔
331
                if globalRequest.WantReply && reqPort == 0 {
40✔
332
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
333
                                Success: true,
×
334
                                Response: &extensions_ssh.GlobalRequestResponse_TcpipForwardResponse{
×
335
                                        TcpipForwardResponse: &extensions_ssh.TcpipForwardResponse{
×
336
                                                ServerPort: serverPort.Value,
×
337
                                        },
×
338
                                },
×
339
                        })
×
340
                }
×
341

342
                sh.tuiDefaultMode = TUIModeTunnelStatus
40✔
343
                return nil
40✔
344
        case *extensions_ssh.GlobalRequest_CancelTcpipForwardRequest:
20✔
345
                if sh.portForwards == nil {
20✔
NEW
346
                        return status.Errorf(codes.InvalidArgument, "cannot request port-forward before auth is complete")
×
NEW
347
                }
×
348
                err := sh.portForwards.RemovePermission(
20✔
349
                        request.CancelTcpipForwardRequest.RemoteAddress,
20✔
350
                        request.CancelTcpipForwardRequest.RemotePort)
20✔
351
                if err != nil {
20✔
352
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
353
                                Success:      false,
×
354
                                DebugMessage: err.Error(),
×
355
                        })
×
356
                } else if globalRequest.WantReply {
20✔
357
                        sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
×
358
                                Success: true,
×
359
                        })
×
360
                }
×
361
                return nil
20✔
362
        default:
×
363
                return status.Errorf(codes.Unimplemented, "received unknown global request")
×
364
        }
365
}
366

367
func (sh *StreamHandler) sendGlobalRequestResponse(response *extensions_ssh.GlobalRequestResponse) {
×
368
        sh.writeC <- &extensions_ssh.ServerMessage{
×
369
                Message: &extensions_ssh.ServerMessage_GlobalRequestResponse{
×
370
                        GlobalRequestResponse: response,
×
371
                },
×
372
        }
×
373
}
×
374

375
func (sh *StreamHandler) ServeChannel(
376
        stream extensions_ssh.StreamManagement_ServeChannelServer,
377
        metadata *extensions_ssh.FilterMetadata,
378
) error {
64✔
379
        // The first channel message on this stream should be a ChannelOpen
64✔
380
        channelOpen, err := stream.Recv()
64✔
381
        if err != nil {
67✔
382
                return err
3✔
383
        }
3✔
384
        rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
61✔
385
        if !ok {
63✔
386
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
387
        }
2✔
388
        var msg ChannelOpenMsg
59✔
389
        if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
61✔
390
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
391
        }
2✔
392

393
        sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
57✔
394
                ChannelType:               msg.ChanType,
57✔
395
                DownstreamChannelId:       msg.PeersID,
57✔
396
                InternalUpstreamChannelId: metadata.ChannelId,
57✔
397
                InitialWindowSize:         msg.PeersWindow,
57✔
398
                MaxPacketSize:             msg.MaxPacketSize,
57✔
399
        }
57✔
400
        sh.state.ChannelType = msg.ChanType
57✔
401
        channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
57✔
402
        switch msg.ChanType {
57✔
403
        case ChannelTypeSession:
41✔
404
                ch := NewChannelHandler(channel, sh.config)
41✔
405
                if !sh.internalSession.CompareAndSwap(nil, ch) {
41✔
406
                        return channel.SendMessage(ChannelOpenFailureMsg{
×
407
                                PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
×
408
                                Reason:  Prohibited,
×
409
                                Message: "multiple concurrent internal session channels not supported",
×
410
                        })
×
411
                }
×
412
                if err := channel.SendMessage(ChannelOpenConfirmMsg{
41✔
413
                        PeersID:       sh.state.DownstreamChannelInfo.DownstreamChannelId,
41✔
414
                        MyID:          sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
41✔
415
                        MyWindow:      ChannelWindowSize,
41✔
416
                        MaxPacketSize: ChannelMaxPacket,
41✔
417
                }); err != nil {
41✔
418
                        return err
×
419
                }
×
420
                var mode TUIDefaultMode
41✔
421
                sh.tuiDefaultModeLock.Lock()
41✔
422
                mode = sh.tuiDefaultMode
41✔
423
                sh.tuiDefaultModeLock.Unlock()
41✔
424

41✔
425
                err := ch.Run(stream.Context(), mode)
41✔
426
                sh.internalSession.Store(nil)
41✔
427
                return err
41✔
428
        case ChannelTypeDirectTcpip:
14✔
429
                if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
18✔
430
                        return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
4✔
431
                }
4✔
432
                var subMsg ChannelOpenDirectMsg
10✔
433
                if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
11✔
434
                        return err
1✔
435
                }
1✔
436
                action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
9✔
437
                if err != nil {
11✔
438
                        return err
2✔
439
                }
2✔
440
                return channel.SendControlAction(action)
7✔
441
        default:
2✔
442
                return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
2✔
443
        }
444
}
445

446
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
210✔
447
        if req.Protocol != "ssh" {
212✔
448
                return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
2✔
449
        }
2✔
450
        if req.Service != ServiceConnection {
210✔
451
                return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
2✔
452
        }
2✔
453
        if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
212✔
454
                return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
6✔
455
        }
6✔
456

457
        if sh.state.Username == nil {
340✔
458
                if req.Username == "" {
142✔
459
                        return status.Errorf(codes.InvalidArgument, "username missing")
2✔
460
                }
2✔
461
                sh.state.Username = &req.Username
138✔
462
        } else if *sh.state.Username != req.Username {
62✔
463
                return status.Errorf(codes.InvalidArgument, "inconsistent username")
2✔
464
        }
2✔
465
        if sh.state.Hostname == nil {
334✔
466
                sh.state.Hostname = &req.Hostname
138✔
467
        } else if *sh.state.Hostname != req.Hostname {
200✔
468
                return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
4✔
469
        }
4✔
470

471
        updateMethods := func(add []string) {
368✔
472
                sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
176✔
473
                sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
176✔
474
        }
176✔
475
        log.Ctx(ctx).Debug().
192✔
476
                Str("method", req.AuthMethod).
192✔
477
                Str("username", *sh.state.Username).
192✔
478
                Str("hostname", *sh.state.Hostname).
192✔
479
                Msg("ssh: handling auth request")
192✔
480

192✔
481
        var partial bool
192✔
482
        switch req.AuthMethod {
192✔
483
        case MethodPublicKey:
154✔
484
                methodReq, _ := req.MethodRequest.UnmarshalNew()
154✔
485
                pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
154✔
486
                if !ok {
156✔
487
                        return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
2✔
488
                }
2✔
489
                response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
152✔
490
                if err != nil {
154✔
491
                        return err
2✔
492
                } else if response.Allow != nil {
280✔
493
                        partial = true
128✔
494
                        sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
128✔
495
                }
128✔
496
                sh.state.PublicKeyAllow.Update(response.Allow)
150✔
497
                updateMethods(response.RequireAdditionalMethods)
150✔
498
        case MethodKeyboardInteractive:
36✔
499
                methodReq, _ := req.MethodRequest.UnmarshalNew()
36✔
500
                kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
36✔
501
                if !ok {
38✔
502
                        return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
2✔
503
                }
2✔
504
                response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
34✔
505
                if err != nil {
42✔
506
                        return err
8✔
507
                }
8✔
508
                partial = response.Allow != nil
26✔
509
                sh.state.KeyboardInteractiveAllow.Update(response.Allow)
26✔
510
                updateMethods(response.RequireAdditionalMethods)
26✔
511
        default:
2✔
512
                return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
2✔
513
        }
514
        log.Ctx(ctx).Debug().
176✔
515
                Str("method", req.AuthMethod).
176✔
516
                Bool("partial", partial).
176✔
517
                Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
176✔
518
                Msg("ssh: auth request complete")
176✔
519

176✔
520
        if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
290✔
521
                // If there are no methods remaining, the user is allowed if all attempted
114✔
522
                // methods have a valid response in the state
114✔
523
                sh.state.InitialAuthComplete = true
114✔
524
                // Initialize the port forward manager
114✔
525
                sh.portForwards = portforward.NewManager(ctx, sh)
114✔
526
                sh.portForwards.OnConfigUpdate(sh.config)
114✔
527
                sh.portForwards.AddUpdateListener(sh)
114✔
528

114✔
529
                log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
114✔
530
                sh.sendAllowResponse()
114✔
531
        } else {
176✔
532
                log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
62✔
533
                sh.sendDenyResponseWithRemainingMethods(partial)
62✔
534
        }
62✔
535
        return nil
176✔
536
}
537

538
func (sh *StreamHandler) reauth(ctx context.Context) error {
46✔
539
        if !sh.state.InitialAuthComplete {
46✔
UNCOV
540
                return nil
×
UNCOV
541
        }
×
542
        return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
46✔
543
}
544

545
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo *extensions_ssh.SSHDownstreamPTYInfo) (*extensions_ssh.SSHChannelControlAction, error) {
10✔
546
        if hostname == "" {
11✔
547
                return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
1✔
548
        }
1✔
549
        if sh.state.Hostname == nil {
9✔
550
                panic("bug: PrepareHandoff called but state is missing a hostname")
×
551
        }
552
        if *sh.state.Hostname != "" {
9✔
553
                panic("bug: PrepareHandoff called but previous hostname is not empty")
×
554
        }
555
        *sh.state.Hostname = hostname
9✔
556
        err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
9✔
557
        if err != nil {
10✔
558
                return nil, status.Error(codes.PermissionDenied, err.Error())
1✔
559
        }
1✔
560
        log.Ctx(ctx).Debug().
8✔
561
                Str("hostname", *sh.state.Hostname).
8✔
562
                Str("username", *sh.state.Username).
8✔
563
                Msg("ssh: initiating handoff to upstream")
8✔
564
        upstreamAllow := sh.buildUpstreamAllowResponse()
8✔
565
        action := &extensions_ssh.SSHChannelControlAction{
8✔
566
                Action: &extensions_ssh.SSHChannelControlAction_HandOff{
8✔
567
                        HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
8✔
568
                                DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
8✔
569
                                DownstreamPtyInfo:     ptyInfo,
8✔
570
                                UpstreamAuth:          upstreamAllow,
8✔
571
                        },
8✔
572
                },
8✔
573
        }
8✔
574
        return action, nil
8✔
575
}
576

577
func (sh *StreamHandler) FormatSession(ctx context.Context) ([]byte, error) {
14✔
578
        return sh.auth.FormatSession(ctx, sh.state.StreamAuthInfo)
14✔
579
}
14✔
580

581
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
12✔
582
        return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
12✔
583
}
12✔
584

585
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
6✔
586
        return func(yield func(*config.Policy) bool) {
12✔
587
                for route := range sh.config.Options.GetAllPolicies() {
30✔
588
                        if route.IsSSH() {
34✔
589
                                if !yield(route) {
12✔
590
                                        return
2✔
591
                                }
2✔
592
                        }
593
                }
594
        }
595
}
596

597
// DownstreamChannelID implements StreamHandlerInterface.
598
func (sh *StreamHandler) DownstreamChannelID() uint32 {
212✔
599
        return sh.state.DownstreamChannelInfo.DownstreamChannelId
212✔
600
}
212✔
601

602
// Hostname implements StreamHandlerInterface.
603
func (sh *StreamHandler) Hostname() *string {
42✔
604
        return sh.state.Hostname
42✔
605
}
42✔
606

607
// Username implements StreamHandlerInterface.
608
func (sh *StreamHandler) Username() *string {
79✔
609
        return sh.state.Username
79✔
610
}
79✔
611

612
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
62✔
613
        sh.writeC <- &extensions_ssh.ServerMessage{
62✔
614
                Message: &extensions_ssh.ServerMessage_AuthResponse{
62✔
615
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
62✔
616
                                Response: &extensions_ssh.AuthenticationResponse_Deny{
62✔
617
                                        Deny: &extensions_ssh.DenyResponse{
62✔
618
                                                Partial: partial,
62✔
619
                                                Methods: sh.state.RemainingUnauthenticatedMethods,
62✔
620
                                        },
62✔
621
                                },
62✔
622
                        },
62✔
623
                },
62✔
624
        }
62✔
625
}
62✔
626

627
func (sh *StreamHandler) sendAllowResponse() {
114✔
628
        var allow *extensions_ssh.AllowResponse
114✔
629
        if *sh.state.Hostname == "" {
204✔
630
                sh.expectingInternalChannel = true
90✔
631
                allow = sh.buildInternalAllowResponse()
90✔
632
        } else {
114✔
633
                allow = sh.buildUpstreamAllowResponse()
24✔
634
        }
24✔
635

636
        sh.writeC <- &extensions_ssh.ServerMessage{
114✔
637
                Message: &extensions_ssh.ServerMessage_AuthResponse{
114✔
638
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
114✔
639
                                Response: &extensions_ssh.AuthenticationResponse_Allow{
114✔
640
                                        Allow: allow,
114✔
641
                                },
114✔
642
                        },
114✔
643
                },
114✔
644
        }
114✔
645
}
646

647
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
34✔
648
        sh.writeC <- &extensions_ssh.ServerMessage{
34✔
649
                Message: &extensions_ssh.ServerMessage_AuthResponse{
34✔
650
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
34✔
651
                                Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
34✔
652
                                        InfoRequest: &extensions_ssh.InfoRequest{
34✔
653
                                                Method:  MethodKeyboardInteractive,
34✔
654
                                                Request: protoutil.NewAny(prompts),
34✔
655
                                        },
34✔
656
                                },
34✔
657
                        },
34✔
658
                },
34✔
659
        }
34✔
660
}
34✔
661

662
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
32✔
663
        var allowedMethods []*extensions_ssh.AllowedMethod
32✔
664
        if value := sh.state.PublicKeyAllow.Value; value != nil {
64✔
665
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
32✔
666
                        Method:     MethodPublicKey,
32✔
667
                        MethodData: protoutil.NewAny(value),
32✔
668
                })
32✔
669
        }
32✔
670
        if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
46✔
671
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
14✔
672
                        Method:     MethodKeyboardInteractive,
14✔
673
                        MethodData: protoutil.NewAny(value),
14✔
674
                })
14✔
675
        }
14✔
676
        return &extensions_ssh.AllowResponse{
32✔
677
                Username: *sh.state.Username,
32✔
678
                Target: &extensions_ssh.AllowResponse_Upstream{
32✔
679
                        Upstream: &extensions_ssh.UpstreamTarget{
32✔
680
                                Hostname:       *sh.state.Hostname,
32✔
681
                                DirectTcpip:    sh.state.ChannelType == ChannelTypeDirectTcpip,
32✔
682
                                AllowedMethods: allowedMethods,
32✔
683
                        },
32✔
684
                },
32✔
685
        }
32✔
686
}
687

688
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
90✔
689
        return &extensions_ssh.AllowResponse{
90✔
690
                Username: *sh.state.Username,
90✔
691
                Target: &extensions_ssh.AllowResponse_Internal{
90✔
692
                        Internal: &extensions_ssh.InternalTarget{
90✔
693
                                SetMetadata: &corev3.Metadata{
90✔
694
                                        TypedFilterMetadata: map[string]*anypb.Any{
90✔
695
                                                "com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
90✔
696
                                                        StreamId: sh.downstream.StreamId,
90✔
697
                                                }),
90✔
698
                                        },
90✔
699
                                },
90✔
700
                        },
90✔
701
                },
90✔
702
        }
90✔
703
}
90✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc