knights-analytics / hugot / build 20024020731

08 Dec 2025 09:56AM UTC · coverage: 55.81% (remained the same)
Push build via GitHub by riccardopinosio: "add pre tokenization option to token classification pipeline"

76 of 103 new or added lines in 1 file covered (73.79%).
116 existing lines in 2 files now uncovered.
2584 of 4630 relevant lines covered (55.81%).
12236.29 hits per line.

Source File: /pipelines/tokenClassification.go (58.61% covered)
package pipelines

import (
        "errors"
        "fmt"
        "slices"
        "strings"
        "sync/atomic"
        "time"

        "github.com/knights-analytics/hugot/backends"
        "github.com/knights-analytics/hugot/options"
        "github.com/knights-analytics/hugot/util/safeconv"
        "github.com/knights-analytics/hugot/util/vectorutil"
)

// TokenClassificationPipeline is a Go version of the Hugging Face tokenClassificationPipeline.
// https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py
type TokenClassificationPipeline struct {
        *backends.BasePipeline
        IDLabelMap          map[int]string
        AggregationStrategy string
        IgnoreLabels        []string
        SplitWords          bool
}
type Entity struct {
        Entity    string
        Word      string
        Scores    []float32
        TokenID   []uint32
        Index     int
        Start     uint
        End       uint
        Score     float32
        IsSubword bool
}
type TokenClassificationOutput struct {
        Entities [][]Entity
}

// GetOutput returns the entities of the batch as a generic []any.
func (t *TokenClassificationOutput) GetOutput() []any {
        out := make([]any, len(t.Entities))
        for i, entity := range t.Entities {
                out[i] = any(entity)
        }
        return out
}

// options

// WithSimpleAggregation sets the aggregation strategy for the token labels to simple.
// It reproduces the simple aggregation strategy from the huggingface implementation.
func WithSimpleAggregation() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.AggregationStrategy = "SIMPLE"
                return nil
        }
}

// WithAverageAggregation sets the aggregation strategy for the token labels to average.
// It reproduces the average aggregation strategy from the huggingface implementation.
func WithAverageAggregation() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.AggregationStrategy = "AVERAGE"
                return nil
        }
}

// WithMaxAggregation sets the aggregation strategy for the token labels to max.
// It reproduces the max aggregation strategy from the huggingface implementation.
func WithMaxAggregation() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.AggregationStrategy = "MAX"
                return nil
        }
}

// WithFirstAggregation sets the aggregation strategy for the token labels to first.
// It reproduces the first aggregation strategy from the huggingface implementation.
func WithFirstAggregation() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.AggregationStrategy = "FIRST"
                return nil
        }
}

// WithoutAggregation disables aggregation and returns the raw token labels.
func WithoutAggregation() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.AggregationStrategy = "NONE"
                return nil
        }
}

// WithIgnoreLabels sets the entity labels that are filtered out of the pipeline output.
func WithIgnoreLabels(ignoreLabels []string) backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.IgnoreLabels = ignoreLabels
                return nil
        }
}

// WithSplitWords enables word-level alignment like Hugging Face's is_split_into_words.
func WithSplitWords() backends.PipelineOption[*TokenClassificationPipeline] {
        return func(pipeline *TokenClassificationPipeline) error {
                pipeline.SplitWords = true
                return nil
        }
}
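
// Example (illustrative, not part of the original file): the option constructors
// above are applied in order by NewTokenClassificationPipeline below through the
// config's Options slice, e.g.
//
//      Options: []backends.PipelineOption[*TokenClassificationPipeline]{
//              WithMaxAggregation(),
//              WithIgnoreLabels([]string{"O", "MISC"}),
//      },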

// NewTokenClassificationPipeline initializes a token classification pipeline.
func NewTokenClassificationPipeline(config backends.PipelineConfig[*TokenClassificationPipeline], s *options.Options, model *backends.Model) (*TokenClassificationPipeline, error) {
        defaultPipeline, err := backends.NewBasePipeline(config, s, model)
        if err != nil {
                return nil, err
        }
        pipeline := &TokenClassificationPipeline{BasePipeline: defaultPipeline}
        for _, o := range config.Options {
                err = o(pipeline)
                if err != nil {
                        return nil, err
                }
        }
        // ID label map
        pipeline.IDLabelMap = model.IDLabelMap
        // default strategies if not set
        if pipeline.AggregationStrategy == "" {
                pipeline.AggregationStrategy = "SIMPLE"
        }
        if len(pipeline.IgnoreLabels) == 0 {
                pipeline.IgnoreLabels = []string{"O"}
        }
        // Additional options needed for postprocessing
        backends.AllInputTokens(pipeline.BasePipeline)
        err = pipeline.Validate()
        if err != nil {
                return nil, err
        }
        return pipeline, nil
}

// INTERFACE IMPLEMENTATION

// GetModel returns the model used by the pipeline.
func (p *TokenClassificationPipeline) GetModel() *backends.Model {
        return p.Model
}

// GetMetadata returns metadata information about the pipeline, in particular:
// OutputsInfo: names and dimensions of the output layer used for token classification.
func (p *TokenClassificationPipeline) GetMetadata() backends.PipelineMetadata {
        return backends.PipelineMetadata{
                OutputsInfo: []backends.OutputInfo{
                        {
                                Name:       p.Model.OutputsMeta[0].Name,
                                Dimensions: p.Model.OutputsMeta[0].Dimensions,
                        },
                },
        }
}

// GetStatistics returns the runtime statistics for the pipeline.
func (p *TokenClassificationPipeline) GetStatistics() backends.PipelineStatistics {
        statistics := backends.PipelineStatistics{}
        statistics.ComputeTokenizerStatistics(p.Model.Tokenizer.TokenizerTimings)
        statistics.ComputeOnnxStatistics(p.PipelineTimings)
        return statistics
}

// Validate checks that the pipeline is valid.
func (p *TokenClassificationPipeline) Validate() error {
        var validationErrors []error
        if p.Model.Tokenizer == nil {
                validationErrors = append(validationErrors, fmt.Errorf("token classification pipeline requires a tokenizer"))
        }
        outputDim := p.Model.OutputsMeta[0].Dimensions
        if len(outputDim) != 3 {
                validationErrors = append(validationErrors,
                        fmt.Errorf("output for token classification must be three dimensional (input, sequence, logits)"))
        }
        if outputDim[len(outputDim)-1] == -1 {
                validationErrors = append(validationErrors,
                        fmt.Errorf("logit dimension cannot be dynamic"))
        }
        if len(p.IDLabelMap) <= 0 {
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero"))
        }
        return errors.Join(validationErrors...)
}

// Preprocess tokenizes the input strings.
func (p *TokenClassificationPipeline) Preprocess(batch *backends.PipelineBatch, inputs []string) error {
        if p.SplitWords {
                return fmt.Errorf("split-words enabled: use RunWords/PreprocessWords for [][]string inputs")
        }
        start := time.Now()
        backends.TokenizeInputs(batch, p.Model.Tokenizer, inputs)
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.NumCalls, 1)
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
        err := backends.CreateInputTensors(batch, p.Model, p.Runtime)
        return err
}

// PreprocessWords tokenizes pre-split words and maps tokens to word IDs via offsets.
func (p *TokenClassificationPipeline) PreprocessWords(batch *backends.PipelineBatch, inputs [][]string) error {
        start := time.Now()
        // Join words with single spaces to simulate pretokenized behavior
        joined := make([]string, len(inputs))
        wordBoundaries := make([][][2]uint, len(inputs))
        // local helper to safely convert an int to uint, clamping negatives to zero
        toUintNonNeg := func(i int) uint {
                if i < 0 {
                        return 0
                }
                return uint(i)
        }
        for i, words := range inputs {
                joined[i] = strings.Join(words, " ")
                // compute boundaries in joined string
                var boundaries [][2]uint
                pos := 0
                for wIdx, w := range words {
                        startPos := pos
                        endPos := pos + len(w)
                        // clamp to non-negative before converting safely to uint
                        if startPos < 0 {
                                startPos = 0
                        }
                        if endPos < 0 {
                                endPos = 0
                        }
                        boundaries = append(boundaries, [2]uint{toUintNonNeg(startPos), toUintNonNeg(endPos)})
                        // add one space after every word except last
                        if wIdx < len(words)-1 {
                                pos = endPos + 1
                        } else {
                                pos = endPos
                        }
                }
                wordBoundaries[i] = boundaries
        }
        backends.TokenizeInputs(batch, p.Model.Tokenizer, joined)
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.NumCalls, 1)
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))

        // Map token offsets to word indices
        for i := range batch.Input {
                input := batch.Input[i]
                boundaries := wordBoundaries[i]
                wordIDs := make([]int, len(input.Offsets))
                for t := range input.Offsets {
                        if input.SpecialTokensMask[t] > 0 {
                                wordIDs[t] = -1
                                continue
                        }
                        tokStart := input.Offsets[t][0]
                        tokEnd := input.Offsets[t][1]
                        id := -1
                        for w := range boundaries {
                                b := boundaries[w]
                                // assign if token lies within the word span
                                if tokStart >= b[0] && tokEnd <= b[1] {
                                        id = w
                                        break
                                }
                        }
                        wordIDs[t] = id
                }
                batch.Input[i].WordIDs = wordIDs
                // also set raw to joined string for offsets consistency
                batch.Input[i].Raw = joined[i]
        }
        return backends.CreateInputTensors(batch, p.Model, p.Runtime)
}
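
// Illustrative walk-through of the offset-to-word mapping above (example added
// for clarity, not part of the original file): for inputs[i] = ["New", "York"],
// joined[i] is "New York" and wordBoundaries[i] is [[0 3] [4 8]]. A sub-token
// whose offsets are [4, 6) falls inside the second boundary and gets word ID 1,
// whereas special tokens such as CLS/SEP keep the sentinel word ID -1.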

// Forward performs the forward inference of the pipeline.
func (p *TokenClassificationPipeline) Forward(batch *backends.PipelineBatch) error {
        start := time.Now()
        err := backends.RunSessionOnBatch(batch, p.BasePipeline)
        if err != nil {
                return err
        }
        atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
        atomic.AddUint64(&p.PipelineTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
        return nil
}

// Postprocess converts the model logits into entities for a token classification pipeline.
func (p *TokenClassificationPipeline) Postprocess(batch *backends.PipelineBatch) (*TokenClassificationOutput, error) {
        if batch.Size == 0 {
                return &TokenClassificationOutput{}, nil
        }
        output := batch.OutputValues[0]
        var outputCast [][][]float32
        switch v := output.(type) {
        case [][][]float32:
                for batchIndex, tokens := range v {
                        v[batchIndex] = make([][]float32, len(tokens))
                        for tokenIndex, tokenLogits := range tokens {
                                v[batchIndex][tokenIndex] = vectorutil.SoftMax(tokenLogits)
                        }
                }
                outputCast = v
        default:
                return nil, fmt.Errorf("expected 3D output, got type %T", output)
        }
        // now convert the logits to the predictions of actual entities
        classificationOutput := TokenClassificationOutput{
                Entities: make([][]Entity, batch.Size),
        }
        for i, input := range batch.Input {
                preEntities := p.GatherPreEntities(input, outputCast[i])
                entities, errAggregate := p.Aggregate(input, preEntities)
                if errAggregate != nil {
                        return nil, errAggregate
                }
                // Filter anything that is in ignore_labels
                var filteredEntities []Entity
                for _, e := range entities {
                        if !slices.Contains(p.IgnoreLabels, e.Entity) && e.Entity != "" {
                                filteredEntities = append(filteredEntities, e)
                        }
                }
                classificationOutput.Entities[i] = filteredEntities
        }
        return &classificationOutput, nil
}

// GatherPreEntities converts a batch of logits into a list of pre-aggregated entity candidates.
func (p *TokenClassificationPipeline) GatherPreEntities(input backends.TokenizedInput, output [][]float32) []Entity {
        sentence := input.Raw
        var preEntities []Entity
        for j, tokenScores := range output {
                // filter out special tokens (skip them)
                if input.SpecialTokensMask[j] > 0.0 {
                        continue
                }
                // TODO: the python code uses id_to_token to get the token here which is a method on the rust tokenizer, check if it's better
                word := input.Tokens[j]
                tokenID := input.TokenIDs[j]
                // TODO: the determination of subword can probably be better done by exporting the words field from the tokenizer directly
                startInd := input.Offsets[j][0]
                endInd := input.Offsets[j][1]
                wordRef := sentence[startInd:endInd]
                isSubword := len(word) != len(wordRef)
                // In split-words mode, grouping will use offsets between tokens rather than IsSubword.
                // TODO: check for unknown token here, it's in the config and can be loaded and compared with the token
                // in that case set the subword as in the python code
                preEntities = append(preEntities, Entity{
                        Word:      word,
                        TokenID:   []uint32{tokenID},
                        Scores:    tokenScores,
                        Start:     startInd,
                        End:       endInd,
                        Index:     j,
                        IsSubword: isSubword,
                })
        }
        return preEntities
}

// aggregateWord merges the sub-token entities of a single word into one Entity
// according to the pipeline's aggregation strategy.
func (p *TokenClassificationPipeline) aggregateWord(entities []Entity) (Entity, error) {
        tokens := make([]uint32, len(entities))
        for i, e := range entities {
                tokens[i] = e.TokenID[0]
        }
        newEntity := Entity{}
        word, err := backends.Decode(tokens, p.Model.Tokenizer, true)
        if err != nil {
                return newEntity, err
        }
        var score float32
        var label string
        switch p.AggregationStrategy {
        case "AVERAGE":
                scores := make([][]float32, len(p.IDLabelMap))
                for _, e := range entities {
                        for i, score := range e.Scores {
                                scores[i] = append(scores[i], score)
                        }
                }
                averages := make([]float32, len(p.IDLabelMap))
                for i, s := range scores {
                        averages[i] = vectorutil.Mean(s)
                }
                entityIdx, maxScore, err := vectorutil.ArgMax(averages)
                if err != nil {
                        return newEntity, err
                }
                entityLabel, ok := p.IDLabelMap[entityIdx]
                if !ok {
                        return newEntity, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", word, entityIdx)
                }
                score = maxScore
                label = entityLabel
        case "MAX":
                var maxScore float32
                var maxIdx int
                for _, e := range entities {
                        idx, score, err := vectorutil.ArgMax(e.Scores)
                        if err != nil {
                                return newEntity, err
                        }
                        if score >= maxScore {
                                maxScore = score
                                maxIdx = idx
                        }
                }
                entityLabel, ok := p.IDLabelMap[maxIdx]
                if !ok {
                        return Entity{}, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", word, maxIdx)
                }
                score = maxScore
                label = entityLabel
        case "FIRST":
                entityIdx, maxScore, err := vectorutil.ArgMax(entities[0].Scores)
                if err != nil {
                        return newEntity, err
                }
                entityLabel, ok := p.IDLabelMap[entityIdx]
                if !ok {
                        return Entity{}, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", word, entityIdx)
                }
                score = maxScore
                label = entityLabel
        default:
                return Entity{}, fmt.Errorf("aggregation strategy %s not recognized", p.AggregationStrategy)
        }
        return Entity{
                Entity:  label,
                Score:   score,
                Word:    word,
                TokenID: tokens,
                Start:   entities[0].Start,
                End:     entities[len(entities)-1].End,
        }, nil
}

// aggregateWords groups adjacent sub-token entities into words and aggregates each group.
func (p *TokenClassificationPipeline) aggregateWords(entities []Entity) ([]Entity, error) {
        var wordGroup []Entity
        var wordEntities []Entity
        for _, entity := range entities {
                if len(wordGroup) == 0 {
                        wordGroup = []Entity{entity}
                        continue
                }
                // Default behavior: group by IsSubword boundaries
                groupBreak := !entity.IsSubword
                if p.SplitWords {
                        // In split-words mode, we group by contiguous tokens of the same word boundary since we pretokenized.
                        // Since preEntities don't carry word IDs directly, simulate grouping by contiguous offsets:
                        // break the group if there is a gap between the previous End and the current Start (a space).
                        // TODO: eventually we should export word IDs from the tokenizer to avoid this heuristic but the rust tokenizer bindings don't expose this yet
                        // and we also use other tokenizers in the go backend.
                        prev := wordGroup[len(wordGroup)-1]
                        // if there is a gap in offsets consider it a new word
                        groupBreak = entity.Start > prev.End
                }
                if groupBreak {
                        aggregated, err := p.aggregateWord(wordGroup)
                        if err != nil {
                                return nil, err
                        }
                        wordEntities = append(wordEntities, aggregated)
                        wordGroup = []Entity{entity}
                } else {
                        wordGroup = append(wordGroup, entity)
                }
        }
        if len(wordGroup) > 0 {
                aggregated, err := p.aggregateWord(wordGroup)
                if err != nil {
                        return nil, err
                }
                wordEntities = append(wordEntities, aggregated)
        }
        return wordEntities, nil
}
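
// Illustrative example of the split-words grouping heuristic above (added for
// clarity, not part of the original file): with SplitWords enabled and the joined
// sentence "Knights Analytics", the sub-tokens covering "Analytics" at offsets
// [8, 11) and [11, 17) stay in one group (11 is not greater than 11), while the
// first of them starts a new group because its Start (8) is greater than the End
// of "Knights" (7), i.e. the space between words breaks the group.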

// Aggregate combines the pre-entities into entities according to the pipeline's aggregation strategy.
func (p *TokenClassificationPipeline) Aggregate(input backends.TokenizedInput, preEntities []Entity) ([]Entity, error) {
        entities := make([]Entity, len(preEntities))
        var aggregationError error
        if p.AggregationStrategy == "SIMPLE" || p.AggregationStrategy == "NONE" {
                for i, preEntity := range preEntities {
                        entityIdx, score, argMaxErr := vectorutil.ArgMax(preEntity.Scores)
                        if argMaxErr != nil {
                                return nil, argMaxErr
                        }
                        label, ok := p.IDLabelMap[entityIdx]
                        if !ok {
                                return nil, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", input.Raw, entityIdx)
                        }
                        entities[i] = Entity{
                                Entity:  label,
                                Score:   score,
                                Index:   preEntity.Index,
                                Word:    preEntity.Word,
                                TokenID: preEntity.TokenID,
                                Start:   preEntity.Start,
                                End:     preEntity.End,
                        }
                }
        } else {
                entities, aggregationError = p.aggregateWords(preEntities)
                if aggregationError != nil {
                        return nil, aggregationError
                }
        }
        if p.AggregationStrategy == "NONE" {
                return entities, nil
        }
        return p.GroupEntities(entities)
}

// getTag splits an entity label into its B/I prefix and the bare tag.
func (p *TokenClassificationPipeline) getTag(entityName string) (string, string) {
        var bi string
        var tag string
        if strings.HasPrefix(entityName, "B-") {
                bi = "B"
                tag = entityName[2:]
        } else if strings.HasPrefix(entityName, "I-") {
                bi = "I"
                tag = entityName[2:]
        } else {
                // defaulting to "I" if the string is not in B-/I- format
                bi = "I"
                tag = entityName
        }
        return bi, tag
}

// groupSubEntities merges a group of adjacent entities with the same tag into a single entity.
func (p *TokenClassificationPipeline) groupSubEntities(entities []Entity) (Entity, error) {
        splits := strings.Split(entities[0].Entity, "-")
        var entityType string
        if len(splits) == 1 {
                entityType = splits[0]
        } else {
                entityType = strings.Join(splits[1:], "-")
        }
        scores := make([]float32, len(entities))
        tokens := make([]uint32, 0, len(entities))
        for i, s := range entities {
                scores[i] = s.Score
                tokens = slices.Concat(tokens, s.TokenID)
        }
        score := vectorutil.Mean(scores)
        // note: here we directly appeal to the tokenizer decoder with the tokenIds
        // in the python code they pass the words to a token_to_string_method
        word, err := backends.Decode(tokens, p.Model.Tokenizer, true)
        if err != nil {
                return Entity{}, err
        }
        return Entity{
                Entity: entityType,
                Score:  score,
                Word:   word,
                Start:  entities[0].Start,
                End:    entities[len(entities)-1].End,
        }, nil
}

// GroupEntities groups together adjacent tokens with the same predicted entity.
func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity, error) {
        var entityGroups []Entity
        var currentGroupDisagg []Entity
        for _, e := range entities {
                if len(currentGroupDisagg) == 0 {
                        currentGroupDisagg = append(currentGroupDisagg, e)
                        continue
                }
                bi, tag := p.getTag(e.Entity)
                _, lastTag := p.getTag(currentGroupDisagg[len(currentGroupDisagg)-1].Entity)
                if tag == lastTag && bi != "B" {
                        currentGroupDisagg = append(currentGroupDisagg, e)
                } else {
                        // create the grouped entity
                        groupedEntity, err := p.groupSubEntities(currentGroupDisagg)
                        if err != nil {
                                return nil, err
                        }
                        entityGroups = append(entityGroups, groupedEntity)
                        currentGroupDisagg = []Entity{e}
                }
        }
        if len(currentGroupDisagg) > 0 {
                // last entity remaining
                groupedEntity, err := p.groupSubEntities(currentGroupDisagg)
                if err != nil {
                        return nil, err
                }
                entityGroups = append(entityGroups, groupedEntity)
        }
        return entityGroups, nil
}
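
// Illustrative example of the grouping logic above (added for clarity, not part
// of the original file): the label sequence B-PER, I-PER, I-PER is merged into a
// single PER entity whose score is the mean of the three token scores, while a
// following B-PER starts a new PER group even though the tag is unchanged.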

// Run the pipeline on a string batch.
func (p *TokenClassificationPipeline) Run(inputs []string) (backends.PipelineBatchOutput, error) {
        return p.RunPipeline(inputs)
}

// RunPipeline is like Run but returns the concrete type rather than the interface.
func (p *TokenClassificationPipeline) RunPipeline(inputs []string) (*TokenClassificationOutput, error) {
        var runErrors []error
        batch := backends.NewBatch(len(inputs))
        defer func(*backends.PipelineBatch) {
                runErrors = append(runErrors, batch.Destroy())
        }(batch)
        runErrors = append(runErrors, p.Preprocess(batch, inputs))
        if e := errors.Join(runErrors...); e != nil {
                return nil, e
        }
        runErrors = append(runErrors, p.Forward(batch))
        if e := errors.Join(runErrors...); e != nil {
                return nil, e
        }
        result, postErr := p.Postprocess(batch)
        runErrors = append(runErrors, postErr)
        return result, errors.Join(runErrors...)
}

// RunWords runs the pipeline for pre-split word inputs.
// Each input is a slice of words representing a pretokenized sentence.
// This is particularly useful when the user wants to control tokenization because of special tokens,
// hashtags, or other domain-specific tokenization needs.
func (p *TokenClassificationPipeline) RunWords(inputs [][]string) (*TokenClassificationOutput, error) {
        var runErrors []error
        batch := backends.NewBatch(len(inputs))
        defer func(*backends.PipelineBatch) {
                runErrors = append(runErrors, batch.Destroy())
        }(batch)
        runErrors = append(runErrors, p.PreprocessWords(batch, inputs))
        if e := errors.Join(runErrors...); e != nil {
                return nil, e
        }
        runErrors = append(runErrors, p.Forward(batch))
        if e := errors.Join(runErrors...); e != nil {
                return nil, e
        }
        result, postErr := p.Postprocess(batch)
        runErrors = append(runErrors, postErr)
        return result, errors.Join(runErrors...)
}
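
Usage sketch (not part of the source file above): a minimal, hedged example of the new pre-tokenization path, assuming a model and an options handle have already been loaded through the usual hugot/backends setup. Only the Options field of the config and the calls shown here come from this file; the surrounding setup is illustrative.

        // opts (*options.Options) and model (*backends.Model) are assumed to come
        // from the usual hugot session and model loading code (hypothetical here).
        config := backends.PipelineConfig[*pipelines.TokenClassificationPipeline]{
                Options: []backends.PipelineOption[*pipelines.TokenClassificationPipeline]{
                        pipelines.WithSplitWords(),        // accept pre-split [][]string inputs via RunWords
                        pipelines.WithSimpleAggregation(), // merge B-/I- tagged tokens into entities
                        pipelines.WithIgnoreLabels([]string{"O"}),
                },
        }
        pipeline, err := pipelines.NewTokenClassificationPipeline(config, opts, model)
        if err != nil {
                log.Fatal(err)
        }
        // Each inner slice is one pre-tokenized sentence; the caller controls the word splits.
        out, err := pipeline.RunWords([][]string{
                {"John", "Smith", "works", "at", "Knights", "Analytics", "."},
        })
        if err != nil {
                log.Fatal(err)
        }
        for _, entity := range out.Entities[0] {
                fmt.Printf("%s -> %s (%.2f)\n", entity.Word, entity.Entity, entity.Score)
        }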