• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

m-lab / autojoin / 20352158914

18 Dec 2025 08:43PM UTC coverage: 89.695% (-0.1%) from 89.803%
20352158914

Pull #78

github

nkinkade
Ensures that service is compatible with machine

If the requested service is "ndt7" or "ndt7_client", then the Autojoin API
should not return, for example, Wehe servers.

Prometheus (well, gcp-service-discovery) queries the Autojoin API for
script-exporter targets for ndt7 e2e testing. The list endpoint should not
return machines running experiments that are not compatible with the requested
service.
Pull Request #78: Ensures that service is compatible with machine

4 of 6 new or added lines in 1 file covered. (66.67%)

1 existing line in 1 file now uncovered.

1323 of 1475 relevant lines covered (89.69%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.81
/handler/handler.go
1
package handler
2

3
import (
4
        "context"
5
        "encoding/json"
6
        "errors"
7
        "fmt"
8
        "log"
9
        "net"
10
        "net/http"
11
        "regexp"
12
        "strconv"
13
        "strings"
14

15
        "github.com/Masterminds/semver/v3"
16
        v0 "github.com/m-lab/autojoin/api/v0"
17
        "github.com/m-lab/autojoin/iata"
18
        "github.com/m-lab/autojoin/internal/adminx"
19
        "github.com/m-lab/autojoin/internal/dnsname"
20
        "github.com/m-lab/autojoin/internal/dnsx"
21
        "github.com/m-lab/autojoin/internal/dnsx/dnsiface"
22
        "github.com/m-lab/autojoin/internal/register"
23
        "github.com/m-lab/gcp-service-discovery/discovery"
24
        "github.com/m-lab/go/host"
25
        "github.com/m-lab/go/rtx"
26
        v2 "github.com/m-lab/locate/api/v2"
27
        "github.com/m-lab/uuid-annotator/annotator"
28
        "github.com/oschwald/geoip2-golang"
29
)
30

31
// Sentinel errors returned by rawLatLon and getLocation when no usable
// location can be derived from a request.
var (
	// errLocationNotFound: neither query parameters nor headers supplied a location.
	errLocationNotFound = errors.New("location not found")
	// errLocationFormat: a location was supplied but could not be parsed.
	errLocationFormat = errors.New("location could not be parsed")
)

// Input validation patterns. Both are anchored with ^ and $ so the entire
// input must match; a valid token embedded in other text is rejected.
var (
	// validName accepts one or more ASCII letters or digits.
	validName = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
	// validUplink accepts one or more decimal digits followed by "g".
	validUplink = regexp.MustCompile(`^[0-9]+g$`)
)
39

40
// Server maintains shared state for the server.
41
type Server struct {
42
        Project    string
43
        Iata       IataFinder
44
        Maxmind    MaxmindFinder
45
        ASN        ASNFinder
46
        DNS        dnsiface.Service
47
        minVersion *semver.Version
48

49
        sm         ServiceAccountSecretManager
50
        dnsTracker DNSTracker
51
        dsm        Datastore
52
}
53

54
// ASNFinder is an interface used by the Server to manage ASN information.
55
type ASNFinder interface {
56
        AnnotateIP(src string) *annotator.Network
57
        Reload(ctx context.Context)
58
}
59

60
// MaxmindFinder is an interface used by the Server to manage Maxmind information.
61
type MaxmindFinder interface {
62
        City(ip net.IP) (*geoip2.City, error)
63
        Reload(ctx context.Context) error
64
}
65

66
// IataFinder is an interface used by the Server to manage IATA information.
67
type IataFinder interface {
68
        Lookup(country string, lat, lon float64) (string, error)
69
        Find(iata string) (iata.Row, error)
70
        Load(ctx context.Context) error
71
}
72

73
// DNSTracker is an interface used by the Server to track registered hostnames
// and the ports they expose. Register calls Update after a successful DNS
// registration, Delete calls Delete, and List reads everything back.
type DNSTracker interface {
	// Update records hostname together with its ports.
	Update(hostname string, ports []string) error
	// Delete stops tracking hostname.
	Delete(hostname string) error
	// List returns the tracked hostnames and, at the same index, the ports
	// recorded for each hostname.
	List() (hostnames []string, ports [][]string, err error)
}
78

79
// ServiceAccountSecretManager is an interface used by the server to allocate
// service account keys. Register asks it for the organization's key and hands
// the key back to the node as its credentials.
type ServiceAccountSecretManager interface {
	// LoadOrCreateKey returns a service account key for orgName (loading an
	// existing key or creating one, as the name indicates).
	LoadOrCreateKey(ctx context.Context, orgName string) (string, error)
}
83

84
type Datastore interface {
85
        GetOrganization(ctx context.Context, name string) (*adminx.Organization, error)
86
}
87

88
// NewServer creates a new Server instance for request handling.
89
func NewServer(project string, finder IataFinder, maxmind MaxmindFinder, asn ASNFinder,
90
        ds dnsiface.Service, tracker DNSTracker, sm ServiceAccountSecretManager, dsm Datastore,
91
        minVersion string) *Server {
1✔
92
        v, err := semver.NewVersion(minVersion)
1✔
93
        rtx.Must(err, "invalid minimum version")
1✔
94
        return &Server{
1✔
95
                Project:    project,
1✔
96
                Iata:       finder,
1✔
97
                Maxmind:    maxmind,
1✔
98
                ASN:        asn,
1✔
99
                DNS:        ds,
1✔
100
                sm:         sm,
1✔
101
                minVersion: v,
1✔
102

1✔
103
                dnsTracker: tracker,
1✔
104
                dsm:        dsm,
1✔
105
        }
1✔
106
}
1✔
107

108
// Reload reloads all resources used by the Server.
109
func (s *Server) Reload(ctx context.Context) {
1✔
110
        s.Iata.Load(ctx)
1✔
111
        s.Maxmind.Reload(ctx)
1✔
112
}
1✔
113

114
// Lookup is a handler used to find the nearest IATA given client IP or lat/lon metadata.
115
func (s *Server) Lookup(rw http.ResponseWriter, req *http.Request) {
1✔
116
        resp := v0.LookupResponse{}
1✔
117
        country, err := s.getCountry(req)
1✔
118
        if country == "" || err != nil {
2✔
119
                resp.Error = &v2.Error{
1✔
120
                        Type:   "?country=<country>",
1✔
121
                        Title:  "could not determine country from request",
1✔
122
                        Status: http.StatusBadRequest,
1✔
123
                }
1✔
124
                rw.WriteHeader(resp.Error.Status)
1✔
125
                writeResponse(rw, resp)
1✔
126
                return
1✔
127
        }
1✔
128
        lat, lon, err := s.getLocation(req)
1✔
129
        if err != nil {
2✔
130
                resp.Error = &v2.Error{
1✔
131
                        Type:   "?lat=<lat>&lon=<lon>",
1✔
132
                        Title:  "could not determine lat/lon from request",
1✔
133
                        Status: http.StatusBadRequest,
1✔
134
                }
1✔
135
                rw.WriteHeader(resp.Error.Status)
1✔
136
                writeResponse(rw, resp)
1✔
137
                return
1✔
138
        }
1✔
139
        code, err := s.Iata.Lookup(country, lat, lon)
1✔
140
        if err != nil {
2✔
141
                resp.Error = &v2.Error{
1✔
142
                        Type:   "internal error",
1✔
143
                        Title:  "could not determine iata from request",
1✔
144
                        Status: http.StatusInternalServerError,
1✔
145
                }
1✔
146
                rw.WriteHeader(resp.Error.Status)
1✔
147
                writeResponse(rw, resp)
1✔
148
                return
1✔
149
        }
1✔
150
        resp.Lookup = &v0.Lookup{
1✔
151
                IATA: code,
1✔
152
        }
1✔
153
        writeResponse(rw, resp)
1✔
154
}
155

156
// Register handler is used by autonodes to register their hostname with M-Lab
157
// on startup and receive additional needed configuration metadata.
158
func (s *Server) Register(rw http.ResponseWriter, req *http.Request) {
1✔
159
        // All replies, errors and successes, should be json.
1✔
160
        rw.Header().Set("Content-Type", "application/json")
1✔
161

1✔
162
        resp := v0.RegisterResponse{}
1✔
163

1✔
164
        // Check version first.
1✔
165
        versionStr := req.URL.Query().Get("version")
1✔
166
        // If no version is provided, default to v0.0.0. This allows existing clients
1✔
167
        // that do not provide the version yet to keep working until a minVersion is set.
1✔
168
        if versionStr == "" {
2✔
169
                versionStr = "v0.0.0"
1✔
170
        }
1✔
171

172
        // Parse the provided version.
173
        clientVersion, err := semver.NewVersion(versionStr)
1✔
174
        if err != nil {
2✔
175
                resp.Error = &v2.Error{
1✔
176
                        Type:   "version.invalid",
1✔
177
                        Title:  "invalid version format - must be semantic version (e.g. v1.2.3)",
1✔
178
                        Detail: err.Error(),
1✔
179
                        Status: http.StatusBadRequest,
1✔
180
                }
1✔
181
                rw.WriteHeader(resp.Error.Status)
1✔
182
                writeResponse(rw, resp)
1✔
183
                return
1✔
184
        }
1✔
185

186
        if clientVersion.LessThan(s.minVersion) {
2✔
187
                resp.Error = &v2.Error{
1✔
188
                        Type: "version.outdated",
1✔
189
                        Title: fmt.Sprintf("version %s is below minimum required version %s",
1✔
190
                                clientVersion.String(), s.minVersion.String()),
1✔
191
                        Status: http.StatusForbidden,
1✔
192
                }
1✔
193
                rw.WriteHeader(resp.Error.Status)
1✔
194
                writeResponse(rw, resp)
1✔
195
                return
1✔
196
        }
1✔
197

198
        param := &register.Params{Project: s.Project}
1✔
199
        param.Service = req.URL.Query().Get("service")
1✔
200
        if !isValidName(param.Service) {
2✔
201
                resp.Error = &v2.Error{
1✔
202
                        Type:   "?service=<service>",
1✔
203
                        Title:  "could not determine service from request",
1✔
204
                        Status: http.StatusBadRequest,
1✔
205
                }
1✔
206
                rw.WriteHeader(resp.Error.Status)
1✔
207
                writeResponse(rw, resp)
1✔
208
                return
1✔
209
        }
1✔
210

211
        // Get the organization from the context.
212
        org, ok := req.Context().Value(orgContextKey).(string)
1✔
213
        if !ok {
1✔
214
                resp.Error = &v2.Error{
×
215
                        Type:   "auth.context",
×
216
                        Title:  "missing organization in context",
×
217
                        Status: http.StatusInternalServerError,
×
218
                }
×
219
                rw.WriteHeader(resp.Error.Status)
×
220
                writeResponse(rw, resp)
×
221
                return
×
222
        }
×
223
        param.Org = org
1✔
224
        param.IPv6 = checkIP(req.URL.Query().Get("ipv6")) // optional.
1✔
225
        param.IPv4 = checkIP(getClientIP(req))
1✔
226
        ip := net.ParseIP(param.IPv4)
1✔
227
        if ip == nil || ip.To4() == nil {
2✔
228
                resp.Error = &v2.Error{
1✔
229
                        Type:   "?ipv4=<ipv4>",
1✔
230
                        Title:  "could not determine client ipv4 from request",
1✔
231
                        Status: http.StatusBadRequest,
1✔
232
                }
1✔
233
                rw.WriteHeader(resp.Error.Status)
1✔
234
                writeResponse(rw, resp)
1✔
235
                return
1✔
236
        }
1✔
237
        param.Type = req.URL.Query().Get("type")
1✔
238
        if !isValidType(param.Type) {
2✔
239
                resp.Error = &v2.Error{
1✔
240
                        Type:   "?type=<type>",
1✔
241
                        Title:  "invalid machine type from request",
1✔
242
                        Status: http.StatusBadRequest,
1✔
243
                }
1✔
244
                rw.WriteHeader(resp.Error.Status)
1✔
245
                writeResponse(rw, resp)
1✔
246
                return
1✔
247
        }
1✔
248
        param.Uplink = req.URL.Query().Get("uplink")
1✔
249
        if !isValidUplink(param.Uplink) {
2✔
250
                resp.Error = &v2.Error{
1✔
251
                        Type:   "?uplink=<uplink>",
1✔
252
                        Title:  "invalid uplink speed from request",
1✔
253
                        Status: http.StatusBadRequest,
1✔
254
                }
1✔
255
                rw.WriteHeader(resp.Error.Status)
1✔
256
                writeResponse(rw, resp)
1✔
257
                return
1✔
258
        }
1✔
259
        iata := getClientIata(req)
1✔
260
        if iata == "" {
1✔
261
                resp.Error = &v2.Error{
×
262
                        Type:   "?iata=<iata>",
×
263
                        Title:  "could not determine iata from request",
×
264
                        Status: http.StatusBadRequest,
×
265
                }
×
266
                rw.WriteHeader(resp.Error.Status)
×
267
                writeResponse(rw, resp)
×
268
                return
×
269
        }
×
270
        row, err := s.Iata.Find(iata)
1✔
271
        if err != nil {
2✔
272
                resp.Error = &v2.Error{
1✔
273
                        Type:   "iata.find",
1✔
274
                        Title:  "could not find given iata in dataset",
1✔
275
                        Status: http.StatusInternalServerError,
1✔
276
                }
1✔
277
                rw.WriteHeader(resp.Error.Status)
1✔
278
                writeResponse(rw, resp)
1✔
279
                return
1✔
280
        }
1✔
281
        param.Metro = row
1✔
282
        record, err := s.Maxmind.City(ip)
1✔
283
        if err != nil {
2✔
284
                resp.Error = &v2.Error{
1✔
285
                        Type:   "maxmind.city",
1✔
286
                        Title:  "could not find city metadata from ip",
1✔
287
                        Status: http.StatusInternalServerError,
1✔
288
                }
1✔
289
                rw.WriteHeader(resp.Error.Status)
1✔
290
                writeResponse(rw, resp)
1✔
291
                return
1✔
292
        }
1✔
293
        param.Geo = record
1✔
294
        param.Network = s.ASN.AnnotateIP(param.IPv4)
1✔
295

1✔
296
        // Get the organization probability multiplier.
1✔
297
        orgEntity, err := s.dsm.GetOrganization(req.Context(), param.Org)
1✔
298
        orgMultiplier := 1.0
1✔
299
        if err == nil && orgEntity != nil && orgEntity.ProbabilityMultiplier != nil {
2✔
300
                orgMultiplier = *orgEntity.ProbabilityMultiplier
1✔
301
        }
1✔
302
        // Assign the probability by multiplying the org multiplier with the
303
        // probability requested by the client.
304
        param.Probability = getProbability(req) * orgMultiplier
1✔
305
        r := register.CreateRegisterResponse(param)
1✔
306

1✔
307
        key, err := s.sm.LoadOrCreateKey(req.Context(), param.Org)
1✔
308
        if err != nil {
2✔
309
                resp.Error = &v2.Error{
1✔
310
                        Type:   "load.serviceaccount.key",
1✔
311
                        Title:  "could not load service account key for node",
1✔
312
                        Status: http.StatusInternalServerError,
1✔
313
                }
1✔
314
                log.Println("loading service account key failure:", err)
1✔
315
                rw.WriteHeader(resp.Error.Status)
1✔
316
                writeResponse(rw, resp)
1✔
317
                return
1✔
318
        }
1✔
319
        r.Registration.Credentials = &v0.Credentials{
1✔
320
                ServiceAccountKey: key,
1✔
321
        }
1✔
322

1✔
323
        // Register the hostname under the organization zone.
1✔
324
        m := dnsx.NewManager(s.DNS, s.Project, dnsname.OrgZone(param.Org, s.Project))
1✔
325
        _, err = m.Register(req.Context(), r.Registration.Hostname+".", param.IPv4, param.IPv6)
1✔
326
        if err != nil {
2✔
327
                resp.Error = &v2.Error{
1✔
328
                        Type:   "dns.register",
1✔
329
                        Title:  "could not register dynamic hostname",
1✔
330
                        Status: http.StatusInternalServerError,
1✔
331
                }
1✔
332
                log.Println("dns register failure:", err)
1✔
333
                rw.WriteHeader(resp.Error.Status)
1✔
334
                writeResponse(rw, resp)
1✔
335
                return
1✔
336
        }
1✔
337

338
        // Add the hostname to the DNS tracker.
339
        err = s.dnsTracker.Update(r.Registration.Hostname, getPorts(req))
1✔
340
        if err != nil {
2✔
341
                resp.Error = &v2.Error{
1✔
342
                        Type:   "tracker.gc",
1✔
343
                        Title:  "could not update DNS tracker",
1✔
344
                        Status: http.StatusInternalServerError,
1✔
345
                }
1✔
346
                log.Println("dns gc update failure:", err)
1✔
347
                rw.WriteHeader(resp.Error.Status)
1✔
348
                writeResponse(rw, resp)
1✔
349
                return
1✔
350
        }
1✔
351

352
        b, _ := json.MarshalIndent(r, "", " ")
1✔
353
        rw.Write(b)
1✔
354
}
355

356
// Delete handler is used by operators to delete a previously registered
357
// hostname from DNS.
358
func (s *Server) Delete(rw http.ResponseWriter, req *http.Request) {
1✔
359
        // All replies, errors and successes, should be json.
1✔
360
        rw.Header().Set("Content-Type", "application/json")
1✔
361

1✔
362
        resp := v0.DeleteResponse{}
1✔
363
        hostname := req.URL.Query().Get("hostname")
1✔
364
        name, err := host.Parse(hostname)
1✔
365
        if err != nil {
2✔
366
                resp.Error = &v2.Error{
1✔
367
                        Type:   "dns.delete",
1✔
368
                        Title:  "failed to parse hostname",
1✔
369
                        Detail: err.Error(),
1✔
370
                        Status: http.StatusBadRequest,
1✔
371
                }
1✔
372
                log.Println("dns delete (parse) failure:", err)
1✔
373
                rw.WriteHeader(resp.Error.Status)
1✔
374
                writeResponse(rw, resp)
1✔
375
                return
1✔
376
        }
1✔
377

378
        m := dnsx.NewManager(s.DNS, s.Project, dnsname.OrgZone(name.Org, s.Project))
1✔
379
        _, err = m.Delete(req.Context(), name.StringAll()+".")
1✔
380
        if err != nil {
2✔
381
                resp.Error = &v2.Error{
1✔
382
                        Type:   "dns.delete",
1✔
383
                        Title:  "failed to delete hostname",
1✔
384
                        Detail: err.Error(),
1✔
385
                        Status: http.StatusInternalServerError,
1✔
386
                }
1✔
387
                log.Println("dns delete failure:", err)
1✔
388
                rw.WriteHeader(resp.Error.Status)
1✔
389
                writeResponse(rw, resp)
1✔
390
                return
1✔
391
        }
1✔
392

393
        err = s.dnsTracker.Delete(name.StringAll())
1✔
394
        if err != nil {
2✔
395
                resp.Error = &v2.Error{
1✔
396
                        Type:   "tracker.gc",
1✔
397
                        Title:  "failed to delete hostname from DNS tracker",
1✔
398
                        Detail: err.Error(),
1✔
399
                        Status: http.StatusInternalServerError,
1✔
400
                }
1✔
401
                log.Println("dns gc delete failure:", err)
1✔
402
                rw.WriteHeader(resp.Error.Status)
1✔
403
                writeResponse(rw, resp)
1✔
404
                return
1✔
405
        }
1✔
406

407
        b, err := json.MarshalIndent(resp, "", " ")
1✔
408
        rtx.Must(err, "failed to marshal DNS delete response")
1✔
409
        rw.Write(b)
1✔
410
}
411

412
// List handler is used by monitoring to generate a list of known, active
413
// hostnames previously registered with the Autojoin API.
414
func (s *Server) List(rw http.ResponseWriter, req *http.Request) {
1✔
415
        // Set CORS policy to allow third-party websites to use returned resources.
1✔
416
        rw.Header().Set("Content-Type", "application/json")
1✔
417
        rw.Header().Set("Access-Control-Allow-Origin", "*")
1✔
418
        rw.Header().Set("Cache-Control", "no-store") // Prevent caching of result.
1✔
419

1✔
420
        configs := []discovery.StaticConfig{}
1✔
421
        resp := v0.ListResponse{}
1✔
422
        hosts, ports, err := s.dnsTracker.List()
1✔
423
        if err != nil {
2✔
424
                resp.Error = &v2.Error{
1✔
425
                        Type:   "list",
1✔
426
                        Title:  "failed to list node records",
1✔
427
                        Detail: err.Error(),
1✔
428
                        Status: http.StatusInternalServerError,
1✔
429
                }
1✔
430
                log.Println("list failure:", err)
1✔
431
                rw.WriteHeader(resp.Error.Status)
1✔
432
                writeResponse(rw, resp)
1✔
433
                return
1✔
434
        }
1✔
435

436
        org := req.URL.Query().Get("org")
1✔
437
        format := req.URL.Query().Get("format")
1✔
438
        service := req.URL.Query().Get("service")
1✔
439
        sites := map[string]bool{}
1✔
440

1✔
441
        // Create a prometheus StaticConfig for each known host.
1✔
442
        for i := range hosts {
2✔
443
                h, err := host.Parse(hosts[i])
1✔
444
                if err != nil {
2✔
445
                        continue
1✔
446
                }
447
                if org != "" && org != h.Org {
2✔
448
                        // Skip hosts that are not part of the given org.
1✔
449
                        continue
1✔
450
                }
451
                if service != "" && !strings.HasPrefix(service, h.Service) {
1✔
NEW
452
                        // Skip hosts that are not part of the requested service
×
NEW
453
                        continue
×
454
                }
455
                sites[h.Site] = true
1✔
456
                if format == "script-exporter" {
2✔
457
                        // NOTE: do not assign any ports for script exporter.
1✔
458
                        ports[i] = []string{""}
1✔
459
                } else {
2✔
460
                        // Convert port strings to ":<port>".
1✔
461
                        p := []string{}
1✔
462
                        for j := range ports[i] {
2✔
463
                                p = append(p, ":"+ports[i][j])
1✔
464
                        }
1✔
465
                        ports[i] = p
1✔
466
                }
467
                for _, port := range ports[i] {
2✔
468
                        labels := map[string]string{
1✔
469
                                "machine":    hosts[i],
1✔
470
                                "type":       "virtual",
1✔
471
                                "deployment": "byos",
1✔
472
                                "managed":    "none",
1✔
473
                                "org":        h.Org,
1✔
474
                        }
1✔
475
                        if service != "" {
2✔
476
                                labels["service"] = service
1✔
477
                        }
1✔
478
                        // We create one record per host to add a unique "machine" label to each one.
479
                        configs = append(configs, discovery.StaticConfig{
1✔
480
                                Targets: []string{hosts[i] + port},
1✔
481
                                Labels:  labels,
1✔
482
                        })
1✔
483
                }
484
        }
485

486
        var results interface{}
1✔
487
        switch format {
1✔
488
        case "script-exporter":
1✔
489
                fallthrough
1✔
490
        case "blackbox":
1✔
491
                fallthrough
1✔
492
        case "prometheus":
1✔
493
                results = configs
1✔
494
        case "servers":
1✔
495
                resp.Servers = hosts
1✔
496
                results = resp
1✔
497
        case "sites":
1✔
498
                for k := range sites {
2✔
499
                        resp.Sites = append(resp.Sites, k)
1✔
500
                }
1✔
501
                results = resp
1✔
502
        default:
1✔
503
                resp.Servers = hosts
1✔
504
                results = resp
1✔
505
        }
506
        // Generate as JSON; the list may be empty.
507
        b, err := json.MarshalIndent(results, "", " ")
1✔
508
        rtx.Must(err, "failed to marshal DNS delete response")
1✔
509
        rw.Write(b)
1✔
510
}
511

512
// Live reports whether the system is live.
513
func (s *Server) Live(rw http.ResponseWriter, req *http.Request) {
1✔
514
        fmt.Fprintf(rw, "ok")
1✔
515
}
1✔
516

517
// Ready reports whether the server is ready.
518
func (s *Server) Ready(rw http.ResponseWriter, req *http.Request) {
1✔
519
        fmt.Fprintf(rw, "ok")
1✔
520
}
1✔
521

522
func getClientIata(req *http.Request) string {
1✔
523
        iata := req.URL.Query().Get("iata")
1✔
524
        if iata != "" && len(iata) == 3 && isValidName(iata) {
2✔
525
                return strings.ToLower(iata)
1✔
526
        }
1✔
UNCOV
527
        return ""
×
528
}
529

530
func isValidName(s string) bool {
1✔
531
        if s == "" {
2✔
532
                return false
1✔
533
        }
1✔
534
        if len(s) > 10 {
2✔
535
                return false
1✔
536
        }
1✔
537
        return validName.MatchString(s)
1✔
538
}
539

540
// isValidType reports whether s is a supported machine type: exactly
// "physical" or "virtual" (case-sensitive).
func isValidType(s string) bool {
	return s == "physical" || s == "virtual"
}
548

549
func isValidUplink(s string) bool {
1✔
550
        // Validate uplink speed specification: numbers followed by "g".
1✔
551
        // Using anchored regex to prevent injection attacks.
1✔
552
        return validUplink.MatchString(s)
1✔
553
}
1✔
554

555
func (s *Server) getCountry(req *http.Request) (string, error) {
1✔
556
        c := req.URL.Query().Get("country")
1✔
557
        if c != "" {
2✔
558
                return c, nil
1✔
559
        }
1✔
560
        c = req.Header.Get("X-AppEngine-Country")
1✔
561
        if c != "" {
2✔
562
                return c, nil
1✔
563
        }
1✔
564
        record, err := s.Maxmind.City(net.ParseIP(getClientIP(req)))
1✔
565
        if err != nil {
2✔
566
                return "", err
1✔
567
        }
1✔
568
        return record.Country.IsoCode, nil
1✔
569
}
570

571
func rawLatLon(req *http.Request) (string, string, error) {
1✔
572
        lat := req.URL.Query().Get("lat")
1✔
573
        lon := req.URL.Query().Get("lon")
1✔
574
        if lat != "" && lon != "" {
2✔
575
                return lat, lon, nil
1✔
576
        }
1✔
577
        latlon := req.Header.Get("X-AppEngine-CityLatLong")
1✔
578
        if latlon != "0.000000,0.000000" {
2✔
579
                fields := strings.Split(latlon, ",")
1✔
580
                if len(fields) == 2 {
2✔
581
                        return fields[0], fields[1], nil
1✔
582
                }
1✔
583
        }
584
        return "", "", errLocationNotFound
1✔
585
}
586

587
func (s *Server) getLocation(req *http.Request) (float64, float64, error) {
1✔
588
        rlat, rlon, err := rawLatLon(req)
1✔
589
        if err == nil {
2✔
590
                lat, errLat := strconv.ParseFloat(rlat, 64)
1✔
591
                lon, errLon := strconv.ParseFloat(rlon, 64)
1✔
592
                if errLat != nil || errLon != nil {
2✔
593
                        return 0, 0, errLocationFormat
1✔
594
                }
1✔
595
                return lat, lon, nil
1✔
596
        }
597
        // Fall back to lookup with request IP.
598
        record, err := s.Maxmind.City(net.ParseIP(getClientIP(req)))
1✔
599
        if err != nil {
2✔
600
                return 0, 0, err
1✔
601
        }
1✔
602
        return record.Location.Latitude, record.Location.Longitude, nil
1✔
603
}
604

605
func writeResponse(rw http.ResponseWriter, resp interface{}) {
1✔
606
        b, err := json.MarshalIndent(resp, "", "  ")
1✔
607
        // NOTE: marshal can only fail on incompatible types, like functions. The
1✔
608
        // panic will be caught by the http server handler.
1✔
609
        rtx.PanicOnError(err, "failed to marshal response")
1✔
610
        rw.Write(b)
1✔
611
}
1✔
612

613
// checkIP returns ip unchanged when it parses as an IPv4 or IPv6 address,
// and "" otherwise.
func checkIP(ip string) string {
	if net.ParseIP(ip) == nil {
		return ""
	}
	return ip
}
619

620
// getClientIP returns the client address for req. In order of preference it
// uses:
//   - the "ipv4" query parameter
//   - the first entry of the X-Forwarded-For header
//   - the host part of req.RemoteAddr
//
// The result is not validated here; callers check it (e.g. with checkIP).
//
// NOTE: the query parameter and X-Forwarded-For are both client-supplied.
func getClientIP(req *http.Request) string {
	// Use given IP parameter.
	if rawip := req.URL.Query().Get("ipv4"); rawip != "" {
		return rawip
	}
	// Use AppEngine's forwarded client address. Entries are separated by
	// commas and the space after a comma is optional, so split on "," and
	// trim instead of relying on ", ".
	fwd := req.Header.Get("X-Forwarded-For")
	if first := strings.TrimSpace(strings.Split(fwd, ",")[0]); first != "" {
		return first
	}
	// Use remote client address.
	hip, _, _ := net.SplitHostPort(req.RemoteAddr)
	return hip
}
635

636
// getProbability returns the "probability" query parameter as a float64.
//
// It returns the default 1.0 when the parameter is absent, fails to parse,
// or is not a finite number. strconv.ParseFloat accepts "NaN" and "Inf";
// such values are meaningless as probabilities, and encoding/json cannot
// marshal them.
func getProbability(req *http.Request) float64 {
	prob := req.URL.Query().Get("probability")
	if prob == "" {
		return 1.0
	}
	p, err := strconv.ParseFloat(prob, 64)
	// For every finite p, p-p == 0. For NaN and ±Inf, p-p is NaN, which is
	// never equal to 0. This rejects non-finite values without extra imports.
	if err != nil || p-p != 0 {
		return 1.0
	}
	return p
}
647

648
// getPorts returns the valid port numbers (1-65535) from the repeated "ports"
// query parameter, in request order and in canonical decimal form.
//
// Values that are not plain unsigned decimal numbers (including ones with a
// "+" or "-" sign), are 0, or exceed 65535 are skipped. When no valid port
// remains, the default port "9990" is returned.
func getPorts(req *http.Request) []string {
	result := []string{}
	for _, port := range req.URL.Query()["ports"] {
		// ParseUint rejects sign prefixes that ParseInt accepts (which would
		// otherwise be stored verbatim, e.g. a "host:+80" target), and
		// bitSize 16 rejects anything above 65535.
		portNum, err := strconv.ParseUint(port, 10, 16)
		if err != nil || portNum == 0 {
			// Skip invalid numbers and port 0.
			continue
		}
		result = append(result, strconv.FormatUint(portNum, 10))
	}
	if len(result) == 0 {
		return []string{"9990"} // default port
	}
	return result
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc