• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pomerium / pomerium / 21647993278

03 Feb 2026 09:10PM UTC coverage: 44.408% (-0.07%) from 44.475%
21647993278

push

github

web-flow
ssh tui: fix data races related to model updates (#6077)

This fixes several data races related to concurrent rendering and model
updating. Some model callbacks now propagate their updates as tea.Msg
values instead, and the top-level model uses a new double-buffered
drawable to ensure View() is always called from the same goroutine as
Update().

0 of 248 new or added lines in 15 files covered. (0.0%)

33 existing lines in 11 files now uncovered.

31413 of 70737 relevant lines covered (44.41%)

115.64 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.18
/pkg/ssh/stream.go
1
package ssh
2

3
import (
4
        "context"
5
        "iter"
6
        "sync/atomic"
7
        "time"
8

9
        corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
10
        envoy_config_endpoint_v3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
11
        datav3 "github.com/envoyproxy/go-control-plane/envoy/data/core/v3"
12
        gossh "golang.org/x/crypto/ssh"
13
        "google.golang.org/grpc/codes"
14
        "google.golang.org/grpc/status"
15
        "google.golang.org/protobuf/types/known/anypb"
16

17
        extensions_ssh "github.com/pomerium/envoy-custom/api/extensions/filters/network/ssh"
18
        "github.com/pomerium/pomerium/config"
19
        "github.com/pomerium/pomerium/internal/log"
20
        "github.com/pomerium/pomerium/pkg/grpc/databroker"
21
        "github.com/pomerium/pomerium/pkg/grpc/session"
22
        "github.com/pomerium/pomerium/pkg/protoutil"
23
        "github.com/pomerium/pomerium/pkg/slices"
24
        "github.com/pomerium/pomerium/pkg/ssh/api"
25
        "github.com/pomerium/pomerium/pkg/ssh/cli"
26
        "github.com/pomerium/pomerium/pkg/ssh/models"
27
        "github.com/pomerium/pomerium/pkg/ssh/portforward"
28
)
29

30
const (
	// Auth method names accepted in AuthenticationRequest.AuthMethod
	// (see handleAuthRequest).
	MethodPublicKey           = "publickey"
	MethodKeyboardInteractive = "keyboard-interactive"

	// Channel types that ServeChannel knows how to serve.
	ChannelTypeSession     = "session"
	ChannelTypeDirectTcpip = "direct-tcpip"

	// ServiceConnection is the only service name accepted in auth requests.
	ServiceConnection = "ssh-connection"
)
39

40
// KeyboardInteractiveQuerier sends keyboard-interactive prompts to the client
// and waits for its answers. StreamHandler implements it (see Prompt) and
// passes itself to AuthInterface.HandleKeyboardInteractiveMethodRequest.
type KeyboardInteractiveQuerier interface {
	// Prompts the client and returns their responses to the given prompts.
	Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error)
}
44

45
// AuthMethodResponse is the outcome of evaluating one auth method attempt.
type AuthMethodResponse[T any] struct {
	// Allow is non-nil when the attempt succeeded.
	Allow *T
	// RequireAdditionalMethods lists auth methods the client must still
	// complete; they are appended to the stream's remaining methods.
	RequireAdditionalMethods []string
}
49

50
// Concrete AuthMethodResponse instantiations returned by AuthInterface.
type (
	PublicKeyAuthMethodResponse           = AuthMethodResponse[extensions_ssh.PublicKeyAllowResponse]
	KeyboardInteractiveAuthMethodResponse = AuthMethodResponse[extensions_ssh.KeyboardInteractiveAllowResponse]
)
54

55
//go:generate go tool -modfile ../../internal/tools/go.mod go.uber.org/mock/mockgen -typed -destination ./mock/mock_auth_interface.go . AuthInterface

// AuthInterface is the authentication/authorization backend used by
// StreamHandler. Every call receives a snapshot of the stream's auth state.
type AuthInterface interface {
	// HandlePublicKeyMethodRequest evaluates a "publickey" auth attempt.
	HandlePublicKeyMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.PublicKeyMethodRequest) (PublicKeyAuthMethodResponse, error)
	// HandleKeyboardInteractiveMethodRequest evaluates a "keyboard-interactive"
	// auth attempt, using querier to prompt the client.
	HandleKeyboardInteractiveMethodRequest(ctx context.Context, info StreamAuthInfo, req *extensions_ssh.KeyboardInteractiveMethodRequest, querier KeyboardInteractiveQuerier) (KeyboardInteractiveAuthMethodResponse, error)
	// EvaluateDelayed re-evaluates authorization for the stream. StreamHandler
	// calls it on reauth and before handing a channel off to an upstream.
	EvaluateDelayed(ctx context.Context, info StreamAuthInfo) error
	GetSession(ctx context.Context, info StreamAuthInfo) (*session.Session, error)
	DeleteSession(ctx context.Context, info StreamAuthInfo) error
	GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
65

66
// ClusterStatsListener receives envoy ClusterStats updates.
type ClusterStatsListener interface {
	HandleClusterStatsUpdate(*envoy_config_endpoint_v3.ClusterStats)
}
69

70
// EndpointDiscoveryInterface is StreamHandler's view of endpoint discovery.
// StreamHandler uses the port-forward manager for tcpip-forward requests and
// forwards cluster endpoint deltas to UpdateClusterEndpoints.
type EndpointDiscoveryInterface interface {
	PortForwardManager() *portforward.Manager
	UpdateClusterEndpoints(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{})
}
74

75
// AuthMethodValue tracks the result of one auth method on a stream.
type AuthMethodValue[T any] struct {
	// attempted is set once Update has been called.
	attempted bool
	// Value is the allow response from the most recent attempt; nil if that
	// attempt was not allowed (or if the method was never attempted).
	Value *T
}
79

80
func (v *AuthMethodValue[T]) Update(value *T) {
163✔
81
        v.attempted = true
163✔
82
        v.Value = value
163✔
83
}
163✔
84

85
func (v *AuthMethodValue[T]) IsValid() bool {
176✔
86
        if v.attempted {
303✔
87
                // method was attempted - valid iff there is a value
127✔
88
                return v.Value != nil
127✔
89
        }
127✔
90
        return true // method was not attempted - valid
49✔
91
}
92

93
// StreamAuthInfo is the per-stream authentication state passed (by value) to
// AuthInterface calls.
type StreamAuthInfo struct {
	// Username and Hostname are taken from the first auth request; later
	// requests must carry the same values. PrepareHandoff may fill in an
	// empty Hostname.
	Username                   *string
	Hostname                   *string
	// StreamID and SourceAddress come from the downstream connect event.
	StreamID                   uint64
	SourceAddress              string
	// ChannelType is the type of the channel most recently opened via ServeChannel.
	ChannelType                string
	// PublicKeyFingerprintSha256 is recorded when public key auth is allowed.
	PublicKeyFingerprintSha256 []byte
	PublicKeyAllow             AuthMethodValue[extensions_ssh.PublicKeyAllowResponse]
	KeyboardInteractiveAllow   AuthMethodValue[extensions_ssh.KeyboardInteractiveAllowResponse]
	// InitialAuthComplete is set once all required methods have succeeded;
	// it gates reauth and port-forward requests.
	InitialAuthComplete        bool
}
104

105
func (i *StreamAuthInfo) allMethodsValid() bool {
88✔
106
        return i.PublicKeyAllow.IsValid() && i.KeyboardInteractiveAllow.IsValid()
88✔
107
}
88✔
108

109
// StreamState is the mutable state of a stream, created when Run starts.
type StreamState struct {
	StreamAuthInfo
	// RemainingUnauthenticatedMethods lists auth methods the client still has
	// to complete; Run starts it with only "publickey".
	RemainingUnauthenticatedMethods []string
	// DownstreamChannelInfo describes the channel most recently opened via
	// ServeChannel.
	DownstreamChannelInfo           *extensions_ssh.SSHDownstreamChannelInfo
}
114

115
// StreamHandler handles a single SSH stream.
//
// Fields of note:
//   - writeC / readC: outbound server messages and inbound client messages,
//     exposed as WriteC and ReadC.
//   - reauthC: unbuffered; a send (see Reauth) blocks until Run's loop
//     receives it.
//   - terminateC: capacity 1; Terminate sends the error that ends Run.
//   - state: created by Run.
//   - close: the callback invoked by Close.
//   - internalSession: the active internal "session" channel handler, if
//     any; ServeChannel allows at most one at a time.
//   - expectingInternalChannel: reported by IsExpectingInternalChannel.
//     NOTE(review): not assigned anywhere in the visible part of this file.
type StreamHandler struct {
	auth       AuthInterface
	discovery  EndpointDiscoveryInterface
	cliCtrl    cli.InternalCLIController
	config     *config.Config
	downstream *extensions_ssh.DownstreamConnectEvent
	writeC     chan *extensions_ssh.ServerMessage
	readC      chan *extensions_ssh.ClientMessage
	reauthC    chan struct{}
	terminateC chan error

	state *StreamState
	close func()

	expectingInternalChannel bool
	internalSession          atomic.Pointer[ChannelHandler]

	// Internal data models
	channelModel    *models.ChannelModel
	routeModel      *models.RouteModel
	permissionModel *models.PermissionModel
}
138

139
// PermissionDataModel implements StreamHandlerInterface.
UNCOV
140
func (sh *StreamHandler) PermissionDataModel() *models.PermissionModel {
×
UNCOV
141
        return sh.permissionModel
×
UNCOV
142
}
×
143

144
// RouteDataModel implements StreamHandlerInterface.
UNCOV
145
func (sh *StreamHandler) RouteDataModel() *models.RouteModel {
×
UNCOV
146
        return sh.routeModel
×
UNCOV
147
}
×
148

149
// ChannelDataModel implements StreamHandlerInterface.
UNCOV
150
func (sh *StreamHandler) ChannelDataModel() *models.ChannelModel {
×
UNCOV
151
        return sh.channelModel
×
UNCOV
152
}
×
153

154
// Compile-time check that *StreamHandler satisfies api.StreamHandlerInterface.
var _ api.StreamHandlerInterface = (*StreamHandler)(nil)
155

156
func NewStreamHandler(
157
        auth AuthInterface,
158
        discovery EndpointDiscoveryInterface,
159
        cliCtrl cli.InternalCLIController,
160
        cfg *config.Config,
161
        downstream *extensions_ssh.DownstreamConnectEvent,
162
        onClosed func(),
163
) *StreamHandler {
150✔
164
        writeC := make(chan *extensions_ssh.ServerMessage, 32)
150✔
165
        sh := &StreamHandler{
150✔
166
                auth:       auth,
150✔
167
                discovery:  discovery,
150✔
168
                cliCtrl:    cliCtrl,
150✔
169
                config:     cfg,
150✔
170
                downstream: downstream,
150✔
171
                writeC:     make(chan *extensions_ssh.ServerMessage, 32),
150✔
172
                readC:      make(chan *extensions_ssh.ClientMessage, 32),
150✔
173
                reauthC:    make(chan struct{}),
150✔
174
                terminateC: make(chan error, 1),
150✔
175
                close: func() {
289✔
176
                        onClosed()
139✔
177
                        close(writeC)
139✔
178
                },
139✔
179
                channelModel:    models.NewChannelModel(),
180
                routeModel:      models.NewRouteModel(cliCtrl.EventHandlers().RouteDataModelEventHandlers),
181
                permissionModel: models.NewPermissionModel(),
182
        }
183
        return sh
150✔
184
}
185

186
// OnClusterEndpointsUpdated implements portforward.UpdateListener.
187
func (sh *StreamHandler) OnClusterEndpointsUpdated(added map[string]portforward.RoutePortForwardInfo, removed map[string]struct{}) {
55✔
188
        sh.discovery.UpdateClusterEndpoints(added, removed)
55✔
189
        sh.routeModel.HandleClusterEndpointsUpdate(added, removed)
55✔
190
        sh.permissionModel.HandleClusterEndpointsUpdate(added, removed)
55✔
191
}
55✔
192

193
// OnPermissionsUpdated implements portforward.UpdateListener.
194
func (sh *StreamHandler) OnPermissionsUpdated(permissions []portforward.Permission) {
51✔
195
        sh.permissionModel.HandlePermissionsUpdate(permissions)
51✔
196
}
51✔
197

198
// OnRoutesUpdated implements portforward.UpdateListener.
199
func (sh *StreamHandler) OnRoutesUpdated(routes []portforward.RouteInfo) {
52✔
200
        sh.routeModel.HandleRoutesUpdate(routes)
52✔
201
}
52✔
202

203
func (sh *StreamHandler) OnClusterHealthUpdate(_ context.Context, event *datav3.HealthCheckEvent) {
×
204
        sh.routeModel.HandleClusterHealthUpdate(event)
×
205
}
×
206

207
func (sh *StreamHandler) Terminate(err error) {
11✔
208
        sh.terminateC <- err
11✔
209
}
11✔
210

211
func (sh *StreamHandler) Close() {
139✔
212
        sh.close()
139✔
213
}
139✔
214

215
func (sh *StreamHandler) IsExpectingInternalChannel() bool {
98✔
216
        return sh.expectingInternalChannel
98✔
217
}
98✔
218

219
func (sh *StreamHandler) ReadC() chan<- *extensions_ssh.ClientMessage {
509✔
220
        return sh.readC
509✔
221
}
509✔
222

223
func (sh *StreamHandler) WriteC() <-chan *extensions_ssh.ServerMessage {
348✔
224
        return sh.writeC
348✔
225
}
348✔
226

227
// Reauth blocks until authorization policy is reevaluated.
228
func (sh *StreamHandler) Reauth() {
10✔
229
        sh.reauthC <- struct{}{}
10✔
230
}
10✔
231

232
func (sh *StreamHandler) periodicReauth() (cancel func()) {
148✔
233
        t := time.NewTicker(1 * time.Minute)
148✔
234
        go func() {
296✔
235
                for range t.C {
148✔
236
                        sh.Reauth()
×
237
                }
×
238
        }()
239
        return t.Stop
148✔
240
}
241

242
// Prompt implements KeyboardInteractiveQuerier.
//
// It sends prompts to the client, then waits for whichever comes first: ctx
// being done, a termination error, or the next message on readC. That message
// must be a keyboard-interactive info response whose payload decodes to
// KeyboardInteractiveInfoPromptResponses; anything else is an error.
func (sh *StreamHandler) Prompt(ctx context.Context, prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) (*extensions_ssh.KeyboardInteractiveInfoPromptResponses, error) {
	sh.sendInfoPrompts(prompts)

	var incoming *extensions_ssh.ClientMessage
	select {
	case <-ctx.Done():
		return nil, context.Cause(ctx)
	case err := <-sh.terminateC:
		return nil, err
	case incoming = <-sh.readC:
	}

	info, isInfo := incoming.Message.(*extensions_ssh.ClientMessage_InfoResponse)
	if !isInfo {
		return nil, status.Errorf(codes.InvalidArgument, "received invalid message, expecting info response")
	}
	if info.InfoResponse.Method != MethodKeyboardInteractive {
		return nil, status.Errorf(codes.Internal, "received invalid info response")
	}
	// The decode error is deliberately not checked: a payload that fails to
	// decode does not yield the expected message type and is rejected below.
	decoded, _ := info.InfoResponse.Response.UnmarshalNew()
	answers, isAnswers := decoded.(*extensions_ssh.KeyboardInteractiveInfoPromptResponses)
	if !isAnswers {
		return nil, status.Errorf(codes.InvalidArgument, "received invalid prompt response")
	}
	return answers, nil
}
267

268
// Run processes messages for the stream until ctx is done, Terminate is
// called, or a message cannot be handled. It must be called only once; a
// second call panics.
//
// Run creates the stream state (public key auth is the first required
// method) and requests a reauthorization every minute while it runs. Each
// message from readC is then dispatched: stream events, auth requests, and
// global requests. Any other message type is an error.
func (sh *StreamHandler) Run(ctx context.Context) error {
	if sh.state != nil {
		panic("Run called twice")
	}
	sh.state = &StreamState{
		RemainingUnauthenticatedMethods: []string{MethodPublicKey},
		StreamAuthInfo: StreamAuthInfo{
			StreamID:      sh.downstream.StreamId,
			SourceAddress: sh.downstream.SourceAddress.GetSocketAddress().GetAddress(),
		},
	}
	cancelReauth := sh.periodicReauth()
	defer cancelReauth()
	for {
		select {
		case <-ctx.Done():
			return context.Cause(ctx)
		case <-sh.reauthC:
			// Signalled via Reauth, including by the periodic ticker above.
			if err := sh.reauth(ctx); err != nil {
				return err
			}
		case err := <-sh.terminateC:
			return err
		case req := <-sh.readC:
			switch req := req.Message.(type) {
			case *extensions_ssh.ClientMessage_Event:
				// NOTE(review): event types not listed below are ignored silently.
				switch event := req.Event.Event.(type) {
				case *extensions_ssh.StreamEvent_DownstreamConnected:
					// this was already received as the first message in the stream
					return status.Errorf(codes.Internal, "received duplicate downstream connected event")
				case *extensions_ssh.StreamEvent_UpstreamConnected:
					log.Ctx(ctx).Debug().
						Msg("ssh: upstream connected")
				case *extensions_ssh.StreamEvent_DownstreamDisconnected:
					log.Ctx(ctx).Debug().
						Uint64("stream-id", sh.downstream.StreamId).
						Str("reason", event.DownstreamDisconnected.Reason).
						Msg("ssh: downstream disconnected")
				case *extensions_ssh.StreamEvent_ChannelEvent:
					sh.handleChannelEvent(event.ChannelEvent)
				case nil:
					return status.Errorf(codes.Internal, "received invalid event")
				}
			case *extensions_ssh.ClientMessage_AuthRequest:
				if err := sh.handleAuthRequest(ctx, req.AuthRequest); err != nil {
					return err
				}
			case *extensions_ssh.ClientMessage_GlobalRequest:
				if err := sh.handleGlobalRequest(ctx, req.GlobalRequest); err != nil {
					return err
				}
			default:
				return status.Errorf(codes.Internal, "received invalid client message type %#T", req)
			}
		}
	}
}
325

326
func (sh *StreamHandler) handleChannelEvent(event *extensions_ssh.ChannelEvent) {
×
327
        sh.channelModel.HandleEvent(event)
×
328
}
×
329

330
// handleGlobalRequest handles the SSH remote port forwarding global requests
// "tcpip-forward" and "cancel-tcpip-forward" (RFC 4254 section 7.1).
//
// Both are refused with an InvalidArgument error until initial
// authentication has completed. Port-forward manager failures are reported
// to the client and do not end the stream. Any other global request returns
// an Unimplemented error.
//
// Per RFC 4254 section 4, a reply (success or failure) is sent exactly when
// the request's want-reply flag is set.
func (sh *StreamHandler) handleGlobalRequest(ctx context.Context, globalRequest *extensions_ssh.GlobalRequest) error {
	switch request := globalRequest.Request.(type) {
	case *extensions_ssh.GlobalRequest_TcpipForwardRequest:
		if !sh.state.InitialAuthComplete {
			return status.Errorf(codes.InvalidArgument, "cannot request port-forward before auth is complete")
		}
		reqHost := request.TcpipForwardRequest.RemoteAddress
		reqPort := request.TcpipForwardRequest.RemotePort
		log.Ctx(ctx).Debug().
			Uint64("stream-id", sh.state.StreamID).
			Str("host", reqHost).
			Msg("got tcpip-forward request")

		serverPort, err := sh.discovery.PortForwardManager().AddPermission(reqHost, reqPort)
		if err != nil {
			log.Ctx(ctx).Debug().
				Uint64("stream-id", sh.state.StreamID).
				Err(err).
				Msg("sending global request failure")
			if globalRequest.WantReply {
				sh.sendGlobalRequestResponse(&extensions_ssh.GlobalRequestResponse{
					Success:      false,
					DebugMessage: err.Error(),
				})
			}
			return nil
		}

		if !globalRequest.WantReply {
			return nil
		}
		log.Ctx(ctx).Debug().
			Uint64("stream-id", sh.state.StreamID).
			Msg("sending global request success")

		// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
		// Every want-reply request gets a success reply; only a request for
		// port 0 carries the server-allocated port as response-specific data.
		response := &extensions_ssh.GlobalRequestResponse{Success: true}
		if reqPort == 0 {
			response.Response = &extensions_ssh.GlobalRequestResponse_TcpipForwardResponse{
				TcpipForwardResponse: &extensions_ssh.TcpipForwardResponse{
					ServerPort: serverPort.Value,
				},
			}
		}
		sh.sendGlobalRequestResponse(response)
		return nil
	case *extensions_ssh.GlobalRequest_CancelTcpipForwardRequest:
		if !sh.state.InitialAuthComplete {
			return status.Errorf(codes.InvalidArgument, "cannot request port-forward before auth is complete")
		}
		err := sh.discovery.PortForwardManager().RemovePermission(
			request.CancelTcpipForwardRequest.RemoteAddress,
			request.CancelTcpipForwardRequest.RemotePort)
		if err != nil {
			// Logged so the failure is visible even when no reply is wanted.
			log.Ctx(ctx).Debug().
				Uint64("stream-id", sh.state.StreamID).
				Err(err).
				Msg("cancel-tcpip-forward request failed")
		}
		if !globalRequest.WantReply {
			return nil
		}
		response := &extensions_ssh.GlobalRequestResponse{Success: true}
		if err != nil {
			response = &extensions_ssh.GlobalRequestResponse{
				Success:      false,
				DebugMessage: err.Error(),
			}
		}
		sh.sendGlobalRequestResponse(response)
		return nil
	default:
		return status.Errorf(codes.Unimplemented, "received unknown global request")
	}
}
395

396
func (sh *StreamHandler) sendGlobalRequestResponse(response *extensions_ssh.GlobalRequestResponse) {
×
397
        sh.writeC <- &extensions_ssh.ServerMessage{
×
398
                Message: &extensions_ssh.ServerMessage_GlobalRequestResponse{
×
399
                        GlobalRequestResponse: response,
×
400
                },
×
401
        }
×
402
}
×
403

404
func (sh *StreamHandler) ServeChannel(
405
        stream extensions_ssh.StreamManagement_ServeChannelServer,
406
        metadata *extensions_ssh.FilterMetadata,
407
) error {
64✔
408
        // The first channel message on this stream should be a ChannelOpen
64✔
409
        channelOpen, err := stream.Recv()
64✔
410
        if err != nil {
71✔
411
                return err
7✔
412
        }
7✔
413
        rawMsg, ok := channelOpen.GetMessage().(*extensions_ssh.ChannelMessage_RawBytes)
57✔
414
        if !ok {
59✔
415
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
416
        }
2✔
417
        var msg ChannelOpenMsg
55✔
418
        if err := gossh.Unmarshal(rawMsg.RawBytes.GetValue(), &msg); err != nil {
57✔
419
                return status.Errorf(codes.InvalidArgument, "first channel message was not ChannelOpen")
2✔
420
        }
2✔
421

422
        sh.state.DownstreamChannelInfo = &extensions_ssh.SSHDownstreamChannelInfo{
53✔
423
                ChannelType:               msg.ChanType,
53✔
424
                DownstreamChannelId:       msg.PeersID,
53✔
425
                InternalUpstreamChannelId: metadata.ChannelId,
53✔
426
                InitialWindowSize:         msg.PeersWindow,
53✔
427
                MaxPacketSize:             msg.MaxPacketSize,
53✔
428
        }
53✔
429
        sh.state.ChannelType = msg.ChanType
53✔
430
        channel := NewChannelImpl(sh, stream, sh.state.DownstreamChannelInfo)
53✔
431
        switch msg.ChanType {
53✔
432
        case ChannelTypeSession:
37✔
433
                ch := NewChannelHandler(channel, sh.cliCtrl, sh.config)
37✔
434
                if !sh.internalSession.CompareAndSwap(nil, ch) {
37✔
435
                        return channel.SendMessage(ChannelOpenFailureMsg{
×
436
                                PeersID: sh.state.DownstreamChannelInfo.DownstreamChannelId,
×
437
                                Reason:  Prohibited,
×
438
                                Message: "multiple concurrent internal session channels not supported",
×
439
                        })
×
440
                }
×
441
                if err := channel.SendMessage(ChannelOpenConfirmMsg{
37✔
442
                        PeersID:       sh.state.DownstreamChannelInfo.DownstreamChannelId,
37✔
443
                        MyID:          sh.state.DownstreamChannelInfo.InternalUpstreamChannelId,
37✔
444
                        MyWindow:      ChannelWindowSize,
37✔
445
                        MaxPacketSize: ChannelMaxPacket,
37✔
446
                }); err != nil {
37✔
447
                        return err
×
448
                }
×
449

450
                err := ch.Run(stream.Context(), metadata.ModeHint)
37✔
451
                sh.internalSession.Store(nil)
37✔
452
                return err
37✔
453
        case ChannelTypeDirectTcpip:
14✔
454
                if !sh.config.Options.IsRuntimeFlagSet(config.RuntimeFlagSSHAllowDirectTcpip) {
18✔
455
                        return status.Errorf(codes.Unavailable, "direct-tcpip channels are not enabled")
4✔
456
                }
4✔
457
                var subMsg ChannelOpenDirectMsg
10✔
458
                if err := gossh.Unmarshal(msg.TypeSpecificData, &subMsg); err != nil {
11✔
459
                        return err
1✔
460
                }
1✔
461
                action, err := sh.PrepareHandoff(stream.Context(), subMsg.DestAddr, nil)
9✔
462
                if err != nil {
11✔
463
                        return err
2✔
464
                }
2✔
465
                return channel.SendControlAction(action)
7✔
466
        default:
2✔
467
                return status.Errorf(codes.InvalidArgument, "unexpected channel type in ChannelOpen message: %s", msg.ChanType)
2✔
468
        }
469
}
470

471
// handleAuthRequest processes one SSH authentication request.
//
// The request must use protocol "ssh" and service "ssh-connection", and its
// method must be one the stream still requires. The first request pins the
// username (which must be non-empty) and the hostname (which may be empty).
// Later requests must match both.
//
// The method is evaluated through AuthInterface, its result is recorded in
// the stream state, and the remaining-methods list is updated. The method is
// removed and any additionally required methods are appended. When no
// methods remain and every attempted method has an allow response, initial
// auth is marked complete and an allow response is sent. Otherwise a deny
// response is sent, flagged as partial success when this method was allowed.
func (sh *StreamHandler) handleAuthRequest(ctx context.Context, req *extensions_ssh.AuthenticationRequest) error {
	if req.Protocol != "ssh" {
		return status.Errorf(codes.InvalidArgument, "invalid protocol: %s", req.Protocol)
	}
	if req.Service != ServiceConnection {
		return status.Errorf(codes.InvalidArgument, "invalid service: %s", req.Service)
	}
	if !slices.Contains(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod) {
		return status.Errorf(codes.InvalidArgument, "unexpected auth method: %s", req.AuthMethod)
	}

	// The state keeps pointers to the first request's Username/Hostname fields.
	if sh.state.Username == nil {
		if req.Username == "" {
			return status.Errorf(codes.InvalidArgument, "username missing")
		}
		sh.state.Username = &req.Username
	} else if *sh.state.Username != req.Username {
		return status.Errorf(codes.InvalidArgument, "inconsistent username")
	}
	if sh.state.Hostname == nil {
		sh.state.Hostname = &req.Hostname
	} else if *sh.state.Hostname != req.Hostname {
		return status.Errorf(codes.InvalidArgument, "inconsistent hostname")
	}

	// updateMethods marks the current method as done and adds any methods the
	// auth backend still requires.
	updateMethods := func(add []string) {
		sh.state.RemainingUnauthenticatedMethods = slices.Remove(sh.state.RemainingUnauthenticatedMethods, req.AuthMethod)
		sh.state.RemainingUnauthenticatedMethods = append(sh.state.RemainingUnauthenticatedMethods, add...)
	}
	log.Ctx(ctx).Debug().
		Str("method", req.AuthMethod).
		Str("username", *sh.state.Username).
		Str("hostname", *sh.state.Hostname).
		Msg("ssh: handling auth request")

	// partial is true when this particular method was allowed.
	var partial bool
	switch req.AuthMethod {
	case MethodPublicKey:
		methodReq, _ := req.MethodRequest.UnmarshalNew()
		pubkeyReq, ok := methodReq.(*extensions_ssh.PublicKeyMethodRequest)
		if !ok {
			return status.Errorf(codes.InvalidArgument, "invalid public key method request type")
		}
		response, err := sh.auth.HandlePublicKeyMethodRequest(ctx, sh.state.StreamAuthInfo, pubkeyReq)
		if err != nil {
			return err
		} else if response.Allow != nil {
			partial = true
			sh.state.PublicKeyFingerprintSha256 = pubkeyReq.PublicKeyFingerprintSha256
		}
		sh.state.PublicKeyAllow.Update(response.Allow)
		updateMethods(response.RequireAdditionalMethods)
	case MethodKeyboardInteractive:
		methodReq, _ := req.MethodRequest.UnmarshalNew()
		kbiReq, ok := methodReq.(*extensions_ssh.KeyboardInteractiveMethodRequest)
		if !ok {
			return status.Errorf(codes.InvalidArgument, "invalid keyboard-interactive method request type")
		}
		// sh is passed as the KeyboardInteractiveQuerier (see Prompt).
		response, err := sh.auth.HandleKeyboardInteractiveMethodRequest(ctx, sh.state.StreamAuthInfo, kbiReq, sh)
		if err != nil {
			return err
		}
		partial = response.Allow != nil
		sh.state.KeyboardInteractiveAllow.Update(response.Allow)
		updateMethods(response.RequireAdditionalMethods)
	default:
		return status.Errorf(codes.Internal, "bug: server requested an unsupported auth method %q", req.AuthMethod)
	}
	log.Ctx(ctx).Debug().
		Str("method", req.AuthMethod).
		Bool("partial", partial).
		Strs("methods-remaining", sh.state.RemainingUnauthenticatedMethods).
		Msg("ssh: auth request complete")

	if len(sh.state.RemainingUnauthenticatedMethods) == 0 && sh.state.allMethodsValid() {
		// If there are no methods remaining, the user is allowed if all attempted
		// methods have a valid response in the state
		sh.state.InitialAuthComplete = true
		log.Ctx(ctx).Debug().Msg("ssh: all methods valid, sending allow response")
		sh.sendAllowResponse()
	} else {
		// NOTE(review): this branch is also taken when no methods remain but an
		// attempted method was denied; the log message only describes the
		// "methods remain" case.
		log.Ctx(ctx).Debug().Msg("ssh: unauthenticated methods remain, sending deny response")
		sh.sendDenyResponseWithRemainingMethods(partial)
	}
	return nil
}
557

558
func (sh *StreamHandler) reauth(ctx context.Context) error {
10✔
559
        if !sh.state.InitialAuthComplete {
10✔
560
                return nil
×
561
        }
×
562
        return sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
10✔
563
}
564

565
func (sh *StreamHandler) PrepareHandoff(ctx context.Context, hostname string, ptyInfo api.SSHPtyInfo) (*extensions_ssh.SSHChannelControlAction, error) {
9✔
566
        if hostname == "" {
10✔
567
                return nil, status.Errorf(codes.PermissionDenied, "invalid hostname")
1✔
568
        }
1✔
569
        if sh.state.Hostname == nil {
8✔
570
                panic("bug: PrepareHandoff called but state is missing a hostname")
×
571
        }
572
        if *sh.state.Hostname != "" {
8✔
573
                panic("bug: PrepareHandoff called but previous hostname is not empty")
×
574
        }
575
        *sh.state.Hostname = hostname
8✔
576
        err := sh.auth.EvaluateDelayed(ctx, sh.state.StreamAuthInfo)
8✔
577
        if err != nil {
9✔
578
                return nil, status.Error(codes.PermissionDenied, err.Error())
1✔
579
        }
1✔
580
        log.Ctx(ctx).Debug().
7✔
581
                Str("hostname", *sh.state.Hostname).
7✔
582
                Str("username", *sh.state.Username).
7✔
583
                Msg("ssh: initiating handoff to upstream")
7✔
584
        upstreamAllow := sh.buildUpstreamAllowResponse()
7✔
585
        var downstreamPtyInfo *extensions_ssh.SSHDownstreamPTYInfo
7✔
586
        if ptyInfo != nil {
7✔
587
                downstreamPtyInfo = &extensions_ssh.SSHDownstreamPTYInfo{
×
588
                        TermEnv:      ptyInfo.GetTermEnv(),
×
589
                        WidthColumns: ptyInfo.GetWidthColumns(),
×
590
                        HeightRows:   ptyInfo.GetHeightRows(),
×
591
                        WidthPx:      ptyInfo.GetWidthPx(),
×
592
                        HeightPx:     ptyInfo.GetHeightPx(),
×
593
                        Modes:        ptyInfo.GetModes(),
×
594
                }
×
595
        }
×
596
        action := &extensions_ssh.SSHChannelControlAction{
7✔
597
                Action: &extensions_ssh.SSHChannelControlAction_HandOff{
7✔
598
                        HandOff: &extensions_ssh.SSHChannelControlAction_HandOffUpstream{
7✔
599
                                DownstreamChannelInfo: sh.state.DownstreamChannelInfo,
7✔
600
                                DownstreamPtyInfo:     downstreamPtyInfo,
7✔
601
                                UpstreamAuth:          upstreamAllow,
7✔
602
                        },
7✔
603
                },
7✔
604
        }
7✔
605
        return action, nil
7✔
606
}
607

608
func (sh *StreamHandler) GetSession(ctx context.Context) (*session.Session, error) {
12✔
609
        return sh.auth.GetSession(ctx, sh.state.StreamAuthInfo)
12✔
610
}
12✔
611

612
func (sh *StreamHandler) DeleteSession(ctx context.Context) error {
12✔
613
        return sh.auth.DeleteSession(ctx, sh.state.StreamAuthInfo)
12✔
614
}
12✔
615

616
func (sh *StreamHandler) AllSSHRoutes() iter.Seq[*config.Policy] {
4✔
617
        return func(yield func(*config.Policy) bool) {
8✔
618
                for route := range sh.config.Options.GetAllPolicies() {
18✔
619
                        if route.IsSSH() {
20✔
620
                                if !yield(route) {
8✔
621
                                        return
2✔
622
                                }
2✔
623
                        }
624
                }
625
        }
626
}
627

628
// DownstreamChannelID implements StreamHandlerInterface.
629
func (sh *StreamHandler) DownstreamChannelID() uint32 {
189✔
630
        return sh.state.DownstreamChannelInfo.DownstreamChannelId
189✔
631
}
189✔
632

633
// DownstreamSourceAddress implements StreamHandlerInterface.
634
func (sh *StreamHandler) DownstreamSourceAddress() string {
×
635
        return sh.state.SourceAddress
×
636
}
×
637

638
// DownstreamPublicKeyFingerprint implements StreamHandlerInterface.
639
func (sh *StreamHandler) DownstreamPublicKeyFingerprint() []byte {
×
640
        return sh.state.PublicKeyFingerprintSha256
×
641
}
×
642

643
// Hostname implements StreamHandlerInterface.
644
func (sh *StreamHandler) Hostname() *string {
42✔
645
        return sh.state.Hostname
42✔
646
}
42✔
647

648
// Username implements StreamHandlerInterface.
649
func (sh *StreamHandler) Username() *string {
42✔
650
        return sh.state.Username
42✔
651
}
42✔
652

653
// PortForwardManager implements StreamHandlerInterface.
654
// This is used by internal channels to add additional update listeners.
655
func (sh *StreamHandler) PortForwardManager() *portforward.Manager {
×
656
        return sh.discovery.PortForwardManager()
×
657
}
×
658

659
func (sh *StreamHandler) sendDenyResponseWithRemainingMethods(partial bool) {
75✔
660
        sh.writeC <- &extensions_ssh.ServerMessage{
75✔
661
                Message: &extensions_ssh.ServerMessage_AuthResponse{
75✔
662
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
75✔
663
                                Response: &extensions_ssh.AuthenticationResponse_Deny{
75✔
664
                                        Deny: &extensions_ssh.DenyResponse{
75✔
665
                                                Partial: partial,
75✔
666
                                                Methods: sh.state.RemainingUnauthenticatedMethods,
75✔
667
                                        },
75✔
668
                                },
75✔
669
                        },
75✔
670
                },
75✔
671
        }
75✔
672
}
75✔
673

674
func (sh *StreamHandler) sendAllowResponse() {
88✔
675
        var allow *extensions_ssh.AllowResponse
88✔
676
        if *sh.state.Hostname == "" {
151✔
677
                sh.expectingInternalChannel = true
63✔
678
                allow = sh.buildInternalAllowResponse()
63✔
679
        } else {
88✔
680
                allow = sh.buildUpstreamAllowResponse()
25✔
681
        }
25✔
682

683
        sh.writeC <- &extensions_ssh.ServerMessage{
88✔
684
                Message: &extensions_ssh.ServerMessage_AuthResponse{
88✔
685
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
88✔
686
                                Response: &extensions_ssh.AuthenticationResponse_Allow{
88✔
687
                                        Allow: allow,
88✔
688
                                },
88✔
689
                        },
88✔
690
                },
88✔
691
        }
88✔
692
}
693

694
func (sh *StreamHandler) sendInfoPrompts(prompts *extensions_ssh.KeyboardInteractiveInfoPrompts) {
47✔
695
        sh.writeC <- &extensions_ssh.ServerMessage{
47✔
696
                Message: &extensions_ssh.ServerMessage_AuthResponse{
47✔
697
                        AuthResponse: &extensions_ssh.AuthenticationResponse{
47✔
698
                                Response: &extensions_ssh.AuthenticationResponse_InfoRequest{
47✔
699
                                        InfoRequest: &extensions_ssh.InfoRequest{
47✔
700
                                                Method:  MethodKeyboardInteractive,
47✔
701
                                                Request: protoutil.NewAny(prompts),
47✔
702
                                        },
47✔
703
                                },
47✔
704
                        },
47✔
705
                },
47✔
706
        }
47✔
707
}
47✔
708

709
func (sh *StreamHandler) buildUpstreamAllowResponse() *extensions_ssh.AllowResponse {
32✔
710
        var allowedMethods []*extensions_ssh.AllowedMethod
32✔
711
        if value := sh.state.PublicKeyAllow.Value; value != nil {
64✔
712
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
32✔
713
                        Method:     MethodPublicKey,
32✔
714
                        MethodData: protoutil.NewAny(value),
32✔
715
                })
32✔
716
        }
32✔
717
        if value := sh.state.KeyboardInteractiveAllow.Value; value != nil {
59✔
718
                allowedMethods = append(allowedMethods, &extensions_ssh.AllowedMethod{
27✔
719
                        Method:     MethodKeyboardInteractive,
27✔
720
                        MethodData: protoutil.NewAny(value),
27✔
721
                })
27✔
722
        }
27✔
723
        return &extensions_ssh.AllowResponse{
32✔
724
                Username: *sh.state.Username,
32✔
725
                Target: &extensions_ssh.AllowResponse_Upstream{
32✔
726
                        Upstream: &extensions_ssh.UpstreamTarget{
32✔
727
                                Hostname:       *sh.state.Hostname,
32✔
728
                                DirectTcpip:    sh.state.ChannelType == ChannelTypeDirectTcpip,
32✔
729
                                AllowedMethods: allowedMethods,
32✔
730
                        },
32✔
731
                },
32✔
732
        }
32✔
733
}
734

735
func (sh *StreamHandler) buildInternalAllowResponse() *extensions_ssh.AllowResponse {
63✔
736
        return &extensions_ssh.AllowResponse{
63✔
737
                Username: *sh.state.Username,
63✔
738
                Target: &extensions_ssh.AllowResponse_Internal{
63✔
739
                        Internal: &extensions_ssh.InternalTarget{
63✔
740
                                SetMetadata: &corev3.Metadata{
63✔
741
                                        TypedFilterMetadata: map[string]*anypb.Any{
63✔
742
                                                "com.pomerium.ssh": protoutil.NewAny(&extensions_ssh.FilterMetadata{
63✔
743
                                                        StreamId: sh.downstream.StreamId,
63✔
744
                                                }),
63✔
745
                                        },
63✔
746
                                },
63✔
747
                        },
63✔
748
                },
63✔
749
        }
63✔
750
}
63✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc