• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

m-lab / locate / 1505

10 Feb 2025 08:04PM UTC coverage: 95.335% (-1.7%) from 97.013%
1505

push

travis-pro

web-flow
Implement rate limiting based on IP+UA (#211)

* Implement rate limiting

* Use separate Redis instance for rate limiting

* Fix flag name

* Remove annotations

* Set rate-limit-max to 40

* Update comments.

* Update comments and benchmark

* Reuse the same tooManyRequests error

* Fix indent in cloudbuild.yaml

* Add RATE_LIMIT_REDIS_ADDRESS to the mlab-ns template too

* Do not set an error status code on rate limit yet

* Address review comments

* Update comment

* Address review comments

42 of 74 new or added lines in 2 files covered. (56.76%)

4 existing lines in 1 file now uncovered.

1921 of 2015 relevant lines covered (95.33%)

1.06 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.3
/handler/handler.go
1
// Package handler provides a client and handlers for responding to locate
2
// requests.
3
package handler
4

5
import (
6
        "bytes"
7
        "context"
8
        "encoding/json"
9
        "errors"
10
        "fmt"
11
        "html/template"
12
        "math/rand"
13
        "net/http"
14
        "net/url"
15
        "path"
16
        "strconv"
17
        "strings"
18
        "time"
19

20
        "github.com/google/uuid"
21
        log "github.com/sirupsen/logrus"
22
        "gopkg.in/square/go-jose.v2/jwt"
23

24
        "github.com/m-lab/go/rtx"
25
        v2 "github.com/m-lab/locate/api/v2"
26
        "github.com/m-lab/locate/clientgeo"
27
        "github.com/m-lab/locate/heartbeat"
28
        "github.com/m-lab/locate/limits"
29
        "github.com/m-lab/locate/metrics"
30
        "github.com/m-lab/locate/siteinfo"
31
        "github.com/m-lab/locate/static"
32
        prom "github.com/prometheus/client_golang/api/prometheus/v1"
33
        "github.com/prometheus/common/model"
34
)
35

36
// Package-level messages shared by the handlers in this file.
var (
37
        // errFailedToLookupClient is returned by checkClientLocation when the
        // client locator fails; its text is also sent to clients by Nearest when
        // the client latitude/longitude cannot be parsed.
        errFailedToLookupClient = errors.New("Failed to look up client location")
38
        // tooManyRequests is the message returned to clients whose user agent
        // exceeded its configured request limit.
        tooManyRequests         = "Too many periodic requests. Please contact support@measurementlab.net."
39
)
40

41
// Signer defines how access tokens are signed.
42
type Signer interface {
43
        // Sign returns the serialized, signed token for the given claims, or an
        // error if signing fails.
        Sign(cl jwt.Claims) (string, error)
44
}
45

46
// Client bundles the dependencies used by the HTTP handlers in this package:
// a token signer, the server and client locators, a Prometheus client, the
// template used to build target hosts, and the request limiters.
47
type Client struct {
48
        // Signer signs the access tokens embedded in target URLs.
        Signer
49
        // project is stored by the constructors; it is not read by the code
        // visible in this file.
        project string
50
        // LocatorV2 finds the nearest machines and reports readiness/instances.
        LocatorV2
51
        // ClientLocator resolves the geo location of the requesting client.
        ClientLocator
52
        // PrometheusClient runs Prometheus queries.
        PrometheusClient
53
        // targetTmpl renders the host portion of each target URL from the
        // "Hostname" and "Ports" values.
        targetTmpl  *template.Template
54
        // agentLimits maps a User-Agent string to its request limit.
        agentLimits limits.Agents
55
        // ipLimiter rate-limits by IP and User-Agent; a nil value disables the
        // check in Nearest.
        ipLimiter   *limits.RateLimiter
56
}
57

58
// LocatorV2 defines how the Nearest handler requests machines nearest to the
59
// client.
60
type LocatorV2 interface {
61
        // Nearest returns target information for the given service near the
        // given latitude/longitude, filtered by opts.
        Nearest(service string, lat, lon float64, opts *heartbeat.NearestOptions) (*heartbeat.TargetInfo, error)
62
        // NOTE(review): presumably the source of the Ready() and Instances()
        // methods called on LocatorV2 in this file - confirm in package heartbeat.
        heartbeat.StatusTracker
63
}
64

65
// ClientLocator defines the interface for looking up the client geo location.
66
type ClientLocator interface {
67
        // Locate returns the location derived from the incoming request, or an
        // error when the location cannot be determined.
        Locate(req *http.Request) (*clientgeo.Location, error)
68
}
69

70
// PrometheusClient defines the interface to query Prometheus.
71
type PrometheusClient interface {
72
        // Query evaluates the given query at time ts and returns the resulting
        // value together with any warnings.
        Query(ctx context.Context, query string, ts time.Time, opts ...prom.Option) (model.Value, prom.Warnings, error)
73
}
74

75
// paramOpts carries the inputs that extraParams uses to build the extra query
// parameters appended to every target URL.
type paramOpts struct {
76
        // raw holds the parsed request form values.
        raw       url.Values
77
        // version is the value emitted as the "locate_version" parameter.
        version   string
78
        // ranks maps a machine name to the value emitted as "metro_rank".
        ranks     map[string]int
79
        // svcParams maps a parameter name to the threshold compared against
        // rand.Float64(); the parameter is forwarded when the random draw is
        // below that threshold.
        svcParams map[string]float64
80
}
81

82
func init() {
1✔
83
        log.SetFormatter(&log.JSONFormatter{})
1✔
84
        log.SetLevel(log.InfoLevel)
1✔
85
}
1✔
86

87
// NewClient creates a new client.
88
func NewClient(project string, private Signer, locatorV2 LocatorV2, client ClientLocator,
89
        prom PrometheusClient, lmts limits.Agents, limiter *limits.RateLimiter) *Client {
1✔
90
        return &Client{
1✔
91
                Signer:           private,
1✔
92
                project:          project,
1✔
93
                LocatorV2:        locatorV2,
1✔
94
                ClientLocator:    client,
1✔
95
                PrometheusClient: prom,
1✔
96
                targetTmpl:       template.Must(template.New("name").Parse("{{.Hostname}}{{.Ports}}")),
1✔
97
                agentLimits:      lmts,
1✔
98
                ipLimiter:        limiter,
1✔
99
        }
1✔
100
}
1✔
101

102
// NewClientDirect creates a new client with a target template using only the target machine.
103
func NewClientDirect(project string, private Signer, locatorV2 LocatorV2, client ClientLocator, prom PrometheusClient) *Client {
1✔
104
        return &Client{
1✔
105
                Signer:           private,
1✔
106
                project:          project,
1✔
107
                LocatorV2:        locatorV2,
1✔
108
                ClientLocator:    client,
1✔
109
                PrometheusClient: prom,
1✔
110
                // Useful for the locatetest package when running a local server.
1✔
111
                targetTmpl: template.Must(template.New("name").Parse("{{.Hostname}}{{.Ports}}")),
1✔
112
        }
1✔
113
}
1✔
114

115
func extraParams(hostname string, index int, p paramOpts) url.Values {
1✔
116
        v := url.Values{}
1✔
117
        // Add client parameters.
1✔
118
        for key := range p.raw {
2✔
119
                if strings.HasPrefix(key, "client_") {
2✔
120
                        // note: we only use the first value.
1✔
121
                        v.Set(key, p.raw.Get(key))
1✔
122
                }
1✔
123

124
                val, ok := p.svcParams[key]
1✔
125
                if ok && rand.Float64() < val {
2✔
126
                        v.Set(key, p.raw.Get(key))
1✔
127
                }
1✔
128
        }
129

130
        // Add Locate Service version.
131
        v.Set("locate_version", p.version)
1✔
132

1✔
133
        // Add metro rank.
1✔
134
        rank, ok := p.ranks[hostname]
1✔
135
        if ok {
2✔
136
                v.Set("metro_rank", strconv.Itoa(rank))
1✔
137
        }
1✔
138

139
        // Add result index.
140
        v.Set("index", strconv.Itoa(index))
1✔
141

1✔
142
        return v
1✔
143
}
144

145
// Nearest uses an implementation of the LocatorV2 interface to look up
146
// nearest servers.
147
func (c *Client) Nearest(rw http.ResponseWriter, req *http.Request) {
1✔
148
        req.ParseForm()
1✔
149
        result := v2.NearestResult{}
1✔
150
        setHeaders(rw)
1✔
151

1✔
152
        if c.limitRequest(time.Now().UTC(), req) {
2✔
153
                result.Error = v2.NewError("client", tooManyRequests, http.StatusTooManyRequests)
1✔
154
                writeResult(rw, result.Error.Status, &result)
1✔
155
                metrics.RequestsTotal.WithLabelValues("nearest", "request limit", http.StatusText(result.Error.Status)).Inc()
1✔
156
                return
1✔
157
        }
1✔
158

159
        // Check rate limit for IP and UA.
160
        if c.ipLimiter != nil {
1✔
NEW
161
                // Get the IP address from the request. X-Forwarded-For is guaranteed to
×
NEW
162
                // be set by AppEngine.
×
NEW
163
                ip := req.Header.Get("X-Forwarded-For")
×
NEW
164
                ips := strings.Split(ip, ",")
×
NEW
165
                if len(ips) > 0 {
×
NEW
166
                        ip = strings.TrimSpace(ips[0])
×
NEW
167
                }
×
NEW
168
                if ip != "" {
×
NEW
169
                        // An empty UA is technically possible. In this case, the key will be
×
NEW
170
                        // "ip:" and the rate limiting will be based on the IP address only.
×
NEW
171
                        ua := req.Header.Get("User-Agent")
×
NEW
172
                        limited, err := c.ipLimiter.IsLimited(ip, ua)
×
NEW
173
                        if err != nil {
×
NEW
174
                                // Log error but don't block request (fail open).
×
NEW
175
                                // TODO: Add tests for this path.
×
NEW
176
                                log.Printf("Rate limiter error: %v", err)
×
NEW
177
                        } else if limited {
×
NEW
178
                                metrics.RequestsTotal.WithLabelValues("nearest", "rate limit",
×
NEW
179
                                        http.StatusText(result.Error.Status)).Inc()
×
NEW
180
                                // For now, we only log the rate limit exceeded message.
×
NEW
181
                                // TODO: Actually block the request and return an appropriate HTTP error
×
NEW
182
                                // code and message.
×
NEW
183
                                log.Printf("Rate limit exceeded for IP %s and UA %s", ip, ua)
×
NEW
184
                        }
×
NEW
185
                } else {
×
NEW
186
                        // This should never happen if Locate is deployed on AppEngine.
×
NEW
187
                        log.Println("Cannot find IP address for rate limiting.")
×
NEW
188
                }
×
189
        }
190

191
        experiment, service := getExperimentAndService(req.URL.Path)
1✔
192

1✔
193
        // Look up client location.
1✔
194
        loc, err := c.checkClientLocation(rw, req)
1✔
195
        if err != nil {
2✔
196
                status := http.StatusServiceUnavailable
1✔
197
                result.Error = v2.NewError("nearest", "Failed to lookup nearest machines", status)
1✔
198
                writeResult(rw, result.Error.Status, &result)
1✔
199
                metrics.RequestsTotal.WithLabelValues("nearest", "client location",
1✔
200
                        http.StatusText(result.Error.Status)).Inc()
1✔
201
                return
1✔
202
        }
1✔
203

204
        // Parse client location.
205
        lat, errLat := strconv.ParseFloat(loc.Latitude, 64)
1✔
206
        lon, errLon := strconv.ParseFloat(loc.Longitude, 64)
1✔
207
        if errLat != nil || errLon != nil {
2✔
208
                result.Error = v2.NewError("client", errFailedToLookupClient.Error(), http.StatusInternalServerError)
1✔
209
                writeResult(rw, result.Error.Status, &result)
1✔
210
                metrics.RequestsTotal.WithLabelValues("nearest", "parse client location",
1✔
211
                        http.StatusText(result.Error.Status)).Inc()
1✔
212
                return
1✔
213
        }
1✔
214

215
        // Find the nearest targets using the client parameters.
216
        q := req.URL.Query()
1✔
217
        t := q.Get("machine-type")
1✔
218
        country := req.Header.Get("X-AppEngine-Country")
1✔
219
        sites := q["site"]
1✔
220
        org := q.Get("org")
1✔
221
        strict := false
1✔
222
        if qsStrict, err := strconv.ParseBool(q.Get("strict")); err == nil {
1✔
223
                strict = qsStrict
×
224
        }
×
225
        // If strict, override the country from the AppEngine header with the one in
226
        // the querystring.
227
        if strict {
1✔
228
                country = q.Get("country")
×
229
        }
×
230
        opts := &heartbeat.NearestOptions{Type: t, Country: country, Sites: sites, Org: org, Strict: strict}
1✔
231
        targetInfo, err := c.LocatorV2.Nearest(service, lat, lon, opts)
1✔
232
        if err != nil {
2✔
233
                result.Error = v2.NewError("nearest", "Failed to lookup nearest machines", http.StatusInternalServerError)
1✔
234
                writeResult(rw, result.Error.Status, &result)
1✔
235
                metrics.RequestsTotal.WithLabelValues("nearest", "server location",
1✔
236
                        http.StatusText(result.Error.Status)).Inc()
1✔
237
                return
1✔
238
        }
1✔
239

240
        pOpts := paramOpts{
1✔
241
                raw:       req.Form,
1✔
242
                version:   "v2",
1✔
243
                ranks:     targetInfo.Ranks,
1✔
244
                svcParams: static.ServiceParams,
1✔
245
        }
1✔
246
        // Populate target URLs and write out response.
1✔
247
        c.populateURLs(targetInfo.Targets, targetInfo.URLs, experiment, pOpts)
1✔
248
        result.Results = targetInfo.Targets
1✔
249
        writeResult(rw, http.StatusOK, &result)
1✔
250
        metrics.RequestsTotal.WithLabelValues("nearest", "success", http.StatusText(http.StatusOK)).Inc()
1✔
251
}
252

253
// Live is a minimal handler to indicate that the server is operating at all.
254
func (c *Client) Live(rw http.ResponseWriter, req *http.Request) {
1✔
255
        fmt.Fprintf(rw, "ok")
1✔
256
}
1✔
257

258
// Ready reports whether the server is working as expected and ready to serve requests.
259
func (c *Client) Ready(rw http.ResponseWriter, req *http.Request) {
1✔
260
        if c.LocatorV2.Ready() {
2✔
261
                fmt.Fprintf(rw, "ok")
1✔
262
        } else {
2✔
263
                rw.WriteHeader(http.StatusInternalServerError)
1✔
264
                fmt.Fprintf(rw, "not ready")
1✔
265
        }
1✔
266
}
267

268
// Registrations returns information about registered machines. There are 3
269
// supported query parameters:
270
//
271
// * format - defines the format of the returned JSON
272
// * org - limits results to only records for the given organization
273
// * exp - limits results to only records for the given experiment (e.g., ndt)
274
func (c *Client) Registrations(rw http.ResponseWriter, req *http.Request) {
1✔
275
        var err error
1✔
276
        var result interface{}
1✔
277

1✔
278
        q := req.URL.Query()
1✔
279
        format := q.Get("format")
1✔
280

1✔
281
        switch format {
1✔
282
        default:
1✔
283
                result, err = siteinfo.Machines(c.LocatorV2.Instances(), q)
1✔
284
        }
285

286
        if err != nil {
2✔
287
                v2Error := v2.NewError("siteinfo", err.Error(), http.StatusInternalServerError)
1✔
288
                writeResult(rw, http.StatusInternalServerError, v2Error)
1✔
289
                return
1✔
290
        }
1✔
291

292
        writeResult(rw, http.StatusOK, result)
1✔
293
}
294

295
// checkClientLocation looks up the client location and copies the location
296
// headers to the response writer.
297
func (c *Client) checkClientLocation(rw http.ResponseWriter, req *http.Request) (*clientgeo.Location, error) {
1✔
298
        // Lookup the client location using the client request.
1✔
299
        loc, err := c.Locate(req)
1✔
300
        if err != nil {
2✔
301
                return nil, errFailedToLookupClient
1✔
302
        }
1✔
303

304
        // Copy location headers to response writer.
305
        for key := range loc.Headers {
2✔
306
                rw.Header().Set(key, loc.Headers.Get(key))
1✔
307
        }
1✔
308

309
        return loc, nil
1✔
310
}
311

312
// populateURLs populates each set of URLs using the target configuration.
313
func (c *Client) populateURLs(targets []v2.Target, ports static.Ports, exp string, pOpts paramOpts) {
1✔
314
        for i, target := range targets {
2✔
315
                token := c.getAccessToken(target.Machine, exp)
1✔
316
                params := extraParams(target.Machine, i, pOpts)
1✔
317
                targets[i].URLs = c.getURLs(ports, target.Hostname, token, params)
1✔
318
        }
1✔
319
}
320

321
// getAccessToken allocates a new access token using the given machine name as
322
// the intended audience and the subject as the target service.
323
func (c *Client) getAccessToken(machine, subject string) string {
1✔
324
        // Create the token. The same access token is reused for every URL of a
1✔
325
        // target port.
1✔
326
        // A uuid is added to the claims so that each new token is unique.
1✔
327
        cl := jwt.Claims{
1✔
328
                Issuer:   static.IssuerLocate,
1✔
329
                Subject:  subject,
1✔
330
                Audience: jwt.Audience{machine},
1✔
331
                Expiry:   jwt.NewNumericDate(time.Now().Add(time.Minute)),
1✔
332
                ID:       uuid.NewString(),
1✔
333
        }
1✔
334
        token, err := c.Sign(cl)
1✔
335
        // Sign errors can only happen due to a misconfiguration of the key.
1✔
336
        // A good config will remain good.
1✔
337
        rtx.PanicOnError(err, "signing claims has failed")
1✔
338
        return token
1✔
339
}
1✔
340

341
// getURLs creates URLs for the named experiment, running on the named machine
342
// for each given port. Every URL will include an `access_token=` parameter,
343
// authorizing the measurement.
344
func (c *Client) getURLs(ports static.Ports, hostname, token string, extra url.Values) map[string]string {
1✔
345
        urls := map[string]string{}
1✔
346
        // For each port config, prepare the target url with access_token and
1✔
347
        // complete host field.
1✔
348
        for _, target := range ports {
2✔
349
                name := target.String()
1✔
350
                params := url.Values{}
1✔
351
                params.Set("access_token", token)
1✔
352
                for key := range extra {
2✔
353
                        // note: we only use the first value.
1✔
354
                        params.Set(key, extra.Get(key))
1✔
355
                }
1✔
356
                target.RawQuery = params.Encode()
1✔
357

1✔
358
                host := &bytes.Buffer{}
1✔
359
                err := c.targetTmpl.Execute(host, map[string]string{
1✔
360
                        "Hostname": hostname,
1✔
361
                        "Ports":    target.Host, // from URL template, so typically just the ":port".
1✔
362
                })
1✔
363
                rtx.PanicOnError(err, "bad template evaluation")
1✔
364
                target.Host = host.String()
1✔
365
                urls[name] = target.String()
1✔
366
        }
367
        return urls
1✔
368
}
369

370
// limitRequest determines whether a client request should be rate-limited.
371
func (c *Client) limitRequest(now time.Time, req *http.Request) bool {
1✔
372
        agent := req.Header.Get("User-Agent")
1✔
373
        l, ok := c.agentLimits[agent]
1✔
374
        if !ok {
2✔
375
                // No limit defined for user agent.
1✔
376
                return false
1✔
377
        }
1✔
378
        return l.IsLimited(now)
1✔
379
}
380

381
// setHeaders sets the response headers for "nearest" requests: a JSON content
// type, a permissive CORS policy, and a directive that disables caching.
func setHeaders(rw http.ResponseWriter) {
	headers := [...][2]string{
		{"Content-Type", "application/json"},
		// Set CORS policy to allow third-party websites to use returned resources.
		{"Access-Control-Allow-Origin", "*"},
		// Prevent caching of result.
		// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control
		{"Cache-Control", "no-store"},
	}
	for _, kv := range headers {
		rw.Header().Set(kv[0], kv[1])
	}
}
1✔
390

391
// writeResult marshals the result and writes the result to the response writer.
392
func writeResult(rw http.ResponseWriter, status int, result interface{}) {
1✔
393
        b, err := json.MarshalIndent(result, "", "  ")
1✔
394
        // Errors are only possible when marshalling incompatible types, like functions.
1✔
395
        rtx.PanicOnError(err, "Failed to format result")
1✔
396
        rw.WriteHeader(status)
1✔
397
        rw.Write(b)
1✔
398
}
1✔
399

400
// getExperimentAndService takes an http request path and extracts the last two
401
// fields. For correct requests (e.g. "/v2/nearest/ndt/ndt5"), this will be the
402
// experiment name (e.g. "ndt") and the datatype (e.g. "ndt5").
403
func getExperimentAndService(p string) (string, string) {
1✔
404
        datatype := path.Base(p)
1✔
405
        experiment := path.Base(path.Dir(p))
1✔
406
        return experiment, experiment + "/" + datatype
1✔
407
}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc