• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

supabase / gotrue / 8316299588

17 Mar 2024 02:55PM UTC coverage: 64.923% (-0.3%) from 65.241%
8316299588

Pull #1474

github

J0
fix: remove unneeded if check
Pull Request #1474: feat: add custom sms hook

87 of 197 new or added lines in 13 files covered. (44.16%)

72 existing lines in 3 files now uncovered.

8005 of 12330 relevant lines covered (64.92%)

59.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.04
/internal/api/hooks.go
1
package api
2

3
import (
4
        "bytes"
5
        "context"
6
        "encoding/json"
7
        "fmt"
8
        "io"
9
        "net"
10
        "net/http"
11
        "net/http/httptrace"
12
        "strings"
13
        "time"
14

15
        "github.com/gofrs/uuid"
16
        "github.com/supabase/auth/internal/observability"
17

18
        "github.com/supabase/auth/internal/conf"
19
        "github.com/supabase/auth/internal/crypto"
20

21
        "github.com/sirupsen/logrus"
22
        "github.com/supabase/auth/internal/hooks"
23

24
        "github.com/supabase/auth/internal/storage"
25
)
26

27
const (
28
        // DefaultHTTPHookTimeout is the http.Client timeout runHTTPHook applies
        // to each individual hook request.
        DefaultHTTPHookTimeout  = 5 * time.Second
29
        // DefaultHTTPHookRetries is the maximum number of attempts runHTTPHook
        // makes before giving up.
        DefaultHTTPHookRetries  = 3
30
        // HTTPHookBackoffDuration is how long runHTTPHook sleeps before retrying
        // after a timeout or a failure to establish a connection.
        HTTPHookBackoffDuration = 2 * time.Second
31
)
32

33
func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
11✔
34
        db := a.db.WithContext(ctx)
11✔
35

11✔
36
        request, err := json.Marshal(input)
11✔
37
        if err != nil {
11✔
UNCOV
38
                panic(err)
×
39
        }
40

41
        var response []byte
11✔
42
        invokeHookFunc := func(tx *storage.Connection) error {
22✔
43
                // We rely on Postgres timeouts to ensure the function doesn't overrun
11✔
44
                if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil {
11✔
UNCOV
45
                        return terr
×
46
                }
×
47

48
                if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil {
13✔
49
                        return terr
2✔
50
                }
2✔
51

52
                // reset the timeout
53
                if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil {
9✔
UNCOV
54
                        return terr
×
55
                }
×
56

57
                return nil
9✔
58
        }
59

60
        if tx != nil {
16✔
61
                if err := invokeHookFunc(tx); err != nil {
5✔
UNCOV
62
                        return nil, err
×
63
                }
×
64
        } else {
6✔
65
                if err := db.Transaction(invokeHookFunc); err != nil {
8✔
66
                        return nil, err
2✔
67
                }
2✔
68
        }
69

70
        if err := json.Unmarshal(response, output); err != nil {
9✔
UNCOV
71
                return response, err
×
72
        }
×
73

74
        return response, nil
9✔
75
}
76

77
// readBodyWithLimit reads the whole response body, rejecting payloads larger
// than 20KB, and always closes the body before returning.
//
// It returns the body bytes on success, the underlying read error if reading
// fails, or a "payload too large" error when the body exceeds the limit.
func readBodyWithLimit(rsp *http.Response) ([]byte, error) {
	defer rsp.Body.Close()

	const limit = 20 * 1024 // 20KB

	// Read at most one byte beyond the limit. Checking the resulting length is
	// reliable for every io.Reader, including readers that hand back their
	// final bytes together with io.EOF; probing with an extra Read and
	// inspecting only its error would let such oversized bodies through
	// truncated.
	body, err := io.ReadAll(io.LimitReader(rsp.Body, limit+1))
	if err != nil {
		return nil, err
	}

	if len(body) > limit {
		return nil, fmt.Errorf("payload too large")
	}

	return body, nil
}
99

100
func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
3✔
101
        client := http.Client{
3✔
102
                Timeout: DefaultHTTPHookTimeout,
3✔
103
        }
3✔
104
        log := observability.GetLogEntry(r)
3✔
105
        requestURL := hookConfig.URI
3✔
106
        hookLog := log.WithFields(logrus.Fields{
3✔
107
                "component": "auth_hook",
3✔
108
                "url":       requestURL,
3✔
109
        })
3✔
110

3✔
111
        inputPayload, err := json.Marshal(input)
3✔
112
        if err != nil {
3✔
NEW
113
                return nil, err
×
NEW
114
        }
×
115
        start := time.Now()
3✔
116
        for i := 0; i < DefaultHTTPHookRetries; i++ {
7✔
117
                hookLog.Infof("invocation attempt: %d", i)
4✔
118
                if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout {
4✔
NEW
119
                        return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout")
×
NEW
120
                }
×
121
                msgID := uuid.Must(uuid.NewV4())
4✔
122
                currentTime := time.Now()
4✔
123
                signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
4✔
124
                if err != nil {
4✔
NEW
125
                        return nil, err
×
NEW
126
                }
×
127

128
                req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
4✔
129
                if err != nil {
4✔
NEW
130
                        return nil, internalServerError("Failed to make request object").WithInternalError(err)
×
NEW
131
                }
×
132

133
                req.Header.Set("Content-Type", "application/json")
4✔
134
                req.Header.Set("webhook-id", msgID.String())
4✔
135
                req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
4✔
136
                req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
4✔
137

4✔
138
                watcher, req := watchForConnection(req)
4✔
139
                rsp, err := client.Do(req)
4✔
140

4✔
141
                if err != nil {
4✔
NEW
142
                        if terr, ok := err.(net.Error); ok && terr.Timeout() {
×
NEW
143
                                hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
×
NEW
144
                                time.Sleep(HTTPHookBackoffDuration)
×
NEW
145
                                continue
×
NEW
146
                        } else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 {
×
NEW
147
                                hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err)
×
NEW
148
                                time.Sleep(HTTPHookBackoffDuration)
×
NEW
149
                                continue
×
NEW
150
                        } else if i == DefaultHTTPHookRetries-1 {
×
NEW
151
                                return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval")
×
NEW
152

×
NEW
153
                        } else {
×
NEW
154
                                return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
×
NEW
155
                        }
×
156
                }
157

158
                switch rsp.StatusCode {
4✔
159
                case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
2✔
160
                        if rsp.Body == nil {
2✔
NEW
161
                                return nil, nil
×
NEW
162
                        }
×
163
                        body, err := readBodyWithLimit(rsp)
2✔
164
                        if err != nil {
2✔
NEW
165
                                return nil, err
×
NEW
166
                        }
×
167
                        return body, nil
2✔
168
                case http.StatusTooManyRequests, http.StatusServiceUnavailable:
1✔
169
                        retryAfterHeader := rsp.Header.Get("retry-after")
1✔
170
                        // Check for truthy values to allow for flexibility to swtich to time duration
1✔
171
                        if retryAfterHeader != "" {
2✔
172
                                continue
1✔
173
                        }
NEW
174
                        return []byte{}, internalServerError("Service currently unavailable")
×
NEW
175
                case http.StatusBadRequest:
×
NEW
176
                        return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook")
×
NEW
177
                case http.StatusUnauthorized:
×
NEW
178
                        return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token")
×
179
                default:
1✔
180
                        return []byte{}, internalServerError("Error executing Hook")
1✔
181
                }
182
        }
NEW
183
        return nil, internalServerError("error executing hook")
×
184
}
185

186
func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) {
4✔
187
        w := new(connectionWatcher)
4✔
188
        t := &httptrace.ClientTrace{
4✔
189
                GotConn: w.GotConn,
4✔
190
        }
4✔
191

4✔
192
        req = req.WithContext(httptrace.WithClientTrace(req.Context(), t))
4✔
193
        return w, req
4✔
194
}
4✔
195

196
// connectionWatcher records whether a traced HTTP request obtained a
// connection.
type connectionWatcher struct {
	// gotConn is set once GotConn has been invoked.
	gotConn bool
}

// GotConn has the signature of the httptrace.ClientTrace GotConn callback; it
// marks the watcher as having obtained a connection and ignores the
// connection details.
func (w *connectionWatcher) GotConn(_ httptrace.GotConnInfo) {
	w.gotConn = true
}
×
203

NEW
204
func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error {
×
NEW
UNCOV
205
        switch input.(type) {
×
NEW
206
        case *hooks.CustomSMSProviderInput:
×
NEW
UNCOV
207
                hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
×
NEW
208
                if !ok {
×
NEW
UNCOV
209
                        panic("output should be *hooks.CustomSMSProviderOutput")
×
210
                }
NEW
UNCOV
211
                var response []byte
×
NEW
UNCOV
212
                var err error
×
NEW
213

×
NEW
UNCOV
214
                if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
×
NEW
UNCOV
215
                        return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
×
NEW
216
                }
×
NEW
UNCOV
217
                if err != nil {
×
NEW
UNCOV
218
                        return err
×
NEW
UNCOV
219
                }
×
220

NEW
221
                if err := json.Unmarshal(response, hookOutput); err != nil {
×
NEW
222
                        return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
×
NEW
223
                }
×
NEW
224
                fmt.Printf("%v", hookOutput)
×
225

NEW
UNCOV
226
        default:
×
NEW
227
                panic("unknown HTTP hook type")
×
228
        }
NEW
229
        return nil
×
230
}
231

232
// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
233
// transaction is opened. If calling invokeHook within a transaction, always
234
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
235
// trigger.
236
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
11✔
237
        config := a.config
11✔
238
        // Switch based on hook type
11✔
239
        switch input.(type) {
11✔
240
        case *hooks.MFAVerificationAttemptInput:
4✔
241
                hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
4✔
242
                if !ok {
4✔
243
                        panic("output should be *hooks.MFAVerificationAttemptOutput")
×
244
                }
245

246
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
6✔
247
                        return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
2✔
248
                }
2✔
249

250
                if hookOutput.IsError() {
2✔
251
                        httpCode := hookOutput.HookError.HTTPCode
×
UNCOV
252

×
253
                        if httpCode == 0 {
×
254
                                httpCode = http.StatusInternalServerError
×
255
                        }
×
256

257
                        httpError := &HTTPError{
×
258
                                HTTPStatus: httpCode,
×
UNCOV
259
                                Message:    hookOutput.HookError.Message,
×
UNCOV
260
                        }
×
UNCOV
261

×
UNCOV
262
                        return httpError.WithInternalError(&hookOutput.HookError)
×
263
                }
264

265
                return nil
2✔
266
        case *hooks.PasswordVerificationAttemptInput:
2✔
267
                hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
2✔
268
                if !ok {
2✔
269
                        panic("output should be *hooks.PasswordVerificationAttemptOutput")
×
270
                }
271

272
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
2✔
UNCOV
273
                        return internalServerError("Error invoking password verification hook.").WithInternalError(err)
×
UNCOV
274
                }
×
275

276
                if hookOutput.IsError() {
2✔
277
                        httpCode := hookOutput.HookError.HTTPCode
×
UNCOV
278

×
UNCOV
279
                        if httpCode == 0 {
×
UNCOV
280
                                httpCode = http.StatusInternalServerError
×
UNCOV
281
                        }
×
282

UNCOV
283
                        httpError := &HTTPError{
×
UNCOV
284
                                HTTPStatus: httpCode,
×
UNCOV
285
                                Message:    hookOutput.HookError.Message,
×
UNCOV
286
                        }
×
UNCOV
287

×
UNCOV
288
                        return httpError.WithInternalError(&hookOutput.HookError)
×
289
                }
290

291
                return nil
2✔
292
        case *hooks.CustomAccessTokenInput:
5✔
293
                hookOutput, ok := output.(*hooks.CustomAccessTokenOutput)
5✔
294
                if !ok {
5✔
UNCOV
295
                        panic("output should be *hooks.CustomAccessTokenOutput")
×
296
                }
297

298
                if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
5✔
UNCOV
299
                        return internalServerError("Error invoking access token hook.").WithInternalError(err)
×
UNCOV
300
                }
×
301

302
                if hookOutput.IsError() {
6✔
303
                        httpCode := hookOutput.HookError.HTTPCode
1✔
304

1✔
305
                        if httpCode == 0 {
1✔
UNCOV
306
                                httpCode = http.StatusInternalServerError
×
UNCOV
307
                        }
×
308

309
                        httpError := &HTTPError{
1✔
310
                                HTTPStatus: httpCode,
1✔
311
                                Message:    hookOutput.HookError.Message,
1✔
312
                        }
1✔
313

1✔
314
                        return httpError.WithInternalError(&hookOutput.HookError)
1✔
315
                }
316
                if err := validateTokenClaims(hookOutput.Claims); err != nil {
5✔
317
                        httpCode := hookOutput.HookError.HTTPCode
1✔
318

1✔
319
                        if httpCode == 0 {
2✔
320
                                httpCode = http.StatusInternalServerError
1✔
321
                        }
1✔
322

323
                        httpError := &HTTPError{
1✔
324
                                HTTPStatus: httpCode,
1✔
325
                                Message:    err.Error(),
1✔
326
                        }
1✔
327

1✔
328
                        return httpError
1✔
329
                }
330
                return nil
3✔
331

UNCOV
332
        default:
×
NEW
UNCOV
333
                panic("unknown Postgres hook input type")
×
334
        }
335
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc