JohnSnowLabs / spark-nlp / 15252839065

26 May 2025 11:30AM UTC. Coverage: 52.115% (-0.6%) from 52.715%.

Pull Request #14585: SparkNLP 1131 - Introducing Florence-2
Merge 625e5c10f into 56512b006 (github / web-flow)

0 of 199 new or added lines in 4 files covered (0.0%).
50 existing lines in 33 files are now uncovered.
9931 of 19056 relevant lines covered (52.11%), 0.52 hits per line.

Source File (70.17% of lines covered):
/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala

In this file, the lines added for the new "florence2" case in BpeTokenizer.forModel are reported as new and uncovered.
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

import com.johnsnowlabs.nlp.annotators.common.{IndexedToken, Sentence, TokenPiece}
import org.apache.commons.lang3.StringUtils

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** A BPE Tokenizer based on GPT2's tokenization scheme. The tokenization can then be used for
  * models based on this scheme (e.g. GPT2, roBERTa, DeBERTa)
  *
  * TODO: truncation assumed?
  *
  * @param merges
  *   Map of tokens that are mergeable
  * @param vocab
  *   Map of tokens to encoded representation
  * @param specialTokens
  *   Collection of special tokens
  * @param padWithSequenceTokens
  *   Whether to pad the sentence with sentence tokens at the start and end
  * @param addPrefixSpaceToSentence
  *   Whether to add a space to the first word of a sentence
  * @param alwaysAddPrefix
  *   Whether to always prefix token ids with `prefixForPieceId`
  */
private[nlp] abstract class BpeTokenizer(
    val merges: Map[(String, String), Int],
    val vocab: Map[String, Int],
    val specialTokens: SpecialTokens,
    val padWithSequenceTokens: Boolean,
    val addPrefixSpaceToSentence: Boolean,
    val alwaysAddPrefix: Boolean) {

  protected val bpeRanks: Map[(String, String), Int] = {
    merges
  }

  /** Rankings for the byte pairs. Derived from merges.txt */
  protected def getBpeRanking: ((String, String)) => Int =
    (bytePair: (String, String)) => bpeRanks.getOrElse(bytePair, Integer.MAX_VALUE)

  /** Cache for already encoded tokens */
  protected val cache: mutable.Map[String, Array[String]] = mutable.Map()

  /** Create a sequence of byte-pairs of the word */
  protected def getBytePairs(word: Array[String]): Array[(String, String)] = {
    val createPairs = (i: Int) => (word(i), word(i + 1))
    (0 until (word.length - 1)).map(createPairs).toArray
  }

  // Can be overridden in inherited class
  protected val prefixForPieceId: Option[String] = None
  protected val suffixForPieceId: Option[String] = None

  /** Iteratively merge the highest-ranked byte pair of the word until no known pair remains or
    * the whole word has been merged into a single known token.
    */
  protected def performMerges(
      wordChars: Array[String],
      charPairs: Array[(String, String)]): Array[String] = {
    var word = wordChars
    var pairs = charPairs
    // get highest priority byte-pair first
    var bytePair: (String, String) =
      pairs.sortWith(getBpeRanking(_) < getBpeRanking(_))(0)
    var done = false
    // while we still have byte-pairs from our vocabulary
    while (bpeRanks.contains(bytePair) && !done) {
      val (first, second) = bytePair
      val newWord: ListBuffer[String] = ListBuffer()
      var i = 0
      var j = 0
      // keep combining characters with the current byte-pair
      while ((i < word.length) && (j != -1)) {
        j = word.indexOf(first, i)
        if (j == -1) newWord ++= word.drop(i)
        else {
          newWord ++= word.slice(i, j)
          i = j
          val bpIsAtIndex =
            (word(i) == first) && (i < word.length - 1) && word(i + 1) == second
          if (bpIsAtIndex) {
            newWord += (first + second)
            i += 2
          } else {
            newWord += word(i)
            i += 1
          }
        }
      }
      word = newWord.toArray
      // if we were able to create a whole word that was in the vocabulary, we're done
      if (word.length == 1) {
        done = true
      } else {
        // do it again with the next byte-pair
        pairs = getBytePairs(word)
        bytePair = pairs.sortWith(getBpeRanking(_) < getBpeRanking(_))(0)
      }
    }
    word
  }

  /** Map the merged subwords back to TokenPieces with begin/end offsets relative to the original
    * token, applying prefix/suffix pieces and the unknown id where needed.
    */
  protected def getTokenPieces(indToken: IndexedToken, word: Array[String]): Array[TokenPiece] = {
    var currentIndex = indToken.begin
    val wordIndexes = word.map((subWord: String) => {
      val startIndex = currentIndex
      currentIndex = startIndex + subWord.length
      (startIndex, startIndex + subWord.length - 1)
    })
    val result = word
      .zip(wordIndexes)
      .map { case (subWord: String, indexes: (Int, Int)) =>
        val isWordStart = indToken.begin == indexes._1
        val isDocumentStart = indToken.begin == 0
        var processedSubWord = subWord
        processedSubWord = if (isDocumentStart && !addPrefixSpaceToSentence) {
          processedSubWord
        } else
          prefixForPieceId match {
            case Some(prepend) if alwaysAddPrefix =>
              if (isWordStart && subWord.indexOf(prepend) < 0) prepend + processedSubWord
              else processedSubWord
            case _ => processedSubWord
          }
        processedSubWord = suffixForPieceId match {
          case None => processedSubWord
          case Some(append) =>
            val isWordEnd = indToken.end == indexes._2
            if (isWordEnd && subWord.indexOf(append) < 0) processedSubWord + append
            else processedSubWord
        }
        // Set unknown id if not found
        val subWordId: Int = vocab.getOrElse(processedSubWord, specialTokens.unk.id)

        TokenPiece(subWord, indToken.token.trim(), subWordId, isWordStart, indexes._1, indexes._2)

      }
    result
  }

  /** Do the BPE algorithm. The goal is to find the token as the largest word in the known
    * vocabulary. If not possible, the word is split into smaller subwords, until they are known.
    *
    * @return
    *   Array of TokenPieces, corresponding to encoded token
    */
  protected def bpe(indToken: IndexedToken): Array[TokenPiece] = {
    var processedToken = ""
    try {
      processedToken = preProcessTokenForBpe(indToken.token)
      // TODO: Caching
      var word: Array[String] = Array[String]()
      // split the word into characters, to be combined into subwords
      word = processedToken.map(_.toString).toArray
      val pairs: Array[(String, String)] = getBytePairs(word)

      if (pairs.isEmpty)
        word = Array(processedToken)
      else
        word = performMerges(word, pairs)

      getTokenPieces(indToken, word)
    } catch {
      case _: java.util.NoSuchElementException =>
        Array(
          TokenPiece(
            indToken.token,
            indToken.token,
            specialTokens.unk.id,
            isWordStart = true,
            indToken.begin,
            indToken.end))
    }
  }

  /** Split the individual sub texts on special tokens, e.g. masking etc. */
  protected def splitOnSpecialToken(
      specialToken: SpecialToken,
      text: String): ListBuffer[String] = {
    val isControl = (c: Char) => {
      if (c == '\t' || c == '\n' || c == '\r') false // count as whitespace
      else c.isControl
    }
    val isPunctuation =
      (c: Char) => raw"""[^[:alnum:]]""".r.findFirstIn(c.toString).isDefined
    val isWordBorder =
      (c: Char) => isControl(c) || isPunctuation(c) || c.isWhitespace

    val isEndOfWord = (text: String) => isWordBorder(text.last)
    val isStartOfWord = (text: String) => isWordBorder(text.head)

    val result: ListBuffer[String] = ListBuffer()
    val tok = specialToken.content

    val splitText = StringUtils.splitByWholeSeparator(text, tok)
    var fullWord = ""

    for ((subText, i) <- splitText.zipWithIndex) {
      var done = false
      // Try to avoid splitting on token
      if (specialToken.singleWord) {
        if ((i < (splitText.length - 1)) && !isEndOfWord(subText) && !isStartOfWord(
            splitText(i + 1))) fullWord += subText + tok
        else if (fullWord.nonEmpty) {
          fullWord += subText
          result += fullWord
          fullWord = ""
          done = true
        }
      }
      if (!done) {
        // A bit counter-intuitive but we strip the left of the string
        // since rstrip means the special token is eating all white spaces on its right
        var subTextProcessed: String = subText
        if (specialToken.rstrip && i > 0)
          subTextProcessed = subText.stripPrefix(" ")
        if (specialToken.lstrip && i < (splitText.length - 1))
          subTextProcessed = subText.stripSuffix(" ")
        if (i == 0 && subTextProcessed.isEmpty)
          result += tok
        else if (i == (splitText.length - 1)) {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
        } else {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
          result += tok
        }
      }
    }
    result
  }

  /** Needs to be implemented */
  protected def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]

  /** Sentence start and end padding tokens of the model */
  val sentencePadding: (String, String) =
    (specialTokens.sentenceStart.content, specialTokens.sentenceEnd.content)

  /** Tokenize considering special tokens and split algorithm */
  def tokenize(sentence: Sentence): Array[IndexedToken] = {
    var text = sentence.content
    if (text.trim.isEmpty) Array[IndexedToken]()
    else {
      val splitTexts: ListBuffer[String] = ListBuffer()
      var textList: ListBuffer[String] = ListBuffer(text)

      for (transformations <- specialTokens.allTokens) {
        splitTexts.clear()
        for (subText <- textList) {
          if (!specialTokens.contains(subText))
            splitTexts ++= splitOnSpecialToken(transformations, subText)
          else
            splitTexts += subText
        }
        textList = splitTexts.clone()
      }

      if (padWithSequenceTokens) {
        text = sentencePadding._1 + text + sentencePadding._2
        splitTexts.prepend(sentencePadding._1)
        splitTexts.append(sentencePadding._2)
      }

      var currentIndex = 0
      val result = mutable.ArrayBuffer[IndexedToken]()
      for (subText <- splitTexts) {
        val subTextIndex = sentence.start + text.indexOf(subText, currentIndex)
        if (!specialTokens.contains(subText)) {
          val splitSubText: Array[IndexedToken] = tokenizeSubText(subText, subTextIndex)
          result.append(splitSubText: _*)
        } else // subtext is just the special token
          result.append(
            IndexedToken(subText, begin = subTextIndex, end = subTextIndex + subText.length - 1))
        currentIndex = subTextIndex + subText.length
      }
      result.toArray
    }
  }

  protected def preProcessTokenForBpe(token: String): String = token

  def encode(indToken: IndexedToken): Array[TokenPiece] = {
    if (!specialTokens.contains(indToken.token))
      bpe(indToken)
    else {
      Array(
        TokenPiece(
          indToken.token,
          indToken.token,
          vocab(indToken.token),
          isWordStart = false,
          indToken.begin,
          indToken.end))
    }
  }

  def encode(indTokens: Array[IndexedToken]): Array[TokenPiece] = indTokens.flatMap(encode(_))
}

object BpeTokenizer {
  def forModel(
      modelType: String,
      merges: Map[(String, String), Int],
      vocab: Map[String, Int],
      padWithSequenceTokens: Boolean = false,
      addPrefixSpaceToSentence: Boolean = false,
      specialTokens: Option[SpecialTokens] = None,
      alwaysAddPrefix: Boolean = true,
      prependString: String = ""): BpeTokenizer = {

    def modelSpecialTokens() = specialTokens match {
      case Some(specialTok) => specialTok
      case None => SpecialTokens.getSpecialTokensForModel(modelType, vocab)
    }

    val tokenizer = modelType match {
      case "roberta" =>
        new RobertaTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix)
      case "xlm" =>
        new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSequenceTokens)
      case "gpt2" =>
        new Gpt2Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix)
      case "bart" =>
        new BartTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "olmo" =>
        new OLMoTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "clip" =>
        new CLIPTokenizer(merges, vocab, modelSpecialTokens())
      case "phi2" =>
        new Phi2Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "qwen" =>
        new QwenTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "starcoder" =>
        new StarCoderTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "llama3" =>
        new LLAMA3Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "Janus" =>
        new JanusTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix,
          prependString = prependString)
      case "mllama" =>
        new MLLamaTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "qwen2vl" =>
        new Qwen2VLTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          prependString = prependString)
      case "llava" =>
        new LLAVATokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          prependString = prependString)
      case "phi3v" =>
        new Phi3VisionTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix,
          prependString = prependString)
      case "smolvlm" =>
        new SmolVLMTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          prependString = prependString)
      case "paligemma" =>
        new PaliGemmaTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix,
          prependString = prependString)
      case "gemma3" =>
        new Gemma3Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix)
      case "internvl" =>
        new InternVLTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "florence2" =>
        new Florence2Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix,
          prependString = prependString)
      case _ =>
        throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.")
    }

    tokenizer
  }
}
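
For orientation, the sketch below shows one way the factory above could be exercised end to end: build a tokenizer with forModel, tokenize a Sentence, then encode the resulting tokens into TokenPieces. It is a minimal illustration, not code from this PR: the toy merges/vocab maps are invented (real models load them from the model's merges.txt and vocab.json), and the positional Sentence arguments (content, start, end, index) are assumed. Because BpeTokenizer is private[nlp], the snippet is written as if it lived inside that package, e.g. in a test source.

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

import com.johnsnowlabs.nlp.annotators.common.Sentence

// Minimal usage sketch with assumed toy data; output shown via default case-class toString.
object BpeTokenizerUsageSketch extends App {

  // Toy GPT2-style vocabulary. "<|endoftext|>" is included because the GPT2 special
  // tokens are expected to be present in the vocabulary.
  val vocab: Map[String, Int] = Map(
    "<|endoftext|>" -> 0,
    "l" -> 1,
    "o" -> 2,
    "w" -> 3,
    "lo" -> 4,
    "low" -> 5)

  // Merge ranks: a lower value means a higher merge priority.
  val merges: Map[(String, String), Int] = Map(
    ("l", "o") -> 0,
    ("lo", "w") -> 1)

  // Pick the concrete tokenizer by model type ("gpt2" is one of the supported cases above).
  val tokenizer = BpeTokenizer.forModel("gpt2", merges, vocab)

  // Tokenize a sentence, then BPE-encode the resulting tokens into TokenPieces.
  // Sentence arguments are assumed to be (content, start, end, index).
  val sentence = Sentence("low", 0, 2, 0)
  val pieces = tokenizer.encode(tokenizer.tokenize(sentence))
  pieces.foreach(println)
}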