• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

knights-analytics / hugot / 20023537116

08 Dec 2025 09:38AM UTC coverage: 55.81% (+0.6%) from 55.172%
20023537116

push

github

RJKeevil
add pre tokenization option to token classification pipeline

76 of 103 new or added lines in 1 file covered. (73.79%)

455 existing lines in 14 files now uncovered.

2584 of 4630 relevant lines covered (55.81%)

12236.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.37
/pipelines/zeroShotClassification.go
1
package pipelines
2

3
import (
4
        "errors"
5
        "fmt"
6
        "math"
7
        "sort"
8
        "strings"
9
        "sync/atomic"
10
        "time"
11

12
        "github.com/knights-analytics/hugot/backends"
13
        "github.com/knights-analytics/hugot/options"
14
        "github.com/knights-analytics/hugot/util/safeconv"
15
)
16

17
// ZeroShotClassificationPipeline classifies input sequences against an
// arbitrary set of candidate labels by casting each (sequence, label) pair as
// a natural-language-inference (entailment) problem.
type ZeroShotClassificationPipeline struct {
	*backends.BasePipeline
	HypothesisTemplate string   // template containing a "{}" placeholder that each label is substituted into
	Sequences          []string // input sequences to classify
	Labels             []string // candidate labels to score each sequence against
	EntailmentID       int      // index of the "entailment" class in the model output; -1 when unknown
	Multilabel         bool     // when true, each label is scored independently rather than softmaxed across labels
}

// ZeroShotClassificationOutput holds the per-label scores for one input
// sequence, sorted from most to least likely.
type ZeroShotClassificationOutput struct {
	Sequence     string
	SortedValues []struct {
		Key   string  // label name
		Value float64 // score produced by Postprocess, sorted descending
	}
}

// ZeroShotOutput is the pipeline result: one classification per input sequence.
type ZeroShotOutput struct {
	ClassificationOutputs []ZeroShotClassificationOutput
}
35

36
// options
37

38
// WithMultilabel can be used to set whether the pipeline is multilabel.
39
func WithMultilabel(multilabel bool) backends.PipelineOption[*ZeroShotClassificationPipeline] {
1✔
40
        return func(pipeline *ZeroShotClassificationPipeline) error {
2✔
41
                pipeline.Multilabel = multilabel
1✔
42
                return nil
1✔
43
        }
1✔
44
}
45

46
// WithLabels can be used to set the labels to classify the examples.
47
func WithLabels(labels []string) backends.PipelineOption[*ZeroShotClassificationPipeline] {
1✔
48
        return func(pipeline *ZeroShotClassificationPipeline) error {
2✔
49
                pipeline.Labels = labels
1✔
50
                return nil
1✔
51
        }
1✔
52
}
53

54
// WithHypothesisTemplate can be used to set the hypothesis template for classification.
55
func WithHypothesisTemplate(hypothesisTemplate string) backends.PipelineOption[*ZeroShotClassificationPipeline] {
1✔
56
        return func(pipeline *ZeroShotClassificationPipeline) error {
2✔
57
                pipeline.HypothesisTemplate = hypothesisTemplate
1✔
58
                return nil
1✔
59
        }
1✔
60
}
61

62
// GetOutput converts raw output to readable output.
63
func (t *ZeroShotOutput) GetOutput() []any {
6✔
64
        out := make([]any, len(t.ClassificationOutputs))
6✔
65
        for i, o := range t.ClassificationOutputs {
15✔
66
                out[i] = any(o)
9✔
67
        }
9✔
68
        return out
6✔
69
}
70

71
// createSequencePairs builds every (sequence, hypothesis) pair needed for
// zero-shot classification: for each input sequence and candidate label, the
// hypothesis template is rendered by substituting the label for its "{}"
// placeholder and paired with the sequence.
//
// sequences may be a single string or a []string. It returns the pairs grouped
// per input sequence, the normalized sequence slice, and an error when labels
// or sequences are empty, sequences has an unsupported type, or the template
// lacks a "{}" placeholder.
func createSequencePairs(sequences any, labels []string, hypothesisTemplate string) ([][][]string, []string, error) {
	// Check if labels or sequences are empty.
	if len(labels) == 0 || sequences == nil {
		return nil, nil, errors.New("you must include at least one label and at least one sequence")
	}
	// The template must contain the "{}" placeholder used below, otherwise
	// every hypothesis would be identical regardless of the label.
	// (The previous fmt.Sprintf-based check could never fail for templates
	// without format verbs: Sprintf appends "%!(EXTRA ...)" so the result
	// always differed from the template.)
	if !strings.Contains(hypothesisTemplate, "{}") {
		return nil, nil, fmt.Errorf(`the provided hypothesis_template "%s" was not able to be formatted with the target labels. Make sure the passed template includes formatting syntax such as {} where the label should go`, hypothesisTemplate)
	}
	// Convert sequences to []string if it's a single string.
	var seqs []string
	switch v := sequences.(type) {
	case string:
		seqs = []string{v}
	case []string:
		seqs = v
	default:
		return nil, nil, errors.New("sequences must be either a string or a []string")
	}
	// Create one group of (sequence, hypothesis) pairs per input sequence.
	sequencePairs := make([][][]string, 0, len(seqs))
	for _, sequence := range seqs {
		pairs := make([][]string, 0, len(labels))
		for _, label := range labels {
			hypothesis := strings.Replace(hypothesisTemplate, "{}", label, 1)
			pairs = append(pairs, []string{sequence, hypothesis})
		}
		sequencePairs = append(sequencePairs, pairs)
	}
	return sequencePairs, seqs, nil
}
103

104
// NewZeroShotClassificationPipeline create new Zero Shot Classification Pipeline.
105
func NewZeroShotClassificationPipeline(config backends.PipelineConfig[*ZeroShotClassificationPipeline], s *options.Options, model *backends.Model) (*ZeroShotClassificationPipeline, error) {
1✔
106
        defaultPipeline, err := backends.NewBasePipeline(config, s, model)
1✔
107
        if err != nil {
1✔
UNCOV
108
                return nil, err
×
UNCOV
109
        }
×
110
        pipeline := &ZeroShotClassificationPipeline{BasePipeline: defaultPipeline}
1✔
111
        for _, o := range config.Options {
5✔
112
                err = o(pipeline)
4✔
113
                if err != nil {
4✔
114
                        return nil, err
×
UNCOV
115
                }
×
116
        }
117
        pipeline.EntailmentID = -1 // Default value
1✔
118
        pipeline.HypothesisTemplate = "This example is {}."
1✔
119
        if len(pipeline.Labels) == 0 {
1✔
120
                return nil, fmt.Errorf("no labels provided, please provide labels using the WithLabels() option")
×
121
        }
×
122
        // Find entailment ID
123
        for id, label := range model.IDLabelMap {
2✔
124
                if strings.HasPrefix(strings.ToLower(label), "entail") {
2✔
125
                        pipeline.EntailmentID = id
1✔
126
                        break
1✔
127
                }
128
        }
129
        return pipeline, err
1✔
130
}
131

132
func (p *ZeroShotClassificationPipeline) Preprocess(batch *backends.PipelineBatch, inputs []string) error {
15✔
133
        start := time.Now()
15✔
134
        backends.TokenizeInputs(batch, p.Model.Tokenizer, inputs)
15✔
135
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.NumCalls, 1)
15✔
136
        atomic.AddUint64(&p.Model.Tokenizer.TokenizerTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
15✔
137
        err := backends.CreateInputTensors(batch, p.Model, p.Runtime)
15✔
138
        return err
15✔
139
}
15✔
140

141
func (p *ZeroShotClassificationPipeline) Forward(batch *backends.PipelineBatch) error {
15✔
142
        start := time.Now()
15✔
143
        err := backends.RunSessionOnBatch(batch, p.BasePipeline)
15✔
144
        if err != nil {
15✔
UNCOV
145
                return err
×
UNCOV
146
        }
×
147
        atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
15✔
148
        atomic.AddUint64(&p.PipelineTimings.TotalNS, safeconv.DurationToU64(time.Since(start)))
15✔
149
        return nil
15✔
150
}
151

152
// Postprocess converts raw model logits into per-label scores for each input
// sequence.
//
// outputTensors is indexed [sequence][label][class-logit]; labels and
// sequences are in the same order the pairs were created in. Two scoring
// modes:
//   - multilabel (or exactly one label): each label gets an independent
//     binary softmax over its (contradiction, entailment) logits;
//   - single-label: a softmax across the entailment logits of all labels, so
//     each sequence's scores sum to 1.
//
// In both modes the labels are returned sorted by descending score.
func (p *ZeroShotClassificationPipeline) Postprocess(outputTensors [][][]float32, labels []string, sequences []string) (*ZeroShotOutput, error) {
	classificationOutputs := make([]ZeroShotClassificationOutput, 0, len(sequences))
	// NOTE(review): this map is shared across sequence iterations; every label
	// key is overwritten each iteration, so per-sequence results stay correct,
	// but a per-iteration map would be less fragile.
	LabelLikelihood := make(map[string]float64)
	if p.Multilabel || len(p.Labels) == 1 {
		for ind, sequence := range outputTensors {
			output := ZeroShotClassificationOutput{
				Sequence: sequences[ind],
			}
			var entailmentLogits []float32
			var contradictionLogits []float32
			var entailmentID int
			var contradictionID int
			switch p.EntailmentID {
			case -1:
				// Entailment class position unknown: assume it is the last
				// class and contradiction is the first.
				entailmentID = len(sequence[0]) - 1
				contradictionID = 0
			default:
				entailmentID = p.EntailmentID
				contradictionID = 0
				// If entailment occupies index 0, take the last class as
				// contradiction instead.
				if entailmentID == 0 {
					contradictionID = len(sequence[0]) - 1
				}
			}
			// Gather the entailment/contradiction logit per label.
			for _, tensor := range sequence {
				entailmentLogits = append(entailmentLogits, tensor[entailmentID])
				contradictionLogits = append(contradictionLogits, tensor[contradictionID])
			}
			// Binary softmax per label over (contradiction, entailment):
			// score = exp(entail) / (exp(contra) + exp(entail)).
			for i := range entailmentLogits {
				logits := []float64{float64(contradictionLogits[i]), float64(entailmentLogits[i])}
				expLogits := []float64{math.Exp(logits[0]), math.Exp(logits[1])}
				sumExpLogits := expLogits[0] + expLogits[1]
				score := expLogits[1] / sumExpLogits
				LabelLikelihood[labels[i]] = score
			}
			// Define ss as a slice of anonymous structs
			var ss []struct {
				Key   string
				Value float64
			}
			for k, v := range LabelLikelihood {
				ss = append(ss, struct {
					Key   string
					Value float64
				}{k, v})
			}
			// Sort the slice by the value field
			sort.Slice(ss, func(i, j int) bool {
				return ss[i].Value > ss[j].Value
			})
			output.SortedValues = ss
			classificationOutputs = append(classificationOutputs, output)
		}
		return &ZeroShotOutput{
			ClassificationOutputs: classificationOutputs,
		}, nil
	}
	// Single-label mode: softmax across all labels' entailment logits.
	for ind, sequence := range outputTensors {
		output := ZeroShotClassificationOutput{}
		var entailmentLogits []float32
		var entailmentID int
		switch p.EntailmentID {
		case -1:
			// Entailment class position unknown: assume the last class.
			entailmentID = len(sequence[0]) - 1
		default:
			entailmentID = p.EntailmentID
		}
		for _, tensor := range sequence {
			entailmentLogits = append(entailmentLogits, tensor[entailmentID])
		}
		// Softmax numerator (exp of each logit) and denominator (their sum).
		var numerator []float64
		var logitSum float64
		for _, logit := range entailmentLogits {
			exp := math.Exp(float64(logit))
			numerator = append(numerator, exp)
			logitSum += exp
		}
		var quotient []float64
		for ind, i := range numerator {
			quotient = append(quotient, i/logitSum)
			LabelLikelihood[labels[ind]] = quotient[ind]
		}
		output.Sequence = sequences[ind]
		// Define ss as a slice of anonymous structs
		var ss []struct {
			Key   string
			Value float64
		}
		for k, v := range LabelLikelihood {
			ss = append(ss, struct {
				Key   string
				Value float64
			}{k, v})
		}
		// Sort the slice by the value field
		sort.Slice(ss, func(i, j int) bool {
			return ss[i].Value > ss[j].Value
		})
		output.SortedValues = ss
		classificationOutputs = append(classificationOutputs, output)
	}
	return &ZeroShotOutput{
		ClassificationOutputs: classificationOutputs,
	}, nil
}
256

257
// RunPipeline classifies each input sequence against the pipeline's labels
// and returns sorted per-label scores for every input.
//
// Each (sequence, hypothesis) pair is run through the model one at a time:
// the pair is joined with the model's separator token, tokenized, and
// forwarded, and the resulting logits are grouped per input sequence before
// postprocessing. Errors from each stage are accumulated and joined; the
// batch is destroyed on both the error and success paths.
func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShotOutput, error) {
	var outputTensors [][][]float32
	var runErrors []error
	sequencePairs, _, err := createSequencePairs(inputs, p.Labels, p.HypothesisTemplate)
	if err != nil {
		return nil, err
	}
	for _, sequence := range sequencePairs {
		var sequenceTensors [][]float32
		for _, pair := range sequence {
			batch := backends.NewBatch(len(inputs))
			// have to do this because python inserts a separator token in between the two clauses when tokenizing
			// separator token isn't universal and depends on its value in special_tokens_map.json of model
			// still isn't perfect because some models (protectai/MoritzLaurer-roberta-base-zeroshot-v2.0-c-onnx for example)
			// insert two separator tokens while others (protectai/deberta-v3-base-zeroshot-v1-onnx and others) only insert one
			// need to find a way to determine how many to insert or find a better way to tokenize inputs
			// The difference in outputs for one separator vs two is very small (differences in the thousandths place), but they
			// definitely are different
			concatenatedString := pair[0] + p.Model.SeparatorToken + pair[1]
			runErrors = append(runErrors, p.Preprocess(batch, []string{concatenatedString}))
			if e := errors.Join(runErrors...); e != nil {
				return nil, errors.Join(e, batch.Destroy())
			}
			runErrors = append(runErrors, p.Forward(batch))
			if e := errors.Join(runErrors...); e != nil {
				return nil, errors.Join(e, batch.Destroy())
			}
			// Only one concatenated string was run, so take the first row of
			// the first output tensor.
			sequenceTensors = append(sequenceTensors, batch.OutputValues[0].([][]float32)[0])
			runErrors = append(runErrors, batch.Destroy())
			if e := errors.Join(runErrors...); e != nil {
				return nil, e
			}
		}
		outputTensors = append(outputTensors, sequenceTensors)
	}
	outputs, err := p.Postprocess(outputTensors, p.Labels, inputs)
	runErrors = append(runErrors, err)
	return outputs, errors.Join(runErrors...)
}
296

297
// PIPELINE INTERFACE IMPLEMENTATION
298

UNCOV
299
func (p *ZeroShotClassificationPipeline) GetModel() *backends.Model {
×
UNCOV
300
        return p.Model
×
UNCOV
301
}
×
302

UNCOV
303
func (p *ZeroShotClassificationPipeline) GetMetadata() backends.PipelineMetadata {
×
UNCOV
304
        return backends.PipelineMetadata{
×
UNCOV
305
                OutputsInfo: []backends.OutputInfo{
×
UNCOV
306
                        {
×
UNCOV
307
                                Name:       p.Model.OutputsMeta[0].Name,
×
308
                                Dimensions: p.Model.OutputsMeta[0].Dimensions,
×
309
                        },
×
UNCOV
310
                },
×
UNCOV
311
        }
×
312
}
×
313

314
// GetStatistics returns the runtime statistics for the pipeline.
UNCOV
315
func (p *ZeroShotClassificationPipeline) GetStatistics() backends.PipelineStatistics {
×
UNCOV
316
        statistics := backends.PipelineStatistics{}
×
UNCOV
317
        statistics.ComputeTokenizerStatistics(p.Model.Tokenizer.TokenizerTimings)
×
318
        statistics.ComputeOnnxStatistics(p.PipelineTimings)
×
319
        return statistics
×
UNCOV
320
}
×
321

UNCOV
322
func (p *ZeroShotClassificationPipeline) Run(inputs []string) (backends.PipelineBatchOutput, error) {
×
UNCOV
323
        return p.RunPipeline(inputs)
×
UNCOV
324
}
×
325

UNCOV
326
func (p *ZeroShotClassificationPipeline) Validate() error {
×
UNCOV
327
        var validationErrors []error
×
UNCOV
328
        if p.Model.Tokenizer == nil {
×
UNCOV
329
                validationErrors = append(validationErrors, fmt.Errorf("zero shot classification pipeline requires a tokenizer"))
×
UNCOV
330
        }
×
331
        if len(p.Model.IDLabelMap) <= 0 {
×
332
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for zero shot classification pipeline must be greater than zero"))
×
333
        }
×
UNCOV
334
        outDims := p.Model.OutputsMeta[0].Dimensions
×
335
        if len(outDims) != 2 {
×
336
                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: zero shot classification must have 2 dimensional output"))
×
337
        }
×
338
        dynamicBatch := false
×
339
        for _, d := range outDims {
×
340
                if d == -1 {
×
341
                        if dynamicBatch {
×
342
                                validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have max one dynamic dimensions (input)"))
×
343
                                break
×
344
                        }
UNCOV
345
                        dynamicBatch = true
×
346
                }
347
        }
348
        return errors.Join(validationErrors...)
×
349
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc