• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / gotrue / 8337863587

19 Mar 2024 05:08AM UTC coverage: 64.957% (-0.3%) from 65.241%
8337863587

Pull #1474

github

J0
fix: apply suggestions
Pull Request #1474: feat: add custom sms hook

75 of 172 new or added lines in 12 files covered. (43.6%)

25 existing lines in 2 files now uncovered.

7993 of 12305 relevant lines covered (64.96%)

59.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.09
/internal/api/hooks.go
1
package api
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/json"
7
        "fmt"
8
        "io"
9
        "net"
10
        "net/http"
11
        "strings"
12
        "time"
13

14
        "github.com/gofrs/uuid"
15
        "github.com/supabase/auth/internal/observability"
16

17
        "github.com/supabase/auth/internal/conf"
18
        "github.com/supabase/auth/internal/crypto"
19

20
        "github.com/sirupsen/logrus"
21
        "github.com/supabase/auth/internal/hooks"
22

23
        "github.com/supabase/auth/internal/storage"
24
)
25

26
// Tunables for invoking HTTP hooks.
const (
	// DefaultHTTPHookTimeout is the client timeout applied to each HTTP
	// request made to a hook endpoint.
	DefaultHTTPHookTimeout = 5 * time.Second
	// DefaultHTTPHookRetries is the maximum number of attempts made for a
	// single HTTP hook invocation.
	DefaultHTTPHookRetries = 3
	// HTTPHookBackoffDuration is how long to pause after a failed request
	// before trying again.
	HTTPHookBackoffDuration = 2 * time.Second
	// PayloadLimit is the largest hook response body, in bytes, that is
	// read (20KB).
	PayloadLimit = 20 << 10
)
32

33
func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
11✔
34
        db := a.db.WithContext(ctx)
11✔
35

11✔
36
        request, err := json.Marshal(input)
11✔
37
        if err != nil {
11✔
UNCOV
38
                panic(err)
×
39
        }
40

41
        var response []byte
11✔
42
        invokeHookFunc := func(tx *storage.Connection) error {
22✔
43
                // We rely on Postgres timeouts to ensure the function doesn't overrun
11✔
44
                if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
11✔
UNCOV
45
                        return terr
×
46
                }
×
47

48
                if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil {
13✔
49
                        return terr
2✔
50
                }
2✔
51

52
                // reset the timeout
53
                if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
9✔
UNCOV
54
                        return terr
×
55
                }
×
56

57
                return nil
9✔
58
        }
59

60
        if tx != nil {
16✔
61
                if err := invokeHookFunc(tx); err != nil {
5✔
UNCOV
62
                        return nil, err
×
63
                }
×
64
        } else {
6✔
65
                if err := db.Transaction(invokeHookFunc); err != nil {
8✔
66
                        return nil, err
2✔
67
                }
2✔
68
        }
69

70
        if err := json.Unmarshal(response, output); err != nil {
9✔
UNCOV
71
                return response, err
×
72
        }
×
73

74
        return response, nil
9✔
75
}
76

77
func readBodyWithLimit(rsp *http.Response) ([]byte, error) {
2✔
78
        defer rsp.Body.Close()
2✔
79

2✔
80
        limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit}
2✔
81

2✔
82
        body, err := io.ReadAll(&limitedReader)
2✔
83
        if err != nil {
2✔
NEW
84
                return nil, err
×
NEW
85
        }
×
86

87
        if limitedReader.N <= 0 {
2✔
NEW
88
                // Attempt to read one more byte to check if we're exactly at the limit or over
×
NEW
89
                _, err := rsp.Body.Read(make([]byte, 1))
×
NEW
90
                if err == nil {
×
NEW
91
                        // If we could read more, then the payload was too large
×
NEW
92
                        return nil, fmt.Errorf("payload too large")
×
NEW
93
                }
×
94
        }
95

96
        return body, nil
2✔
97
}
98

99
func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
3✔
100
        client := http.Client{
3✔
101
                Timeout: DefaultHTTPHookTimeout,
3✔
102
        }
3✔
103
        log := observability.GetLogEntry(r)
3✔
104
        requestURL := hookConfig.URI
3✔
105
        hookLog := log.WithFields(logrus.Fields{
3✔
106
                "component": "auth_hook",
3✔
107
                "url":       requestURL,
3✔
108
        })
3✔
109

3✔
110
        inputPayload, err := json.Marshal(input)
3✔
111
        if err != nil {
3✔
NEW
112
                return nil, err
×
NEW
113
        }
×
114
        start := time.Now()
3✔
115
        for i := 0; i < DefaultHTTPHookRetries; i++ {
7✔
116
                hookLog.Infof("invocation attempt: %d", i)
4✔
117
                if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout {
4✔
NEW
118
                        return []byte{}, unprocessableEntityError(ErrorHookTimeout, "failed to reach hook within timeout")
×
NEW
119
                }
×
120
                msgID := uuid.Must(uuid.NewV4())
4✔
121
                currentTime := time.Now()
4✔
122
                signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
4✔
123
                if err != nil {
4✔
NEW
124
                        return nil, err
×
NEW
125
                }
×
126

127
                req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
4✔
128
                if err != nil {
4✔
NEW
129
                        panic("Failed to make requst object")
×
130
                }
131

132
                req.Header.Set("Content-Type", "application/json")
4✔
133
                req.Header.Set("webhook-id", msgID.String())
4✔
134
                req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
4✔
135
                req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
4✔
136

4✔
137
                rsp, err := client.Do(req)
4✔
138
                if err != nil {
4✔
NEW
139
                        if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 {
×
NEW
140
                                hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
×
NEW
141
                                time.Sleep(HTTPHookBackoffDuration)
×
NEW
142
                                continue
×
NEW
143
                        } else if i == DefaultHTTPHookRetries-1 {
×
NEW
144
                                return nil, unprocessableEntityError(ErrorHookTimeout, "Failed to reach hook within allotted interval")
×
NEW
145
                        } else {
×
NEW
146
                                return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
×
NEW
147
                        }
×
148
                }
149

150
                switch rsp.StatusCode {
4✔
151
                case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
2✔
152
                        if rsp.Body == nil {
2✔
NEW
153
                                return nil, nil
×
NEW
154
                        }
×
155
                        body, err := readBodyWithLimit(rsp)
2✔
156
                        if err != nil {
2✔
NEW
157
                                return nil, err
×
NEW
158
                        }
×
159
                        return body, nil
2✔
160
                case http.StatusTooManyRequests, http.StatusServiceUnavailable:
1✔
161
                        retryAfterHeader := rsp.Header.Get("retry-after")
1✔
162
                        // Check for truthy values to allow for flexibility to swtich to time duration
1✔
163
                        if retryAfterHeader != "" {
2✔
164
                                continue
1✔
165
                        }
NEW
166
                        return []byte{}, internalServerError("Service currently unavailable")
×
NEW
167
                case http.StatusBadRequest:
×
NEW
168
                        return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook")
×
NEW
169
                case http.StatusUnauthorized:
×
NEW
170
                        return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token")
×
171
                default:
1✔
172
                        return []byte{}, internalServerError("Error executing Hook")
1✔
173
                }
174
        }
NEW
175
        return nil, internalServerError("error executing hook")
×
176
}
177

NEW
178
func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error {
×
NEW
179
        switch input.(type) {
×
NEW
180
        case *hooks.CustomSMSProviderInput:
×
NEW
181
                hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
×
NEW
182
                if !ok {
×
NEW
183
                        panic("output should be *hooks.CustomSMSProviderOutput")
×
184
                }
NEW
185
                var response []byte
×
NEW
186
                var err error
×
NEW
187

×
NEW
188
                if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
×
NEW
189
                        return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
×
NEW
190
                }
×
NEW
191
                if err != nil {
×
NEW
192
                        return err
×
NEW
193
                }
×
194

NEW
195
                if err := json.Unmarshal(response, hookOutput); err != nil {
×
NEW
196
                        return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
×
NEW
197
                }
×
198

NEW
199
        default:
×
NEW
200
                panic("unknown HTTP hook type")
×
201
        }
NEW
202
        return nil
×
203
}
204

205
// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
206
// transaction is opened. If calling invokeHook within a transaction, always
207
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
208
// trigger.
209
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
11✔
210
        config := a.config
11✔
211
        // Switch based on hook type
11✔
212
        switch input.(type) {
11✔
213
        case *hooks.MFAVerificationAttemptInput:
4✔
214
                hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
4✔
215
                if !ok {
4✔
216
                        panic("output should be *hooks.MFAVerificationAttemptOutput")
×
217
                }
218

219
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
6✔
220
                        return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
2✔
221
                }
2✔
222

223
                if hookOutput.IsError() {
2✔
224
                        httpCode := hookOutput.HookError.HTTPCode
×
225

×
UNCOV
226
                        if httpCode == 0 {
×
227
                                httpCode = http.StatusInternalServerError
×
228
                        }
×
229

230
                        httpError := &HTTPError{
×
231
                                HTTPStatus: httpCode,
×
232
                                Message:    hookOutput.HookError.Message,
×
UNCOV
233
                        }
×
UNCOV
234

×
UNCOV
235
                        return httpError.WithInternalError(&hookOutput.HookError)
×
236
                }
237

238
                return nil
2✔
239
        case *hooks.PasswordVerificationAttemptInput:
2✔
240
                hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
2✔
241
                if !ok {
2✔
242
                        panic("output should be *hooks.PasswordVerificationAttemptOutput")
×
243
                }
244

245
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
2✔
UNCOV
246
                        return internalServerError("Error invoking password verification hook.").WithInternalError(err)
×
247
                }
×
248

249
                if hookOutput.IsError() {
2✔
250
                        httpCode := hookOutput.HookError.HTTPCode
×
251

×
UNCOV
252
                        if httpCode == 0 {
×
253
                                httpCode = http.StatusInternalServerError
×
254
                        }
×
255

256
                        httpError := &HTTPError{
×
257
                                HTTPStatus: httpCode,
×
258
                                Message:    hookOutput.HookError.Message,
×
UNCOV
259
                        }
×
UNCOV
260

×
UNCOV
261
                        return httpError.WithInternalError(&hookOutput.HookError)
×
262
                }
263

264
                return nil
2✔
265
        case *hooks.CustomAccessTokenInput:
5✔
266
                hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
5✔
267
                if !ok {
5✔
268
                        panic("output should be *hooks.CustomAccessTokenOutput")
×
269
                }
270

271
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
5✔
UNCOV
272
                        return internalServerError("Error invoking access token hook.").WithInternalError(err)
×
UNCOV
273
                }
×
274

275
                if hookOutput.IsError() {
6✔
276
                        httpCode := hookOutput.HookError.HTTPCode
1✔
277

1✔
278
                        if httpCode == 0 {
1✔
UNCOV
279
                                httpCode = http.StatusInternalServerError
×
UNCOV
280
                        }
×
281

282
                        httpError := &HTTPError{
1✔
283
                                HTTPStatus: httpCode,
1✔
284
                                Message:    hookOutput.HookError.Message,
1✔
285
                        }
1✔
286

1✔
287
                        return httpError.WithInternalError(&hookOutput.HookError)
1✔
288
                }
289
                if err := validateTokenClaims(hookOutput.Claims); err != nil {
5✔
290
                        httpCode := hookOutput.HookError.HTTPCode
1✔
291

1✔
292
                        if httpCode == 0 {
2✔
293
                                httpCode = http.StatusInternalServerError
1✔
294
                        }
1✔
295

296
                        httpError := &HTTPError{
1✔
297
                                HTTPStatus: httpCode,
1✔
298
                                Message:    err.Error(),
1✔
299
                        }
1✔
300

1✔
301
                        return httpError
1✔
302
                }
303
                return nil
3✔
304

UNCOV
305
        default:
×
NEW
UNCOV
306
                panic("unknown Postgres hook input type")
×
307
        }
308
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc