• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

knights-analytics / hugot / 20023537116

08 Dec 2025 09:38AM UTC coverage: 55.81% (+0.6%) from 55.172%
20023537116

push

github

RJKeevil
add pre tokenization option to token classification pipeline

76 of 103 new or added lines in 1 file covered. (73.79%)

455 existing lines in 14 files now uncovered.

2584 of 4630 relevant lines covered (55.81%)

12236.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.13
/pipelines/textClassification.go
1
package pipelines
2

3
import (
4
        "errors"
5
        "fmt"
6
        "sync/atomic"
7
        "time"
8

9
        "github.com/knights-analytics/hugot/backends"
10
        "github.com/knights-analytics/hugot/options"
11
        "github.com/knights-analytics/hugot/util/safeconv"
12
        "github.com/knights-analytics/hugot/util/vectorutil"
13
)
14

15
// types
16

17
type TextClassificationPipeline struct {
18
        *backends.BasePipeline
19
        AggregationFunctionName string
20
        ProblemType             string
21
        FixedPaddingLength      int
22
}
23

24
// ClassificationOutput is a single label together with its score.
type ClassificationOutput struct {
	Label string
	Score float32
}

// TextClassificationOutput holds the classification results for a batch:
// ClassificationOutputs[i] are the results for the i-th input.
type TextClassificationOutput struct {
	ClassificationOutputs [][]ClassificationOutput
}

// GetOutput returns one element per input. Each element is that input's
// []ClassificationOutput, boxed as any. The result is never nil, even for an
// empty batch.
func (t *TextClassificationOutput) GetOutput() []any {
	results := make([]any, 0, len(t.ClassificationOutputs))
	for _, perInput := range t.ClassificationOutputs {
		results = append(results, perInput)
	}
	return results
}
40

41
// options
42

43
func WithSoftmax() backends.PipelineOption[*TextClassificationPipeline] {
1✔
44
        return func(pipeline *TextClassificationPipeline) error {
2✔
45
                pipeline.AggregationFunctionName = "SOFTMAX"
1✔
46
                return nil
1✔
47
        }
1✔
48
}
49

50
func WithSigmoid() backends.PipelineOption[*TextClassificationPipeline] {
1✔
51
        return func(pipeline *TextClassificationPipeline) error {
2✔
52
                pipeline.AggregationFunctionName = "SIGMOID"
1✔
53
                return nil
1✔
54
        }
1✔
55
}
56

57
func WithSingleLabel() backends.PipelineOption[*TextClassificationPipeline] {
1✔
58
        return func(pipeline *TextClassificationPipeline) error {
2✔
59
                pipeline.ProblemType = "singleLabel"
1✔
60
                return nil
1✔
61
        }
1✔
62
}
63

64
func WithMultiLabel() backends.PipelineOption[*TextClassificationPipeline] {
1✔
65
        return func(pipeline *TextClassificationPipeline) error {
2✔
66
                pipeline.ProblemType = "multiLabel"
1✔
67
                return nil
1✔
68
        }
1✔
69
}
70

71
func WithFixedPadding(fixedPaddingLength int) backends.PipelineOption[*TextClassificationPipeline] {
1✔
72
        return func(pipeline *TextClassificationPipeline) error {
2✔
73
                pipeline.FixedPaddingLength = fixedPaddingLength
1✔
74
                return nil
1✔
75
        }
1✔
76
}
77

78
// NewTextClassificationPipeline initializes a new text classification pipeline.
79
func NewTextClassificationPipeline(config backends.PipelineConfig[*TextClassificationPipeline], s *options.Options, model *backends.Model) (*TextClassificationPipeline, error) {
5✔
80
        defaultPipeline, err := backends.NewBasePipeline(config, s, model)
5✔
81
        if err != nil {
5✔
UNCOV
82
                return nil, err
×
83
        }
×
84

85
        pipeline := &TextClassificationPipeline{BasePipeline: defaultPipeline}
5✔
86
        for _, o := range config.Options {
11✔
87
                err = o(pipeline)
6✔
88
                if err != nil {
6✔
UNCOV
89
                        return nil, err
×
90
                }
×
91
        }
92

93
        if pipeline.ProblemType == "" {
8✔
94
                pipeline.ProblemType = "singleLabel"
3✔
95
        }
3✔
96
        if pipeline.AggregationFunctionName == "" {
8✔
97
                if pipeline.PipelineName == "singleLabel" {
3✔
UNCOV
98
                        pipeline.AggregationFunctionName = "SOFTMAX"
×
99
                } else {
3✔
100
                        pipeline.AggregationFunctionName = "SIGMOID"
3✔
101
                }
3✔
102
        }
103

104
        // validate
105
        err = pipeline.Validate()
5✔
106
        if err != nil {
5✔
UNCOV
107
                return nil, err
×
108
        }
×
109
        return pipeline, nil
5✔
110
}
111

112
// INTERFACE IMPLEMENTATION
113

UNCOV
114
func (p *TextClassificationPipeline) GetModel() *backends.Model {
×
115
        return p.Model
×
116
}
×
117

118
// GetMetadata returns metadata information about the pipeline, in particular:
119
// OutputInfo: names and dimensions of the output layer used for text classification.
UNCOV
120
func (p *TextClassificationPipeline) GetMetadata() backends.PipelineMetadata {
×
121
        return backends.PipelineMetadata{
×
122
                OutputsInfo: []backends.OutputInfo{
×
123
                        {
×
124
                                Name:       p.Model.OutputsMeta[0].Name,
×
125
                                Dimensions: p.Model.OutputsMeta[0].Dimensions,
×
126
                        },
×
127
                },
×
128
        }
×
129
}
×
130

131
// GetStatistics returns the runtime statistics for the pipeline.
132
func (p *TextClassificationPipeline) GetStatistics() backends.PipelineStatistics {
3✔
133
        statistics := backends.PipelineStatistics{}
3✔
134
        statistics.ComputeTokenizerStatistics(p.Model.Tokenizer.TokenizerTimings)
3✔
135
        statistics.ComputeOnnxStatistics(p.PipelineTimings)
3✔
136
        return statistics
3✔
137
}
3✔
138

139
// Validate checks that the pipeline is valid.
140
func (p *TextClassificationPipeline) Validate() error {
9✔
141
        var validationErrors []error
9✔
142

9✔
143
        if p.Model.Tokenizer == nil {
9✔
UNCOV
144
                validationErrors = append(validationErrors, fmt.Errorf("feature extraction pipeline requires a tokenizer"))
×
UNCOV
145
        }
×
146

147
        if len(p.Model.IDLabelMap) <= 0 {
12✔
148
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for text classification pipeline must be greater than zero"))
3✔
149
        }
3✔
150

151
        outDims := p.Model.OutputsMeta[0].Dimensions
9✔
152
        if len(outDims) != 2 {
12✔
153
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have 2 dimensional output"))
3✔
154
        }
3✔
155
        dynamicBatch := false
9✔
156
        for _, d := range outDims {
26✔
157
                if d == -1 {
28✔
158
                        if dynamicBatch {
14✔
159
                                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have max one dynamic dimensions (input)"))
3✔
160
                                break
3✔
161
                        }
162
                        dynamicBatch = true
9✔
163
                }
164
        }
165
        nLogits := int(outDims[len(outDims)-1])
9✔
166
        if len(p.Model.IDLabelMap) != nLogits {
14✔
167
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match number of logits in output (%d)", nLogits))
5✔
168
        }
5✔
169
        return errors.Join(validationErrors...)
9✔
170
}
171

172
// Preprocess tokenizes the input strings.
173
func (p *TextClassificationPipeline) Preprocess(batch *backends.PipelineBatch, inputs []string) error {
3✔
174
        start := time.Now()
3✔
175
        backends.TokenizeInputs(batch, p.Model.Tokenizer, inputs)
3✔
176
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.NumCalls, 1)
3✔
177

3✔
178
        if p.FixedPaddingLength > 0 {
4✔
179
                batch.MaxSequenceLength = p.FixedPaddingLength
1✔
180
        }
1✔
181

182
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
3✔
183
        err := backends.CreateInputTensors(batch, p.Model, p.Runtime)
3✔
184
        return err
3✔
185
}
186

187
func (p *TextClassificationPipeline) Forward(batch *backends.PipelineBatch) error {
3✔
188
        start := time.Now()
3✔
189
        err := backends.RunSessionOnBatch(batch, p.BasePipeline)
3✔
190
        if err != nil {
3✔
UNCOV
191
                return err
×
UNCOV
192
        }
×
193
        atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
3✔
194
        atomic.AddUint64(&p.PipelineTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
3✔
195
        return nil
3✔
196
}
197

198
func (p *TextClassificationPipeline) Postprocess(batch *backends.PipelineBatch) (*TextClassificationOutput, error) {
3✔
199
        var aggregationFunction func([]float32) []float32
3✔
200
        switch p.AggregationFunctionName {
3✔
201
        case "SIGMOID":
1✔
202
                aggregationFunction = vectorutil.Sigmoid
1✔
203
        case "SOFTMAX":
1✔
204
                aggregationFunction = vectorutil.SoftMax
1✔
UNCOV
205
        default:
×
UNCOV
206
                return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName)
×
207
        }
208

209
        output := batch.OutputValues[0]
3✔
210
        var outputCast [][]float32
3✔
211
        switch v := output.(type) {
3✔
212
        case [][]float32:
3✔
213
                for i, logits := range v {
7✔
214
                        v[i] = aggregationFunction(logits)
4✔
215
                }
4✔
216
                outputCast = v
3✔
UNCOV
217
        default:
×
UNCOV
218
                return nil, fmt.Errorf("output is not 2D, expected batch size x logits, got %T", output)
×
219
        }
220

221
        batchClassificationOutputs := TextClassificationOutput{
3✔
222
                ClassificationOutputs: make([][]ClassificationOutput, batch.Size),
3✔
223
        }
3✔
224

3✔
225
        var err error
3✔
226

3✔
227
        for i := 0; i < batch.Size; i++ {
7✔
228
                switch p.ProblemType {
4✔
229
                case "singleLabel":
3✔
230
                        inputClassificationOutputs := make([]ClassificationOutput, 1)
3✔
231
                        index, value, errArgMax := vectorutil.ArgMax(outputCast[i])
3✔
232
                        if errArgMax != nil {
3✔
UNCOV
233
                                err = errArgMax
×
UNCOV
234
                                continue
×
235
                        }
236
                        class, ok := p.Model.IDLabelMap[index]
3✔
237
                        if !ok {
3✔
UNCOV
238
                                err = fmt.Errorf("class with index number %d not found in id label map", index)
×
UNCOV
239
                        }
×
240
                        inputClassificationOutputs[0] = ClassificationOutput{
3✔
241
                                Label: class,
3✔
242
                                Score: value,
3✔
243
                        }
3✔
244
                        batchClassificationOutputs.ClassificationOutputs[i] = inputClassificationOutputs
3✔
245
                case "multiLabel":
1✔
246
                        inputClassificationOutputs := make([]ClassificationOutput, len(p.Model.IDLabelMap))
1✔
247
                        for j := range outputCast[i] {
30✔
248
                                class, ok := p.Model.IDLabelMap[j]
29✔
249
                                if !ok {
29✔
UNCOV
250
                                        err = fmt.Errorf("class with index number %d not found in id label map", j)
×
UNCOV
251
                                }
×
252
                                inputClassificationOutputs[j] = ClassificationOutput{
29✔
253
                                        Label: class,
29✔
254
                                        Score: outputCast[i][j],
29✔
255
                                }
29✔
256
                        }
257
                        batchClassificationOutputs.ClassificationOutputs[i] = inputClassificationOutputs
1✔
258
                default:
×
259
                        err = fmt.Errorf("problem type %s not recognized", p.ProblemType)
×
260
                }
261
        }
262
        return &batchClassificationOutputs, err
3✔
263
}
264

265
// Run the pipeline on a string batch.
266
func (p *TextClassificationPipeline) Run(inputs []string) (backends.PipelineBatchOutput, error) {
×
267
        return p.RunPipeline(inputs)
×
UNCOV
268
}
×
269

270
func (p *TextClassificationPipeline) RunPipeline(inputs []string) (*TextClassificationOutput, error) {
3✔
271
        var runErrors []error
3✔
272
        batch := backends.NewBatch(len(inputs))
3✔
273
        defer func(*backends.PipelineBatch) {
6✔
274
                runErrors = append(runErrors, batch.Destroy())
3✔
275
        }(batch)
3✔
276

277
        runErrors = append(runErrors, p.Preprocess(batch, inputs))
3✔
278
        if e := errors.Join(runErrors...); e != nil {
3✔
UNCOV
279
                return nil, e
×
UNCOV
280
        }
×
281

282
        runErrors = append(runErrors, p.Forward(batch))
3✔
283
        if e := errors.Join(runErrors...); e != nil {
3✔
UNCOV
284
                return nil, e
×
UNCOV
285
        }
×
286

287
        result, postErr := p.Postprocess(batch)
3✔
288
        runErrors = append(runErrors, postErr)
3✔
289
        return result, errors.Join(runErrors...)
3✔
290
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc