
JohnSnowLabs / spark-nlp / 18652478786

20 Oct 2025 12:47PM UTC coverage: 55.25% (+0.2%) from 55.094%

Pull Request #14674: SPARKNLP-1293 Enhancements EntityRuler and DocumentNormalizer
Merge b08968fc1 into b827818c7 (committed via GitHub web-flow)

114 of 149 new or added lines in 3 files covered. (76.51%)

40 existing lines in 36 files now uncovered.

11919 of 21573 relevant lines covered (55.25%)

0.55 hits per line

Source File
/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala (59.22% of lines covered)

Per the per-line coverage markers, the id-based SpecialTokens.apply overload, the "olmo", "phi2", "qwen", "starcoder", and "bert" branches of getSpecialTokensForModel, and SpecialToken.canEqual, toString, and the non-matching branch of equals are uncovered; the vocabulary require check is among the newly uncovered lines.
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

private[johnsnowlabs] class SpecialTokens(
    vocab: Map[String, Int],
    startTokenString: String,
    endTokenString: String,
    unkTokenString: String,
    maskTokenString: String,
    padTokenString: String,
    additionalStrings: Array[String] = Array()) {

  val allTokenStrings: Array[String] = Array(
    maskTokenString,
    startTokenString,
    endTokenString,
    unkTokenString,
    padTokenString) ++ additionalStrings

  for (specialTok <- allTokenStrings)
    require(vocab.contains(specialTok), s"Special Token '$specialTok' needs to be in vocabulary.")

  val sentenceStart: SpecialToken = SpecialToken(startTokenString, vocab(startTokenString))
  val sentenceEnd: SpecialToken = SpecialToken(endTokenString, vocab(endTokenString))
  val unk: SpecialToken = SpecialToken(unkTokenString, vocab(unkTokenString))
  val mask: SpecialToken = SpecialToken(
    maskTokenString,
    vocab(maskTokenString),
    lstrip = true // TODO: check if should be done for every model
  )
  val pad: SpecialToken = SpecialToken(padTokenString, vocab(padTokenString))

  val additionalTokens: Array[SpecialToken] =
    additionalStrings.map((tok: String) => SpecialToken(tok, vocab(tok)))

  // Put mask first, in case all special tokens are identical (so the stripping can be done first)
  val allTokens: Set[SpecialToken] =
    Set(mask, sentenceStart, sentenceEnd, unk, pad) ++ additionalTokens

  def contains(s: String): Boolean = allTokens.contains(SpecialToken(content = s, id = 0))
}

private[johnsnowlabs] object SpecialTokens {

  def apply(
      vocab: Map[String, Int],
      startTokenString: String,
      endTokenString: String,
      unkTokenString: String,
      maskTokenString: String,
      padTokenString: String,
      additionalStrings: Array[String] = Array()): SpecialTokens = new SpecialTokens(
    vocab,
    startTokenString,
    endTokenString,
    unkTokenString,
    maskTokenString,
    padTokenString,
    additionalStrings)

  def apply(
      vocab: Map[String, Int],
      startTokenId: Int,
      endTokenId: Int,
      unkTokenId: Int,
      maskTokenId: Int,
      padTokenId: Int,
      additionalTokenIds: Array[Int]): SpecialTokens = {
    val idToString = vocab.map { case (str, id) => (id, str) }

    new SpecialTokens(
      vocab,
      idToString(startTokenId),
      idToString(endTokenId),
      idToString(unkTokenId),
      idToString(maskTokenId),
      idToString(padTokenId),
      additionalTokenIds.map(idToString))
  }

  def getSpecialTokensForModel(modelType: String, vocab: Map[String, Int]): SpecialTokens =
    modelType match {
      case "roberta" =>
        SpecialTokens(
          vocab,
          startTokenString = "<s>",
          endTokenString = "</s>",
          unkTokenString = "<unk>",
          maskTokenString = "<mask>",
          padTokenString = "<pad>")
      case "gpt2" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|endoftext|>",
          endTokenString = "<|endoftext|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|endoftext|>")
      case "xlm" =>
        SpecialTokens(
          vocab,
          "<s>",
          "</s>",
          "<unk>",
          "<special1>",
          "<pad>",
          Array(
            "<special0>",
            "<special2>",
            "<special3>",
            "<special4>",
            "<special5>",
            "<special6>",
            "<special7>",
            "<special8>",
            "<special9>"))
      case "bart" =>
        SpecialTokens(
          vocab,
          startTokenString = "<s>",
          endTokenString = "</s>",
          unkTokenString = "<unk>",
          maskTokenString = "<mask>",
          padTokenString = "<pad>")
      case "olmo" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|endoftext|>",
          endTokenString = "<|endoftext|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|padding|>")
      case "clip" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|startoftext|>",
          endTokenString = "<|endoftext|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|endoftext|>")
      case "phi2" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|endoftext|>",
          endTokenString = "<|endoftext|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|endoftext|>")
      case "qwen" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|im_start|>",
          endTokenString = "<|im_end|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|endoftext|>")
      case "starcoder" =>
        SpecialTokens(
          vocab,
          startTokenString = "<|endoftext|>",
          endTokenString = "<|endoftext|>",
          unkTokenString = "<|endoftext|>",
          maskTokenString = "<|endoftext|>",
          padTokenString = "<|endoftext|>")
      case "bert" =>
        SpecialTokens(
          vocab,
          startTokenString = "[CLS]",
          endTokenString = "[SEP]",
          unkTokenString = "[UNK]",
          maskTokenString = "[MASK]",
          padTokenString = "[PAD]")
    }
}

case class SpecialToken(
    content: String,
    id: Int,
    singleWord: Boolean = false,
    lstrip: Boolean = false,
    rstrip: Boolean = false) {

  override def hashCode(): Int = content.hashCode

  override def canEqual(that: Any): Boolean = that.isInstanceOf[SpecialToken]

  override def equals(obj: Any): Boolean = obj match {
    case obj: SpecialToken => obj.content == content
    case _ => false
  }

  override def toString: String = content
}