
JohnSnowLabs / spark-nlp / build 9929262317

14 Jul 2024 04:27PM UTC, coverage: 62.618% (+0.008%) from 62.61%
Push via GitHub by maziyarpanahi: "Bump version to 5.4.1 [run doc]"

2 of 2 new or added lines in 2 files covered (100.0%)
52 existing lines in 36 files now uncovered
8970 of 14325 relevant lines covered (62.62%)
0.63 hits per line

Source file: /src/main/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingModel.scala (84.17% of lines covered)

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.spell.norvig

import com.johnsnowlabs.nlp.annotators.spell.util.Utilities
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp._
import org.apache.spark.ml.util.Identifiable
import org.slf4j.LoggerFactory

import scala.collection.immutable.HashSet

/** This annotator retrieves tokens and automatically corrects those that are not found in an
  * English dictionary. Inspired by the Norvig model and
  * [[https://github.com/wolfgarbe/SymSpell SymSpell]].
  *
  * The Symmetric Delete spelling correction algorithm reduces the complexity of edit candidate
  * generation and dictionary lookup for a given Damerau-Levenshtein distance. It is six orders of
  * magnitude faster (than the standard approach with deletes + transposes + replaces + inserts)
  * and language independent.
  *
  * This is the instantiated model of the [[NorvigSweetingApproach]]. For training your own model,
  * please see the documentation of that class.
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val spellChecker = NorvigSweetingModel.pretrained()
  *   .setInputCols("token")
  *   .setOutputCol("spell")
  *   .setDoubleVariants(true)
  * }}}
  * The default model is `"spellcheck_norvig"`, if no name is provided. For available pretrained
  * models please see the [[https://sparknlp.org/models?task=Spell+Check Models Hub]].
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingTestSpec.scala NorvigSweetingTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel
  *
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("document")
  *   .setOutputCol("token")
  *
  * val spellChecker = NorvigSweetingModel.pretrained()
  *   .setInputCols("token")
  *   .setOutputCol("spell")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   tokenizer,
  *   spellChecker
  * ))
  *
  * val data = Seq("somtimes i wrrite wordz erong.").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  * result.select("spell.result").show(false)
  * +--------------------------------------+
  * |result                                |
  * +--------------------------------------+
  * |[sometimes, i, write, words, wrong, .]|
  * +--------------------------------------+
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel SymmetricDeleteModel]]
  *   for an alternative approach to spell checking
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel ContextSpellCheckerModel]]
  *   for a DL based approach
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class NorvigSweetingModel(override val uid: String)
    extends AnnotatorModel[NorvigSweetingModel]
    with HasSimpleAnnotate[NorvigSweetingModel]
    with NorvigSweetingParams {

  import com.johnsnowlabs.nlp.AnnotatorType._

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("SPELL"))

  private val logger = LoggerFactory.getLogger("NorvigApproach")

  /** Output annotator type : TOKEN
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = TOKEN

  /** Input annotator type : TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(TOKEN)

  /** Number of occurrences of each word in the dictionary
    *
    * @group param
    */
  protected val wordCount: MapFeature[String, Long] = new MapFeature(this, "wordCount")

  /** @group getParam */
  protected def getWordCount: Map[String, Long] = $$(wordCount)

  /** @group setParam */
  def setWordCount(value: Map[String, Long]): this.type = set(wordCount, value)

  private lazy val allWords: HashSet[String] = {
    if ($(caseSensitive)) HashSet($$(wordCount).keys.toSeq: _*)
    else HashSet($$(wordCount).keys.toSeq.map(_.toLowerCase): _*)
  }

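  /** Minimum and maximum dictionary counts among words longer than `wordSizeIgnore`. Cached here
    * so that `normalizeFrequencyValue` can min-max scale a raw frequency into a [0, 1] score.
    */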
  private lazy val frequencyBoundaryValues: (Long, Long) = {
    val min: Long = $$(wordCount).filter(_._1.length > $(wordSizeIgnore)).minBy(_._2)._2
    val max = $$(wordCount).filter(_._1.length > $(wordSizeIgnore)).maxBy(_._2)._2
    (min, max)
  }

  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    annotations.map { token =>
      val verifiedWord = checkSpellWord(token.result)
      Annotation(
        outputAnnotatorType,
        token.begin,
        token.end,
        verifiedWord._1,
        Map("confidence" -> verifiedWord._2.toString, "sentence" -> token.metadata("sentence")))
    }
  }

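  /** Checks a single token and returns the corrected word together with a confidence score.
    *
    * The input is first passed through `Utilities.limitDuplicates` with the `dupsLimit`
    * parameter. If a direct suggestion is available (for example the word is already in the
    * dictionary or is too short to correct), it is returned immediately; otherwise candidate
    * corrections are generated, ranked by dictionary frequency and by Hamming distance to the
    * input, and the two rankings are reconciled in `getResult`.
    *
    * A minimal usage sketch (illustrative only; the corrected word and its confidence depend on
    * the dictionary of the loaded model):
    * {{{
    * val spellChecker = NorvigSweetingModel.pretrained()
    * val (word, confidence) = spellChecker.checkSpellWord("wrrite")
    * }}}
    */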
  def checkSpellWord(raw: String): (String, Double) = {
    val input = Utilities.limitDuplicates($(dupsLimit), raw)
    logger.debug(s"spell checker target word: $input")
    val possibility = getBestSpellingSuggestion(input)
    if (possibility._1.isDefined) return (possibility._1.get, possibility._2)
    val listedSuggestions = suggestions(input)
    val sortedFrequencies = getSortedWordsByFrequency(listedSuggestions, input)
    val sortedHamming = getSortedWordsByHamming(listedSuggestions, input)
    getResult(sortedFrequencies, sortedHamming, input)
  }

  private def getBestSpellingSuggestion(word: String): (Option[String], Double) = {
    var score: Double = 0
    if ($(shortCircuit)) {
      val suggestedWord = getShortCircuitSuggestion(word).getOrElse(word)
      score = getScoreFrequency(suggestedWord)
      (Some(suggestedWord), score)
    } else {
      val suggestions = getSuggestion(word: String)
      (suggestions._1, suggestions._2)
    }
  }

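  /** Fast-path lookup used when `shortCircuit` is enabled: returns the input word as soon as any
    * reduction, vowel swap, variant, combined reduction and swap, or (optionally) double variant
    * of it is found in the dictionary, and `None` otherwise.
    */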
  private def getShortCircuitSuggestion(word: String): Option[String] = {
    if (Utilities.reductions(word, $(reductLimit)).exists(allWords.contains)) Some(word)
    else if (Utilities.getVowelSwaps(word, $(vowelSwapLimit)).exists(allWords.contains))
      Some(word)
    else if (Utilities.variants(word).exists(allWords.contains)) Some(word)
    else if (both(word).exists(allWords.contains)) Some(word)
    else if ($(doubleVariants) && computeDoubleVariants(word).exists(allWords.contains))
      Some(word)
    else None
  }

  /** variants of variants of a word */
  def computeDoubleVariants(word: String): List[String] =
    Utilities.variants(word).flatMap(variant => Utilities.variants(variant))

  private def getSuggestion(word: String): (Option[String], Double) = {
    if (allWords.contains(word)) {
      logger.debug("Word found in dictionary. No spell change")
      (Some(word), 1)
    } else if (word.length <= $(wordSizeIgnore)) {
      logger.debug("word ignored because length is less than wordSizeIgnore")
      (Some(word), 0)
    } else if (allWords.contains(word.distinct)) {
      logger.debug("Word as distinct found in dictionary")
      val score = getScoreFrequency(word.distinct)
      (Some(word.distinct), score)
    } else (None, -1)
  }

  def getScoreFrequency(word: String): Double = {
    val frequency = Utilities.getFrequency(word, $$(wordCount))
    normalizeFrequencyValue(frequency)
  }

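  /** Min-max scales a raw dictionary count into [0, 1] using `frequencyBoundaryValues`, rounding
    * to four decimal places. Counts above the upper boundary map to 1 and counts below the lower
    * boundary map to 0. For example (illustrative numbers only), with boundaries (min = 10,
    * max = 1010) a count of 510 normalizes to (510 - 10) / (1010 - 10) = 0.5.
    */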
  def normalizeFrequencyValue(value: Long): Double = {
    if (value > frequencyBoundaryValues._2) {
      return 1
    }
    if (value < frequencyBoundaryValues._1) {
      return 0
    }
    val normalizedValue =
      (value - frequencyBoundaryValues._1).toDouble / (frequencyBoundaryValues._2 - frequencyBoundaryValues._1).toDouble
    BigDecimal(normalizedValue).setScale(4, BigDecimal.RoundingMode.HALF_UP).toDouble
  }

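  /** Generates correction candidates for a word by combining reductions, vowel swaps, single
    * variants, reduction-plus-swap combinations (`both`) and, when `doubleVariants` is enabled,
    * variants of variants, keeping only the candidates that appear in the dictionary.
    */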
  private def suggestions(word: String): List[String] = {
    val intersectedPossibilities = allWords.intersect({
      val base =
        Utilities.reductions(word, $(reductLimit)) ++
          Utilities.getVowelSwaps(word, $(vowelSwapLimit)) ++
          Utilities.variants(word) ++
          both(word)
      if ($(doubleVariants)) base ++ computeDoubleVariants(word) else base
    }.toSet)
    if (intersectedPossibilities.nonEmpty) intersectedPossibilities.toList
    else List.empty[String]
  }

  private def both(word: String): List[String] = {
    Utilities
      .reductions(word, $(reductLimit))
      .flatMap(reduction => Utilities.getVowelSwaps(reduction, $(vowelSwapLimit)))
  }

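  /** Ranks candidates that are at least as long as the input by their dictionary frequency and
    * keeps the top `intersections` words, highest frequency first.
    */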
  def getSortedWordsByFrequency(words: List[String], input: String): List[(String, Long)] = {
    val filteredWords = words.withFilter(_.length >= input.length)
    val sortedWordsByFrequency = filteredWords
      .map(word => (word, compareFrequencies(word)))
      .sortWith(_._2 > _._2)
      .take($(intersections))
    logger.debug(s"recommended by frequency: ${sortedWordsByFrequency.mkString(", ")}")
    sortedWordsByFrequency
  }

  private def compareFrequencies(value: String): Long =
    Utilities.getFrequency(value, $$(wordCount))

  def getSortedWordsByHamming(words: List[String], input: String): List[(String, Long)] = {
    val sortedWordByHamming = words
      .map(word => (word, compareHammers(input)(word)))
      .sortBy(_._2)
      .takeRight($(intersections))
    logger.debug(s"recommended by hamming: ${sortedWordByHamming.mkString(", ")}")
    sortedWordByHamming
  }

  private def compareHammers(input: String)(value: String): Long =
    Utilities.computeHammingDistance(input, value)

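  /** Reconciles the frequency-ranked and Hamming-ranked candidate lists: if both are empty the
    * input is returned with confidence 0; if only one list is non-empty its best candidate is
    * used; if the two lists share words those shared candidates are compared on both criteria;
    * otherwise a recommendation is drawn from one of the two lists.
    */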
  def getResult(
      wordsByFrequency: List[(String, Long)],
      wordsByHamming: List[(String, Long)],
      input: String): (String, Double) = {
    var recommendation: (Option[String], Double) = (None, 0)
    val intersectWords =
      wordsByFrequency.map(word => word._1).intersect(wordsByHamming.map(word => word._1))
    if (wordsByFrequency.isEmpty && wordsByHamming.isEmpty) {
      logger.debug("no intersection or frequent words found")
      recommendation = (Some(input), 0)
    } else if (wordsByFrequency.isEmpty || wordsByHamming.isEmpty) {
      logger.debug("no intersection but one recommendation found")
      recommendation = getRecommendation(wordsByFrequency, wordsByHamming)
    } else if (intersectWords.nonEmpty) {
      logger.debug("hammer and frequency recommendations found")
      val frequencyAndHammingRecommendation =
        getFrequencyAndHammingRecommendation(wordsByFrequency, wordsByHamming, intersectWords)
      recommendation =
        (frequencyAndHammingRecommendation._1, frequencyAndHammingRecommendation._2)
    } else {
      logger.debug("no intersection of hammer and frequency")
      recommendation =
        getFrequencyOrHammingRecommendation(wordsByFrequency, wordsByHamming, input)
    }
    (recommendation._1.getOrElse(input), recommendation._2)
  }

  private def getRecommendation(
      wordsByFrequency: List[(String, Long)],
      wordsByHamming: List[(String, Long)]) = {
    if (wordsByFrequency.nonEmpty) {
      getResultByFrequency(wordsByFrequency)
    } else {
      getResultByHamming(wordsByHamming)
    }
  }

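  /** For candidates present in both rankings, prefers a word that simultaneously has the highest
    * frequency and the lowest Hamming distance, choosing randomly among ties. If no single word
    * wins on both criteria, `frequencyPriority` decides whether frequency or Hamming distance
    * breaks the tie, with a confidence of 1.
    */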
  private def getFrequencyAndHammingRecommendation(
      wordsByFrequency: List[(String, Long)],
      wordsByHamming: List[(String, Long)],
      intersectWords: List[String]): (Option[String], Double) = {
    val wordsByFrequencyAndHamming = intersectWords.map { word =>
      val frequency = wordsByFrequency.find(_._1 == word).get._2
      val hamming = wordsByHamming.find(_._1 == word).get._2
      (word, frequency, hamming)
    }
    val bestFrequencyValue = wordsByFrequencyAndHamming.maxBy(_._2)._2
    val bestHammingValue = wordsByFrequencyAndHamming.minBy(_._3)._3
    val bestRecommendations = wordsByFrequencyAndHamming.filter(word =>
      word._2 == bestFrequencyValue && word._3 == bestHammingValue)
    if (bestRecommendations.nonEmpty) {
      val result = (
        Utilities.getRandomValueFromList(bestRecommendations),
        Utilities.computeConfidenceValue(bestRecommendations))
      (Some(result._1.get._1), result._2)
    } else {
      if ($(frequencyPriority)) {
        (Some(wordsByFrequencyAndHamming.sortBy(_._3).maxBy(_._2)._1), 1.toDouble)
      } else {
        (Some(wordsByFrequencyAndHamming.sortBy(_._2).reverse.minBy(_._3)._1), 1.toDouble)
      }
    }
  }

  def getResultByFrequency(wordsByFrequency: List[(String, Long)]): (Option[String], Double) = {
    val bestFrequencyValue = wordsByFrequency.maxBy(_._2)._2
    val bestRecommendations = wordsByFrequency.filter(_._2 == bestFrequencyValue).map(_._1)
    (
      Utilities.getRandomValueFromList(bestRecommendations),
      Utilities.computeConfidenceValue(bestRecommendations))
  }

  def getResultByHamming(wordsByHamming: List[(String, Long)]): (Option[String], Double) = {
    val bestHammingValue = wordsByHamming.minBy(_._2)._2
    val bestRecommendations = wordsByHamming.filter(_._2 == bestHammingValue).map(_._1)
    (
      Utilities.getRandomValueFromList(bestRecommendations),
      Utilities.computeConfidenceValue(bestRecommendations))
  }

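  /** Used when the frequency and Hamming rankings share no candidates: takes the best word from
    * each ranking, drops one of them if it merely echoes the input, and picks randomly among the
    * remaining options.
    */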
  def getFrequencyOrHammingRecommendation(
      wordsByFrequency: List[(String, Long)],
      wordsByHamming: List[(String, Long)],
      input: String): (Option[String], Double) = {
    val frequencyResult: String = getResultByFrequency(wordsByFrequency)._1.getOrElse(input)
    val hammingResult: String = getResultByHamming(wordsByHamming)._1.getOrElse(input)
    var result = List(frequencyResult, hammingResult)
    if (frequencyResult == input) {
      result = List(hammingResult)
    } else if (hammingResult == input) {
      result = List(frequencyResult)
    }

    (Utilities.getRandomValueFromList(result), Utilities.computeConfidenceValue(result))
  }

}

trait ReadablePretrainedNorvig
    extends ParamsAndFeaturesReadable[NorvigSweetingModel]
    with HasPretrained[NorvigSweetingModel] {
  override val defaultModelName = Some("spellcheck_norvig")

  /** Java-compliant overrides */
  override def pretrained(): NorvigSweetingModel = super.pretrained()

  override def pretrained(name: String): NorvigSweetingModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): NorvigSweetingModel =
    super.pretrained(name, lang)

  override def pretrained(name: String, lang: String, remoteLoc: String): NorvigSweetingModel =
    super.pretrained(name, lang, remoteLoc)
}

/** This is the companion object of [[NorvigSweetingModel]]. Please refer to that class for the
  * documentation.
  */
object NorvigSweetingModel extends ReadablePretrainedNorvig