• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

fogfish / swarm / 11426008059

20 Oct 2024 11:40AM UTC coverage: 60.333% (-0.2%) from 60.498%
11426008059

push

github

web-flow
(fix): pass IOContext from WebSocket (#106)

22 of 42 new or added lines in 4 files covered. (52.38%)

1 existing line in 1 file now uncovered.

1016 of 1684 relevant lines covered (60.33%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/broker/websocket/lambda/auth/auth.go
1
//
2
// Copyright (C) 2021 - 2022 Dmitry Kolesnikov
3
//
4
// This file may be modified and distributed under the terms
5
// of the Apache License Version 2.0. See the LICENSE file for details.
6
// https://github.com/fogfish/swarm
7
//
8

9
package main
10

11
import (
12
        "context"
13
        "crypto/sha256"
14
        "crypto/subtle"
15
        "encoding/base64"
16
        "errors"
17
        "log/slog"
18
        "net/url"
19
        "os"
20
        "strings"
21
        "time"
22

23
        "github.com/auth0/go-jwt-middleware/v2/jwks"
24
        "github.com/auth0/go-jwt-middleware/v2/validator"
25
        "github.com/aws/aws-lambda-go/events"
26
        "github.com/aws/aws-lambda-go/lambda"
27
        _ "github.com/fogfish/logger/v3"
28
)
29

30
func main() {
×
31
        basic, err := NewAuthBasic()
×
32
        if err != nil {
×
33
                slog.Warn("Basic Auth disabled.")
×
34
                basic = nil
×
35
        }
×
36

37
        jwt, err := NewAuthJWT()
×
38
        if err != nil {
×
39
                slog.Warn("JWT Auth disabled.")
×
40
                jwt = nil
×
41
        }
×
42

43
        lambda.Start(
×
44
                func(evt events.APIGatewayV2CustomAuthorizerV1Request) (events.APIGatewayCustomAuthorizerResponse, error) {
×
45
                        tkn, has := evt.QueryStringParameters["apikey"]
×
46
                        if !has || len(tkn) == 0 {
×
47
                                return None, ErrForbidden
×
48
                        }
×
49

50
                        if jwt != nil && strings.HasPrefix(tkn, "ey") {
×
51
                                principal, context, err := jwt.Validate(tkn)
×
52
                                if err != nil {
×
53
                                        return None, ErrForbidden
×
54
                                }
×
55

56
                                return AccessPolicy(principal, evt.MethodArn, context), nil
×
57
                        }
58

59
                        if basic != nil {
×
NEW
60
                                scope := evt.QueryStringParameters["scope"]
×
NEW
61
                                principal, context, err := basic.Validate(tkn, scope)
×
62
                                if err != nil {
×
63
                                        return None, ErrForbidden
×
64
                                }
×
65

66
                                return AccessPolicy(principal, evt.MethodArn, context), nil
×
67
                        }
68

69
                        return None, ErrForbidden
×
70
                },
71
        )
72

73
}
74

75
var (
76
        None         = events.APIGatewayCustomAuthorizerResponse{}
77
        ErrForbidden = errors.New("forbidden")
78
)
79

80
//------------------------------------------------------------------------------
81

82
// Grant the access to WebSocket with the policy
83
func AccessPolicy(principal, method string, context map[string]any) events.APIGatewayCustomAuthorizerResponse {
×
84
        return events.APIGatewayCustomAuthorizerResponse{
×
85
                PrincipalID: principal,
×
86
                PolicyDocument: events.APIGatewayCustomAuthorizerPolicy{
×
87
                        Version: "2012-10-17",
×
88
                        Statement: []events.IAMPolicyStatement{
×
89
                                {
×
90
                                        Action:   []string{"execute-api:*"},
×
91
                                        Effect:   "Allow",
×
92
                                        Resource: []string{method},
×
93
                                },
×
94
                        },
×
95
                },
×
96
                Context: context,
×
97
        }
×
98
}
×
99

100
//------------------------------------------------------------------------------
101

102
// AuthBasic authenticates clients with a shared access/secret key pair and
// restricts which scopes a client may request.
type AuthBasic struct {
	access string   // expected access key
	secret string   // expected secret key
	scope  []string // scopes a client is permitted to request
}
106

107
func NewAuthBasic() (*AuthBasic, error) {
×
108
        access := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_ACCESS")
×
109
        secret := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SECRET")
×
NEW
110
        scope := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_SCOPE")
×
111

×
112
        if access == "" || secret == "" {
×
113
                return nil, errors.New("basic auth is not configured")
×
114
        }
×
115

116
        return &AuthBasic{
×
117
                access: access,
×
118
                secret: secret,
×
NEW
119
                scope:  strings.Split(scope, " "),
×
UNCOV
120
        }, nil
×
121
}
122

NEW
123
func (auth *AuthBasic) Validate(apikey, scope string) (string, map[string]any, error) {
×
124
        c, err := base64.RawStdEncoding.DecodeString(apikey)
×
125
        if err != nil {
×
126
                return "", nil, ErrForbidden
×
127
        }
×
128

129
        access, secret, ok := strings.Cut(string(c), ":")
×
130
        if !ok {
×
131
                return "", nil, ErrForbidden
×
132
        }
×
133

NEW
134
        seq, err := url.QueryUnescape(scope)
×
NEW
135
        if err != nil {
×
NEW
136
                return "", nil, ErrForbidden
×
NEW
137
        }
×
NEW
138
        for _, sid := range strings.Split(seq, " ") {
×
NEW
139
                has := false
×
NEW
140
                for _, allowed := range auth.scope {
×
NEW
141
                        if allowed == sid {
×
NEW
142
                                has = true
×
NEW
143
                        }
×
144
                }
NEW
145
                if !has {
×
NEW
146
                        return "", nil, ErrForbidden
×
NEW
147
                }
×
148
        }
149

150
        gaccess := sha256.Sum256([]byte(access))
×
151
        gsecret := sha256.Sum256([]byte(secret))
×
152
        haccess := sha256.Sum256([]byte(auth.access))
×
153
        hsecret := sha256.Sum256([]byte(auth.secret))
×
154

×
155
        accessMatch := (subtle.ConstantTimeCompare(gaccess[:], haccess[:]) == 1)
×
156
        secretMatch := (subtle.ConstantTimeCompare(gsecret[:], hsecret[:]) == 1)
×
157

×
158
        if accessMatch && secretMatch {
×
NEW
159
                return access, map[string]any{"auth": "basic", "sub": access, "scope": scope}, nil
×
160
        }
×
161

162
        return "", nil, ErrForbidden
×
163
}
164

165
//------------------------------------------------------------------------------
166

167
type AuthJWT struct {
168
        *validator.Validator
169
}
170

171
// Claims holds the custom JWT claims read by this authorizer: the token's
// "scope" claim, which is copied into the authorizer context.
type Claims struct {
	Scope string `json:"scope"`
}

// Validate satisfies validator.CustomClaims. The custom claims need no
// extra checks, so it always succeeds.
func (Claims) Validate(context.Context) error { return nil }
×
176

177
func NewAuthJWT() (*AuthJWT, error) {
×
178
        iss := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_ISS")
×
179
        aud := os.Getenv("CONFIG_SWARM_WS_AUTHORIZER_AUD")
×
180

×
181
        if iss == "" || aud == "" {
×
182
                return nil, errors.New("jwt auth is not configured")
×
183
        }
×
184

185
        issuer, err := url.Parse(iss)
×
186
        if err != nil {
×
187
                return nil, err
×
188
        }
×
189

190
        provider := jwks.NewCachingProvider(issuer, 5*time.Minute)
×
191

×
192
        auth, err := validator.New(
×
193
                provider.KeyFunc,
×
194
                validator.RS256,
×
195
                iss,
×
196
                []string{aud},
×
197
                validator.WithCustomClaims(func() validator.CustomClaims { return &Claims{} }),
×
198
                validator.WithAllowedClockSkew(time.Minute),
199
        )
200
        if err != nil {
×
201
                return nil, err
×
202
        }
×
203

204
        return &AuthJWT{Validator: auth}, nil
×
205
}
206

207
func (auth *AuthJWT) Validate(token string) (string, map[string]any, error) {
×
208
        claims, err := auth.ValidateToken(context.Background(), token)
×
209
        if err != nil {
×
210
                return "", nil, ErrForbidden
×
211
        }
×
212

213
        switch c := claims.(type) {
×
214
        case *validator.ValidatedClaims:
×
215
                ctx := map[string]any{
×
216
                        "iss": c.RegisteredClaims.Issuer,
×
217
                        "sub": c.RegisteredClaims.Subject,
×
218
                        // "aud": c.RegisteredClaims.Audience,
×
219
                        "exp":   c.RegisteredClaims.Expiry,
×
220
                        "nbf":   c.RegisteredClaims.NotBefore,
×
221
                        "iat":   c.RegisteredClaims.IssuedAt,
×
222
                        "scope": c.CustomClaims.(*Claims).Scope,
×
223
                        "auth":  "jwt",
×
224
                }
×
225

×
226
                return c.RegisteredClaims.Subject, ctx, nil
×
227
        default:
×
228
                return "", nil, ErrForbidden
×
229
        }
230
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc