• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

go-pkgz / pool / 13210176460

08 Feb 2025 12:12AM UTC coverage: 90.421% (-1.9%) from 92.337%
13210176460

push

github

umputun
Reorganize "Install and update" section in README.md

236 of 261 relevant lines covered (90.42%)

2105.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.5
/pool.go
1
package pool
2

3
import (
4
        "context"
5
        "fmt"
6
        "hash/fnv"
7
        "math/rand"
8
        "sync"
9

10
        "golang.org/x/sync/errgroup"
11

12
        "github.com/go-pkgz/pool/metrics"
13
)
14

15
// WorkerGroup is a simple case of flow with a single stage only running a common function in workers pool.
16
// IN type if for input (submitted) records, OUT type is for output records in case if worker function should
17
// return some values.
18
type WorkerGroup[T any] struct {
19
        poolSize  int // number of workers (goroutines)
20
        batchSize int // size of batch sends to workers
21

22
        chunkFn         func(T) string // worker selector function
23
        workerChanSize  int            // size of worker channels
24
        worker          Worker[T]      // worker function
25
        workerMaker     WorkerMaker[T] // worker maker function
26
        completeFn      CompleteFn[T]  // completion callback function, called by each worker on completion
27
        continueOnError bool           // don't terminate on first error
28

29
        buf        [][]T             // batch buffers for workers
30
        workersCh  []chan []T        // workers input channels
31
        workerCtxs []context.Context // store worker contexts
32

33
        eg  *errgroup.Group
34
        err struct {
35
                sync.Once
36
                ch chan struct{}
37
        }
38

39
        activated bool
40
        ctx       context.Context
41
}
42

43
// Worker is the interface that wraps the Do method.
// Do processes a single submitted record; a non-nil error reports that the record failed.
type Worker[T any] interface {
	Do(ctx context.Context, v T) error
}

// WorkerFunc is an adapter to allow the use of ordinary functions as Workers.
// If f is a function with the appropriate signature, WorkerFunc[T](f) is a Worker[T] that calls f.
type WorkerFunc[T any] func(ctx context.Context, v T) error

// compile-time check that WorkerFunc satisfies Worker
var _ Worker[int] = WorkerFunc[int](nil)

// Do calls f(ctx, v).
func (f WorkerFunc[T]) Do(ctx context.Context, v T) error { return f(ctx, v) }
53

54
// WorkerMaker is a function that returns a new Worker
55
type WorkerMaker[T any] func() Worker[T]
56

57
// CompleteFn called (optionally) on worker completion
58
type CompleteFn[T any] func(ctx context.Context, id int, worker Worker[T]) error
59

60
// Send func called by worker code to publish results
61
type Send[T any] func(val T) error
62

63
// New creates a worker pool with a shared, stateless worker.
64
// Size defines the number of goroutines (workers) processing requests.
65
func New[T any](size int, worker Worker[T], opts ...Option[T]) (*WorkerGroup[T], error) {
32✔
66
        if size < 1 {
32✔
67
                size = 1
×
68
        }
×
69
        if worker == nil {
32✔
70
                return nil, fmt.Errorf("worker cannot be nil")
×
71
        }
×
72

73
        res := &WorkerGroup[T]{
32✔
74
                poolSize:       size,
32✔
75
                workersCh:      make([]chan []T, size),
32✔
76
                buf:            make([][]T, size),
32✔
77
                workerCtxs:     make([]context.Context, size),
32✔
78
                worker:         worker,
32✔
79
                batchSize:      1,
32✔
80
                workerChanSize: 1,
32✔
81
                ctx:            context.Background(),
32✔
82
        }
32✔
83
        res.err.ch = make(chan struct{})
32✔
84

32✔
85
        // apply all options
32✔
86
        for _, opt := range opts {
51✔
87
                opt(res)
19✔
88
        }
19✔
89

90
        // initialize worker's channels and batch buffers
91
        for id := 0; id < size; id++ {
85✔
92
                res.workersCh[id] = make(chan []T, res.workerChanSize)
53✔
93
                if res.batchSize > 1 {
60✔
94
                        res.buf[id] = make([]T, 0, size)
7✔
95
                }
7✔
96
        }
97

98
        return res, nil
32✔
99
}
100

101
// NewStateful creates a worker pool with a separate worker instance for each goroutine.
102
// Size defines number of goroutines (workers) processing requests.
103
// Maker function is called for each goroutine to create a new worker instance.
104
func NewStateful[T any](size int, maker func() Worker[T], opts ...Option[T]) (*WorkerGroup[T], error) {
3✔
105
        if size < 1 {
3✔
106
                size = 1
×
107
        }
×
108
        if maker == nil {
3✔
109
                return nil, fmt.Errorf("worker maker cannot be nil")
×
110
        }
×
111

112
        res := &WorkerGroup[T]{
3✔
113
                poolSize:       size,
3✔
114
                workersCh:      make([]chan []T, size),
3✔
115
                buf:            make([][]T, size),
3✔
116
                workerCtxs:     make([]context.Context, size),
3✔
117
                workerMaker:    maker,
3✔
118
                batchSize:      1,
3✔
119
                workerChanSize: 1,
3✔
120
                ctx:            context.Background(),
3✔
121
        }
3✔
122
        res.err.ch = make(chan struct{})
3✔
123

3✔
124
        // apply all options
3✔
125
        for _, opt := range opts {
5✔
126
                opt(res)
2✔
127
        }
2✔
128

129
        // initialize worker's channels and batch buffers
130
        for id := 0; id < size; id++ {
9✔
131
                res.workersCh[id] = make(chan []T, res.workerChanSize)
6✔
132
                if res.batchSize > 1 {
8✔
133
                        res.buf[id] = make([]T, 0, size)
2✔
134
                }
2✔
135
        }
136

137
        return res, nil
3✔
138
}
139

140
// Submit record to pool, can be blocked if worker channels are full
141
func (p *WorkerGroup[T]) Submit(v T) {
20,200✔
142
        // randomize distribution by default
20,200✔
143
        id := rand.Intn(p.poolSize) //nolint:gosec // no need for secure random here, just distribution
20,200✔
144
        if p.chunkFn != nil {
30,214✔
145
                // chunked distribution
10,014✔
146
                h := fnv.New32a()
10,014✔
147
                _, _ = h.Write([]byte(p.chunkFn(v)))
10,014✔
148
                id = int(h.Sum32()) % p.poolSize
10,014✔
149
        }
10,014✔
150

151
        if p.batchSize <= 1 {
40,380✔
152
                // skip all buffering if batch size is 1 or less
20,180✔
153
                p.workersCh[id] <- append([]T{}, v)
20,180✔
154
                return
20,180✔
155
        }
20,180✔
156

157
        if !p.continueOnError {
40✔
158
                select {
20✔
159
                case <-p.err.ch: // closed due to worker error
×
160
                        return
×
161
                default:
20✔
162
                }
163
        }
164

165
        p.buf[id] = append(p.buf[id], v)   // add to batch buffer
20✔
166
        if len(p.buf[id]) >= p.batchSize { // submit buffer to workers
27✔
167
                // commit copy to workers
7✔
168
                cp := make([]T, len(p.buf[id]))
7✔
169
                copy(cp, p.buf[id])
7✔
170
                p.workersCh[id] <- cp
7✔
171
                p.buf[id] = p.buf[id][:0] // reset size, keep capacity
7✔
172
        }
7✔
173
}
174

175
// Go activates worker pool, closes cursor on completion
176
func (p *WorkerGroup[T]) Go(ctx context.Context) error {
32✔
177
        if p.activated {
32✔
178
                return fmt.Errorf("workers poll already activated")
×
179
        }
×
180
        defer func() { p.activated = true }()
64✔
181

182
        // Create errgroup first
183
        var egCtx context.Context
32✔
184
        p.eg, egCtx = errgroup.WithContext(ctx)
32✔
185
        p.ctx = egCtx
32✔
186

32✔
187
        // start all goroutines
32✔
188
        for i := 0; i < p.poolSize; i++ {
86✔
189
                workerCtx := metrics.Make(metrics.WithWorkerID(egCtx, i))
54✔
190
                p.workerCtxs[i] = workerCtx
54✔
191
                p.eg.Go(p.workerProc(workerCtx, i, p.workersCh[i]))
54✔
192
        }
54✔
193

194
        return nil
32✔
195
}
196

197
// workerProc is a worker goroutine function, reads from the input channel and processes records
198
func (p *WorkerGroup[T]) workerProc(wCtx context.Context, id int, inCh chan []T) func() error {
54✔
199
        return func() error {
108✔
200
                var lastErr error
54✔
201
                var totalErrs int
54✔
202

54✔
203
                m := metrics.Get(wCtx)
54✔
204

54✔
205
                // get worker instance based on mode
54✔
206
                var worker Worker[T]
54✔
207
                if p.worker != nil {
104✔
208
                        worker = p.worker // use shared worker for stateless mode
50✔
209
                } else {
54✔
210
                        worker = p.workerMaker() // create new worker instance for stateful mode
4✔
211
                }
4✔
212

213
                // track initialization time
214
                initEndTmr := m.StartTimer(metrics.DurationInit)
54✔
215
                initEndTmr()
54✔
216

54✔
217
                for {
20,295✔
218
                        select {
20,241✔
219
                        case vv, ok := <-inCh:
20,235✔
220
                                if !ok { // input channel closed
20,283✔
221
                                        wrapEndTmr := m.StartTimer(metrics.DurationWrap)
48✔
222
                                        e := p.finalizeWorker(wCtx, id, worker)
48✔
223
                                        wrapEndTmr()
48✔
224
                                        if e != nil {
48✔
225
                                                return e
×
226
                                        }
×
227
                                        if lastErr != nil {
55✔
228
                                                return fmt.Errorf("total errors: %d, last error: %w", totalErrs, lastErr)
7✔
229
                                        }
7✔
230
                                        return nil
41✔
231
                                }
232

233
                                // track wait time - from when item was received till processing starts
234
                                waitEndTmr := m.StartTimer(metrics.DurationWait)
20,187✔
235

20,187✔
236
                                // read from the input slice
20,187✔
237
                                for _, v := range vv {
40,384✔
238
                                        // even if not continue on error has to read from input channel all it has
20,197✔
239
                                        if lastErr != nil && !p.continueOnError {
20,198✔
240
                                                m.Inc(metrics.CountDropped)
1✔
241
                                                continue
1✔
242
                                        }
243

244
                                        if err := worker.Do(wCtx, v); err != nil {
20,207✔
245
                                                m.Inc(metrics.CountErrors)
11✔
246
                                                e := fmt.Errorf("worker %d failed: %w", id, err)
11✔
247
                                                if !p.continueOnError {
15✔
248
                                                        // close err.ch once. indicates to Submit what all other records can be ignored
4✔
249
                                                        p.err.Do(func() { close(p.err.ch) })
8✔
250
                                                }
251
                                                totalErrs++
11✔
252
                                                lastErr = e // errors allowed to continue, capture the last error only
11✔
253
                                        }
254
                                }
255
                                waitEndTmr()
20,187✔
256

257
                        case <-p.ctx.Done(): // parent context, passed by caller
2✔
258
                                return p.ctx.Err()
2✔
259

260
                        case <-wCtx.Done(): // worker context from errgroup
×
261
                                // triggered by other worker, kill only if errors not allowed
×
262
                                if !p.continueOnError {
×
263
                                        return wCtx.Err()
×
264
                                }
×
265
                        }
266
                }
267
        }
268
}
269

270
// finWorker worker flushes records left in buffer to workers, called once for each worker
271
// if completeFn allowed, will be called as well
272
func (p *WorkerGroup[T]) finalizeWorker(ctx context.Context, id int, worker Worker[T]) (err error) {
53✔
273
        // process all requests left in the not submitted yet buffer
53✔
274
        for _, v := range p.buf[id] {
63✔
275
                if e := worker.Do(ctx, v); e != nil {
14✔
276
                        if !p.continueOnError {
7✔
277
                                return fmt.Errorf("worker %d failed in finalizer: %w", id, e)
3✔
278
                        }
3✔
279
                }
280
        }
281

282
        // call completeFn for given worker id
283
        if p.completeFn != nil {
53✔
284
                if e := p.completeFn(ctx, id, worker); e != nil {
4✔
285
                        err = fmt.Errorf("complete func for %d failed: %w", id, e)
1✔
286
                }
1✔
287
        }
288

289
        return err
50✔
290
}
291

292
// Close pool. Has to be called by consumer as the indication of "all records submitted".
293
// The call is blocking till all processing completed by workers. After this call poll can't be reused.
294
// Returns an error if any happened during the run
295
func (p *WorkerGroup[T]) Close(ctx context.Context) (err error) {
27✔
296
        for _, ch := range p.workersCh {
76✔
297
                close(ch)
49✔
298
        }
49✔
299

300
        doneCh := make(chan error)
27✔
301
        go func() {
54✔
302
                doneCh <- p.eg.Wait()
27✔
303
        }()
27✔
304

305
        for {
54✔
306
                select {
27✔
307
                case err := <-doneCh:
27✔
308
                        return err
27✔
309
                case <-ctx.Done():
×
310
                        return ctx.Err()
×
311
                }
312
        }
313
}
314

315
// Wait till workers completed and result channel closed. This can be used instead of the cursor
316
// in case if the result channel can be ignored and the goal is to wait for the completion.
317
func (p *WorkerGroup[T]) Wait(ctx context.Context) (err error) {
2✔
318
        doneCh := make(chan error)
2✔
319
        go func() {
4✔
320
                doneCh <- p.eg.Wait()
2✔
321
        }()
2✔
322

323
        for {
4✔
324
                select {
2✔
325
                case err := <-doneCh:
2✔
326
                        return err
2✔
327
                case <-ctx.Done():
×
328
                        return ctx.Err()
×
329
                }
330
        }
331
}
332

333
// Metrics returns combined metrics from all workers
334
func (p *WorkerGroup[T]) Metrics() *metrics.Value {
5✔
335
        values := make([]*metrics.Value, p.poolSize)
5✔
336
        for i := 0; i < p.poolSize; i++ {
15✔
337
                values[i] = metrics.Get(p.workerCtxs[i])
10✔
338
        }
10✔
339
        return metrics.Aggregate(values...)
5✔
340
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc