• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / gotrue / 8135412125

04 Mar 2024 03:34AM UTC coverage: 65.142% (+0.1%) from 65.009%
8135412125

push

github

web-flow
fix: refactor request params to use generics (#1464)

## What kind of change does this PR introduce?
* Introduce a new generic function, `retrieveRequestParams`, which reads a
request body and decodes it into a caller-supplied params struct
* This simplifies request parsing from:
```go

params := RequestParams{}
body, err := getBodyBytes(r)
if err != nil {
  return nil, badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
  return nil, badRequestError("Could not decode request params: %v", err)
}
```
to:
```go
params := &Request{}
err := retrieveRequestParams(r, params)
```

## TODO
- [x] Add type constraint instead of using `any`

48 of 69 new or added lines in 19 files covered. (69.57%)

19 existing lines in 14 files now uncovered.

7806 of 11983 relevant lines covered (65.14%)

59.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.96
/internal/api/sso.go
1
package api
2

3
import (
4
        "net/http"
5

6
        "github.com/crewjam/saml"
7
        "github.com/gofrs/uuid"
8
        "github.com/supabase/auth/internal/models"
9
        "github.com/supabase/auth/internal/storage"
10
)
11

12
type SingleSignOnParams struct {
13
        ProviderID          uuid.UUID `json:"provider_id"`
14
        Domain              string    `json:"domain"`
15
        RedirectTo          string    `json:"redirect_to"`
16
        SkipHTTPRedirect    *bool     `json:"skip_http_redirect"`
17
        CodeChallenge       string    `json:"code_challenge"`
18
        CodeChallengeMethod string    `json:"code_challenge_method"`
19
}
20

21
// SingleSignOnResponse is the JSON body SingleSignOn sends when the caller
// asked to skip the HTTP redirect; URL is the IdP redirect URL.
type SingleSignOnResponse struct {
	URL string `json:"url"`
}
24

25
func (p *SingleSignOnParams) validate() (bool, error) {
6✔
26
        hasProviderID := p.ProviderID != uuid.Nil
6✔
27
        hasDomain := p.Domain != ""
6✔
28

6✔
29
        if hasProviderID && hasDomain {
6✔
30
                return hasProviderID, badRequestError("Only one of provider_id or domain supported")
×
31
        } else if !hasProviderID && !hasDomain {
6✔
32
                return hasProviderID, badRequestError("A provider_id or domain needs to be provided")
×
33
        }
×
34

35
        return hasProviderID, nil
6✔
36
}
37

38
// SingleSignOn handles the single-sign-on flow for a provided SSO domain or provider.
39
func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error {
6✔
40
        ctx := r.Context()
6✔
41
        db := a.db.WithContext(ctx)
6✔
42

6✔
43
        params := &SingleSignOnParams{}
6✔
44
        if err := retrieveRequestParams(r, params); err != nil {
6✔
NEW
45
                return err
×
UNCOV
46
        }
×
47

48
        var err error
6✔
49
        hasProviderID := false
6✔
50

6✔
51
        if hasProviderID, err = params.validate(); err != nil {
6✔
52
                return err
×
53
        }
×
54
        codeChallengeMethod := params.CodeChallengeMethod
6✔
55
        codeChallenge := params.CodeChallenge
6✔
56

6✔
57
        if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil {
6✔
58
                return err
×
59
        }
×
60
        flowType := getFlowFromChallenge(params.CodeChallenge)
6✔
61
        var flowStateID *uuid.UUID
6✔
62
        flowStateID = nil
6✔
63
        if flowType == models.PKCEFlow {
7✔
64
                codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod)
1✔
65
                if err != nil {
1✔
66
                        return err
×
67
                }
×
68
                flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML)
1✔
69
                if err != nil {
1✔
70
                        return err
×
71
                }
×
72
                if err := a.db.Create(flowState); err != nil {
1✔
73
                        return err
×
74
                }
×
75
                flowStateID = &flowState.ID
1✔
76
        }
77

78
        var ssoProvider *models.SSOProvider
6✔
79

6✔
80
        if hasProviderID {
9✔
81
                ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID)
3✔
82
                if models.IsNotFoundError(err) {
4✔
83
                        return notFoundError("No such SSO provider")
1✔
84
                } else if err != nil {
3✔
85
                        return internalServerError("Unable to find SSO provider by ID").WithInternalError(err)
×
86
                }
×
87
        } else {
3✔
88
                ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain)
3✔
89
                if models.IsNotFoundError(err) {
4✔
90
                        return notFoundError("No SSO provider assigned for this domain")
1✔
91
                } else if err != nil {
3✔
92
                        return internalServerError("Unable to find SSO provider by domain").WithInternalError(err)
×
93
                }
×
94
        }
95

96
        entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor()
4✔
97
        if err != nil {
4✔
98
                return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err)
×
99
        }
×
100

101
        // TODO: fetch new metadata if validUntil < time.Now()
102

103
        serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */)
4✔
104

4✔
105
        authnRequest, err := serviceProvider.MakeAuthenticationRequest(
4✔
106
                serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding),
4✔
107
                saml.HTTPRedirectBinding,
4✔
108
                saml.HTTPPostBinding,
4✔
109
        )
4✔
110
        if err != nil {
4✔
111
                return internalServerError("Error creating SAML Authentication Request").WithInternalError(err)
×
112
        }
×
113

114
        relayState := models.SAMLRelayState{
4✔
115
                SSOProviderID: ssoProvider.ID,
4✔
116
                RequestID:     authnRequest.ID,
4✔
117
                RedirectTo:    params.RedirectTo,
4✔
118
                FlowStateID:   flowStateID,
4✔
119
        }
4✔
120

4✔
121
        if err := db.Transaction(func(tx *storage.Connection) error {
8✔
122
                if terr := tx.Create(&relayState); terr != nil {
4✔
123
                        return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err)
×
124
                }
×
125

126
                return nil
4✔
127
        }); err != nil {
×
128
                return err
×
129
        }
×
130

131
        ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider)
4✔
132
        if err != nil {
4✔
133
                return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err)
×
134
        }
×
135

136
        skipHTTPRedirect := false
4✔
137

4✔
138
        if params.SkipHTTPRedirect != nil {
5✔
139
                skipHTTPRedirect = *params.SkipHTTPRedirect
1✔
140
        }
1✔
141

142
        if skipHTTPRedirect {
5✔
143
                return sendJSON(w, http.StatusOK, SingleSignOnResponse{
1✔
144
                        URL: ssoRedirectURL.String(),
1✔
145
                })
1✔
146
        }
1✔
147

148
        http.Redirect(w, r, ssoRedirectURL.String(), http.StatusSeeOther)
3✔
149
        return nil
3✔
150
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc