JohnSnowLabs / spark-nlp / 7861513225

11 Feb 2024 11:05AM UTC coverage: 62.678% (-0.05%) from 62.731%

Pull Request #14169: Fixed a bug with models that has 'onnx_data' file not working in dbfs/hdfs
Merge 13f2acde4 into 6010244ba

8951 of 14281 relevant lines covered (62.68%)
0.63 hits per line

Source File (82.12% covered):
/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.tokenizer.bpe

import com.johnsnowlabs.nlp.annotators.common.{IndexedToken, Sentence, TokenPiece}
import org.apache.commons.lang3.StringUtils

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** A BPE Tokenizer based on GPT2's tokenization scheme. The tokenization can then be used for
  * models based on this scheme (e.g. GPT2, RoBERTa, DeBERTa)
  *
  * TODO: truncation assumed?
  *
  * @param merges
  *   Map of tokens that are mergeable
  * @param vocab
  *   Map of tokens to encoded representation
  * @param specialTokens
  *   Collection of special tokens
  * @param padWithSequenceTokens
  *   Whether to pad the sentence with sentence tokens at the start and end
  * @param addPrefixSpaceToSentence
  *   Whether to add a space to the first word of a sentence
  * @param alwaysAddPrefix
  *   Whether to always prefix token ids with `prefixForPieceId`
  */
private[nlp] abstract class BpeTokenizer(
    val merges: Map[(String, String), Int],
    val vocab: Map[String, Int],
    val specialTokens: SpecialTokens,
    val padWithSequenceTokens: Boolean,
    val addPrefixSpaceToSentence: Boolean,
    val alwaysAddPrefix: Boolean) {

  protected val bpeRanks: Map[(String, String), Int] = {
    merges
  }

  /** Rankings for the byte pairs. Derived from merges.txt */
  protected def getBpeRanking: ((String, String)) => Int =
    (bytePair: (String, String)) => bpeRanks.getOrElse(bytePair, Integer.MAX_VALUE)

  /** cache for already encoded tokens */
  protected val cache: mutable.Map[String, Array[String]] = mutable.Map()

  /** Create a sequence of byte-pairs of the word */
  protected def getBytePairs(word: Array[String]): Array[(String, String)] = {
    val createPairs = (i: Int) => (word(i), word(i + 1))
    (0 until (word.length - 1)).map(createPairs).toArray
  }
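
  // For example, a word already split into single-character strings yields every adjacent pair:
  //   getBytePairs(Array("l", "o", "w", "e", "r"))
  //   // => Array(("l", "o"), ("o", "w"), ("w", "e"), ("e", "r"))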

  // Can be overridden in inherited class
  protected val prefixForPieceId: Option[String] = None
  protected val suffixForPieceId: Option[String] = None

  protected def performMerges(
      wordChars: Array[String],
      charPairs: Array[(String, String)]): Array[String] = {
    var word = wordChars
    var pairs = charPairs
    // get highest priority byte-pair first
    var bytePair: (String, String) =
      pairs.sortWith(getBpeRanking(_) < getBpeRanking(_))(0)
    var done = false
    // while we still have byte-pairs from our vocabulary
    while (bpeRanks.contains(bytePair) && !done) {
      val (first, second) = bytePair
      val newWord: ListBuffer[String] = ListBuffer()
      var i = 0
      var j = 0
      // keep combining characters with the current byte-pair
      while ((i < word.length) && (j != -1)) {
        j = word.indexOf(first, i)
        if (j == -1) newWord ++= word.drop(i)
        else {
          newWord ++= word.slice(i, j)
          i = j
          val bpIsAtIndex =
            (word(i) == first) && (i < word.length - 1) && word(i + 1) == second
          if (bpIsAtIndex) {
            newWord += (first + second)
            i += 2
          } else {
            newWord += word(i)
            i += 1
          }
        }
      }
      word = newWord.toArray
      // if we were able to create a whole word that was in the vocabulary, we're done
      if (word.length == 1) {
        done = true
      } else {
        // do it again with the next byte-pair
        pairs = getBytePairs(word)
        bytePair = pairs.sortWith(getBpeRanking(_) < getBpeRanking(_))(0)
      }
    }
    word
  }
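
  // Walk-through with a hypothetical merges table that ranks ("l", "o") highest and ("e", "r")
  // next (and contains no other pair of this word): starting from Array("l", "o", "w", "e", "r"),
  // the highest-priority pair is merged first, giving Array("lo", "w", "e", "r"), then
  // Array("lo", "w", "er"); the loop stops once no remaining adjacent pair appears in bpeRanks,
  // or the word collapses into a single known piece.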

  protected def getTokenPieces(indToken: IndexedToken, word: Array[String]): Array[TokenPiece] = {
    var currentIndex = indToken.begin
    val wordIndexes = word.map((subWord: String) => {
      val startIndex = currentIndex
      currentIndex = startIndex + subWord.length
      (startIndex, startIndex + subWord.length - 1)
    })
    val result = word
      .zip(wordIndexes)
      .map { case (subWord: String, indexes: (Int, Int)) =>
        val isWordStart = indToken.begin == indexes._1
        val isDocumentStart = indToken.begin == 0
        var processedSubWord = subWord
        processedSubWord = if (isDocumentStart && !addPrefixSpaceToSentence) {
          processedSubWord
        } else
          prefixForPieceId match {
            case Some(prepend) if alwaysAddPrefix =>
              if (isWordStart && subWord.indexOf(prepend) < 0) prepend + processedSubWord
              else processedSubWord
            case _ => processedSubWord
          }
        processedSubWord = suffixForPieceId match {
          case None => processedSubWord
          case Some(append) =>
            val isWordEnd = indToken.end == indexes._2
            if (isWordEnd && subWord.indexOf(append) < 0) processedSubWord + append
            else processedSubWord
        }
        // Set unknown id if not found
        val subWordId: Int = vocab.getOrElse(processedSubWord, specialTokens.unk.id)

        TokenPiece(subWord, indToken.token.trim(), subWordId, isWordStart, indexes._1, indexes._2)

      }
    result
  }

  /** Do the BPE algorithm. The goal is to match the token to the largest pieces in the known
    * vocabulary. If that is not possible, the word is split into smaller subwords until they
    * are known.
    *
    * @return
    *   Array of TokenPieces corresponding to the encoded token
    */
  protected def bpe(indToken: IndexedToken): Array[TokenPiece] = {
    var processedToken = ""
    try {
      processedToken = preProcessTokenForBpe(indToken.token)
      // TODO: Caching
      var word: Array[String] = Array[String]()
      // split the word into characters, to be combined into subwords
      word = processedToken.map(_.toString).toArray
      val pairs: Array[(String, String)] = getBytePairs(word)

      if (pairs.isEmpty)
        word = Array(processedToken)
      else
        word = performMerges(word, pairs)

      getTokenPieces(indToken, word)
    } catch {
      case _: java.util.NoSuchElementException =>
        Array(
          TokenPiece(
            indToken.token,
            indToken.token,
            specialTokens.unk.id,
            isWordStart = true,
            indToken.begin,
            indToken.end))
    }
  }
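
  // Putting the pieces together (using the hypothetical merges table sketched above): a token
  // "lower" is first split into Array("l", "o", "w", "e", "r"), merged by performMerges into
  // Array("lo", "w", "er"), and then mapped by getTokenPieces to TokenPieces whose ids are
  // looked up in vocab, falling back to specialTokens.unk.id for unknown pieces.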

  /** Split the individual sub texts on special tokens, e.g. masking etc. */
  protected def splitOnSpecialToken(
      specialToken: SpecialToken,
      text: String): ListBuffer[String] = {
    val isControl = (c: Char) => {
      if (c == '\t' || c == '\n' || c == '\r') false // count as whitespace
      else c.isControl
    }
    val isPunctuation =
      (c: Char) => raw"""[^[:alnum:]]""".r.findFirstIn(c.toString).isDefined
    val isWordBorder =
      (c: Char) => isControl(c) || isPunctuation(c) || c.isWhitespace

    val isEndOfWord = (text: String) => isWordBorder(text.last)
    val isStartOfWord = (text: String) => isWordBorder(text.head)

    val result: ListBuffer[String] = ListBuffer()
    val tok = specialToken.content

    val splitText = StringUtils.splitByWholeSeparator(text, tok)
    var fullWord = ""

    for ((subText, i) <- splitText.zipWithIndex) {
      var done = false
      // Try to avoid splitting on token
      if (specialToken.singleWord) {
        if ((i < (splitText.length - 1)) && !isEndOfWord(subText) && !isStartOfWord(
            splitText(i + 1))) fullWord += subText + tok
        else if (fullWord.nonEmpty) {
          fullWord += subText
          result += fullWord
          fullWord = ""
          done = true
        }
      }
      if (!done) {
        // A bit counter-intuitive, but we strip the left of the string,
        // since rstrip means the special token is eating all white spaces on its right
        var subTextProcessed: String = subText
        if (specialToken.rstrip && i > 0)
          subTextProcessed = subText.stripPrefix(" ")
        if (specialToken.lstrip && i < (splitText.length - 1))
          subTextProcessed = subText.stripSuffix(" ")
        if (i == 0 && subTextProcessed.isEmpty)
          result += tok
        else if (i == (splitText.length - 1)) {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
        } else {
          if (subTextProcessed.nonEmpty) result += subTextProcessed
          result += tok
        }
      }
    }
    result
  }
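
  // For instance, for a special token "<mask>" with lstrip and rstrip enabled, splitting
  // "Hello <mask> world" yields ListBuffer("Hello", "<mask>", "world"): the special token is
  // kept as its own element so it is never merged into neighboring sub texts.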

  /** Needs to be implemented */
  protected def tokenizeSubText(text: String, indexOffset: Int): Array[IndexedToken]

  /** Special tokens of the model for processing */
  val sentencePadding: (String, String) =
    (specialTokens.sentenceStart.content, specialTokens.sentenceEnd.content)

  /** Tokenize considering special tokens and split algorithm */
  def tokenize(sentence: Sentence): Array[IndexedToken] = {
    var text = sentence.content
    if (text.trim.isEmpty) Array[IndexedToken]()
    else {
      val splitTexts: ListBuffer[String] = ListBuffer()
      var textList: ListBuffer[String] = ListBuffer(text)

      for (transformations <- specialTokens.allTokens) {
        splitTexts.clear()
        for (subText <- textList) {
          if (!specialTokens.contains(subText))
            splitTexts ++= splitOnSpecialToken(transformations, subText)
          else
            splitTexts += subText
        }
        textList = splitTexts.clone()
      }

      if (padWithSequenceTokens) {
        text = sentencePadding._1 + text + sentencePadding._2
        splitTexts.prepend(sentencePadding._1)
        splitTexts.append(sentencePadding._2)
      }

      var currentIndex = 0
      val result = mutable.ArrayBuffer[IndexedToken]()
      for (subText <- splitTexts) {
        val subTextIndex = sentence.start + text.indexOf(subText, currentIndex)
        if (!specialTokens.contains(subText)) {
          val splitSubText: Array[IndexedToken] = tokenizeSubText(subText, subTextIndex)
          result.append(splitSubText: _*)
        } else // subtext is just the special token
          result.append(
            IndexedToken(subText, begin = subTextIndex, end = subTextIndex + subText.length - 1))
        currentIndex = subTextIndex + subText.length
      }
      result.toArray
    }
  }
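
  // Sketch of tokenize (assuming a Sentence is built roughly as Sentence(content, start, end,
  // index)): tokenizer.tokenize(Sentence("I love Spark NLP", 0, 15, 0)) first splits the text on
  // any special tokens, optionally wraps it in sentencePadding, and delegates the remaining sub
  // texts to the subclass's tokenizeSubText, so every returned IndexedToken carries absolute
  // character offsets relative to sentence.start.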

  protected def preProcessTokenForBpe(token: String): String = token

  def encode(indToken: IndexedToken): Array[TokenPiece] = {
    if (!specialTokens.contains(indToken.token))
      bpe(indToken)
    else
      Array(
        TokenPiece(
          indToken.token,
          indToken.token,
          vocab(indToken.token),
          isWordStart = true,
          indToken.begin,
          indToken.end))
  }

  def encode(indTokens: Array[IndexedToken]): Array[TokenPiece] = indTokens.flatMap(encode(_))
}

object BpeTokenizer {
  def forModel(
      modelType: String,
      merges: Map[(String, String), Int],
      vocab: Map[String, Int],
      padWithSequenceTokens: Boolean = false,
      addPrefixSpaceToSentence: Boolean = false,
      specialTokens: Option[SpecialTokens] = None,
      alwaysAddPrefix: Boolean = true): BpeTokenizer = {

    def modelSpecialTokens() = specialTokens match {
      case Some(specialTok) => specialTok
      case None => SpecialTokens.getSpecialTokensForModel(modelType, vocab)
    }

    val tokenizer = modelType match {
      case "roberta" =>
        new RobertaTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix)
      case "xlm" =>
        new XlmTokenizer(merges, vocab, modelSpecialTokens(), padWithSequenceTokens)
      case "gpt2" =>
        new Gpt2Tokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence,
          alwaysAddPrefix = alwaysAddPrefix)
      case "bart" =>
        new BartTokenizer(
          merges,
          vocab,
          modelSpecialTokens(),
          padWithSequenceTokens,
          addPrefixSpaceToSentence = addPrefixSpaceToSentence)
      case "clip" =>
        new CLIPTokenizer(merges, vocab, modelSpecialTokens())
      case _ =>
        throw new IllegalArgumentException("Model type \"" + modelType + "\" not supported yet.")
    }

    tokenizer
  }
}
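
// Usage sketch for the factory above. This is illustrative only: the merges and vocab values are
// toy placeholders; in practice they are loaded from a model's merges.txt and vocab.json, and the
// vocabulary must also contain the model's special tokens (e.g. "<|endoftext|>" for GPT-2).
//
//   val merges: Map[(String, String), Int] = Map(("l", "o") -> 0, ("e", "r") -> 1)
//   val vocab: Map[String, Int] =
//     Map("lo" -> 0, "w" -> 1, "er" -> 2, "<|endoftext|>" -> 3)
//
//   val tokenizer = BpeTokenizer.forModel(
//     "gpt2",
//     merges = merges,
//     vocab = vocab,
//     padWithSequenceTokens = false)
//
//   val indexedTokens = tokenizer.tokenize(Sentence("lower", 0, 4, 0))
//   val tokenPieces = tokenizer.encode(indexedTokens)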