JohnSnowLabs / spark-nlp / build 4947838414 (pending completion)
Pull Request #13796 (GitHub): Add unzip param to downloadModelDirectly in ResourceDownloader
Merge 30bdeef19 into ef7906c5e

39 of 39 new or added lines in 2 files covered (100.0%)
8632 of 13111 relevant lines covered (65.84%)
0.66 hits per line
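The change being measured by this build adds an unzip parameter to ResourceDownloader.downloadModelDirectly. A minimal sketch of how such a call might look is below; the Boolean type, the folder argument, the default value, and the archive name are assumptions for illustration and are not confirmed by this report:

import com.johnsnowlabs.nlp.pretrained.ResourceDownloader

// Hypothetical usage: download a model archive but keep it zipped on disk
// (assuming unzip defaults to true, so existing callers keep the old behaviour).
ResourceDownloader.downloadModelDirectly(
  "public/models/sentence_detector_dl_en_example.zip",
  folder = "public/models",
  unzip = false)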

Source File

/src/main/scala/com/johnsnowlabs/nlp/annotators/sentence_detector_dl/SentenceDetectorDLModel.scala (82.72% of lines covered)
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.sentence_detector_dl

import com.johnsnowlabs.ml.ai.SentenceDetectorDL
import com.johnsnowlabs.ml.tensorflow.{
  ReadTensorflowModel,
  TensorflowWrapper,
  WriteTensorflowModel
}
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{BooleanParam, IntParam, Param, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

case class Metrics(accuracy: Double, recall: Double, precision: Double, f1: Double)

/** Annotator that detects sentence boundaries using a deep learning approach.
  *
  * Instantiated Model of the
  * [[com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLApproach SentenceDetectorDLApproach]].
  * Detects sentence boundaries using a deep learning approach.
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val sentenceDL = SentenceDetectorDLModel.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("sentencesDL")
  * }}}
  * The default model is `"sentence_detector_dl"`, if no name is provided. For available
  * pretrained models please see the
  * [[https://sparknlp.org/models?task=Sentence+Detection Models Hub]].
  *
  * Each extracted sentence can be returned in an Array or exploded to separate rows, if
  * `explodeSentences` is set to `true`.
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/multilingual/SentenceDetectorDL.ipynb Examples]]
  * and the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/sentence_detector_dl/SentenceDetectorDLSpec.scala SentenceDetectorDLSpec]].
  *
  * ==Example==
  * In this example, the normal `SentenceDetector` is compared to the `SentenceDetectorDLModel`.
  * In a pipeline, `SentenceDetectorDLModel` can be used as a replacement for the
  * `SentenceDetector`.
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
  * import com.johnsnowlabs.nlp.annotator.SentenceDetector
  * import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentences")
  *
  * val sentenceDL = SentenceDetectorDLModel
  *   .pretrained("sentence_detector_dl", "en")
  *   .setInputCols("document")
  *   .setOutputCol("sentencesDL")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   sentence,
  *   sentenceDL
  * ))
  *
  * val data = Seq("""John loves Mary.Mary loves Peter
  *   Peter loves Helen .Helen loves John;
  *   Total: four people involved.""").toDF("text")
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(sentences.result) as sentences").show(false)
  * +----------------------------------------------------------+
  * |sentences                                                 |
  * +----------------------------------------------------------+
  * |John loves Mary.Mary loves Peter\n     Peter loves Helen .|
  * |Helen loves John;                                         |
  * |Total: four people involved.                              |
  * +----------------------------------------------------------+
  *
  * result.selectExpr("explode(sentencesDL.result) as sentencesDL").show(false)
  * +----------------------------+
  * |sentencesDL                 |
  * +----------------------------+
  * |John loves Mary.            |
  * |Mary loves Peter            |
  * |Peter loves Helen .         |
  * |Helen loves John;           |
  * |Total: four people involved.|
  * +----------------------------+
  * }}}
  *
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLApproach SentenceDetectorDLApproach]]
  *   for training a model yourself
  * @see
  *   [[com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector SentenceDetector]] for non
  *   deep learning extraction
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class SentenceDetectorDLModel(override val uid: String)
    extends AnnotatorModel[SentenceDetectorDLModel]
    with HasSimpleAnnotate[SentenceDetectorDLModel]
    with HasStorageRef
    with ParamsAndFeaturesWritable
    with WriteTensorflowModel
    with HasEngine {

  def this() = this(Identifiable.randomUID("SentenceDetectorDLModel"))

  /** Input annotator type : DOCUMENT
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT)

  /** Output annotator type : DOCUMENT
    *
    * @group anno
    */
  override val outputAnnotatorType: String = DOCUMENT

  var encoder = new SentenceDetectorDLEncoderParam(this, "Encoder", "Data encoder")

  def setEncoder(encoder: SentenceDetectorDLEncoder): SentenceDetectorDLModel.this.type =
    set(this.encoder, encoder)

  def getEncoder: SentenceDetectorDLEncoder = $(this.encoder)

  /** Model architecture (Default: `"cnn"`)
    *
    * @group param
    */
  var modelArchitecture =
    new Param[String](this, "modelArchitecture", "Model Architecture: one of (CNN)")

  /** Set architecture
    *
    * @group setParam
    */
  def setModel(modelArchitecture: String): SentenceDetectorDLModel.this.type =
    set(this.modelArchitecture, modelArchitecture)

  /** Get model architecture
    *
    * @group getParam
    */
  def getModel: String = $(this.modelArchitecture)

  /** Impossible penultimates (Default: `Array()`)
    *
    * @group param
    */
  val impossiblePenultimates =
    new StringArrayParam(this, "impossiblePenultimates", "Impossible penultimates")

  /** Length at which sentences will be forcibly split (Ignored if not set)
    *
    * @group param
    */

  val splitLength: IntParam =
    new IntParam(this, "splitLength", "length at which sentences will be forcibly split.")

  /** Set the minimum allowed length for each sentence (Default: `0`)
    *
    * @group param
    */

  val minLength =
    new IntParam(this, "minLength", "Set the minimum allowed length for each sentence")

  /** Set the maximum allowed length for each sentence (Ignored if not set)
    *
    * @group param
    */
  val maxLength =
    new IntParam(this, "maxLength", "Set the maximum allowed length for each sentence")

  /** A flag indicating whether to split sentences into different Dataset rows. Useful for higher
    * parallelism in fat rows (Default: `false`)
    *
    * @group param
    */
  val explodeSentences =
    new BooleanParam(this, "explodeSentences", "Split sentences in separate rows")

  /** Whether to only utilize custom bounds for sentence detection (Default: `false`)
    *
    * @group param
    */
  val useCustomBoundsOnly = new BooleanParam(
    this,
    "useCustomBoundsOnly",
    "whether to only utilize custom bounds for sentence detection")

  /** Characters used to explicitly mark sentence bounds (Default: None)
    *
    * @group param
    */
  val customBounds: StringArrayParam = new StringArrayParam(
    this,
    "customBounds",
    "characters used to explicitly mark sentence bounds")

  /** Length at which sentences will be forcibly split
    * @group setParam
    */
  def setSplitLength(value: Int): this.type = set(splitLength, value)

  /** Length at which sentences will be forcibly split
    * @group getParam
    */
  def getSplitLength: Int = $(splitLength)

  /** Set the minimum allowed length for each sentence
    * @group setParam
    */
  def setMinLength(value: Int): this.type = {
    require(value >= 0, "minLength must be greater than or equal to 0")
    require(value.isValidInt, "minLength must be Int")
    set(minLength, value)
  }

  /** Get the minimum allowed length for each sentence
    * @group getParam
    */
  def getMinLength: Int = $(minLength)

  /** Set the maximum allowed length for each sentence
    * @group setParam
    */
  def setMaxLength(value: Int): this.type = {
    require(
      value >= $(minLength),
      "maxLength must be greater than or equal to minLength")
    require(value.isValidInt, "maxLength must be Int")
    set(maxLength, value)
  }

  /** Get the maximum allowed length for each sentence
    * @group getParam
    */
  def getMaxLength: Int = $(maxLength)

  /** Set impossible penultimates
    *
    * @group setParam
    */
  def setImpossiblePenultimates(
      impossiblePenultimates: Array[String]): SentenceDetectorDLModel.this.type =
    set(this.impossiblePenultimates, impossiblePenultimates)

  /** Get impossible penultimates
    *
    * @group getParam
    */
  def getImpossiblePenultimates: Array[String] = $(this.impossiblePenultimates)

  /** Whether to split sentences into different Dataset rows. Useful for higher parallelism in fat
    * rows. Defaults to false.
    *
    * @group setParam
    */
  def setExplodeSentences(value: Boolean): SentenceDetectorDLModel.this.type =
    set(this.explodeSentences, value)

  /** Whether to split sentences into different Dataset rows. Useful for higher parallelism in fat
    * rows. Defaults to false.
    *
    * @group getParam
    */
  def getExplodeSentences: Boolean = $(this.explodeSentences)

  /** Custom sentence separator text
    * @group setParam
    */
  def setCustomBounds(value: Array[String]): this.type = set(customBounds, value)

  /** Custom sentence separator text
    * @group getParam
    */
  def getCustomBounds: Array[String] = $(customBounds)

  /** Use only custom bounds without considering those of Pragmatic Segmenter. Defaults to false.
    * Needs customBounds.
    * @group setParam
    */
  def setUseCustomBoundsOnly(value: Boolean): this.type = set(useCustomBoundsOnly, value)

  /** Use only custom bounds without considering those of Pragmatic Segmenter. Defaults to false.
    * Needs customBounds.
    * @group getParam
    */
  def getUseCustomBoundsOnly: Boolean = $(useCustomBoundsOnly)

  setDefault(
    modelArchitecture -> "cnn",
    impossiblePenultimates -> Array(),
    explodeSentences -> false,
    minLength -> 0,
    maxLength -> Int.MaxValue,
    splitLength -> Int.MaxValue,
    useCustomBoundsOnly -> false,
    customBounds -> Array.empty[String])

  private var _tfClassifier: Option[Broadcast[SentenceDetectorDL]] = None

  def setupTFClassifier(spark: SparkSession, tfWrapper: TensorflowWrapper): this.type = {
    if (_tfClassifier.isEmpty) {
      _tfClassifier = Some(spark.sparkContext.broadcast(new SentenceDetectorDL(tfWrapper)))
    }
    this
  }

  def setupNew(spark: SparkSession, modelPath: String, vocabularyPath: String): this.type = {
    val encoder = new SentenceDetectorDLEncoder()
    encoder.loadVocabulary(vocabularyPath)
    setEncoder(encoder)

    val (wrapper, _) = TensorflowWrapper.read(modelPath)
    setupTFClassifier(spark, wrapper)
  }

  def getTFClassifier: SentenceDetectorDL = {
    require(_tfClassifier.isDefined, "TF model not setup.")
    _tfClassifier.get.value
  }

  def getMetrics(text: String, injectNewLines: Boolean = false): Metrics = {

    var nExamples = 0.0
    var nRecall = 0.0
    var nPrecision = 0.0

    var accuracy = 0.0
    var recall = 0.0
    var precision = 0.0

    var pText = text

    if (injectNewLines) {
      val nlShare = (text.split("\n").length / 10).toInt
      Array
        .fill(nlShare)(Random.nextInt(text.length - 10))
        .foreach(pos => {
          if (text(pos) != '\n' && text(pos + 1) != '\n' && text(pos - 1) != '\n') {
            pText = pText.slice(0, pos) + "\n" + pText.slice(pos + 1, pText.length - 1)
          }
        })
    } else {
      pText = text
    }

    getEncoder
      .getEOSPositions(pText)
      .foreach(ex => {
        val (pos, vector) = ex
        val output = getTFClassifier.predict(Array(vector))
        val posPrediction = output._1(0)
        val posActivation = output._2(0)

        val groundTruth = (
          (pos < (text.length - 1) && text(pos + 1) == '\n')
            || (text(pos) == '\n' && pos > 0 && (!Array('.', ':', '?', '!', ';').contains(
              text(pos - 1))))
        )

        val prediction = (posActivation > 0.5f)

        accuracy += (if (groundTruth == prediction) 1.0 else 0.0)
        nExamples += 1.0

        if (groundTruth) {
          recall += (if (groundTruth == prediction) 1.0 else 0.0)

          nRecall += 1.0
        }

        if (prediction) {
          precision += (if (groundTruth == prediction) 1.0 else 0.0)
          nPrecision += 1.0
        }
      })

    accuracy = (if (nExamples > 0) (accuracy / nExamples) else 1)
    recall = (if (nRecall > 0) (recall / nRecall) else 1)
    precision = (if (nPrecision > 0) (precision / nPrecision) else 1)

    Metrics(
      accuracy,
      recall,
      precision,
      2.0 * (if ((recall + precision) > 0.0) (recall * precision / (recall + precision))
             else 0.0))
  }

  def processText(
      text: String,
      processCustomBounds: Boolean = true): Iterator[(Int, Int, String)] = {

    if (processCustomBounds) {
      var sentences = Array("")
      var sentenceStarts = Array(0)
      var currentPos = 0
      text.zipWithIndex.foreach(x => {
        val boundary = $(customBounds).find(b => sentences(currentPos).matches(".*" + b + "$"))
        if (boundary.isDefined) {
//          sentences(currentPos) = sentences(currentPos).dropRight(boundary.get.length)
          sentences = sentences ++ Array("")
          sentenceStarts = sentenceStarts ++ Array(x._2)
          currentPos += 1
        }
        if (!(sentences(currentPos).isEmpty && getEncoder.getSkipChars.contains(x._1)))
          sentences(currentPos) = sentences(currentPos) + x._1
        else
          sentenceStarts(currentPos) = sentenceStarts(currentPos) + 1
      })
      return if ($(useCustomBoundsOnly)) {
        sentences.zip(sentenceStarts).map(x => (x._2, x._2 + x._1.length, x._1)).toIterator
      } else {
        sentences
          .zip(sentenceStarts)
          .flatMap(x => {
            processText(x._1, false).map(s => (s._1 + x._2, s._2 + x._2, s._3))
          })
          .toIterator
      }

    }

    var startPos = 0
    val skipChars = getEncoder.getSkipChars

    val sentences = getEncoder
      .getEOSPositions(text, getImpossiblePenultimates)
      .map(ex => {
        val (pos, vector) = ex
        val output = getTFClassifier.predict(Array(vector))
        val posActivation = output._2(0)
        (pos, posActivation)
      })
      .filter(ex => ex._2 > 0.5f)
      .map(_._1)
      .map(eos => {

        while ((startPos < eos) && skipChars.contains(text(startPos))) {
          startPos += 1
        }

        val endPos = if (skipChars.contains(text(eos))) eos else eos + 1
        val s = (startPos, eos, text.slice(startPos, endPos))

        startPos = eos + 1

        s

      })

    sentences ++ (if (startPos < text.length)
                    Array((startPos, text.length, text.slice(startPos, text.length))).toIterator
                  else
                    Array().toIterator)
  }

  private def truncateSentence(sentence: String, maxLength: Int): Array[String] = {
    var currentLength = 0
    val allSentences = ArrayBuffer.empty[String]
    val currentSentence = ArrayBuffer.empty[String]

    def addWordToSentence(word: String): Unit = {

      /** Adds +1 because of the space joining words */
      currentLength += word.length + 1
      currentSentence.append(word)
    }

    sentence
      .split(" ")
      .foreach(word => {
        if (currentLength + word.length > maxLength) {
          allSentences.append(currentSentence.mkString(" "))
          currentSentence.clear()
          currentLength = 0
          addWordToSentence(word)
        } else {
          addWordToSentence(word)
        }
      })

    /** add leftovers */
    allSentences.append(currentSentence.mkString(" "))
    allSentences.toArray
  }

  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {

    val documents = annotations.filter(_.annotatorType == DOCUMENT)
    val outputAnnotations = ArrayBuffer[Annotation]()

    documents.foreach(doc => {
      var sentenceNo = 0
      processText(doc.result).foreach(posSentence => {

        if (posSentence._3.trim.nonEmpty) {
          var sentenceBegin = posSentence._1

          truncateSentence(posSentence._3, getSplitLength).foreach(splitSentence => {

            outputAnnotations.append(
              new Annotation(
                annotatorType = AnnotatorType.DOCUMENT,
                begin = sentenceBegin,
                end = sentenceBegin + splitSentence.length - 1,
                result = splitSentence,
                metadata = mutable.Map("sentence" -> sentenceNo.toString)))
            sentenceBegin += splitSentence.length
            sentenceNo += 1
          })
        }
      })
      if ((sentenceNo == 0) && (doc.end > doc.begin)) {
        outputAnnotations.append(
          new Annotation(
            annotatorType = AnnotatorType.DOCUMENT,
            begin = doc.begin,
            end = doc.end,
            result = doc.result,
            metadata = mutable.Map("sentence" -> sentenceNo.toString)))
      }
    })

    outputAnnotations
      .filter(anno => anno.result.length >= getMinLength && anno.result.length <= getMaxLength)
  }

  override protected def afterAnnotate(dataset: DataFrame): DataFrame = {

    import org.apache.spark.sql.functions.{array, col, explode}

    if ($(explodeSentences)) {
      dataset
        .select(dataset.columns.filterNot(_ == getOutputCol).map(col) :+ explode(
          col(getOutputCol)).as("_tmp"): _*)
        .withColumn(
          getOutputCol,
          array(col("_tmp"))
            .as(getOutputCol, dataset.schema.fields.find(_.name == getOutputCol).get.metadata))
        .drop("_tmp")
    } else dataset
  }

  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)

    writeTensorflowModel(
      path,
      spark,
      getTFClassifier.getTFModel,
      "_genericclassifier",
      SentenceDetectorDLModel.tfFile)
  }
}

trait ReadsSentenceDetectorDLGraph
    extends ParamsAndFeaturesReadable[SentenceDetectorDLModel]
    with ReadTensorflowModel {

  override val tfFile = "generic_classifier_tensorflow"

  def readSentenceDetectorDLGraph(
      instance: SentenceDetectorDLModel,
      path: String,
      spark: SparkSession): Unit = {

    val tf = readTensorflowModel(path, spark, "_genericclassifier")
    instance.setupTFClassifier(spark, tf)
  }

  addReader(readSentenceDetectorDLGraph)
}

trait ReadablePretrainedSentenceDetectorDL
    extends ParamsAndFeaturesReadable[SentenceDetectorDLModel]
    with HasPretrained[SentenceDetectorDLModel] {

  override val defaultModelName: Some[String] = Some("sentence_detector_dl")

  /** Java compliant-overrides */
  override def pretrained(): SentenceDetectorDLModel = super.pretrained()

  override def pretrained(name: String): SentenceDetectorDLModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): SentenceDetectorDLModel =
    super.pretrained(name, lang)

  override def pretrained(
      name: String,
      lang: String,
      remoteLoc: String): SentenceDetectorDLModel = super.pretrained(name, lang, remoteLoc)
}

object SentenceDetectorDLModel
    extends ReadsSentenceDetectorDLGraph
    with ReadablePretrainedSentenceDetectorDL
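As a complement to the scaladoc example above, a minimal sketch of tuning the length-related parameters defined in this file; the model name and column names follow the scaladoc, while the numeric values are illustrative only:

import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
import org.apache.spark.ml.Pipeline

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// Sentences shorter than 5 or longer than 500 characters are filtered out in
// annotate(), sentences longer than 200 characters are forcibly split by
// truncateSentence(), and each resulting sentence is exploded to its own row
// in afterAnnotate().
val sentenceDL = SentenceDetectorDLModel
  .pretrained("sentence_detector_dl", "en")
  .setInputCols("document")
  .setOutputCol("sentencesDL")
  .setMinLength(5)
  .setMaxLength(500)
  .setSplitLength(200)
  .setExplodeSentences(true)

val pipeline = new Pipeline().setStages(Array(documentAssembler, sentenceDL))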