• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / gotrue / 8135412125

04 Mar 2024 03:34AM UTC coverage: 65.142% (+0.1%) from 65.009%
8135412125

push

github

web-flow
fix: refactor request params to use generics (#1464)

## What kind of change does this PR introduce?
* Introduce a new function `retrieveRequestParams` that uses
generics to parse a request
* This simplifies parsing a request from:
```go

params := RequestParams{}
body, err := getBodyBytes(r)
if err != nil {
  return nil, badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
  return nil, badRequestError("Could not decode request params: %v", err)
}
```
to:
```go
params := &Request{}
err := retrieveRequestParams(req, params)
```

## TODO
- [x] Add type constraint instead of using `any`

48 of 69 new or added lines in 19 files covered. (69.57%)

19 existing lines in 14 files now uncovered.

7806 of 11983 relevant lines covered (65.14%)

59.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

66.05
/internal/api/ssoadmin.go
1
package api
2

3
import (
	"context"
	"io"
	"net/http"
	"net/url"
	"time"
	"unicode/utf8"

	"github.com/crewjam/saml"
	"github.com/crewjam/saml/samlsp"
	"github.com/go-chi/chi"
	"github.com/gofrs/uuid"
	"github.com/supabase/auth/internal/models"
	"github.com/supabase/auth/internal/observability"
	"github.com/supabase/auth/internal/storage"
	"github.com/supabase/auth/internal/utilities"
)
19

20
// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider
21
// with that ID (or resource ID) and adds it to the context.
22
func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) {
13✔
23
        ctx := r.Context()
13✔
24
        db := a.db.WithContext(ctx)
13✔
25

13✔
26
        idpParam := chi.URLParam(r, "idp_id")
13✔
27

13✔
28
        idpID, err := uuid.FromString(idpParam)
13✔
29
        if err != nil {
14✔
30
                // idpParam is not UUIDv4
1✔
31
                return nil, notFoundError("SSO Identity Provider not found")
1✔
32
        }
1✔
33

34
        // idpParam is a UUIDv4
35
        provider, err := models.FindSSOProviderByID(db, idpID)
12✔
36
        if err != nil {
14✔
37
                if models.IsNotFoundError(err) {
4✔
38
                        return nil, notFoundError("SSO Identity Provider not found")
2✔
39
                } else {
2✔
40
                        return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err)
×
41
                }
×
42
        }
43

44
        observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String())
10✔
45

10✔
46
        return withSSOProvider(r.Context(), provider), nil
10✔
47
}
48

49
// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does
50
// not deal with pagination at this time.
51
func (a *API) adminSSOProvidersList(w http.ResponseWriter, r *http.Request) error {
6✔
52
        ctx := r.Context()
6✔
53
        db := a.db.WithContext(ctx)
6✔
54

6✔
55
        providers, err := models.FindAllSAMLProviders(db)
6✔
56
        if err != nil {
6✔
57
                return err
×
58
        }
×
59

60
        for i := range providers {
21✔
61
                // remove metadata XML so that the returned JSON is not ginormous
15✔
62
                providers[i].SAMLProvider.MetadataXML = ""
15✔
63
        }
15✔
64

65
        return sendJSON(w, http.StatusOK, map[string]interface{}{
6✔
66
                "items": providers,
6✔
67
        })
6✔
68
}
69

70
type CreateSSOProviderParams struct {
71
        Type string `json:"type"`
72

73
        MetadataURL      string                      `json:"metadata_url"`
74
        MetadataXML      string                      `json:"metadata_xml"`
75
        Domains          []string                    `json:"domains"`
76
        AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"`
77
}
78

79
func (p *CreateSSOProviderParams) validate(forUpdate bool) error {
22✔
80
        if !forUpdate && p.Type != "saml" {
24✔
81
                return badRequestError("Only 'saml' supported for SSO provider type")
2✔
82
        } else if p.MetadataURL != "" && p.MetadataXML != "" {
23✔
83
                return badRequestError("Only one of metadata_xml or metadata_url needs to be set")
1✔
84
        } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" {
21✔
85
                return badRequestError("Either metadata_xml or metadata_url must be set")
1✔
86
        } else if p.MetadataURL != "" {
21✔
87
                metadataURL, err := url.ParseRequestURI(p.MetadataURL)
2✔
88
                if err != nil {
3✔
89
                        return badRequestError("metadata_url is not a valid URL")
1✔
90
                }
1✔
91

92
                if metadataURL.Scheme != "https" {
2✔
93
                        return badRequestError("metadata_url is not a HTTPS URL")
1✔
94
                }
1✔
95
        }
96

97
        // TODO validate p.AttributeMapping
98
        // TODO validate domains
99

100
        return nil
16✔
101
}
102

103
func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.EntityDescriptor, error) {
13✔
104
        var rawMetadata []byte
13✔
105
        var err error
13✔
106

13✔
107
        if p.MetadataXML != "" {
26✔
108
                rawMetadata = []byte(p.MetadataXML)
13✔
109
        } else if p.MetadataURL != "" {
13✔
110
                rawMetadata, err = fetchSAMLMetadata(ctx, p.MetadataURL)
×
111
                if err != nil {
×
112
                        return nil, nil, err
×
113
                }
×
114
        } else {
×
115
                // impossible situation if you called validate() prior
×
116
                return nil, nil, nil
×
117
        }
×
118

119
        metadata, err := parseSAMLMetadata(rawMetadata)
13✔
120
        if err != nil {
13✔
121
                return nil, nil, err
×
122
        }
×
123

124
        return rawMetadata, metadata, nil
13✔
125
}
126

127
func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) {
16✔
128
        if !utf8.Valid(rawMetadata) {
16✔
129
                return nil, badRequestError("SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time")
×
130
        }
×
131

132
        metadata, err := samlsp.ParseMetadata(rawMetadata)
16✔
133
        if err != nil {
16✔
134
                return nil, err
×
135
        }
×
136

137
        if metadata.EntityID == "" {
16✔
138
                return nil, badRequestError("SAML Metadata does not contain an EntityID")
×
139
        }
×
140

141
        if len(metadata.IDPSSODescriptors) < 1 {
16✔
142
                return nil, badRequestError("SAML Metadata does not contain any IDPSSODescriptor")
×
143
        }
×
144

145
        if len(metadata.IDPSSODescriptors) > 1 {
16✔
146
                return nil, badRequestError("SAML Metadata contains multiple IDPSSODescriptors")
×
147
        }
×
148

149
        return metadata, nil
16✔
150
}
151

152
func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) {
×
153
        req, err := http.NewRequest(http.MethodGet, url, nil)
×
154
        if err != nil {
×
155
                return nil, badRequestError("Unable to create a request to metadata_url").WithInternalError(err)
×
156
        }
×
157

158
        req = req.WithContext(ctx)
×
159

×
160
        req.Header.Set("Accept", "application/xml;charset=UTF-8")
×
161
        req.Header.Set("Accept-Charset", "UTF-8")
×
162

×
163
        resp, err := http.DefaultClient.Do(req)
×
164
        if err != nil {
×
165
                return nil, err
×
166
        }
×
167

168
        defer utilities.SafeClose(resp.Body)
×
169
        if resp.StatusCode != http.StatusOK {
×
170
                return nil, badRequestError("HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url)
×
171
        }
×
172

173
        data, err := io.ReadAll(resp.Body)
×
174
        if err != nil {
×
175
                return nil, err
×
176
        }
×
177

178
        return data, nil
×
179
}
180

181
// adminSSOProvidersCreate creates a new SAML Identity Provider in the system.
182
func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) error {
18✔
183
        ctx := r.Context()
18✔
184
        db := a.db.WithContext(ctx)
18✔
185

18✔
186
        params := &CreateSSOProviderParams{}
18✔
187
        if err := retrieveRequestParams(r, params); err != nil {
18✔
NEW
188
                return err
×
UNCOV
189
        }
×
190

191
        if err := params.validate(false /* <- forUpdate */); err != nil {
24✔
192
                return err
6✔
193
        }
6✔
194

195
        rawMetadata, metadata, err := params.metadata(ctx)
12✔
196
        if err != nil {
12✔
197
                return err
×
198
        }
×
199

200
        existingProvider, err := models.FindSAMLProviderByEntityID(db, metadata.EntityID)
12✔
201
        if err != nil && !models.IsNotFoundError(err) {
12✔
202
                return err
×
203
        }
×
204
        if existingProvider != nil {
13✔
205
                return badRequestError("SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID)
1✔
206
        }
1✔
207

208
        provider := &models.SSOProvider{
11✔
209
                // TODO handle Name, Description, Attribute Mapping
11✔
210
                SAMLProvider: models.SAMLProvider{
11✔
211
                        EntityID:    metadata.EntityID,
11✔
212
                        MetadataXML: string(rawMetadata),
11✔
213
                },
11✔
214
        }
11✔
215

11✔
216
        if params.MetadataURL != "" {
11✔
217
                provider.SAMLProvider.MetadataURL = &params.MetadataURL
×
218
        }
×
219

220
        provider.SAMLProvider.AttributeMapping = params.AttributeMapping
11✔
221

11✔
222
        for _, domain := range params.Domains {
15✔
223
                existingProvider, err := models.FindSSOProviderByDomain(db, domain)
4✔
224
                if err != nil && !models.IsNotFoundError(err) {
4✔
225
                        return err
×
226
                }
×
227
                if existingProvider != nil {
5✔
228
                        return badRequestError("SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String())
1✔
229
                }
1✔
230

231
                provider.SSODomains = append(provider.SSODomains, models.SSODomain{
3✔
232
                        Domain: domain,
3✔
233
                })
3✔
234
        }
235

236
        if err := db.Transaction(func(tx *storage.Connection) error {
20✔
237
                if terr := tx.Eager().Create(provider); terr != nil {
10✔
238
                        return terr
×
239
                }
×
240

241
                return tx.Eager().Load(provider)
10✔
242
        }); err != nil {
×
243
                return err
×
244
        }
×
245

246
        return sendJSON(w, http.StatusCreated, provider)
10✔
247
}
248

249
// adminSSOProvidersGet returns an existing SAML Identity Provider in the system.
250
func (a *API) adminSSOProvidersGet(w http.ResponseWriter, r *http.Request) error {
5✔
251
        provider := getSSOProvider(r.Context())
5✔
252

5✔
253
        return sendJSON(w, http.StatusOK, provider)
5✔
254
}
5✔
255

256
// adminSSOProvidersUpdate updates a provider with the provided diff values.
257
func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) error {
4✔
258
        ctx := r.Context()
4✔
259
        db := a.db.WithContext(ctx)
4✔
260

4✔
261
        params := &CreateSSOProviderParams{}
4✔
262
        if err := retrieveRequestParams(r, params); err != nil {
4✔
NEW
263
                return err
×
UNCOV
264
        }
×
265

266
        if err := params.validate(true /* <- forUpdate */); err != nil {
4✔
267
                return err
×
268
        }
×
269

270
        modified := false
4✔
271
        updateSAMLProvider := false
4✔
272

4✔
273
        provider := getSSOProvider(ctx)
4✔
274

4✔
275
        if params.MetadataXML != "" || params.MetadataURL != "" {
5✔
276
                // metadata is being updated
1✔
277
                rawMetadata, metadata, err := params.metadata(ctx)
1✔
278
                if err != nil {
1✔
279
                        return err
×
280
                }
×
281

282
                if provider.SAMLProvider.EntityID != metadata.EntityID {
2✔
283
                        return badRequestError("SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID)
1✔
284
                }
1✔
285

286
                if params.MetadataURL != "" {
×
287
                        provider.SAMLProvider.MetadataURL = &params.MetadataURL
×
288
                }
×
289

290
                provider.SAMLProvider.MetadataXML = string(rawMetadata)
×
291
                updateSAMLProvider = true
×
292
                modified = true
×
293
        }
294

295
        // domains are being "updated" only when params.Domains is not nil, if
296
        // it was nil (but not `[]`) then the caller is expecting not to modify
297
        // the domains
298
        updateDomains := params.Domains != nil
3✔
299

3✔
300
        var createDomains, deleteDomains []models.SSODomain
3✔
301
        keepDomains := make(map[string]bool)
3✔
302

3✔
303
        for _, domain := range params.Domains {
6✔
304
                existingProvider, err := models.FindSSOProviderByDomain(db, domain)
3✔
305
                if err != nil && !models.IsNotFoundError(err) {
3✔
306
                        return err
×
307
                }
×
308
                if existingProvider != nil {
5✔
309
                        if existingProvider.ID == provider.ID {
3✔
310
                                keepDomains[domain] = true
1✔
311
                        } else {
2✔
312
                                return badRequestError("SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String())
1✔
313
                        }
1✔
314
                } else {
1✔
315
                        modified = true
1✔
316
                        createDomains = append(createDomains, models.SSODomain{
1✔
317
                                Domain:        domain,
1✔
318
                                SSOProviderID: provider.ID,
1✔
319
                        })
1✔
320
                }
1✔
321
        }
322

323
        if updateDomains {
3✔
324
                for i, domain := range provider.SSODomains {
2✔
325
                        if !keepDomains[domain.Domain] {
1✔
326
                                modified = true
×
327
                                deleteDomains = append(deleteDomains, provider.SSODomains[i])
×
328
                        }
×
329
                }
330
        }
331

332
        updateAttributeMapping := !provider.SAMLProvider.AttributeMapping.Equal(&params.AttributeMapping)
2✔
333
        if updateAttributeMapping {
3✔
334
                modified = true
1✔
335
                provider.SAMLProvider.AttributeMapping = params.AttributeMapping
1✔
336
        }
1✔
337

338
        if modified {
4✔
339
                if err := db.Transaction(func(tx *storage.Connection) error {
4✔
340
                        if terr := tx.Eager().Update(provider); terr != nil {
2✔
341
                                return terr
×
342
                        }
×
343

344
                        if updateDomains {
3✔
345
                                if terr := tx.Destroy(deleteDomains); terr != nil {
1✔
346
                                        return terr
×
347
                                }
×
348

349
                                if terr := tx.Eager().Create(createDomains); terr != nil {
1✔
350
                                        return terr
×
351
                                }
×
352
                        }
353

354
                        if updateAttributeMapping || updateSAMLProvider {
3✔
355
                                if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil {
1✔
356
                                        return terr
×
357
                                }
×
358
                        }
359

360
                        return tx.Eager().Load(provider)
2✔
361
                }); err != nil {
×
362
                        return unprocessableEntityError("Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err)
×
363
                }
×
364
        }
365

366
        return sendJSON(w, http.StatusOK, provider)
2✔
367
}
368

369
// adminSSOProvidersDelete deletes a SAML identity provider.
370
func (a *API) adminSSOProvidersDelete(w http.ResponseWriter, r *http.Request) error {
1✔
371
        ctx := r.Context()
1✔
372
        db := a.db.WithContext(ctx)
1✔
373

1✔
374
        provider := getSSOProvider(ctx)
1✔
375

1✔
376
        if err := db.Transaction(func(tx *storage.Connection) error {
2✔
377
                return tx.Eager().Destroy(provider)
1✔
378
        }); err != nil {
1✔
379
                return err
×
380
        }
×
381

382
        return sendJSON(w, http.StatusOK, provider)
1✔
383
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc