• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 8133034558

03 Mar 2024 09:07PM UTC coverage: 62.603%. First build
8133034558

Pull #14190

github

web-flow
Merge 901c88425 into ad5a4ea14
Pull Request #14190: Release/531 release candidate

0 of 14 new or added lines in 1 file covered. (0.0%)

8956 of 14306 relevant lines covered (62.6%)

0.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala
1
/*
2
 * Copyright 2017-2023 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.ner.dl
18

19
import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification}
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel}
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN}
import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, HasPretrained, ParamsAndFeaturesReadable}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.param.{FloatParam, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession

import java.util
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
37

38
/** ZeroShotNerModel implements zero shot named entity recognition by utilizing RoBERTa
  * transformer models fine tuned on a question answering task.
  *
  * Its input is a list of document annotations and it automatically generates questions which are
  * used to recognize entities. The definitions of entities is given by a dictionary structures,
  * specifying a set of questions for each entity. The model is based on
  * RoBertaForQuestionAnswering.
  *
  * For more extended examples see the
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/named-entity-recognition/ZeroShot_NER.ipynb Examples]]
  *
  * Pretrained models can be loaded with `pretrained` of the companion object:
  * {{{
  * val zeroShotNer = ZeroShotNerModel.pretrained()
  *   .setInputCols("document")
  *   .setOutputCol("zero_shot_ner")
  * }}}
  *
  * For available pretrained models please see the
  * [[https://sparknlp.org/models?task=Zero-Shot-NER Models Hub]].
  *
  * ==Example==
  * {{{
  *  val documentAssembler = new DocumentAssembler()
  *    .setInputCol("text")
  *    .setOutputCol("document")
  *
  *  val sentenceDetector = new SentenceDetector()
  *    .setInputCols(Array("document"))
  *    .setOutputCol("sentences")
  *
  *  val zeroShotNer = ZeroShotNerModel
  *    .pretrained()
  *    .setEntityDefinitions(
  *      Map(
  *        "NAME" -> Array("What is his name?", "What is her name?"),
  *        "CITY" -> Array("Which city?")))
  *    .setPredictionThreshold(0.01f)
  *    .setInputCols("sentences")
  *    .setOutputCol("zero_shot_ner")
  *
  *  val pipeline = new Pipeline()
  *    .setStages(Array(
  *      documentAssembler,
  *      sentenceDetector,
  *      zeroShotNer))
  *
  *  val model = pipeline.fit(Seq("").toDS.toDF("text"))
  *  val results = model.transform(
  *    Seq("Clara often travels between New York and Paris.").toDS.toDF("text"))
  *
  *  results
  *    .selectExpr("document", "explode(zero_shot_ner) AS entity")
  *    .select(
  *      col("entity.result"),
  *      col("entity.metadata.word"),
  *      col("entity.metadata.sentence"),
  *      col("entity.begin"),
  *      col("entity.end"),
  *      col("entity.metadata.confidence"),
  *      col("entity.metadata.question"))
  *    .show(truncate=false)
  *
  * +------+-----+--------+-----+---+----------+------------------+
  * |result|word |sentence|begin|end|confidence|question          |
  * +------+-----+--------+-----+---+----------+------------------+
  * |B-CITY|Paris|0       |41   |45 |0.78655756|Which is the city?|
  * |B-CITY|New  |0       |28   |30 |0.29346612|Which city?       |
  * |I-CITY|York |0       |32   |35 |0.29346612|Which city?       |
  * +------+-----+--------+-----+---+----------+------------------+
  *
  * }}}
  *
  * @see
  *   [[https://arxiv.org/abs/1907.11692]] for details about the RoBERTa transformer
  * @see
  *   [[RoBertaForQuestionAnswering]] for the SparkNLP implementation of RoBERTa question
  *   answering
  * @param uid
  *   required uid for storing annotator to disk
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname Ungrouped Members
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class ZeroShotNerModel(override val uid: String) extends RoBertaForQuestionAnswering {

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("ZeroShotNerModel"))

  /** Input Annotator Types: DOCUMENT, TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Output Annotator Types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NAMED_ENTITY

  /** List of definitions of named entities: maps an entity label (e.g. "CITY") to the array of
    * natural-language questions used to extract that entity via question answering.
    *
    * @group param
    */
  private val entityDefinitions = new MapFeature[String, Array[String]](this, "entityDefinitions")

  /** Set definitions of named entities
    *
    * @group setParam
    */
  def setEntityDefinitions(definitions: Map[String, Array[String]]): this.type = {
    set(this.entityDefinitions, definitions)
  }

  /** Set definitions of named entities (Java-friendly overload; converts the Java collections to
    * their immutable Scala counterparts before storing them).
    *
    * @group setParam
    */
  def setEntityDefinitions(definitions: util.HashMap[String, util.List[String]]): this.type = {
    val c = definitions.asScala.mapValues(_.asScala.toList.toArray).toMap
    set(this.entityDefinitions, c)
  }

  /** Get definitions of named entities; returns an empty map when the feature was never set.
    *
    * @group getParam
    */
  private def getEntityDefinitions: scala.collection.immutable.Map[String, Array[String]] = {
    if (!entityDefinitions.isSet)
      return Map.empty
    $$(entityDefinitions)
  }

  /** Entity definitions flattened to strings, one per entity, in the form
    * "LABEL@@@question1@@@question2..." (presumably for cross-language parameter transport —
    * TODO confirm against the Python API).
    */
  def getEntityDefinitionsStr: Array[String] = {
    getEntityDefinitions.map(x => x._1 + "@@@" + x._2.mkString("@@@")).toArray
  }

  /** Minimal confidence score a QA answer must exceed to be accepted as an entity.
    *
    * @group param
    */
  var predictionThreshold =
    new FloatParam(this, "predictionThreshold", "Minimal score of predicted entity")

  /** Entity labels whose predictions are discarded from the output.
    *
    * @group param
    */
  var ignoreEntities = new StringArrayParam(this, "ignoreEntities", "List of entities to ignore")

  /** Get the minimum entity prediction score
    *
    * @group getParam
    */
  def getPredictionThreshold: Float = $(predictionThreshold)

  /** Set the minimum entity prediction score
    *
    * @group setParam
    */
  def setPredictionThreshold(value: Float): this.type = set(this.predictionThreshold, value)

  /** Get the list of questions to catch the distractor entity
    *
    * @group getParam
    */
  def getIgnoreEntities: Array[String] = $(ignoreEntities)

  /** Get the list of entities which are recognized
    *
    * @group getParam
    */
  def getEntities: Array[String] = getEntityDefinitions.keys.toArray

  /** Set the list of questions to catch the distractor entity
    *
    * @group setParam
    */
  def setIgnoreEntities(value: Array[String]): this.type = set(this.ignoreEntities, value)

  /** Turn each entity question into a DOCUMENT annotation carrying the entity label and the
    * question text in its metadata, so it can be paired with a document for QA inference.
    */
  private def getNerQuestionAnnotations
      : scala.collection.immutable.Map[String, Array[Annotation]] = {
    getEntityDefinitions.map(nerDef => {
      (
        nerDef._1,
        nerDef._2.map(nerQ =>
          new Annotation(
            AnnotatorType.DOCUMENT,
            0,
            nerQ.length,
            nerQ,
            Map("entity" -> nerDef._1) ++ Map("ner_question" -> nerQ))))
    })
  }

  // Broadcast handle for the classification model, shared across executors.
  private var _model: Option[Broadcast[ZeroShotNerClassification]] = None

  /** Broadcast the underlying ZeroShotNerClassification once; subsequent calls are no-ops. */
  override def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: Option[TensorflowWrapper],
      onnxWrapper: Option[OnnxWrapper]): ZeroShotNerModel = {
    if (_model.isEmpty) {
      _model = Some(
        spark.sparkContext.broadcast(
          new ZeroShotNerClassification(
            tensorflowWrapper,
            onnxWrapper,
            sentenceStartTokenId,
            sentenceEndTokenId,
            padTokenId,
            false,
            configProtoBytes = getConfigProtoBytes,
            tags = Map.empty[String, Int],
            signatures = getSignatures,
            $$(merges),
            $$(vocabulary))))
    }

    this
  }

  override def getModelIfNotSet: RoBertaClassification = _model.get.value

  setDefault(ignoreEntities -> Array(), predictionThreshold -> 0.01f)

  // Placeholder written over an already-recognized entity span before re-running QA
  // (see maskEntity / recognizeMultipleEntities).
  val maskSymbol = "_"

  /** True when the two inclusive (begin, end) character spans intersect. */
  private def spansOverlap(span1: (Int, Int), span2: (Int, Int)): Boolean = {
    !((span2._1 > span1._2) || (span1._1 > span2._2))
  }

  /** Run QA once for every (question, document) pair, keep non-empty answers above the
    * prediction threshold, and resolve overlapping answers by keeping the higher-confidence one.
    */
  private def recognizeEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]]): Seq[Annotation] = {
    val docPredictions = nerDefs
      .flatMap(nerDef => {
        // Each batch element is [question, document], the input shape QA inference expects.
        val nerBatch = nerDef._2.map(nerQuestion => Array(nerQuestion, document)).toSeq
        val entityPredictions = super
          .batchAnnotate(nerBatch)
          .zip(nerBatch.map(_.head.result))
          .map(x => (x._1.head, x._2))
          .filter(x => x._1.result.nonEmpty)
          // Answers without a "score" entry default to Float.MinValue and are dropped.
          .filter(x =>
            (if (x._1.metadata.contains("score"))
               x._1.metadata("score").toFloat
             else Float.MinValue) > getPredictionThreshold)
        entityPredictions.map(prediction =>
          new Annotation(
            AnnotatorType.CHUNK,
            prediction._1.begin,
            prediction._1.end,
            prediction._1.result,
            Map(
              "entity" -> nerDef._1,
              "sentence" -> document.metadata("sentence"),
              "word" -> prediction._1.result,
              "confidence" -> prediction._1.metadata("score"),
              "question" -> prediction._2)))
      })
      .toSeq
    // Detect overlapping predictions and choose the one with the higher score.
    // NOTE(review): "confidence" values are compared as strings here (lexicographic `>`),
    // not numerically — correct only while score strings share formatting; verify.
    docPredictions
      .filter(x => ! $(ignoreEntities).contains(x.metadata("entity"))) // Discard ignored entities
      .filter(prediction => {
        !docPredictions
          .filter(_ != prediction)
          .exists(otherPrediction => {
            spansOverlap(
              (prediction.begin, prediction.end),
              (otherPrediction.begin, otherPrediction.end)) && (otherPrediction.metadata(
              "confidence") > prediction.metadata("confidence"))
          })
      })

  }

  /** Replace the entity's span in the document text with the mask symbol, padding with spaces so
    * the remaining character offsets stay aligned; output is truncated at maxSentenceLength.
    */
  def maskEntity(document: Annotation, entity: Annotation): String = {
    val entityStart = entity.begin - document.begin
    val entityEnd = entity.end - document.begin
    document.result.slice(0, entityStart) + maskSymbol + {
      entityStart to entityEnd - 2
    }.map(_ => " ").mkString + document.result.slice(entityEnd, $(maxSentenceLength))
  }

  /** Recursively recognize entities: after each pass, every newly found entity is masked out of
    * the document and QA is re-run (restricted to that entity's questions) to surface further
    * occurrences. Recursion stops when a pass yields no new entities. Entities overlapping
    * previously recognized spans, or whose text equals the mask symbol, are skipped.
    */
  def recognizeMultipleEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]],
      recognizedEntities: Seq[Annotation] = Seq()): Seq[Annotation] = {
    val newEntities = recognizeEntities(document, nerDefs)
      .filter(entity =>
        (!recognizedEntities
          .exists(recognizedEntity =>
            spansOverlap(
              (entity.begin, entity.end),
              (recognizedEntity.begin, recognizedEntity.end)))) && (entity.result != maskSymbol))

    newEntities ++ newEntities.flatMap { entity =>
      val newDoc = new Annotation(
        document.annotatorType,
        document.begin,
        document.end,
        maskEntity(document, entity),
        document.metadata)
      recognizeMultipleEntities(
        newDoc,
        nerDefs.filter(x => x._1 == entity.metadata("entity")),
        recognizedEntities ++ newEntities)
    }
  }

  /** True when the token lies entirely inside the entity span and in the same sentence. */
  def isTokenInEntity(token: Annotation, entity: Annotation): Boolean = {
    (
      token.metadata("sentence") == entity.metadata("sentence")
      && (token.begin >= entity.begin) && (token.end <= entity.end)
    )
  }

  /** For each (document, token) annotation group: recognize entities in every document, then
    * emit one NAMED_ENTITY annotation per token, labeled BIO-style (B-/I- + entity label) for
    * tokens inside an entity and "O" otherwise.
    */
  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {

    batchedAnnotations.map(annotations => {
      val documents = annotations
        .filter(_.annotatorType == AnnotatorType.DOCUMENT)
        .toSeq
      val tokens = annotations.filter(_.annotatorType == AnnotatorType.TOKEN)
      val entities = documents.flatMap { doc =>
        recognizeMultipleEntities(doc, getNerQuestionAnnotations).flatMap { entity =>
          tokens
            .filter(t => isTokenInEntity(t, entity))
            .zipWithIndex
            .map { case (token, i) =>
              // First token of an entity gets "B-", the rest "I-" (BIO tagging).
              val bioPrefix = if (i == 0) "B-" else "I-"
              new Annotation(
                annotatorType = AnnotatorType.NAMED_ENTITY,
                begin = token.begin,
                end = token.end,
                result = bioPrefix + entity.metadata("entity"),
                metadata = Map(
                  "sentence" -> entity.metadata("sentence"),
                  "word" -> token.result,
                  "confidence" -> entity.metadata("confidence"),
                  "question" -> entity.metadata("question")))
            }
        }.toList
      }
      // Project entity annotations back onto the token sequence; tokens not covered by any
      // entity are tagged "O".
      tokens
        .map(token => {
          val entity = entities.find(e => isTokenInEntity(token, e))
          if (entity.nonEmpty) {
            entity.get
          } else {
            new Annotation(
              annotatorType = AnnotatorType.NAMED_ENTITY,
              begin = token.begin,
              end = token.end,
              result = "O",
              metadata = Map("sentence" -> token.metadata("sentence"), "word" -> token.result))
          }
        })
        .toSeq
    })
  }

}
411

412
trait ReadablePretrainedZeroShotNer
    extends ParamsAndFeaturesReadable[ZeroShotNerModel]
    with HasPretrained[ZeroShotNerModel] {

  /** Default pretrained model name used by the no-argument `pretrained()`. */
  override val defaultModelName: Some[String] = Some("zero_shot_ner_roberta")

  /** Java compliant-overrides */
  override def pretrained(): ZeroShotNerModel =
    pretrained(defaultModelName.get, defaultLang, defaultLoc)

  override def pretrained(name: String): ZeroShotNerModel =
    pretrained(name, defaultLang, defaultLoc)

  override def pretrained(name: String, lang: String): ZeroShotNerModel =
    pretrained(name, lang, defaultLoc)

  /** Download a pretrained model: first try downloading it as a RoBertaForQuestionAnswering and
    * converting; on a RuntimeException fall back to downloading a ZeroShotNerModel directly.
    */
  override def pretrained(name: String, lang: String, remoteLoc: String): ZeroShotNerModel = {
    try {
      ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
        ResourceDownloader
          .downloadModel(RoBertaForQuestionAnswering, name, Option(lang), remoteLoc))
    } catch {
      case _: java.lang.RuntimeException =>
        ResourceDownloader.downloadModel(ZeroShotNerModel, name, Option(lang), remoteLoc)
    }
  }

  /** Load a saved model from `path`. If the artifact turns out to be a plain
    * RoBertaForQuestionAnswering (ClassCastException), load it as such and convert it; should
    * the conversion fail too, the original ClassCastException is rethrown.
    */
  override def load(path: String): ZeroShotNerModel = {
    try {
      super.load(path)
    } catch {
      case e: java.lang.ClassCastException =>
        try {
          ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
            RoBertaForQuestionAnswering.load(path))
        } catch {
          // Only trap non-fatal failures: the previous catch-all on Throwable would also have
          // swallowed fatal errors such as OutOfMemoryError or InterruptedException.
          case NonFatal(_) => throw e
        }
    }
  }
}
452

453
trait ReadZeroShotNerDLModel extends ReadTensorflowModel with ReadOnnxModel {
  this: ParamsAndFeaturesReadable[ZeroShotNerModel] =>

  override val tfFile: String = "roberta_classification_tensorflow"
  override val onnxFile: String = "roberta_classification_onnx"

  /** Deserialize the engine-specific model files from `path` and attach them to `instance`.
    * Exactly one of the TensorFlow/ONNX wrappers is loaded, selected by the instance's engine;
    * any other engine value is rejected.
    */
  def readModel(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {
    // Resolve the engine once, then build the (tensorflow, onnx) wrapper pair for it.
    val wrappers = instance.getEngine match {
      case TensorFlow.name =>
        val tf =
          readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
        (Some(tf), None)
      case ONNX.name =>
        val onnx = readOnnxModel(
          path,
          spark,
          "_roberta_classification_onnx",
          zipped = true,
          useBundle = false,
          None)
        (None, Some(onnx))
      case _ =>
        throw new Exception(notSupportedEngineError)
    }
    instance.setModelIfNotSet(spark, wrappers._1, wrappers._2)
  }

  addReader(readModel)
}
×
482
object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel {

  /** Convert a RoBertaForQuestionAnswering stage to a ZeroShotNerModel; any other stage
    * (including an existing ZeroShotNerModel) is returned unchanged.
    */
  def apply(model: PipelineStage): PipelineStage = {
    model match {
      case answering: RoBertaForQuestionAnswering if !model.isInstanceOf[ZeroShotNerModel] =>
        getFromRoBertaForQuestionAnswering(answering)
      case _ =>
        model
    }
  }

  /** Build a ZeroShotNerModel from a trained RoBertaForQuestionAnswering: copies vocabulary,
    * merges, case sensitivity, batch size, signatures (when set), the engine-specific model
    * wrapper, and finally every explicitly set param.
    *
    * @throws RuntimeException
    *   when the source model's vocabulary, merges, or signatures features are unset
    * @throws Exception
    *   when the source model's engine is neither TensorFlow nor ONNX
    */
  def getFromRoBertaForQuestionAnswering(model: RoBertaForQuestionAnswering): ZeroShotNerModel = {
    val spark = SparkSession.builder.getOrCreate()

    val newModel = new ZeroShotNerModel()
      .setVocabulary(
        model.vocabulary.get.getOrElse(throw new RuntimeException("Vocabulary not set")))
      .setMerges(model.merges.get.getOrElse(throw new RuntimeException("Merges not set")))
      .setCaseSensitive(model.getCaseSensitive)
      .setBatchSize(model.getBatchSize)

    if (model.signatures.isSet)
      newModel.setSignatures(
        model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set")))

    model.getEngine match {
      case TensorFlow.name =>
        newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
      case ONNX.name =>
        newModel.setModelIfNotSet(spark, None, model.getModelIfNotSet.onnxWrapper)
      case _ =>
        // Previously a missing default made this a scala.MatchError on unknown engines;
        // fail with the same explicit error used by ReadZeroShotNerDLModel.readModel.
        throw new Exception(notSupportedEngineError)
    }

    // Carry over every param the source model has explicitly set (e.g. maxSentenceLength).
    model
      .extractParamMap()
      .toSeq
      .foreach(x => {
        newModel.set(x.param.name, x.value)
      })

    newModel
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc