
JohnSnowLabs / spark-nlp / build 4992350528 (pending completion)

Pull Request #13797 (GitHub): SPARKNLP-835: ProtectedParam and ProtectedFeature
Merge 424c7ff18 into ef7906c5e

24 of 24 new or added lines in 6 files covered (100.0%)
8643 of 13129 relevant lines covered (65.83%), 0.66 hits per line

Source File: /src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala (66.27% covered)
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ld.dl

import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
  loadTextAsset,
  modelSanityCheck,
  notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession

import scala.collection.immutable.ListMap

/** Language Identification and Detection by using CNN and RNN architectures in TensorFlow.
  *
  * `LanguageDetectorDL` is an annotator that detects the language of documents or sentences
  * depending on the inputCols. The models are trained on large datasets such as Wikipedia and
  * Tatoeba. Depending on the language (how similar the characters are), the LanguageDetectorDL
  * works best with text longer than 140 characters. The output is a language code in
  * [[https://en.wikipedia.org/wiki/List_of_Wikipedias Wiki Code style]].
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val languageDetector = LanguageDetectorDL.pretrained()
  *   .setInputCols("sentence")
  *   .setOutputCol("language")
  * }}}
  * The default model is `"ld_wiki_tatoeba_cnn_21"` and the default language is `"xx"` (meaning
  * multi-lingual), if no values are provided. For available pretrained models please see the
  * [[https://sparknlp.org/models?task=Language+Detection Models Hub]].
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/language-detection/Language_Detection_and_Indentification.ipynb Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDLTestSpec.scala LanguageDetectorDLTestSpec]].
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.ld.dl.LanguageDetectorDL
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val languageDetector = LanguageDetectorDL.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("language")
  *
  * val pipeline = new Pipeline()
  *   .setStages(Array(
  *     documentAssembler,
  *     languageDetector
  *   ))
  *
  * val data = Seq(
  *   "Spark NLP is an open-source text processing library for advanced natural language processing for the Python, Java and Scala programming languages.",
  *   "Spark NLP est une bibliothèque de traitement de texte open source pour le traitement avancé du langage naturel pour les langages de programmation Python, Java et Scala.",
  *   "Spark NLP ist eine Open-Source-Textverarbeitungsbibliothek für fortgeschrittene natürliche Sprachverarbeitung für die Programmiersprachen Python, Java und Scala."
  * ).toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.select("language.result").show(false)
  * +------+
  * |result|
  * +------+
  * |[en]  |
  * |[fr]  |
  * |[de]  |
  * +------+
  * }}}
  *
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class LanguageDetectorDL(override val uid: String)
    extends AnnotatorModel[LanguageDetectorDL]
    with HasSimpleAnnotate[LanguageDetectorDL]
    with WriteTensorflowModel
    with HasEngine {

  def this() = this(Identifiable.randomUID("LANGUAGE_DETECTOR_DL"))

  /** Alphabet used to feed the TensorFlow model for prediction
    *
    * @group param
    */
  val alphabet: MapFeature[String, Int] = new MapFeature(this, "alphabet").setProtected()

  /** Languages used to map predictions to ISO 639-1 language codes
    *
    * @group param
    */
  val language: MapFeature[String, Int] = new MapFeature(this, "language").setProtected()

  /** The minimum threshold for the final result, otherwise it will be either `"unk"` or the
    * value set in `thresholdLabel` (Default: `0.1f`). Value is between 0.0 and 1.0. Try setting
    * this lower if your text is hard to predict.
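    *
    * For example, to demand higher confidence before emitting a language code (a minimal
    * sketch; the values are illustrative, not defaults):
    * {{{
    * languageDetector
    *   .setThreshold(0.3f)
    *   .setThresholdLabel("unknown")
    * }}}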
    *
    * @group param
    */
  val threshold = new FloatParam(
    this,
    "threshold",
    "The minimum threshold for the final result otherwise it will be either Unknown or the value set in thresholdLabel.")

  /** Value for the classification, if confidence is less than `threshold` (Default: `"unk"`).
    *
    * @group param
    */
  val thresholdLabel = new Param[String](
    this,
    "thresholdLabel",
    "In case the score is less than threshold, what should be the label. Default is Unknown.")

  /** Output average of sentences instead of one output per sentence (Default: `true`).
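    *
    * For example (a minimal sketch of the two modes; `languageDetector` is assumed to be an
    * instance of this annotator):
    * {{{
    * // one averaged prediction for the whole document
    * languageDetector.setCoalesceSentences(true)
    * // one prediction per detected sentence
    * languageDetector.setCoalesceSentences(false)
    * }}}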
    *
    * @group param
    */
  val coalesceSentences = new BooleanParam(
    this,
    "coalesceSentences",
    "If set to true, the output of all sentences will be averaged to one output instead of one output per sentence. Defaults to true.")

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** Languages the model was trained with.
    *
    * @group param
    */
  val languages =
    new StringArrayParam(this, "languages", "keep an internal copy of languages for Python")

  /** @group setParam */
  def setLanguage(value: Map[String, Int]): this.type = {
    set(this.language, value)
    this
  }

  /** @group setParam */
  def setAlphabet(value: Map[String, Int]): this.type = {
    set(alphabet, value)
    this
  }

  /** @group setParam */
  def setThreshold(threshold: Float): this.type = set(this.threshold, threshold)

  /** @group setParam */
  def setThresholdLabel(label: String): this.type = set(this.thresholdLabel, label)

  /** @group setParam */
  def setCoalesceSentences(value: Boolean): this.type = set(coalesceSentences, value)

  /** @group setParam */
  def setConfigProtoBytes(bytes: Array[Int]): LanguageDetectorDL.this.type =
    set(this.configProtoBytes, bytes)

  /** @group getParam */
  def getLanguage: Array[String] = {
    // Also caches the keys in the `languages` param, which keeps an internal copy for Python
    val langs = $$(language).keys.toArray
    set(languages, langs)
    langs
  }

  /** @group getParam */
  def getThreshold: Float = $(this.threshold)

  /** @group getParam */
  def getThresholdLabel: String = $(this.thresholdLabel)

  /** @group getParam */
  def getCoalesceSentences: Boolean = $(coalesceSentences)

  /** @group getParam */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  setDefault(
    inputCols -> Array("document"),
    outputCol -> "language",
    threshold -> 0.1f,
    thresholdLabel -> "unk",
    coalesceSentences -> true)

  private var _model: Option[Broadcast[TensorflowLD]] = None

  /** @group getParam */
  def getModelIfNotSet: TensorflowLD = _model.get.value

  /** @group setParam */
  def setModelIfNotSet(spark: SparkSession, tensorflow: TensorflowWrapper): this.type = {
    if (_model.isEmpty) {

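      // Sort the language and alphabet maps by their integer index so their order matches the
      // model's output layout, and broadcast the wrapped model so executors share one copy.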
      _model = Some(
        spark.sparkContext.broadcast(
          new TensorflowLD(
            tensorflow,
            configProtoBytes = getConfigProtoBytes,
            ListMap($$(language).toSeq.sortBy(_._2): _*),
            ListMap($$(alphabet).toSeq.sortBy(_._2): _*))))
    }

    this
  }

  /** Takes a document and annotations and produces new annotations of this annotator's
    * annotation type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators, if
    *   any
    * @return
    *   any number of annotations processed for every input annotation. Not necessarily a
    *   one-to-one relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = SentenceSplit.unpack(annotations)
    val nonEmptySentences = sentences.filter(_.content.nonEmpty)
    if (nonEmptySentences.nonEmpty) {
      getModelIfNotSet.predict(
        nonEmptySentences,
        $(threshold),
        $(thresholdLabel),
        $(coalesceSentences))
    } else {
      Seq.empty[Annotation]
    }
  }

  /** Annotator reference id. Used to identify elements in metadata or to refer to this
    * annotator type
    */
  override val inputAnnotatorTypes: Array[String] = Array(AnnotatorType.DOCUMENT)
  override val outputAnnotatorType: AnnotatorType = AnnotatorType.LANGUAGE

  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)
    writeTensorflowModelV2(
      path,
      spark,
      getModelIfNotSet.tensorflow,
      "_ld",
      LanguageDetectorDL.tfFile,
      configProtoBytes = getConfigProtoBytes)
  }

}

trait ReadablePretrainedLanguageDetectorDLModel
    extends ParamsAndFeaturesReadable[LanguageDetectorDL]
    with HasPretrained[LanguageDetectorDL] {
  override val defaultModelName: Some[String] = Some("ld_wiki_tatoeba_cnn_21")
  override val defaultLang: String = "xx"

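  // A minimal sketch of loading a specific pretrained model by name and language; the values
  // shown are the documented defaults, not the only valid choices:
  //
  //   val detector = LanguageDetectorDL.pretrained("ld_wiki_tatoeba_cnn_21", lang = "xx")
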
  /** Java-compliant overrides */
  override def pretrained(): LanguageDetectorDL = super.pretrained()

  override def pretrained(name: String): LanguageDetectorDL = super.pretrained(name)

  override def pretrained(name: String, lang: String): LanguageDetectorDL =
    super.pretrained(name, lang)

  override def pretrained(name: String, lang: String, remoteLoc: String): LanguageDetectorDL =
    super.pretrained(name, lang, remoteLoc)
}

trait ReadLanguageDetectorDLTensorflowModel extends ReadTensorflowModel {
  this: ParamsAndFeaturesReadable[LanguageDetectorDL] =>

  override val tfFile: String = "ld_tensorflow"

  def readModel(instance: LanguageDetectorDL, path: String, spark: SparkSession): Unit = {

    val tf = readTensorflowModel(path, spark, "_ld_tf")
    instance.setModelIfNotSet(spark, tf)
    // Populate the `languages` param from the `language` MapFeature so that Python can access
    // the getLanguages function
    val t = instance.language.get.toArray
    val r = t(0).keys.toArray
    instance.set(instance.languages, r)
  }

  addReader(readModel)


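  /** Loads a TensorFlow SavedModel (exported with `alphabet.txt` and `language.txt` assets) as
    * a `LanguageDetectorDL`. A minimal usage sketch; the local path is hypothetical:
    * {{{
    * val detector = LanguageDetectorDL.loadSavedModel("/models/ld_saved_model", spark)
    *   .setInputCols("document")
    *   .setOutputCol("language")
    * }}}
    */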
  def loadSavedModel(modelPath: String, spark: SparkSession): LanguageDetectorDL = {

    val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)

    val alphabets = loadTextAsset(localModelPath, "alphabet.txt").zipWithIndex.toMap
    val languages = loadTextAsset(localModelPath, "language.txt").zipWithIndex.toMap

    /* Universal parameters for all engines */
    val annotatorModel = new LanguageDetectorDL()
      .setAlphabet(alphabets)
      .setLanguage(languages)

    annotatorModel.set(annotatorModel.engine, detectedEngine)

    detectedEngine match {
      case ModelEngine.tensorflow =>
        val (wrapper, _) =
          TensorflowWrapper.read(
            localModelPath,
            zipped = false,
            useBundle = true,
            tags = Array("serve"))

        /** the order of setSignatures is important if we use getSignatures inside
          * setModelIfNotSet
          */
        annotatorModel
          .setModelIfNotSet(spark, wrapper)

      case _ =>
        throw new Exception(notSupportedEngineError)
    }

    annotatorModel
  }
}

/** This is the companion object of [[LanguageDetectorDL]]. Please refer to that class for the
  * documentation.
  */
object LanguageDetectorDL
    extends ReadablePretrainedLanguageDetectorDLModel
    with ReadLanguageDetectorDLTensorflowModel