go-pkgz / pool — build 13377913852
17 Feb 2025 08:13PM UTC coverage: 94.32% (+0.4%) from 93.95%
push via github by umputun: "update comment explaining thread-safety"
548 of 581 relevant lines covered (94.32%), 171661.14 hits per line

Source file: /pool.go — 92.41% covered
package pool

import (
	"context"
	"errors"
	"fmt"
	"hash/fnv"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/go-pkgz/pool/metrics"
)

// WorkerGroup represents a pool of workers processing items in parallel.
// Supports both direct item processing and batching modes.
type WorkerGroup[T any] struct {
	poolSize         int                 // number of workers (goroutines)
	workerChanSize   int                 // size of worker channels
	workerCompleteFn WorkerCompleteFn[T] // completion callback function, called by each worker on completion
	poolCompleteFn   GroupCompleteFn[T]  // pool-level completion callback, called once when all workers are done
	continueOnError  bool                // don't terminate on first error
	chunkFn          func(T) string      // worker selector function
	worker           Worker[T]           // worker function
	workerMaker      WorkerMaker[T]      // worker maker function

	metrics *metrics.Value // shared metrics

	workersCh     []chan T     // workers input channels
	sharedCh      chan T       // shared input channel for all workers
	activeWorkers atomic.Int32 // track number of active workers

	// batching support
	batchSize     int        // if > 0, accumulate items up to this size
	accumulators  [][]T      // per-worker accumulators for batching
	workerBatchCh []chan []T // per-worker batch channels
	sharedBatchCh chan []T   // shared batch channel

	eg        *errgroup.Group
	activated bool
	ctx       context.Context

	sendMu sync.Mutex
}

// Worker is the interface that wraps the Do method.
type Worker[T any] interface {
	Do(ctx context.Context, v T) error
}

// WorkerFunc is an adapter to allow the use of ordinary functions as Workers.
type WorkerFunc[T any] func(ctx context.Context, v T) error

// Do calls f(ctx, v).
func (f WorkerFunc[T]) Do(ctx context.Context, v T) error { return f(ctx, v) }

// WorkerMaker is a function that returns a new Worker.
type WorkerMaker[T any] func() Worker[T]

// WorkerCompleteFn is called on worker completion.
type WorkerCompleteFn[T any] func(ctx context.Context, id int, worker Worker[T]) error

// GroupCompleteFn is called once when all workers are done.
type GroupCompleteFn[T any] func(ctx context.Context) error

// Send is a function called by worker code to publish results.
type Send[T any] func(val T) error

// New creates a worker pool with a shared worker instance.
// All goroutines share the same worker, suitable for stateless processing.
func New[T any](size int, worker Worker[T]) *WorkerGroup[T] {
	if size < 1 {
		size = 1
	}

	res := &WorkerGroup[T]{
		poolSize:       size,
		worker:         worker,
		workerChanSize: 1,
		batchSize:      10, // default batch size

		// initialize channels
		workersCh:     make([]chan T, size),
		sharedCh:      make(chan T, size),
		workerBatchCh: make([]chan []T, size),
		sharedBatchCh: make(chan []T, size),
		accumulators:  make([][]T, size),
	}

	// initialize worker's channels
	for i := range size {
		res.workersCh[i] = make(chan T, res.workerChanSize)
		res.workerBatchCh[i] = make(chan []T, res.workerChanSize)
	}

	return res
}
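
// Example (illustrative sketch, not part of this file): a minimal stateless pool built
// with New and the WorkerFunc adapter, assuming the package is imported as
// "github.com/go-pkgz/pool" along with context, fmt and log.
//
//	p := pool.New[string](4, pool.WorkerFunc[string](func(ctx context.Context, v string) error {
//		fmt.Println("processing", v) // shared, stateless worker
//		return nil
//	}))
//	if err := p.Go(context.Background()); err != nil { // start workers before submitting
//		log.Fatal(err)
//	}
//	p.Submit("task-1")
//	p.Submit("task-2")
//	if err := p.Close(context.Background()); err != nil { // "all submitted", wait for completion
//		log.Fatal(err)
//	}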

// NewStateful creates a worker pool where each goroutine gets its own worker instance.
// Suitable for operations requiring state (e.g., database connections).
func NewStateful[T any](size int, maker func() Worker[T]) *WorkerGroup[T] {
	if size < 1 {
		size = 1
	}

	res := &WorkerGroup[T]{
		poolSize:       size,
		workerMaker:    maker,
		workerChanSize: 1,
		batchSize:      10, // default batch size
		ctx:            context.Background(),

		// initialize channels
		workersCh:     make([]chan T, size),
		sharedCh:      make(chan T, size),
		workerBatchCh: make([]chan []T, size),
		sharedBatchCh: make(chan []T, size),
		accumulators:  make([][]T, size),
	}

	// initialize worker's channels
	for i := range size {
		res.workersCh[i] = make(chan T, res.workerChanSize)
		res.workerBatchCh[i] = make(chan []T, res.workerChanSize)
	}

	return res
}
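
// Example (illustrative sketch, not part of this file): a stateful pool where each
// goroutine builds its own worker, e.g. to hold a dedicated connection. The openConn
// helper and the conn value it returns are hypothetical.
//
//	p := pool.NewStateful[string](4, func() pool.Worker[string] {
//		conn := openConn() // hypothetical per-worker resource
//		return pool.WorkerFunc[string](func(ctx context.Context, v string) error {
//			return conn.Write(ctx, v)
//		})
//	})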

// WithWorkerChanSize sets channel buffer size for each worker.
// Larger sizes can help with bursty workloads but increase memory usage.
// Default: 1
func (p *WorkerGroup[T]) WithWorkerChanSize(size int) *WorkerGroup[T] {
	p.workerChanSize = size
	if size < 1 {
		p.workerChanSize = 1
	}
	return p
}

// WithWorkerCompleteFn sets callback executed on worker completion.
// Useful for cleanup or finalization of worker resources.
// Default: none (disabled)
func (p *WorkerGroup[T]) WithWorkerCompleteFn(fn WorkerCompleteFn[T]) *WorkerGroup[T] {
	p.workerCompleteFn = fn
	return p
}

// WithPoolCompleteFn sets callback executed once when all workers are done.
func (p *WorkerGroup[T]) WithPoolCompleteFn(fn GroupCompleteFn[T]) *WorkerGroup[T] {
	p.poolCompleteFn = fn
	return p
}

// WithChunkFn enables predictable item distribution.
// Items with the same key (returned by fn) are processed by the same worker.
// Useful for maintaining order within groups of related items.
// Default: none (random distribution)
func (p *WorkerGroup[T]) WithChunkFn(fn func(T) string) *WorkerGroup[T] {
	p.chunkFn = fn
	return p
}
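
// Example (illustrative sketch): routing with WithChunkFn so that all items sharing a
// key are processed by the same worker, preserving per-key ordering. The event type
// and its userID field are hypothetical.
//
//	p := pool.New[event](8, worker).WithChunkFn(func(e event) string {
//		return e.userID // same user always maps to the same worker
//	})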

// WithContinueOnError sets whether the pool should continue on error.
// Default: false
func (p *WorkerGroup[T]) WithContinueOnError() *WorkerGroup[T] {
	p.continueOnError = true
	return p
}

// WithBatchSize enables item batching with specified size.
// Items are accumulated until batch is full before processing.
// Set to 0 to disable batching.
// Default: 10
func (p *WorkerGroup[T]) WithBatchSize(size int) *WorkerGroup[T] {
	p.batchSize = size
	if size > 0 {
		// initialize accumulators with capacity
		for i := range p.poolSize {
			p.accumulators[i] = make([]T, 0, size)
		}
	}
	return p
}
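
// Example (illustrative sketch): tuning batching. Batching is enabled by default with
// size 10; setting it to 0 switches Submit to direct per-item delivery.
//
//	batched := pool.New[int](4, worker).WithBatchSize(100) // larger batches, fewer channel sends
//	direct := pool.New[int](4, worker).WithBatchSize(0)    // no batching, items delivered one by one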

// Submit adds an item to the pool for processing. May block if worker channels are full.
// Not thread-safe, intended for use by the main thread or a single producer's thread.
func (p *WorkerGroup[T]) Submit(v T) {
	// check context early
	select {
	case <-p.ctx.Done():
		return // don't submit if context is cancelled
	default:
	}

	if p.batchSize == 0 {
		// direct submission mode
		if p.chunkFn == nil {
			p.sharedCh <- v
			return
		}
		h := fnv.New32a()
		_, _ = h.Write([]byte(p.chunkFn(v)))
		id := int(h.Sum32()) % p.poolSize
		p.workersCh[id] <- v
		return
	}

	// batching mode
	var id int
	if p.chunkFn != nil {
		h := fnv.New32a()
		_, _ = h.Write([]byte(p.chunkFn(v)))
		id = int(h.Sum32()) % p.poolSize
	} else {
		id = rand.Intn(p.poolSize) //nolint:gosec // no need for secure random here
	}

	// add to accumulator
	p.accumulators[id] = append(p.accumulators[id], v)

	// check if we should flush
	var shouldFlush bool
	select {
	case <-p.ctx.Done():
		shouldFlush = true // always flush on context cancellation
	default:
		// in normal case, flush only when batch is full
		shouldFlush = len(p.accumulators[id]) >= p.batchSize
	}

	if shouldFlush && len(p.accumulators[id]) > 0 {
		if p.chunkFn == nil {
			select {
			case p.sharedBatchCh <- p.accumulators[id]:
			case <-p.ctx.Done(): // handle case where channel send would block
				return
			}
		} else {
			select {
			case p.workerBatchCh[id] <- p.accumulators[id]:
			case <-p.ctx.Done():
				return
			}
		}
		p.accumulators[id] = make([]T, 0, p.batchSize)
	}
}

// Send adds an item to the pool for processing.
// Safe for concurrent use, intended for worker-to-pool submissions or for use by multiple concurrent producers.
func (p *WorkerGroup[T]) Send(v T) {
	p.sendMu.Lock()
	defer p.sendMu.Unlock()
	p.Submit(v)
}
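
// Example (illustrative sketch): Submit vs Send. A single producer goroutine can call
// Submit directly; workers resubmitting items back into the pool, or multiple concurrent
// producers, should call Send, which serializes access via sendMu. Assumes p is an
// already activated *pool.WorkerGroup[int] declared earlier.
//
//	worker := pool.WorkerFunc[int](func(ctx context.Context, v int) error {
//		if v < 100 {
//			p.Send(v * 2) // worker-to-pool resubmission goes through Send, not Submit
//		}
//		return nil
//	})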

// Go activates the pool and starts worker goroutines.
// Must be called before submitting items.
func (p *WorkerGroup[T]) Go(ctx context.Context) error {
	if p.activated {
		return fmt.Errorf("workers pool already activated")
	}
	defer func() { p.activated = true }()

	var egCtx context.Context
	p.eg, egCtx = errgroup.WithContext(ctx)
	p.ctx = egCtx

	// create metrics context for all workers
	metricsCtx := metrics.Make(egCtx, p.poolSize)
	p.metrics = metrics.Get(metricsCtx)

	// set initial count
	p.activeWorkers.Store(int32(p.poolSize)) //nolint:gosec // no risk of overflow

	// start all goroutines (workers)
	for i := range p.poolSize {
		withWorkerIDctx := metrics.WithWorkerID(metricsCtx, i)
		workerCh := p.sharedCh
		batchCh := p.sharedBatchCh
		if p.chunkFn != nil {
			workerCh = p.workersCh[i]
			batchCh = p.workerBatchCh[i]
		}
		r := workerRequest[T]{inCh: workerCh, batchCh: batchCh, m: p.metrics, id: i}
		p.eg.Go(p.workerProc(withWorkerIDctx, r))
	}

	return nil
}

// workerRequest is a request to worker goroutine containing all necessary data
type workerRequest[T any] struct {
	inCh    <-chan T
	batchCh <-chan []T
	m       *metrics.Value
	id      int
}

// workerProc is a worker goroutine function, reads from the input channel and processes records
func (p *WorkerGroup[T]) workerProc(wCtx context.Context, r workerRequest[T]) func() error {
	return func() error {
		var lastErr error
		var totalErrs int

		initEndTmr := r.m.StartTimer(r.id, metrics.TimerInit)
		worker := p.worker
		if p.workerMaker != nil {
			worker = p.workerMaker()
		}
		initEndTmr()

		lastActivity := time.Now()

		// processItem handles a single item with metrics
		processItem := func(v T) error {
			waitTime := time.Since(lastActivity)
			r.m.AddWaitTime(r.id, waitTime)
			lastActivity = time.Now()

			procEndTmr := r.m.StartTimer(r.id, metrics.TimerProc)
			defer procEndTmr()

			if err := worker.Do(wCtx, v); err != nil {
				r.m.IncErrors(r.id)
				totalErrs++
				if !p.continueOnError {
					return fmt.Errorf("worker %d failed: %w", r.id, err)
				}
				lastErr = fmt.Errorf("worker %d failed: %w", r.id, err)
				return nil // continue on error
			}
			r.m.IncProcessed(r.id)
			return nil
		}

		// processBatch handles batch of items
		processBatch := func(items []T) error {
			waitTime := time.Since(lastActivity)
			r.m.AddWaitTime(r.id, waitTime)
			lastActivity = time.Now()

			procEndTmr := r.m.StartTimer(r.id, metrics.TimerProc)
			defer procEndTmr()

			for _, v := range items {
				if err := worker.Do(wCtx, v); err != nil {
					r.m.IncErrors(r.id)
					totalErrs++
					if !p.continueOnError {
						return fmt.Errorf("worker %d failed: %w", r.id, err)
					}
					lastErr = fmt.Errorf("worker %d failed: %w", r.id, err)
					continue
				}
				r.m.IncProcessed(r.id)
			}
			return nil
		}

		// track if channels are closed
		normalClosed := false
		batchClosed := false

		// main processing loop
		for {
			if normalClosed && batchClosed {
				return p.finishWorker(wCtx, r.id, worker, lastErr, totalErrs)
			}

			select {
			case <-wCtx.Done():
				return p.finishWorker(wCtx, r.id, worker, wCtx.Err(), totalErrs)

			case v, ok := <-r.inCh:
				if !ok {
					normalClosed = true
					continue
				}
				if err := processItem(v); err != nil {
					return p.finishWorker(wCtx, r.id, worker, err, totalErrs)
				}

			case batch, ok := <-r.batchCh:
				if !ok {
					batchClosed = true
					continue
				}
				if err := processBatch(batch); err != nil {
					return p.finishWorker(wCtx, r.id, worker, err, totalErrs)
				}
			}
		}
	}
}

// finishWorker handles worker completion logic
func (p *WorkerGroup[T]) finishWorker(ctx context.Context, id int, worker Worker[T], lastErr error, totalErrs int) error {
	// worker completion should be called only if we are continuing on error or no error
	if p.workerCompleteFn != nil && (lastErr == nil || p.continueOnError) {
		wrapFinTmr := p.metrics.StartTimer(id, metrics.TimerWrap)
		if e := p.workerCompleteFn(ctx, id, worker); e != nil {
			if lastErr == nil {
				lastErr = fmt.Errorf("complete worker func for %d failed: %w", id, e)
			}
		}
		wrapFinTmr()
	}

	activeWorkers := p.activeWorkers.Add(-1)

	// pool completion should be called when this is the last worker,
	// regardless of error state, except for context cancellation
	if activeWorkers == 0 && p.poolCompleteFn != nil && !errors.Is(lastErr, context.Canceled) {
		if e := p.poolCompleteFn(ctx); e != nil {
			if lastErr == nil {
				lastErr = fmt.Errorf("complete pool func for %d failed: %w", id, e)
			}
		}
	}

	if lastErr != nil {
		return fmt.Errorf("total errors: %d, last error: %w", totalErrs, lastErr)
	}
	return nil
}

// Close closes the pool. It has to be called by the consumer as the indication that all records have been submitted.
// The call blocks until all processing is completed by workers. After this call the pool can't be reused.
// Returns an error if any occurred during the run.
func (p *WorkerGroup[T]) Close(ctx context.Context) error {
	// if context canceled, return immediately
	switch {
	case ctx.Err() != nil:
		return ctx.Err()
	default:
	}

	// flush any remaining items in accumulators
	if p.batchSize > 0 {
		for i, acc := range p.accumulators {
			if len(acc) > 0 {
				// ensure we flush any non-empty accumulator, regardless of size
				if p.chunkFn == nil {
					p.sharedBatchCh <- acc
				} else {
					p.workerBatchCh[i] <- acc
				}
				p.accumulators[i] = nil // help GC
			}
		}
	}

	close(p.sharedCh)
	close(p.sharedBatchCh)
	for i := range p.poolSize {
		close(p.workersCh[i])
		close(p.workerBatchCh[i])
	}
	return p.eg.Wait()
}

// Wait waits until all workers have completed and the result channel is closed.
func (p *WorkerGroup[T]) Wait(ctx context.Context) error {
	// if context canceled, return immediately
	switch {
	case ctx.Err() != nil:
		return ctx.Err()
	default:
	}
	return p.eg.Wait()
}
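
// Example (illustrative sketch): Close vs Wait. The goroutine that submits items calls
// Close to signal completion; other goroutines that only need to block until processing
// finishes can call Wait. Assumes p is an already activated pool and items is a
// hypothetical slice of inputs.
//
//	go func() {
//		for _, v := range items {
//			p.Submit(v)
//		}
//		_ = p.Close(context.Background())
//	}()
//	if err := p.Wait(context.Background()); err != nil {
//		log.Printf("pool finished with error: %v", err)
//	}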

// Metrics returns combined metrics from all workers
func (p *WorkerGroup[T]) Metrics() *metrics.Value {
	return p.metrics
}

// Middleware wraps worker and adds functionality
type Middleware[T any] func(Worker[T]) Worker[T]

// Use applies middlewares to the worker group's worker. Middlewares are applied
// in the same order as they are provided, matching the HTTP middleware pattern in Go.
// The first middleware is the outermost wrapper, and the last middleware is the
// innermost wrapper closest to the original worker.
func (p *WorkerGroup[T]) Use(middlewares ...Middleware[T]) *WorkerGroup[T] {
	if len(middlewares) == 0 {
		return p
	}

	// if we have a worker maker (stateful), wrap it
	if p.workerMaker != nil {
		originalMaker := p.workerMaker
		p.workerMaker = func() Worker[T] {
			worker := originalMaker()
			// apply middlewares in order from last to first
			// this makes first middleware outermost
			wrapped := worker
			for i := len(middlewares) - 1; i >= 0; i-- {
				prev := wrapped
				wrapped = middlewares[i](prev)
			}
			return wrapped
		}
		return p
	}

	// for stateless worker, just wrap it directly
	wrapped := p.worker
	for i := len(middlewares) - 1; i >= 0; i-- {
		prev := wrapped
		wrapped = middlewares[i](prev)
	}
	p.worker = wrapped
	return p
}
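
// Example (illustrative sketch): a logging middleware applied with Use. The first
// middleware passed is the outermost wrapper, as described above; worker is assumed
// to be an existing pool.Worker[string].
//
//	logging := func(next pool.Worker[string]) pool.Worker[string] {
//		return pool.WorkerFunc[string](func(ctx context.Context, v string) error {
//			log.Printf("start %q", v)
//			err := next.Do(ctx, v)
//			log.Printf("done %q, err=%v", v, err)
//			return err
//		})
//	}
//	p := pool.New[string](4, worker).Use(logging)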