JohnSnowLabs / spark-nlp / 4992350528 (pending completion)

Pull Request #13797 (GitHub): SPARKNLP-835: ProtectedParam and ProtectedFeature
Merge 424c7ff18 into ef7906c5e

24 of 24 new or added lines in 6 files covered. (100.0%)
8643 of 13129 relevant lines covered (65.83%)
0.66 hits per line

Source File
/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala (0.0% of relevant lines covered)

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.johnsnowlabs.nlp.annotators.coref

import com.johnsnowlabs.ml.ai.SpanBertCoref
import com.johnsnowlabs.ml.tensorflow.{
  ReadTensorflowModel,
  TensorflowWrapper,
  WriteTensorflowModel
}
import com.johnsnowlabs.ml.util.LoadExternalModel.{
  loadTextAsset,
  modelSanityCheck,
  notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
import com.johnsnowlabs.nlp.embeddings.HasEmbeddingsProperties
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{IntArrayParam, IntParam, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}

/** A coreference resolution model based on SpanBert
  *
  * A coreference resolution model identifies expressions which refer to the same entity in a
  * text. For example, given a sentence "John told Mary he would like to borrow a book from her."
  * the model will link "he" to "John" and "her" to "Mary".
  *
  * This model is based on SpanBert, which is fine-tuned on the OntoNotes 5.0 data set.
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val corefResolution = SpanBertCorefModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("corefs")
  * }}}
  * The default model is `"spanbert_base_coref"`, if no name is provided. For available pretrained
  * models please see the [[https://sparknlp.org/models Models Hub]].
  *
  * For extended examples of usage, see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/coreference-resolution/Coreference_Resolution_SpanBertCorefModel.ipynb Examples]]
  *
  * '''References:'''
  *   - [[https://github.com/mandarjoshi90/coref]]
  *
  * ==Example==
  * {{{
  * import spark.implicits._
  * import com.johnsnowlabs.nlp.base._
  * import com.johnsnowlabs.nlp.annotator._
  * import org.apache.spark.ml.Pipeline
  *
  * val documentAssembler = new DocumentAssembler()
  *   .setInputCol("text")
  *   .setOutputCol("document")
  *
  * val sentence = new SentenceDetector()
  *   .setInputCols("document")
  *   .setOutputCol("sentence")
  *
  * val tokenizer = new Tokenizer()
  *   .setInputCols("sentence")
  *   .setOutputCol("token")
  *
  * val corefResolution = SpanBertCorefModel.pretrained()
  *   .setInputCols("sentence", "token")
  *   .setOutputCol("corefs")
  *
  * val pipeline = new Pipeline().setStages(Array(
  *   documentAssembler,
  *   sentence,
  *   tokenizer,
  *   corefResolution
  * ))
  *
  * val data = Seq(
  *   "John told Mary he would like to borrow a book from her."
  * ).toDF("text")
  *
  * val result = pipeline.fit(data).transform(data)
  *
  * result.selectExpr("explode(corefs) AS coref")
  *   .selectExpr("coref.result as token", "coref.metadata").show(truncate = false)
  * +-----+------------------------------------------------------------------------------------+
  * |token|metadata                                                                            |
  * +-----+------------------------------------------------------------------------------------+
  * |John |{head.sentence -> -1, head -> ROOT, head.begin -> -1, head.end -> -1, sentence -> 0}|
  * |he   |{head.sentence -> 0, head -> John, head.begin -> 0, head.end -> 3, sentence -> 0}   |
  * |Mary |{head.sentence -> -1, head -> ROOT, head.begin -> -1, head.end -> -1, sentence -> 0}|
  * |her  |{head.sentence -> 0, head -> Mary, head.begin -> 10, head.end -> 13, sentence -> 0} |
  * +-----+------------------------------------------------------------------------------------+
  * }}}
  *
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class SpanBertCorefModel(override val uid: String)
    extends AnnotatorModel[SpanBertCorefModel]
    with HasSimpleAnnotate[SpanBertCorefModel]
    with WriteTensorflowModel
    with HasEmbeddingsProperties
    with HasStorageRef
    with HasCaseSensitiveProperties
    with HasEngine {

  def this() = this(Identifiable.randomUID("SPANBERTCOREFMODEL"))

  override val inputAnnotatorTypes: Array[String] =
    Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN)
  override val outputAnnotatorType: AnnotatorType = AnnotatorType.DEPENDENCY

  def sentenceStartTokenId: Int = {
    $$(vocabulary)("[CLS]")
  }

  /** @group setParam */
  def sentenceEndTokenId: Int = {
    $$(vocabulary)("[SEP]")
  }

  /** Vocabulary used to encode the words to ids with WordPieceEncoder
    *
    * @group param
    */
  val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected()

  /** @group setParam */
  def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * `config_proto.SerializeToString()`
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** @group setParam */
  def setConfigProtoBytes(bytes: Array[Int]): SpanBertCorefModel.this.type =
    set(this.configProtoBytes, bytes)

  /** @group getParam */
  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))

  /** Max sentence length to process (Default: `512`)
    *
    * @group param
    */
  val maxSentenceLength =
    new IntParam(this, "maxSentenceLength", "Max sentence length to process")

  /** @group setParam */
  def setMaxSentenceLength(value: Int): this.type = {
    require(
      value <= 512,
      "BERT models do not support sequences longer than 512 because of trainable positional embeddings.")
    require(value >= 1, "The maxSentenceLength must be at least 1")
    set(maxSentenceLength, value)
    this
  }

  /** @group getParam */
  def getMaxSentenceLength: Int = $(maxSentenceLength)

  /** It contains TF model signatures for the loaded saved model
    *
    * @group param
    */
  val signatures =
    new MapFeature[String, String](model = this, name = "signatures").setProtected()

  /** @group setParam */
  def setSignatures(value: Map[String, String]): this.type = {
    set(signatures, value)
    this
  }

  /** @group getParam */
  def getSignatures: Option[Map[String, String]] = get(this.signatures)

  val _textGenres: Array[String] = Array(
    "bc", // Broadcast conversation, default
    "bn", // Broadcast news
    "mz", // Magazine
    "nw", // News wire
    "pt", // Pivot text: Old Testament and New Testament text
    "tc", // Telephone conversation
    "wb" // Web data
  )

  /** Text genre, one of the following values: `bc`: Broadcast conversation (default), `bn`:
    * Broadcast news, `mz`: Magazine, `nw`: News wire, `pt`: Pivot text (Old Testament and New
    * Testament text), `tc`: Telephone conversation, `wb`: Web data.
    *
    * @group param
    */
  val textGenre =
    new Param[String](
      this,
      "textGenre",
      s"Text genre, one of %s. Default is 'bc'.".format(
        _textGenres.map("\"" + _ + "\"").mkString(", ")))

  /** @group setParam */
  def setTextGenre(value: String): this.type = {
    require(
      _textGenres.contains(value.toLowerCase),
      s"Text genre must be one of %s".format(
        _textGenres.map("\"" + _ + "\"").mkString(", ")))
    set(textGenre, value.toLowerCase)
    this
  }

  /** @group getParam */
  def getTextGenre: String = $(textGenre)

  /** Max segment length to process (Read-only, depends on model)
    *
    * @group param
    */
  val maxSegmentLength = new IntParam(this, "maxSegmentLength", "Maximum segment length")

  /** @group setParam */
  def setMaxSegmentLength(value: Int): this.type = {
    if (get(maxSegmentLength).isEmpty)
      set(maxSegmentLength, value)
    this
  }

  /** @group getParam */
  def getMaxSegmentLength: Int = $(maxSegmentLength)

  private var _model: Option[Broadcast[SpanBertCoref]] = None

  setDefault(
    maxSentenceLength -> 512,
    caseSensitive -> true,
    textGenre -> _textGenres(0)
//    maxSegmentLength -> 384,
  )

  def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: TensorflowWrapper): SpanBertCorefModel = {
    if (_model.isEmpty) {
      _model = Some(
        spark.sparkContext.broadcast(
          new SpanBertCoref(
            tensorflowWrapper,
            sentenceStartTokenId,
            sentenceEndTokenId,
            configProtoBytes = getConfigProtoBytes,
            signatures = getSignatures)))
    }

    this
  }

  def getModelIfNotSet: SpanBertCoref = _model.get.value

  def tokenizeSentence(tokens: Seq[TokenizedSentence]): Seq[WordpieceTokenizedSentence] = {
    val basicTokenizer = new BasicTokenizer($(caseSensitive))
    val encoder = new WordpieceEncoder($$(vocabulary))

    tokens.map { tokenIndex =>
      // filter empty and only whitespace tokens
      val bertTokens =
        tokenIndex.indexedTokens.filter(x => x.token.nonEmpty && !x.token.equals(" ")).map {
          token =>
            val content = if ($(caseSensitive)) token.token else token.token.toLowerCase()
            val sentenceBegin = token.begin
            val sentenceEnd = token.end
            val sentenceIndex = tokenIndex.sentenceIndex
            val result = basicTokenizer.tokenize(
              Sentence(content, sentenceBegin, sentenceEnd, sentenceIndex))
            if (result.nonEmpty) result.head else IndexedToken("")
        }
      val wordPieceTokens =
        bertTokens.flatMap(token => encoder.encode(token)).take($(maxSentenceLength) - 2)
      WordpieceTokenizedSentence(wordPieceTokens)
    }
  }

  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {

    val sentencesWithRow = TokenizedWithSentence.unpack(annotations)
    val tokenizedSentences = tokenizeSentence(sentencesWithRow).toArray
    val inputIds = tokenizedSentences.map(x => x.tokens.map(_.pieceId))

    if (inputIds.map(x => x.length).sum < 2) {
      return Seq()
    }

    val predictedClusters = getModelIfNotSet.predict(
      inputIds = inputIds,
      genre = _textGenres.indexOf($(textGenre)),
      maxSegmentLength = $(maxSegmentLength))

    def getTokensFromSpan(span: ((Int, Int), (Int, Int))): Array[(TokenPiece, Int)] = {
      val sentence1 = span._1._1
      val sentence2 = span._2._1
      val tokenStart = span._1._2
      val tokenEnd = span._2._2
      if (sentence1 == sentence2) {
        tokenizedSentences(sentence1).tokens.slice(tokenStart, tokenEnd + 1).map((_, sentence1))
      } else {
        (tokenizedSentences(sentence1).tokens
          .slice(tokenStart, tokenizedSentences(sentence1).tokens.length - 1)
          .map((_, sentence1))
          ++
            tokenizedSentences(sentence2).tokens.slice(0, tokenEnd + 1).map((_, sentence2)))
      }
    }

//    predictedClusters.zipWithIndex.foreach{
//      case (cluster, i) =>
//        print(s"Cluster #$i\n")
//        print(s"\t%s\n".format(
//          cluster.map(
//            xy =>
//              getTokensFromSpan(xy).map(x => (if (x.isWordStart) " " else "") + x.wordpiece.replaceFirst("##", "") ).mkString("").trim,
//            ).mkString(", ")))
//    }
    predictedClusters.flatMap(cluster => {

      val clusterSpans = cluster.map(xy => getTokensFromSpan(xy))
      val clusterHeadSpan = clusterSpans.head
      val clusterHeadSpanText = clusterHeadSpan
        .map(x => (if (x._1.isWordStart) " " else "") + x._1.wordpiece.replaceFirst("##", ""))
        .mkString("")
        .trim
      Array(
        Annotation(
          annotatorType = AnnotatorType.DEPENDENCY,
          begin = clusterHeadSpan.head._1.begin,
          end = clusterHeadSpan.last._1.end,
          result = clusterHeadSpanText,
          metadata = Map(
            "head" -> "ROOT",
            "head.begin" -> "-1",
            "head.end" -> "-1",
            "head.sentence" -> "-1",
            "sentence" -> clusterHeadSpan.head._2.toString))) ++ clusterSpans.tail.map(span => {
        Annotation(
          annotatorType = AnnotatorType.DEPENDENCY,
          begin = span.head._1.begin,
          end = span.last._1.end,
          result = span
            .map(x => (if (x._1.isWordStart) " " else "") + x._1.wordpiece.replaceFirst("##", ""))
            .mkString("")
            .trim,
          metadata = Map(
            "head" -> clusterHeadSpanText,
            "head.begin" -> clusterHeadSpan.head._1.begin.toString,
            "head.end" -> clusterHeadSpan.last._1.end.toString,
            "head.sentence" -> clusterHeadSpan.head._2.toString,
            "sentence" -> span.head._2.toString))
      })
    })
  }

  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)
    writeTensorflowModelV2(
      path,
      spark,
      getModelIfNotSet.tensorflowWrapper,
      "_bert",
      SpanBertCorefModel.tfFile,
      configProtoBytes = getConfigProtoBytes)
  }
}

trait ReadablePretrainedSpanBertCorefModel
    extends ParamsAndFeaturesReadable[SpanBertCorefModel]
    with HasPretrained[SpanBertCorefModel] {
  override val defaultModelName: Some[String] = Some("spanbert_base_coref")

  /** Java compliant-overrides */
  override def pretrained(): SpanBertCorefModel = super.pretrained()

  override def pretrained(name: String): SpanBertCorefModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): SpanBertCorefModel =
    super.pretrained(name, lang)

  override def pretrained(name: String, lang: String, remoteLoc: String): SpanBertCorefModel =
    super.pretrained(name, lang, remoteLoc)
}

trait ReadSpanBertCorefTensorflowModel extends ReadTensorflowModel {
  this: ParamsAndFeaturesReadable[SpanBertCorefModel] =>

  override val tfFile: String = "spanbert_tensorflow"

  def readModel(instance: SpanBertCorefModel, path: String, spark: SparkSession): Unit = {

    val tf = readTensorflowModel(path, spark, "_bert_tf", initAllTables = false)
    instance.setModelIfNotSet(spark, tf)
  }

  addReader(readModel)

  def loadSavedModel(modelPath: String, spark: SparkSession): SpanBertCorefModel = {

    val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)

    val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap

    /*Universal parameters for all engines*/
    val annotatorModel = new SpanBertCorefModel()
      .setVocabulary(vocabs)

    annotatorModel.set(annotatorModel.engine, detectedEngine)

    detectedEngine match {
      case ModelEngine.tensorflow =>
        val (wrapper, signatures) =
          TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

        val _signatures = signatures match {
          case Some(s) => s
          case None => throw new Exception("Cannot load signature definitions from model!")
        }

        /** the order of setSignatures is important if we use getSignatures inside
          * setModelIfNotSet
          */
        annotatorModel
          .setSignatures(_signatures)
          .setModelIfNotSet(spark, wrapper)

      case _ =>
        throw new Exception(notSupportedEngineError)
    }

    annotatorModel
  }
}

/** This is the companion object of [[SpanBertCorefModel]]. Please refer to that class for the
  * documentation.
  */
object SpanBertCorefModel
    extends ReadablePretrainedSpanBertCorefModel
    with ReadSpanBertCorefTensorflowModel {
  private[SpanBertCorefModel] val logger: Logger = LoggerFactory.getLogger("SpanBertCorefModel")
}
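
The reader trait above exposes `loadSavedModel` for wiring a locally exported TensorFlow SavedModel into the same kind of pipeline shown in the class documentation. Below is a minimal usage sketch, not part of the source file: it assumes an active SparkSession named `spark`, and the export directory `./spanbert_coref_tf` (with the SavedModel and its `vocab.txt` asset) is a hypothetical path; only methods defined above (`loadSavedModel`, `setInputCols`, `setOutputCol`, `setTextGenre`) and the stages from the scaladoc example are used.

import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel
import org.apache.spark.ml.Pipeline
import spark.implicits._

// Hypothetical local export; loadSavedModel reads vocab.txt and the TF SavedModel from this path.
val corefModel = SpanBertCorefModel
  .loadSavedModel("./spanbert_coref_tf", spark)
  .setInputCols("sentence", "token")
  .setOutputCol("corefs")
  .setTextGenre("nw") // news wire; the default genre is "bc" (broadcast conversation)

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

val sentenceDetector = new SentenceDetector()
  .setInputCols("document")
  .setOutputCol("sentence")

val tokenizer = new Tokenizer()
  .setInputCols("sentence")
  .setOutputCol("token")

val pipeline = new Pipeline()
  .setStages(Array(documentAssembler, sentenceDetector, tokenizer, corefModel))

val data = Seq("John told Mary he would like to borrow a book from her.").toDF("text")

// Each output row links a mention back to the head mention of its cluster via the metadata fields.
pipeline.fit(data).transform(data)
  .selectExpr("explode(corefs) AS coref")
  .selectExpr("coref.result AS token", "coref.metadata")
  .show(truncate = false)

For the published `"spanbert_base_coref"` model, `SpanBertCorefModel.pretrained()` can be substituted for the `loadSavedModel` call.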