• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-phorce / dolly / 4269215184

25 Feb 2023 08:53AM UTC coverage: 85.275% (-0.009%) from 85.284%
4269215184

Pull #226

github

GitHub
Bump golang.org/x/sys from 0.0.0-20210615035016-665e8c7367d1 to 0.1.0
Pull Request #226: Bump golang.org/x/sys from 0.0.0-20210615035016-665e8c7367d1 to 0.1.0

9492 of 11131 relevant lines covered (85.28%)

7585.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.68
/rest/tlsconfig/reloader.go
1
package tlsconfig
2

3
import (
4
        "crypto/tls"
5
        "crypto/x509"
6
        "os"
7
        "path"
8
        "sync"
9
        "sync/atomic"
10
        "time"
11

12
        "github.com/pkg/errors"
13
)
14

15
// makeTicker starts a time.Ticker that fires every interval and hands back
// the ticker's Stop function together with its delivery channel. It is a
// package-level variable so tests can substitute a ticker they control.
var makeTicker = func(interval time.Duration) (func(), <-chan time.Time) {
	ticker := time.NewTicker(interval)
	stop, ticks := ticker.Stop, ticker.C
	return stop, ticks
}
20

21
// OnReloadFunc is a callback that receives the key pair produced by a reload.
type OnReloadFunc func(cert *tls.Certificate)
23

24
// KeypairReloader keeps necessary info to provide reloaded certificate
25
type KeypairReloader struct {
26
        label          string
27
        lock           sync.RWMutex
28
        loadedAt       time.Time
29
        count          uint32
30
        keypair        *tls.Certificate
31
        certPath       string
32
        certModifiedAt time.Time
33
        keyPath        string
34
        keyModifiedAt  time.Time
35
        inProgress     bool
36
        stopChan       chan<- struct{}
37
        closed         bool
38
        handlers       []OnReloadFunc
39
}
40

41
// NewKeypairReloader return an instance of the TLS cert loader
42
func NewKeypairReloader(certPath, keyPath string, checkInterval time.Duration) (*KeypairReloader, error) {
4✔
43
        return NewKeypairReloaderWithLabel("", certPath, keyPath, checkInterval)
4✔
44
}
4✔
45

46
// NewKeypairReloaderWithLabel return an instance of the TLS cert loader
47
func NewKeypairReloaderWithLabel(label, certPath, keyPath string, checkInterval time.Duration) (*KeypairReloader, error) {
4✔
48
        if label == "" {
8✔
49
                label = path.Base(certPath)
4✔
50
        }
4✔
51

52
        result := &KeypairReloader{
4✔
53
                label:    label,
4✔
54
                certPath: certPath,
4✔
55
                keyPath:  keyPath,
4✔
56
                stopChan: make(chan struct{}),
4✔
57
        }
4✔
58

4✔
59
        logger.Infof("label=%s, status=started", label)
4✔
60

4✔
61
        err := result.Reload()
4✔
62
        if err != nil {
4✔
63
                return nil, errors.WithStack(err)
×
64
        }
×
65

66
        stopChan := make(chan struct{})
4✔
67
        tickerStop, tickChan := makeTicker(checkInterval)
4✔
68
        go func() {
8✔
69
                for {
24✔
70
                        select {
20✔
71
                        case <-stopChan:
4✔
72
                                tickerStop()
4✔
73
                                logger.Infof("status=closed, label=%s, count=%d", result.label, result.LoadedCount())
4✔
74
                                return
4✔
75
                        case <-tickChan:
16✔
76
                                modified := false
16✔
77
                                fi, err := os.Stat(certPath)
16✔
78
                                if err == nil {
32✔
79
                                        modified = fi.ModTime().After(result.certModifiedAt)
16✔
80
                                } else {
16✔
81
                                        logger.Warningf("reason=stat, label=%s, file=%q, err=[%v]", result.label, certPath, err)
×
82
                                }
×
83
                                if !modified {
29✔
84
                                        fi, err = os.Stat(keyPath)
13✔
85
                                        if err == nil {
26✔
86
                                                modified = fi.ModTime().After(result.keyModifiedAt)
13✔
87
                                        } else {
13✔
88
                                                logger.Warningf("reason=stat, label=%s, file=%q, err=[%v]", result.label, keyPath, err)
×
89
                                        }
×
90
                                }
91
                                // reload on modified, or force to reload each hour
92
                                if modified || result.loadedAt.Add(1*time.Hour).Before(time.Now().UTC()) {
19✔
93
                                        err := result.Reload()
3✔
94
                                        if err != nil {
3✔
95
                                                logger.Errorf("label=%s, err=[%+v]", result.label, err)
×
96
                                        }
×
97
                                }
98
                        }
99
                }
100
        }()
101
        result.stopChan = stopChan
4✔
102
        return result, nil
4✔
103
}
104

105
// OnReload allows to add OnReloadFunc handler
106
func (k *KeypairReloader) OnReload(f OnReloadFunc) *KeypairReloader {
3✔
107
        k.lock.Lock()
3✔
108
        defer k.lock.Unlock()
3✔
109

3✔
110
        if f != nil {
6✔
111
                k.handlers = append(k.handlers, f)
3✔
112
        }
3✔
113
        return k
3✔
114
}
115

116
// Reload will explicitly load TLS certs from the disk
117
func (k *KeypairReloader) Reload() error {
17✔
118
        k.lock.Lock()
17✔
119
        if k.inProgress {
17✔
120
                k.lock.Unlock()
×
121
                return nil
×
122
        }
×
123

124
        k.inProgress = true
17✔
125
        defer func() {
34✔
126
                k.inProgress = false
17✔
127
                k.lock.Unlock()
17✔
128
        }()
17✔
129

130
        oldModifiedAt := k.certModifiedAt
17✔
131

17✔
132
        var newCert tls.Certificate
17✔
133
        var err error
17✔
134

17✔
135
        for i := 0; i < 3; i++ {
34✔
136
                // sleep a little as notification occurs right after process starts writing the file,
17✔
137
                // so it needs to finish writing the file
17✔
138
                time.Sleep(100 * time.Millisecond)
17✔
139
                newCert, err = tls.LoadX509KeyPair(k.certPath, k.keyPath)
17✔
140
                if err == nil {
34✔
141
                        break
17✔
142
                }
143
                logger.Warningf("reason=LoadX509KeyPair, label=%s, file=%q, err=[%v]", k.label, k.certPath, err)
×
144
        }
145
        if err != nil {
17✔
146
                return errors.WithMessagef(err, "count: %d", k.count)
×
147
        }
×
148

149
        atomic.AddUint32(&k.count, 1)
17✔
150
        k.loadedAt = time.Now().UTC()
17✔
151

17✔
152
        certFileInfo, err := os.Stat(k.certPath)
17✔
153
        if err == nil {
34✔
154
                k.certModifiedAt = certFileInfo.ModTime()
17✔
155
        } else {
17✔
156
                logger.Warningf("reason=stat, label=%s, file=%q, err=[%v]", k.label, k.certPath, err)
×
157
        }
×
158

159
        keyFileInfo, err := os.Stat(k.keyPath)
17✔
160
        if err == nil {
34✔
161
                k.keyModifiedAt = keyFileInfo.ModTime()
17✔
162
        } else {
17✔
163
                logger.Warningf("reason=stat, label=%s, file=%q, err=[%v]", k.label, k.keyPath, err)
×
164
        }
×
165

166
        logger.Noticef("label=%s, count=%d, cert=%q, modifiedAt=%q",
17✔
167
                k.label, k.count, k.certPath, k.certModifiedAt.Format(time.RFC3339))
17✔
168

17✔
169
        k.keypair = &newCert
17✔
170
        keypair := k.tlsCert()
17✔
171

17✔
172
        if oldModifiedAt != k.certModifiedAt {
24✔
173
                // execute notifications outside of the lock
7✔
174
                for _, h := range k.handlers {
10✔
175
                        go h(keypair)
3✔
176
                }
3✔
177
        }
178

179
        return nil
17✔
180
}
181

182
func (k *KeypairReloader) tlsCert() *tls.Certificate {
20✔
183
        var err error
20✔
184
        kp := k.keypair
20✔
185
        if kp.Leaf == nil && len(kp.Certificate) > 0 {
37✔
186
                kp.Leaf, err = x509.ParseCertificate(kp.Certificate[0])
17✔
187
                if err != nil {
17✔
188
                        logger.Warningf("reason=ParseCertificate, label=%s, err=[%v]", k.label, err)
×
189
                }
×
190
        }
191

192
        if kp.Leaf != nil && kp.Leaf.NotAfter.Add(1*time.Hour).Before(time.Now().UTC()) {
20✔
193
                logger.Warningf("label=%s, count=%d, cert=%q, expires=%q",
×
194
                        k.label, k.count, k.certPath, kp.Leaf.NotAfter.Format(time.RFC3339))
×
195
        }
×
196
        return kp
20✔
197
}
198

199
// GetKeypairFunc is a callback for TLSConfig to provide TLS certificate and key pair for Server
200
func (k *KeypairReloader) GetKeypairFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
3✔
201
        return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
9✔
202
                return k.tlsCert(), nil
6✔
203
        }
6✔
204
}
205

206
// GetClientCertificateFunc is a callback for TLSConfig to provide TLS certificate and key pair for Client
207
func (k *KeypairReloader) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
2✔
208
        return func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
3✔
209
                return k.tlsCert(), nil
1✔
210
        }
1✔
211
}
212

213
// Keypair returns current pair
214
func (k *KeypairReloader) Keypair() *tls.Certificate {
1✔
215
        if k == nil {
1✔
216
                return nil
×
217
        }
×
218
        k.lock.RLock()
1✔
219
        defer k.lock.RUnlock()
1✔
220

1✔
221
        return k.tlsCert()
1✔
222
}
223

224
// CertAndKeyFiles returns cert and key files
225
func (k *KeypairReloader) CertAndKeyFiles() (string, string) {
1✔
226
        if k == nil {
1✔
227
                return "", ""
×
228
        }
×
229
        k.lock.RLock()
1✔
230
        defer k.lock.RUnlock()
1✔
231

1✔
232
        return k.certPath, k.keyPath
1✔
233
}
234

235
// LoadedAt return the last time when the pair was loaded
236
func (k *KeypairReloader) LoadedAt() time.Time {
3✔
237
        k.lock.RLock()
3✔
238
        defer k.lock.RUnlock()
3✔
239

3✔
240
        return k.loadedAt
3✔
241
}
3✔
242

243
// LoadedCount returns the number of times the pair was loaded from disk
244
func (k *KeypairReloader) LoadedCount() uint32 {
9✔
245
        return atomic.LoadUint32(&k.count)
9✔
246
}
9✔
247

248
// Close will close the reloader and release its resources
249
func (k *KeypairReloader) Close() error {
4✔
250
        if k == nil {
4✔
251
                return nil
×
252
        }
×
253

254
        k.lock.RLock()
4✔
255
        defer k.lock.RUnlock()
4✔
256

4✔
257
        if k.closed {
4✔
258
                return errors.New("already closed")
×
259
        }
×
260

261
        logger.Infof("label=%s, count=%d, cert=%q, key=%q", k.label, k.count, k.certPath, k.keyPath)
4✔
262

4✔
263
        k.closed = true
4✔
264
        k.stopChan <- struct{}{}
4✔
265

4✔
266
        return nil
4✔
267
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc