• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / gotrue / 8135412125

04 Mar 2024 03:34AM UTC coverage: 65.142% (+0.1%) from 65.009%
8135412125

push

github

web-flow
fix: refactor request params to use generics (#1464)

## What kind of change does this PR introduce?
* Introduce a new generic function, `retrieveRequestParams`, which reads a
request's body and decodes it into the given params struct
* This will help to simplify parsing a request from:
```go

params := RequestParams{}
body, err := getBodyBytes(r)
if err != nil {
  return nil, badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
  return nil, badRequestError("Could not decode request params: %v", err)
}
```
to 
```go
params := &RequestParams{}
err := retrieveRequestParams(r, params)
```

## TODO
- [x] Add type constraint instead of using `any`

48 of 69 new or added lines in 19 files covered. (69.57%)

19 existing lines in 14 files now uncovered.

7806 of 11983 relevant lines covered (65.14%)

59.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/internal/api/token_oidc.go
1
package api
2

3
import (
4
        "context"
5
        "crypto/sha256"
6
        "fmt"
7
        "net/http"
8

9
        "github.com/coreos/go-oidc/v3/oidc"
10
        "github.com/supabase/auth/internal/api/provider"
11
        "github.com/supabase/auth/internal/conf"
12
        "github.com/supabase/auth/internal/models"
13
        "github.com/supabase/auth/internal/observability"
14
        "github.com/supabase/auth/internal/storage"
15
)
16

17
type (
	// IdTokenGrantParams holds the JSON request parameters accepted by the
	// IdTokenGrant handler.
	IdTokenGrantParams struct {
		// IdToken is the raw ID token to verify; IdTokenGrant rejects an empty value.
		IdToken string `json:"id_token"`
		// AccessToken is optional; when non-empty it is handed to
		// provider.ParseIDToken so the access-token check is not skipped.
		AccessToken string `json:"access_token"`
		// Nonce is the raw nonce; unless the provider skips nonce checks, its
		// hex-encoded SHA-256 digest must equal the ID token's nonce claim.
		Nonce string `json:"nonce"`
		// Provider names a built-in provider: "apple", "google", "azure",
		// "facebook" or "keycloak".
		Provider string `json:"provider"`
		// ClientID is the accepted audience when Issuer is an allow-listed
		// custom issuer (the deprecated path used when Provider is not matched).
		ClientID string `json:"client_id"`
		// Issuer can select a built-in provider by its issuer URL, or name an
		// allow-listed custom issuer.
		Issuer string `json:"issuer"`
	}
)
26

27
func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, *conf.OAuthProviderConfiguration, string, []string, error) {
×
28
        log := observability.GetLogEntry(r)
×
29

×
30
        var cfg *conf.OAuthProviderConfiguration
×
31
        var issuer string
×
32
        var providerType string
×
33
        var acceptableClientIDs []string
×
34

×
35
        switch true {
×
36
        case p.Provider == "apple" || p.Issuer == provider.IssuerApple:
×
37
                cfg = &config.External.Apple
×
38
                providerType = "apple"
×
39
                issuer = provider.IssuerApple
×
40
                acceptableClientIDs = append(acceptableClientIDs, config.External.Apple.ClientID...)
×
41

×
42
                if config.External.IosBundleId != "" {
×
43
                        acceptableClientIDs = append(acceptableClientIDs, config.External.IosBundleId)
×
44
                }
×
45

46
        case p.Provider == "google" || p.Issuer == provider.IssuerGoogle:
×
47
                cfg = &config.External.Google
×
48
                providerType = "google"
×
49
                issuer = provider.IssuerGoogle
×
50
                acceptableClientIDs = append(acceptableClientIDs, config.External.Google.ClientID...)
×
51

52
        case p.Provider == "azure" || provider.IsAzureIssuer(p.Issuer):
×
53
                issuer = p.Issuer
×
54
                if issuer == "" || !provider.IsAzureIssuer(issuer) {
×
55
                        detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken)
×
56
                        if err != nil {
×
57
                                return nil, nil, "", nil, badRequestError("Unable to detect issuer in ID token for Azure provider").WithInternalError(err)
×
58
                        }
×
59
                        issuer = detectedIssuer
×
60
                }
61
                cfg = &config.External.Azure
×
62
                providerType = "azure"
×
63
                acceptableClientIDs = append(acceptableClientIDs, config.External.Azure.ClientID...)
×
64

65
        case p.Provider == "facebook" || p.Issuer == provider.IssuerFacebook:
×
66
                cfg = &config.External.Facebook
×
67
                providerType = "facebook"
×
68
                issuer = provider.IssuerFacebook
×
69
                acceptableClientIDs = append(acceptableClientIDs, config.External.Facebook.ClientID...)
×
70

71
        case p.Provider == "keycloak" || (config.External.Keycloak.Enabled && config.External.Keycloak.URL != "" && p.Issuer == config.External.Keycloak.URL):
×
72
                cfg = &config.External.Keycloak
×
73
                providerType = "keycloak"
×
74
                issuer = config.External.Keycloak.URL
×
75
                acceptableClientIDs = append(acceptableClientIDs, config.External.Keycloak.ClientID...)
×
76

77
        default:
×
78
                log.WithField("issuer", p.Issuer).WithField("client_id", p.ClientID).Warn("Use of POST /token with arbitrary issuer and client_id is deprecated for security reasons. Please switch to using the API with provider only!")
×
79

×
80
                allowed := false
×
81
                for _, allowedIssuer := range config.External.AllowedIdTokenIssuers {
×
82
                        if p.Issuer == allowedIssuer {
×
83
                                allowed = true
×
84
                                providerType = allowedIssuer
×
85
                                acceptableClientIDs = []string{p.ClientID}
×
86
                                issuer = allowedIssuer
×
87
                                break
×
88
                        }
89
                }
90

91
                if !allowed {
×
92
                        return nil, nil, "", nil, badRequestError(fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider))
×
93
                }
×
94
        }
95

96
        if cfg != nil && !cfg.Enabled {
×
97
                return nil, nil, "", nil, badRequestError(fmt.Sprintf("Provider (issuer %q) is not enabled", issuer))
×
98
        }
×
99

100
        oidcProvider, err := oidc.NewProvider(ctx, issuer)
×
101
        if err != nil {
×
102
                return nil, nil, "", nil, err
×
103
        }
×
104

105
        return oidcProvider, cfg, providerType, acceptableClientIDs, nil
×
106
}
107

108
// IdTokenGrant implements the id_token grant type flow
109
func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
×
110
        log := observability.GetLogEntry(r)
×
111

×
112
        db := a.db.WithContext(ctx)
×
113
        config := a.config
×
114

×
115
        params := &IdTokenGrantParams{}
×
NEW
116
        if err := retrieveRequestParams(r, params); err != nil {
×
NEW
117
                return err
×
UNCOV
118
        }
×
119

120
        if params.IdToken == "" {
×
121
                return oauthError("invalid request", "id_token required")
×
122
        }
×
123

124
        if params.Provider == "" && (params.ClientID == "" || params.Issuer == "") {
×
125
                return oauthError("invalid request", "provider or client_id and issuer required")
×
126
        }
×
127

128
        oidcProvider, oauthConfig, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r)
×
129
        if err != nil {
×
130
                return err
×
131
        }
×
132

133
        idToken, userData, err := provider.ParseIDToken(ctx, oidcProvider, nil, params.IdToken, provider.ParseIDTokenOptions{
×
134
                SkipAccessTokenCheck: params.AccessToken == "",
×
135
                AccessToken:          params.AccessToken,
×
136
        })
×
137
        if err != nil {
×
138
                return oauthError("invalid request", "Bad ID token").WithInternalError(err)
×
139
        }
×
140

141
        userData.Metadata.EmailVerified = false
×
142
        for _, email := range userData.Emails {
×
143
                if email.Primary {
×
144
                        userData.Metadata.Email = email.Email
×
145
                        userData.Metadata.EmailVerified = email.Verified
×
146
                        break
×
147
                } else {
×
148
                        userData.Metadata.Email = email.Email
×
149
                        userData.Metadata.EmailVerified = email.Verified
×
150
                }
×
151
        }
152

153
        if idToken.Subject == "" {
×
154
                return oauthError("invalid request", "Missing sub claim in id_token")
×
155
        }
×
156

157
        correctAudience := false
×
158
        for _, clientID := range acceptableClientIDs {
×
159
                if clientID == "" {
×
160
                        continue
×
161
                }
162

163
                for _, aud := range idToken.Audience {
×
164
                        if aud == clientID {
×
165
                                correctAudience = true
×
166
                                break
×
167
                        }
168
                }
169

170
                if correctAudience {
×
171
                        break
×
172
                }
173
        }
174

175
        if !correctAudience {
×
176
                return oauthError("invalid request", "Unacceptable audience in id_token")
×
177
        }
×
178

179
        if !oauthConfig.SkipNonceCheck {
×
180
                tokenHasNonce := idToken.Nonce != ""
×
181
                paramsHasNonce := params.Nonce != ""
×
182

×
183
                if tokenHasNonce != paramsHasNonce {
×
184
                        return oauthError("invalid request", "Passed nonce and nonce in id_token should either both exist or not.")
×
185
                } else if tokenHasNonce && paramsHasNonce {
×
186
                        // verify nonce to mitigate replay attacks
×
187
                        hash := fmt.Sprintf("%x", sha256.Sum256([]byte(params.Nonce)))
×
188
                        if hash != idToken.Nonce {
×
189
                                return oauthError("invalid nonce", "Nonces mismatch")
×
190
                        }
×
191
                }
192
        }
193

194
        if params.AccessToken == "" {
×
195
                if idToken.AccessTokenHash != "" {
×
196
                        log.Warn("ID token has a at_hash claim, but no access_token parameter was provided. In future versions, access_token will be mandatory as it's security best practice.")
×
197
                }
×
198
        } else {
×
199
                if idToken.AccessTokenHash == "" {
×
200
                        log.Info("ID token does not have a at_hash claim, access_token parameter is unused.")
×
201
                }
×
202
        }
203

204
        var token *AccessTokenResponse
×
205
        var grantParams models.GrantParams
×
206

×
207
        grantParams.FillGrantParams(r)
×
208

×
209
        if err := db.Transaction(func(tx *storage.Connection) error {
×
210
                var user *models.User
×
211
                var terr error
×
212

×
213
                user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType)
×
214
                if terr != nil {
×
215
                        return terr
×
216
                }
×
217

218
                token, terr = a.issueRefreshToken(ctx, tx, user, models.OAuth, grantParams)
×
219
                if terr != nil {
×
220
                        return terr
×
221
                }
×
222

223
                return nil
×
224
        }); err != nil {
×
225
                switch err.(type) {
×
226
                case *storage.CommitWithError:
×
227
                        return err
×
228
                default:
×
229
                        return oauthError("server_error", "Internal Server Error").WithInternalError(err)
×
230
                }
231
        }
232

233
        return sendJSON(w, http.StatusOK, token)
×
234
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc