JohnSnowLabs / spark-nlp — build 18652478786

20 Oct 2025 12:47PM UTC — coverage: 55.25% (+0.2%) from 55.094%

Pull Request #14674: SPARKNLP-1293 Enhancements EntityRuler and DocumentNormalizer (merge b08968fc1 into b827818c7)

114 of 149 new or added lines in 3 files covered (76.51%). 40 existing lines in 36 files now uncovered. 11919 of 21573 relevant lines covered (55.25%), 0.55 hits per line.

Source file (81.03% covered): /src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.client.CloudResources
import com.johnsnowlabs.client.util.CloudHelper
import com.johnsnowlabs.ml.crf.TextSentenceLabels
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.SystemUtils
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.tensorflow.Graph
import org.tensorflow.proto.framework.GraphDef

import java.io.File
import scala.collection.mutable
import scala.util.Random

/** This Named Entity recognition annotator allows you to train a generic NER model based on
  * neural networks.
  *
  * The architecture of the neural network is Char CNN - BiLSTM - CRF, which achieves
  * state-of-the-art results on most datasets.
  *
  * For instantiated/pretrained models, see [[NerDLModel]].
  *
  * The training data should be a labeled Spark Dataset, in the format of
  * [[com.johnsnowlabs.nlp.training.CoNLL CoNLL]] 2003 IOB with `Annotation` type columns. The
  * data should have columns of type `DOCUMENT, TOKEN, WORD_EMBEDDINGS` and an additional label
  * column of annotator type `NAMED_ENTITY`. Excluding the label, this can be done with, for
  * example,
  *   - a [[com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector SentenceDetector]],
  *   - a [[com.johnsnowlabs.nlp.annotators.Tokenizer Tokenizer]] and
  *   - a [[com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel WordEmbeddingsModel]] (any
  *     embeddings can be chosen, e.g.
  *     [[com.johnsnowlabs.nlp.embeddings.BertEmbeddings BertEmbeddings]] for BERT based
  *     embeddings).
  *
  * Setting a test dataset to monitor model metrics can be done with `.setTestDataset`. The method
  * expects a path to a parquet file containing a dataframe that has the same required columns as
  * the training dataframe. The pre-processing steps for the training dataframe should also be
  * applied to the test dataframe. The following example shows how to create the test dataset
  * from a CoNLL dataset:
  *
  * {{{
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val embeddings = WordEmbeddingsModel
  *   .pretrained()
  *   .setInputCols("document", "token")
  *   .setOutputCol("embeddings")
  *
  * val preProcessingPipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
  *
  * val conll = CoNLL()
  * val Array(train, test) = conll
  *   .readDataset(spark, "src/test/resources/conll2003/eng.train")
  *   .randomSplit(Array(0.8, 0.2))
  *
  * preProcessingPipeline
  *   .fit(test)
  *   .transform(test)
  *   .write
  *   .mode("overwrite")
  *   .parquet("test_data")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("document", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setTestDataset("test_data")
  * }}}
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala NerDLSpec]].
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
  * import com.johnsnowlabs.nlp.training.CoNLL
  * import org.apache.spark.ml.Pipeline
  *
  * // This CoNLL dataset already includes a sentence, token and label
  * // column with their respective annotator types. If a custom dataset is used,
  * // these need to be defined with, for example:
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * // Then the training can start
  * val embeddings = BertEmbeddings.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("sentence", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setMaxEpochs(1)
  *   .setRandomSeed(0)
  *   .setVerbose(0)
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   embeddings,
  *   nerTagger
  * ))
  *
  * // We use the sentences, tokens and labels from the CoNLL dataset
  * val conll = CoNLL()
  * val trainingData = conll.readDataset(spark, "src/test/resources/conll2003/eng.train")
  *
  * val pipelineModel = pipeline.fit(trainingData)
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfApproach NerCrfApproach]] for a generic CRF
  *   approach
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.NerConverter NerConverter]] to further process the
  *   results
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class NerDLApproach(override val uid: String)
    extends AnnotatorApproach[NerDLModel]
    with NerApproach[NerDLApproach]
    with Logging
    with ParamsAndFeaturesWritable
    with EvaluationDLParams {

  def this() = this(Identifiable.randomUID("NerDL"))

  override def getLogName: String = "NerDL"

  /** Trains Tensorflow based Char-CNN-BLSTM model */
  override val description = "Trains Tensorflow based Char-CNN-BLSTM model"

  /** Input annotator types: DOCUMENT, TOKEN, WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN, WORD_EMBEDDINGS)

  /** Output annotator types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: String = NAMED_ENTITY

  /** Learning Rate (Default: `1e-3f`)
    *
    * @group param
    */
  val lr = new FloatParam(this, "lr", "Learning Rate")

  /** Learning rate decay coefficient (Default: `0.005f`). The effective learning rate is
    * `lr / (1 + po * epoch)`.
    *
    * @group param
    */
  val po = new FloatParam(
    this,
    "po",
    "Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)")
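
  // Worked example (illustrative, using the defaults set below): with lr = 1e-3f and
  // po = 0.005f, the effective learning rate at epoch 10 is 1e-3 / (1 + 0.005 * 10) ≈ 9.52e-4.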

  /** Batch size (Default: `8`)
    *
    * @group param
    */
  val batchSize = new IntParam(this, "batchSize", "Batch size")

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group param
    */
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")

  /** Folder path that contains external graph files
    *
    * @group param
    */
  val graphFolder =
    new Param[String](this, "graphFolder", "Folder path that contains external graph files")
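
  // Note: graphs in this folder are resolved by searchForSuitableGraph below, which expects
  // names of the form "blstm_{tags}_{embeddingsDim}_{lstmSize}_{nChars}.pb" (the third field,
  // assumed here to be the LSTM size, is parsed but not inspected).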

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** Whether to use contrib LSTM Cells (Default: `true`). Not compatible with Windows. Might
    * slightly improve accuracy. This param is deprecated and only exists for backward
    * compatibility.
    *
    * @group param
    */
  val useContrib =
    new BooleanParam(this, "useContrib", "deprecated param - the value won't have any effect")

  /** Whether to include confidence scores in annotation metadata (Default: `false`)
    *
    * @group param
    */
  val includeConfidence = new BooleanParam(
    this,
    "includeConfidence",
    "Whether to include confidence scores in annotation metadata")

  /** Whether to include all confidence scores in annotation metadata or just the score of the
    * predicted tag
    *
    * @group param
    */
  val includeAllConfidenceScores = new BooleanParam(
    this,
    "includeAllConfidenceScores",
    "whether to include all confidence scores in annotation metadata")

  /** Whether to optimize for large datasets or not (Default: `false`). Enabling this option can
    * slow down training.
    *
    * @group param
    */
  val enableMemoryOptimizer = new BooleanParam(
    this,
    "enableMemoryOptimizer",
    "Whether to optimize for large datasets or not. Enabling this option can slow down training.")

  /** Whether to restore and use the model that has achieved the best performance at the end of
    * training. The monitored metric is F1 on testDataset if it is set, otherwise on
    * validationSplit if that is set, and finally falls back to the training loss.
    *
    * @group param
    */
  val useBestModel = new BooleanParam(
    this,
    "useBestModel",
    "Whether to restore and use the model that has achieved the best performance at the end of the training.")

  /** Whether to use F1 Micro-average or F1 Macro-average as the final metric for selecting the
    * best model. This falls back to loss if there is no validation or test dataset.
    *
    * @group param
    */
  val bestModelMetric = new Param[String](
    this,
    "bestModelMetric",
    "Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model.")
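
  // Illustrative usage: the setter validates against ModelMetrics.values, and
  // ModelMetrics.loss (used in setDefault below) is the default, e.g.
  //   new NerDLApproach().setBestModelMetric(ModelMetrics.loss)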

  /** Learning Rate
    *
    * @group getParam
    */
  def getLr: Float = $(this.lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group getParam
    */
  def getPo: Float = $(this.po)

  /** Batch size
    *
    * @group getParam
    */
  def getBatchSize: Int = $(this.batchSize)

  /** Dropout coefficient
    *
    * @group getParam
    */
  def getDropout: Float = $(this.dropout)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group getParam
    */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @group getParam
    */
  def getUseContrib: Boolean = $(this.useContrib)

  /** Memory Optimizer
    *
    * @group getParam
    */
  def getEnableMemoryOptimizer: Boolean = $(this.enableMemoryOptimizer)

  /** useBestModel
    *
    * @group getParam
    */
  def getUseBestModel: Boolean = $(this.useBestModel)

  /** @group getParam */
  def getBestModelMetric: String = $(bestModelMetric)

  /** Learning Rate
    *
    * @group setParam
    */
  def setLr(lr: Float): NerDLApproach.this.type = set(this.lr, lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group setParam
    */
  def setPo(po: Float): NerDLApproach.this.type = set(this.po, po)

  /** Batch size
    *
    * @group setParam
    */
  def setBatchSize(batch: Int): NerDLApproach.this.type = set(this.batchSize, batch)

  /** Dropout coefficient
    *
    * @group setParam
    */
  def setDropout(dropout: Float): NerDLApproach.this.type = set(this.dropout, dropout)

  /** Folder path that contains external graph files
    *
    * @group setParam
    */
  def setGraphFolder(path: String): NerDLApproach.this.type = set(this.graphFolder, path)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group setParam
    */
  def setConfigProtoBytes(bytes: Array[Int]): NerDLApproach.this.type =
    set(this.configProtoBytes, bytes)

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @group setParam
    */
  def setUseContrib(value: Boolean): NerDLApproach.this.type =
    if (value && SystemUtils.IS_OS_WINDOWS)
      throw new UnsupportedOperationException("Cannot set contrib in Windows")
    else set(useContrib, value)

  /** Whether to optimize for large datasets or not. Enabling this option can slow down training.
    *
    * @group setParam
    */
  def setEnableMemoryOptimizer(value: Boolean): NerDLApproach.this.type =
    set(this.enableMemoryOptimizer, value)

  /** Whether to include confidence scores in annotation metadata
    *
    * @group setParam
    */
  def setIncludeConfidence(value: Boolean): NerDLApproach.this.type =
    set(this.includeConfidence, value)

  /** Whether to include confidence scores for all tags rather than just for the predicted one
    *
    * @group setParam
    */
  def setIncludeAllConfidenceScores(value: Boolean): this.type =
    set(this.includeAllConfidenceScores, value)

  /** @group setParam */
  def setUseBestModel(value: Boolean): NerDLApproach.this.type = set(this.useBestModel, value)

  /** @group setParam */
  def setBestModelMetric(value: String): NerDLApproach.this.type = {
    require(
      ModelMetrics.values.contains(value),
      s"Invalid metric: $value. Allowed metrics are: ${ModelMetrics.values.mkString(", ")}")

    set(this.bestModelMetric, value)
  }

  setDefault(
    minEpochs -> 0,
    maxEpochs -> 70,
    lr -> 1e-3f,
    po -> 0.005f,
    batchSize -> 8,
    dropout -> 0.5f,
    useContrib -> true,
    includeConfidence -> false,
    includeAllConfidenceScores -> false,
    enableMemoryOptimizer -> false,
    useBestModel -> false,
    bestModelMetric -> ModelMetrics.loss)

  override val verboseLevel: Verbose.Level = Verbose($(verbose))
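
  /** Dimensionality of the word embeddings, taken from the first sentence that has tokens;
    * falls back to 1 when no tokens are present.
    */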
  def calculateEmbeddingsDim(sentences: Seq[WordpieceEmbeddingsSentence]): Int = {
    sentences
      .find(s => s.tokens.nonEmpty)
      .map(s => s.tokens.head.embeddings.length)
      .getOrElse(1)
  }

  override def beforeTraining(spark: SparkSession): Unit = {
    LoadsContrib.loadContribToCluster(spark)
    LoadsContrib.loadContribToTensorflow()
  }

  override def train(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): NerDLModel = {

    require(
      $(validationSplit) <= 1f && $(validationSplit) >= 0f,
      "The validationSplit must be between 0f and 1f")

    val train = dataset.toDF()

    val test = if (!isDefined(testDataset)) {
      train.limit(0) // keep the schema only
    } else {
      ResourceHelper.readSparkDataFrame($(testDataset))
    }

    val embeddingsRef =
      HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)

    val Array(validSplit, trainSplit) =
      train.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit)))

    val trainIteratorFunc = NerDLApproach.getIteratorFunc(
      trainSplit,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val validIteratorFunc = NerDLApproach.getIteratorFunc(
      validSplit,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val testIteratorFunc = NerDLApproach.getIteratorFunc(
      test,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val (labels, chars, embeddingsDim, dsLen) =
      NerDLApproach.getDataSetParams(trainIteratorFunc())

    val settings = DatasetEncoderParams(
      labels.toList,
      chars.toList,
      Array.fill(embeddingsDim)(0f).toList,
      embeddingsDim)
    val encoder = new NerDatasetEncoder(settings)

    val graphFile = NerDLApproach.searchForSuitableGraph(
      labels.size,
      embeddingsDim,
      chars.size + 1,
      get(graphFolder))

    val graph = new Graph()
    val graphStream = ResourceHelper.getResourceStream(graphFile)
    val graphBytesDef = IOUtils.toByteArray(graphStream)
    graph.importGraphDef(GraphDef.parseFrom(graphBytesDef))

    val tfWrapper = new TensorflowWrapper(
      Variables(Array.empty[Array[Byte]], Array.empty[Byte]),
      graph.toGraphDef.toByteArray)

    val (ner, trainedTf) =
      try {
        val model = new TensorflowNer(tfWrapper, encoder, Verbose($(verbose)))
        if (isDefined(randomSeed)) {
          Random.setSeed($(randomSeed))
        }

        // start the iterator here once again
        val trainedTf = model.train(
          trainIteratorFunc(),
          dsLen,
          validIteratorFunc(),
          (dsLen * $(validationSplit)).toLong,
          $(lr),
          $(po),
          $(dropout),
          $(batchSize),
          $(useBestModel),
          $(bestModelMetric),
          graphFileName = graphFile,
          test = testIteratorFunc(),
          startEpoch = 0,
          endEpoch = $(maxEpochs),
          configProtoBytes = getConfigProtoBytes,
          validationSplit = $(validationSplit),
          evaluationLogExtended = $(evaluationLogExtended),
          enableOutputLogs = $(enableOutputLogs),
          outputLogsPath = $(outputLogsPath),
          uuid = this.uid)
        (model, trainedTf)
      } catch {
        case e: Exception =>
          graph.close()
          throw e
      }

    val newWrapper =
      new TensorflowWrapper(
        TensorflowWrapper.extractVariablesSavedModel(trainedTf),
        tfWrapper.graph)

    val model = new NerDLModel()
      .setDatasetParams(ner.encoder.params)
      .setModelIfNotSet(dataset.sparkSession, newWrapper)
      .setIncludeConfidence($(includeConfidence))
      .setIncludeAllConfidenceScores($(includeAllConfidenceScores))
      .setStorageRef(embeddingsRef)

    if (get(configProtoBytes).isDefined)
      model.setConfigProtoBytes($(configProtoBytes))

    model
  }
}

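/** Resolves a suitable pre-built TensorFlow graph file for a requested number of tags,
  * embeddings dimension, and character-set size, either from a user-supplied folder or from
  * the graphs bundled in the /ner-dl resource directory.
  */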
trait WithGraphResolver {

  def searchForSuitableGraph(
      tags: Int,
      embeddingsNDims: Int,
      nChars: Int,
      localGraphPath: Option[String] = None): String = {

    val files: Seq[String] = getFiles(localGraphPath)

    // 1. Filter graphs by embeddings dimension
    val embeddingsFiltered = files.map { filePath =>
      val file = new File(filePath)
      val name = file.getName
      val graphPrefix = "blstm_"

      if (name.startsWith(graphPrefix)) {
        val clean = name.replace(graphPrefix, "").replace(".pb", "")
        val graphParams = clean.split("_").take(4).map(s => s.toInt)
        val Array(fileTags, fileEmbeddingsNDims, _, fileNChars) = graphParams
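        // e.g. "blstm_10_100_128_120.pb" parses to fileTags = 10, fileEmbeddingsNDims = 100,
        // fileNChars = 120 (the third value is not used here)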

        if (embeddingsNDims == fileEmbeddingsNDims)
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
        else
          None
      } else {
        None
      }
    }

    require(
      embeddingsFiltered.exists(_.nonEmpty),
      s"Graph dimensions should be $embeddingsNDims: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 2. Filter by number of tags
    val tagsFiltered = embeddingsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (tags > fileTags)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(
      tagsFiltered.exists(_.nonEmpty),
      s"Graph tags size should be $tags: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 3. Filter by number of chars
    val charsFiltered = tagsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (nChars > fileNChars)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(
      charsFiltered.exists(_.nonEmpty),
      s"Graph chars size should be $nChars: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    for (i <- files.indices) {
      if (charsFiltered(i).nonEmpty)
        return files(i)
    }

    throw new IllegalStateException("Unreachable: a suitable graph was verified above")
  }

  private def getFiles(localGraphPath: Option[String]): Seq[String] = {
    var files: Seq[String] = List()

    if (localGraphPath.isDefined && CloudHelper.isCloudPath(localGraphPath.get)) {
      val tmpDirectory = CloudResources.downloadBucketToLocalTmp(localGraphPath.get).getPath
      files = ResourceHelper.listLocalFiles(tmpDirectory).map(_.getAbsolutePath)
    } else {
      if (localGraphPath.isDefined && OutputHelper
          .getFileSystem(localGraphPath.get)
          .getScheme == "dbfs") {
        files =
          ResourceHelper.listLocalFiles(localGraphPath.get).map(file => file.getAbsolutePath)
      } else {
        files = localGraphPath
          .map(path =>
            ResourceHelper
              .listLocalFiles(ResourceHelper.copyToLocal(path))
              .map(_.getAbsolutePath))
          .getOrElse(ResourceHelper.listResourceDirectory("/ner-dl"))
      }
    }
    files
  }
}

/** This is the companion object of [[NerDLApproach]]. Please refer to that class for the
  * documentation.
  */
object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraphResolver {
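
  /** Builds a function that returns a fresh batch iterator of (labels, sentence embeddings)
    * pairs. With `enableMemoryOptimizer` enabled, the dataframe is re-iterated on every call,
    * trading speed for driver memory; otherwise the selected columns are collected to the
    * driver once and batches are served from that in-memory array.
    */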
  def getIteratorFunc(
      dataset: Dataset[Row],
      inputColumns: Array[String],
      labelColumn: String,
      batchSize: Int,
      enableMemoryOptimizer: Boolean)
      : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {

    if (enableMemoryOptimizer) { () =>
      NerTagged.iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)
    } else {
      val inMemory = dataset
        .select(labelColumn, inputColumns.toSeq: _*)
        .collect()

      () => NerTagged.iterateOnArray(inMemory, inputColumns, batchSize)
    }
  }

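  /** Gathers, in a single pass over the batches, the label set, the character set, the
    * embeddings dimensionality, and the total number of sentences.
    */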
  def getDataSetParams(dsIt: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
      : (mutable.Set[String], mutable.Set[Char], Int, Long) = {

    val labels = scala.collection.mutable.Set[String]()
    val chars = scala.collection.mutable.Set[Char]()
    var embeddingsDim = 1
    var dsLen = 0L

    // try to be frugal with memory and with the number of passes through the iterator
    for (batch <- dsIt) {
      dsLen += batch.length
      for (datapoint <- batch) {

        for (label <- datapoint._1.labels)
          labels += label

        for (token <- datapoint._2.tokens; char <- token.token.toCharArray)
          chars += char

        if (datapoint._2.tokens.nonEmpty)
          embeddingsDim = datapoint._2.tokens.head.embeddings.length
      }
    }

    (labels, chars, embeddingsDim, dsLen)
  }

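  /** Computes the graph dimensions a dataset requires: the number of tags, the embeddings
    * dimension, and the number of distinct characters + 1.
    */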
  def getGraphParams(
      dataset: Dataset[_],
      inputColumns: java.util.ArrayList[java.lang.String],
      labelColumn: String): (Int, Int, Int) = {

    val trainIteratorFunc =
      getIteratorFunc(dataset.toDF(), inputColumns.toArray.map(_.toString), labelColumn, 0, false)

    val (labels, chars, embeddingsDim, _) = getDataSetParams(trainIteratorFunc())

    (labels.size, embeddingsDim, chars.size + 1)
  }
}