
JohnSnowLabs / spark-nlp · build 4951808959 (pending completion)
Pull Request #13792 (GitHub): SPARKNLP-825 Adding multilabel param
Merge efe6b42df into ef7906c5e

7 of 7 new or added lines in 1 file covered (100.0%)
8637 of 13128 relevant lines covered (65.79%)
0.66 hits per line

Source File (80.51% covered):
/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.ml.crf.TextSentenceLabels
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_EMBEDDINGS}
import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.SystemUtils
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.tensorflow.Graph
import org.tensorflow.proto.framework.GraphDef

import java.io.File
import scala.collection.mutable
import scala.util.Random

/** This Named Entity recognition annotator allows you to train a generic NER model based on
  * neural networks.
  *
  * The architecture of the neural network is a Char CNNs - BiLSTM - CRF that achieves
  * state-of-the-art results on most datasets.
  *
  * For instantiated/pretrained models, see [[NerDLModel]].
  *
  * The training data should be a labeled Spark Dataset, in the format of
  * [[com.johnsnowlabs.nlp.training.CoNLL CoNLL]] 2003 IOB with `Annotation` type columns. The
  * data should have columns of type `DOCUMENT, TOKEN, WORD_EMBEDDINGS` and an additional label
  * column of annotator type `NAMED_ENTITY`. Excluding the label, this can be done with, for
  * example:
  *   - a [[com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector SentenceDetector]],
  *   - a [[com.johnsnowlabs.nlp.annotators.Tokenizer Tokenizer]] and
  *   - a [[com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel WordEmbeddingsModel]] (any
  *     embeddings can be chosen, e.g.
  *     [[com.johnsnowlabs.nlp.embeddings.BertEmbeddings BertEmbeddings]] for BERT based
  *     embeddings).
  *
  * Setting a test dataset to monitor model metrics can be done with `.setTestDataset`. The method
  * expects a path to a parquet file containing a dataframe that has the same required columns as
  * the training dataframe. The pre-processing steps for the training dataframe should also be
  * applied to the test dataframe. The following example shows how to create the test dataset
  * from a CoNLL dataset:
  *
  * {{{
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val embeddings = WordEmbeddingsModel
  *   .pretrained()
  *   .setInputCols("document", "token")
  *   .setOutputCol("embeddings")
  *
  * val preProcessingPipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
  *
  * val conll = CoNLL()
  * val Array(train, test) = conll
  *   .readDataset(spark, "src/test/resources/conll2003/eng.train")
  *   .randomSplit(Array(0.8, 0.2))
  *
  * preProcessingPipeline
  *   .fit(test)
  *   .transform(test)
  *   .write
  *   .mode("overwrite")
  *   .parquet("test_data")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("document", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setTestDataset("test_data")
  * }}}
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala NerDLSpec]].
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
  * import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
  * import com.johnsnowlabs.nlp.training.CoNLL
  * import org.apache.spark.ml.Pipeline
  *
  * // This CoNLL dataset already includes a sentence, token and label
  * // column with their respective annotator types. If a custom dataset is used,
  * // these need to be defined with, for example:
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * // Then the training can start
  * val embeddings = BertEmbeddings.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("embeddings")
  *
  * val nerTagger = new NerDLApproach()
  *   .setInputCols("sentence", "token", "embeddings")
  *   .setLabelColumn("label")
  *   .setOutputCol("ner")
  *   .setMaxEpochs(1)
  *   .setRandomSeed(0)
  *   .setVerbose(0)
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   embeddings,
  *   nerTagger
  * ))
  *
  * // We use the sentences, tokens and labels from the CoNLL dataset
  * val conll = CoNLL()
  * val trainingData = conll.readDataset(spark, "src/test/resources/conll2003/eng.train")
  *
  * val pipelineModel = pipeline.fit(trainingData)
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfApproach NerCrfApproach]] for a generic CRF
  *   approach
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.ner.NerConverter NerConverter]] to further process the
  *   results
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class NerDLApproach(override val uid: String)
    extends AnnotatorApproach[NerDLModel]
    with NerApproach[NerDLApproach]
    with Logging
    with ParamsAndFeaturesWritable
    with EvaluationDLParams {

  def this() = this(Identifiable.randomUID("NerDL"))

  override def getLogName: String = "NerDL"

  /** Trains Tensorflow based Char-CNN-BLSTM model */
  override val description = "Trains Tensorflow based Char-CNN-BLSTM model"

  /** Input annotator types: DOCUMENT, TOKEN, WORD_EMBEDDINGS
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN, WORD_EMBEDDINGS)

  /** Output annotator types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: String = NAMED_ENTITY

  /** Learning Rate (Default: `1e-3f`)
    *
    * @group param
    */
  val lr = new FloatParam(this, "lr", "Learning Rate")

  /** Learning rate decay coefficient (Default: `0.005f`). The real learning rate is calculated
    * as `lr / (1 + po * epoch)`
    *
    * @group param
    */
  val po = new FloatParam(
    this,
    "po",
    "Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)")
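  // For example, with the defaults lr = 1e-3f and po = 0.005f, epoch 10 trains at
  // 1e-3 / (1 + 0.005 * 10) ≈ 9.52e-4, i.e. the rate decays gently over training.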

  /** Batch size (Default: `8`)
    *
    * @group param
    */
  val batchSize = new IntParam(this, "batchSize", "Batch size")

  /** Dropout coefficient (Default: `0.5f`)
    *
    * @group param
    */
  val dropout = new FloatParam(this, "dropout", "Dropout coefficient")

  /** Folder path that contains external graph files
    *
    * @group param
    */
  val graphFolder =
    new Param[String](this, "graphFolder", "Folder path that contains external graph files")

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** Whether to use contrib LSTM Cells (Default: `true`). Not compatible with Windows. Might
    * slightly improve accuracy. This param is deprecated and only exists for backward
    * compatibility
    *
    * @group param
    */
  val useContrib =
    new BooleanParam(this, "useContrib", "deprecated param - the value won't have any effect")

  /** Whether to include confidence scores in annotation metadata (Default: `false`)
    *
    * @group param
    */
  val includeConfidence = new BooleanParam(
    this,
    "includeConfidence",
    "Whether to include confidence scores in annotation metadata")

  /** Whether to include all confidence scores in annotation metadata or just the score of the
    * predicted tag
    *
    * @group param
    */
  val includeAllConfidenceScores = new BooleanParam(
    this,
    "includeAllConfidenceScores",
    "whether to include all confidence scores in annotation metadata")

  /** Whether to optimize for large datasets or not (Default: `false`). Enabling this option can
    * slow down training.
    *
    * @group param
    */
  val enableMemoryOptimizer = new BooleanParam(
    this,
    "enableMemoryOptimizer",
    "Whether to optimize for large datasets or not. Enabling this option can slow down training.")

  /** Whether to restore and use the model that has achieved the best performance at the end of
    * the training. The metric monitored is F1 on testDataset; if that is not set,
    * validationSplit is used, and failing that, the loss.
    *
    * @group param
    */
  val useBestModel = new BooleanParam(
    this,
    "useBestModel",
    "Whether to restore and use the model that has achieved the best performance at the end of the training.")

  /** Whether to check F1 Micro-average or F1 Macro-average as the final metric for the best
    * model. This will fall back to loss if there is no validation or test dataset.
    *
    * @group param
    */
  val bestModelMetric = new Param[String](
    this,
    "bestModelMetric",
    "Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model.")

  /** Learning Rate
    *
    * @group getParam
    */
  def getLr: Float = $(this.lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group getParam
    */
  def getPo: Float = $(this.po)

  /** Batch size
    *
    * @group getParam
    */
  def getBatchSize: Int = $(this.batchSize)

  /** Dropout coefficient
    *
    * @group getParam
    */
  def getDropout: Float = $(this.dropout)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group getParam
    */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @group getParam
    */
  def getUseContrib: Boolean = $(this.useContrib)

  /** Memory Optimizer
    *
    * @group getParam
    */
  def getEnableMemoryOptimizer: Boolean = $(this.enableMemoryOptimizer)

  /** useBestModel
    *
    * @group getParam
    */
  def getUseBestModel: Boolean = $(this.useBestModel)

  /** @group getParam */
  def getBestModelMetric: String = $(bestModelMetric)

  /** Learning Rate
    *
    * @group setParam
    */
  def setLr(lr: Float): NerDLApproach.this.type = set(this.lr, lr)

  /** Learning rate decay coefficient. Real Learning Rate = lr / (1 + po * epoch)
    *
    * @group setParam
    */
  def setPo(po: Float): NerDLApproach.this.type = set(this.po, po)

  /** Batch size
    *
    * @group setParam
    */
  def setBatchSize(batch: Int): NerDLApproach.this.type = set(this.batchSize, batch)

  /** Dropout coefficient
    *
    * @group setParam
    */
  def setDropout(dropout: Float): NerDLApproach.this.type = set(this.dropout, dropout)

  /** Folder path that contains external graph files
    *
    * @group setParam
    */
  def setGraphFolder(path: String): NerDLApproach.this.type = set(this.graphFolder, path)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group setParam
    */
  def setConfigProtoBytes(bytes: Array[Int]): NerDLApproach.this.type =
    set(this.configProtoBytes, bytes)

  /** Whether to use contrib LSTM Cells. Not compatible with Windows. Might slightly improve
    * accuracy.
    *
    * @group setParam
    */
  def setUseContrib(value: Boolean): NerDLApproach.this.type =
    if (value && SystemUtils.IS_OS_WINDOWS)
      throw new UnsupportedOperationException("Cannot set contrib in Windows")
    else set(useContrib, value)

  /** Whether to optimize for large datasets or not. Enabling this option can slow down training.
    *
    * @group setParam
    */
  def setEnableMemoryOptimizer(value: Boolean): NerDLApproach.this.type =
    set(this.enableMemoryOptimizer, value)

  /** Whether to include confidence scores in annotation metadata
    *
    * @group setParam
    */
  def setIncludeConfidence(value: Boolean): NerDLApproach.this.type =
    set(this.includeConfidence, value)

  /** Whether to include confidence scores for all tags rather than just for the predicted one
    *
    * @group setParam
    */
  def setIncludeAllConfidenceScores(value: Boolean): this.type =
    set(this.includeAllConfidenceScores, value)

  /** @group setParam */
  def setUseBestModel(value: Boolean): NerDLApproach.this.type = set(this.useBestModel, value)

  /** @group setParam */
  def setBestModelMetric(value: String): NerDLApproach.this.type = {
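    // Unrecognized metric names silently fall back to micro-averaged F1 rather than failing.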

    if (value == ModelMetrics.macroF1)
      set(bestModelMetric, value)
    else
      set(bestModelMetric, ModelMetrics.microF1)
    this
  }

  setDefault(
    minEpochs -> 0,
    maxEpochs -> 70,
    lr -> 1e-3f,
    po -> 0.005f,
    batchSize -> 8,
    dropout -> 0.5f,
    useContrib -> true,
    includeConfidence -> false,
    includeAllConfidenceScores -> false,
    enableMemoryOptimizer -> false,
    useBestModel -> false,
    bestModelMetric -> ModelMetrics.microF1)

  override val verboseLevel: Verbose.Level = Verbose($(verbose))
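
  // Infers the embeddings dimension from the first non-empty sentence, falling back to 1
  // when every sentence is empty.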
  def calculateEmbeddingsDim(sentences: Seq[WordpieceEmbeddingsSentence]): Int = {
    sentences
      .find(s => s.tokens.nonEmpty)
      .map(s => s.tokens.head.embeddings.length)
      .getOrElse(1)
  }

  override def beforeTraining(spark: SparkSession): Unit = {
    LoadsContrib.loadContribToCluster(spark)
    LoadsContrib.loadContribToTensorflow()
  }

  override def train(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): NerDLModel = {

    require(
      $(validationSplit) <= 1f && $(validationSplit) >= 0f,
      "The validationSplit must be between 0f and 1f")

    val train = dataset.toDF()

    val test = if (!isDefined(testDataset)) {
      train.limit(0) // keep the schema only
    } else {
      ResourceHelper.readSparkDataFrame($(testDataset))
    }

    val embeddingsRef =
      HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)
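
    // The validation set is carved out of the training data itself; the split proportions
    // (validationSplit, 1 - validationSplit) line up with (validSplit, trainSplit).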
    val Array(validSplit, trainSplit) =
      train.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit)))

    val trainIteratorFunc = NerDLApproach.getIteratorFunc(
      trainSplit,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val validIteratorFunc = NerDLApproach.getIteratorFunc(
      validSplit,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val testIteratorFunc = NerDLApproach.getIteratorFunc(
      test,
      inputColumns = getInputCols,
      labelColumn = $(labelColumn),
      batchSize = $(batchSize),
      enableMemoryOptimizer = $(enableMemoryOptimizer))

    val (labels, chars, embeddingsDim, dsLen) =
      NerDLApproach.getDataSetParams(trainIteratorFunc())

    val settings = DatasetEncoderParams(
      labels.toList,
      chars.toList,
      Array.fill(embeddingsDim)(0f).toList,
      embeddingsDim)
    val encoder = new NerDatasetEncoder(settings)
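
    // Pick a pre-built TensorFlow graph whose capacity covers this dataset's label set,
    // embedding dimension and character vocabulary (the +1 presumably reserves a slot for
    // unseen characters).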
    val graphFile = NerDLApproach.searchForSuitableGraph(
      labels.size,
      embeddingsDim,
      chars.size + 1,
      get(graphFolder))

    val graph = new Graph()
    val graphStream = ResourceHelper.getResourceStream(graphFile)
    val graphBytesDef = IOUtils.toByteArray(graphStream)
    graph.importGraphDef(GraphDef.parseFrom(graphBytesDef))
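
    // Wrap the freshly imported, still-untrained graph; the variables start out empty
    // and are filled in during training.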
    val tfWrapper = new TensorflowWrapper(
      Variables(Array.empty[Array[Byte]], Array.empty[Byte]),
      graph.toGraphDef.toByteArray)

    val (ner, trainedTf) =
      try {
        val model = new TensorflowNer(tfWrapper, encoder, Verbose($(verbose)))
        if (isDefined(randomSeed)) {
          Random.setSeed($(randomSeed))
        }

        // start the iterator here once again
        val trainedTf = model.train(
          trainIteratorFunc(),
          dsLen,
          validIteratorFunc(),
          (dsLen * $(validationSplit)).toLong,
          $(lr),
          $(po),
          $(dropout),
          $(batchSize),
          $(useBestModel),
          $(bestModelMetric),
          graphFileName = graphFile,
          test = testIteratorFunc(),
          startEpoch = 0,
          endEpoch = $(maxEpochs),
          configProtoBytes = getConfigProtoBytes,
          validationSplit = $(validationSplit),
          evaluationLogExtended = $(evaluationLogExtended),
          enableOutputLogs = $(enableOutputLogs),
          outputLogsPath = $(outputLogsPath),
          uuid = this.uid)
        (model, trainedTf)
      } catch {
        case e: Exception =>
          graph.close()
          throw e
      }

    val newWrapper =
      new TensorflowWrapper(
        TensorflowWrapper.extractVariablesSavedModel(trainedTf),
        tfWrapper.graph)

    val model = new NerDLModel()
      .setDatasetParams(ner.encoder.params)
      .setModelIfNotSet(dataset.sparkSession, newWrapper)
      .setIncludeConfidence($(includeConfidence))
      .setIncludeAllConfidenceScores($(includeAllConfidenceScores))
      .setStorageRef(embeddingsRef)

    if (get(configProtoBytes).isDefined)
      model.setConfigProtoBytes($(configProtoBytes))

    model

  }
}

trait WithGraphResolver {

  def searchForSuitableGraph(
      tags: Int,
      embeddingsNDims: Int,
      nChars: Int,
      localGraphPath: Option[String] = None): String = {

    val files: Seq[String] = getFiles(localGraphPath)

    // 1. Filter graphs by embeddings dimension
    val embeddingsFiltered = files.map { filePath =>
      val file = new File(filePath)
      val name = file.getName
      val graphPrefix = "blstm_"
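
      // Graph file names are expected to follow blstm_{tags}_{embeddingsDim}_{lstmSize}_{nChars}.pb;
      // the four leading integers are parsed below (the third, presumably the LSTM size, is ignored).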
      if (name.startsWith(graphPrefix)) {
        val clean = name.replace(graphPrefix, "").replace(".pb", "")
        val graphParams = clean.split("_").take(4).map(s => s.toInt)
        val Array(fileTags, fileEmbeddingsNDims, _, fileNChars) = graphParams

        if (embeddingsNDims == fileEmbeddingsNDims)
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
        else
          None
      } else {
        None
      }
    }

    require(
      embeddingsFiltered.exists(_.nonEmpty),
      s"Graph dimensions should be $embeddingsNDims: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 2. Filter by number of tags
    val tagsFiltered = embeddingsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (tags > fileTags)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(
      tagsFiltered.exists(_.nonEmpty),
      s"Graph tags size should be $tags: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph.")

    // 3. Filter by number of chars
    val charsFiltered = tagsFiltered.map {
      case Some((fileTags, fileEmbeddingsNDims, fileNChars)) =>
        if (nChars > fileNChars)
          None
        else
          Some((fileTags, fileEmbeddingsNDims, fileNChars))
      case _ => None
    }

    require(
      charsFiltered.exists(_.nonEmpty),
      s"Graph chars size should be $nChars: Could not find a suitable tensorflow graph for embeddings dim: $embeddingsNDims tags: $tags nChars: $nChars. " +
        s"Check https://sparknlp.org/docs/en/graph for instructions to generate the required graph")

    for (i <- files.indices) {
      if (charsFiltered(i).nonEmpty)
        return files(i)
    }

    throw new IllegalStateException("Code shouldn't pass here")
  }

  private def getFiles(localGraphPath: Option[String]): Seq[String] = {
    var files: Seq[String] = List()
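
    // Three resolution strategies: an s3:// folder is downloaded to a temporary directory,
    // a dbfs path is listed in place, and any other path is copied to the local filesystem;
    // when no folder is given, the graphs bundled under /ner-dl on the classpath are used.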
    if (localGraphPath.isDefined && localGraphPath.get
        .startsWith("s3://")) { // TODO: Might be able to remove this condition

      val tmpDirectory = ResourceDownloader.downloadS3Directory(localGraphPath.get).getPath

      files = ResourceHelper.listLocalFiles(tmpDirectory).map(_.getAbsolutePath)
    } else {

      if (localGraphPath.isDefined && OutputHelper
          .getFileSystem(localGraphPath.get)
          .getScheme == "dbfs") {
        files =
          ResourceHelper.listLocalFiles(localGraphPath.get).map(file => file.getAbsolutePath)
      } else {
        files = localGraphPath
          .map(path =>
            ResourceHelper
              .listLocalFiles(ResourceHelper.copyToLocal(path))
              .map(_.getAbsolutePath))
          .getOrElse(ResourceHelper.listResourceDirectory("/ner-dl"))
      }

    }
    files
  }

}

/** This is the companion object of [[NerDLApproach]]. Please refer to that class for the
  * documentation.
  */
object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraphResolver {

  def getIteratorFunc(
      dataset: Dataset[Row],
      inputColumns: Array[String],
      labelColumn: String,
      batchSize: Int,
      enableMemoryOptimizer: Boolean)
      : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
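
    // With the memory optimizer on, batches are re-read from the DataFrame on every pass;
    // otherwise the relevant columns are collected to the driver once and iterated in
    // memory, which is faster but can use a lot of driver memory for large datasets.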
    if (enableMemoryOptimizer) { () =>
      NerTagged.iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)

    } else {
      val inMemory = dataset
        .select(labelColumn, inputColumns.toSeq: _*)
        .collect()

      () => NerTagged.iterateOnArray(inMemory, inputColumns, batchSize)
    }
  }

  def getDataSetParams(dsIt: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
      : (mutable.Set[String], mutable.Set[Char], Int, Long) = {

    val labels = scala.collection.mutable.Set[String]()
    val chars = scala.collection.mutable.Set[Char]()
    var embeddingsDim = 1
    var dsLen = 0L

    // try to be frugal with memory and with the number of passes through the iterator
    for (batch <- dsIt) {
      dsLen += batch.length
      for (datapoint <- batch) {

        for (label <- datapoint._1.labels)
          labels += label

        for (token <- datapoint._2.tokens; char <- token.token.toCharArray)
          chars += char

        if (datapoint._2.tokens.nonEmpty)
          embeddingsDim = datapoint._2.tokens.head.embeddings.length
      }
    }

    (labels, chars, embeddingsDim, dsLen)
  }

  def getGraphParams(
      dataset: Dataset[_],
      inputColumns: java.util.ArrayList[java.lang.String],
      labelColumn: String): (Int, Int, Int) = {

    val trainIteratorFunc =
      getIteratorFunc(dataset.toDF(), inputColumns.toArray.map(_.toString), labelColumn, 0, false)

    val (labels, chars, embeddingsDim, _) = getDataSetParams(trainIteratorFunc())

    (labels.size, embeddingsDim, chars.size + 1)
  }
}
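
Below the listing, for quick reference: a minimal sketch (not part of the covered file) of
probing the graph resolver directly, e.g. to validate a custom graph folder before launching
a long training run. The tag/dimension/character counts and the folder path are hypothetical;
searchForSuitableGraph is the same method train() calls above:

  // assuming 10 IOB labels, 100-dim embeddings and ~120 distinct characters in the corpus
  val graphPath = NerDLApproach.searchForSuitableGraph(
    tags = 10,
    embeddingsNDims = 100,
    nChars = 121, // train() passes chars.size + 1
    localGraphPath = Some("/tmp/custom_graphs")) // hypothetical folder of blstm_*.pb files
  println(s"Training would use graph: $graphPath")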