• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4111789725

pending completion
4111789725

Pull #13346

github

GitHub
Merge 56af3827e into 0fcb84467
Pull Request #13346: Release/430 release candidate

419 of 419 new or added lines in 23 files covered. (100.0%)

8492 of 12789 relevant lines covered (66.4%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala
1
/*
2
 * Copyright 2017-2023 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.ner.dl
18

19
import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification}
20
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper}
21
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN}
22
import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering
23
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
24
import com.johnsnowlabs.nlp.serialization.MapFeature
25
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, HasPretrained, ParamsAndFeaturesReadable}
26
import org.apache.spark.broadcast.Broadcast
27
import org.apache.spark.ml.PipelineStage
28
import org.apache.spark.ml.param.{FloatParam, StringArrayParam}
29
import org.apache.spark.ml.util.Identifiable
30
import org.apache.spark.sql.SparkSession
31

32
import java.util
33
import scala.collection.JavaConverters._
34

35
/** ZeroShotNerModel implements zero shot named entity recognition by utilizing RoBERTa
36
  * transformer models fine tuned on a question answering task.
37
  *
38
  * Its input is a list of document annotations and it automatically generates questions which are
39
  * used to recognize entities. The definitions of entities are given by a dictionary structure,
40
  * specifying a set of questions for each entity. The model is based on
41
  * RoBertaForQuestionAnswering.
42
  *
43
  * For more extended examples see the
44
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english/named-entity-recognition/ZeroShot_NER.ipynb Examples]]
45
  *
46
  * Pretrained models can be loaded with `pretrained` of the companion object:
47
  * {{{
48
  * val zeroShotNer = ZeroShotNerModel.pretrained()
49
  *   .setInputCols("document")
50
  *   .setOutputCol("zero_shot_ner")
51
  * }}}
52
  *
53
  * For available pretrained models please see the
54
  * [[https://nlp.johnsnowlabs.com/models?task=Zero-Shot-NER Models Hub]].
55
  *
56
  * ==Example==
57
  * {{{
58
  *  val documentAssembler = new DocumentAssembler()
59
  *    .setInputCol("text")
60
  *    .setOutputCol("document")
61
  *
62
  *  val sentenceDetector = new SentenceDetector()
63
  *    .setInputCols(Array("document"))
64
  *    .setOutputCol("sentences")
65
  *
66
  *  val zeroShotNer = ZeroShotNerModel
67
  *    .pretrained()
68
  *    .setEntityDefinitions(
69
  *      Map(
70
  *        "NAME" -> Array("What is his name?", "What is her name?"),
71
  *        "CITY" -> Array("Which city?")))
72
  *    .setPredictionThreshold(0.01f)
73
  *    .setInputCols("sentences")
74
  *    .setOutputCol("zero_shot_ner")
75
  *
76
  *  val pipeline = new Pipeline()
77
  *    .setStages(Array(
78
  *      documentAssembler,
79
  *      sentenceDetector,
80
  *      zeroShotNer))
81
  *
82
  *  val model = pipeline.fit(Seq("").toDS.toDF("text"))
83
  *  val results = model.transform(
84
  *    Seq("Clara often travels between New York and Paris.").toDS.toDF("text"))
85
  *
86
  *  results
87
  *    .selectExpr("document", "explode(zero_shot_ner) AS entity")
88
  *    .select(
89
  *      col("entity.result"),
90
  *      col("entity.metadata.word"),
91
  *      col("entity.metadata.sentence"),
92
  *      col("entity.begin"),
93
  *      col("entity.end"),
94
  *      col("entity.metadata.confidence"),
95
  *      col("entity.metadata.question"))
96
  *    .show(truncate=false)
97
  *
98
  * +------+-----+--------+-----+---+----------+------------------+
99
  * |result|word |sentence|begin|end|confidence|question          |
100
  * +------+-----+--------+-----+---+----------+------------------+
101
  * |B-CITY|Paris|0       |41   |45 |0.78655756|Which city?       |
102
  * |B-CITY|New  |0       |28   |30 |0.29346612|Which city?       |
103
  * |I-CITY|York |0       |32   |35 |0.29346612|Which city?       |
104
  * +------+-----+--------+-----+---+----------+------------------+
105
  *
106
  * }}}
107
  *
108
  * @see
109
  *   [[https://arxiv.org/abs/1907.11692]] for details about the RoBERTa transformer
110
  * @see
111
  *   [[RoBertaForQuestionAnswering]] for the SparkNLP implementation of RoBERTa question
112
  *   answering
113
  * @param uid
114
  *   required uid for storing annotator to disk
115
  * @groupname anno Annotator types
116
  * @groupdesc anno
117
  *   Required input and expected output annotator types
118
  * @groupname Ungrouped Members
119
  * @groupname param Parameters
120
  * @groupname setParam Parameter setters
121
  * @groupname getParam Parameter getters
122
  * @groupname Ungrouped Members
123
  * @groupprio param  1
124
  * @groupprio anno  2
125
  * @groupprio Ungrouped 3
126
  * @groupprio setParam  4
127
  * @groupprio getParam  5
128
  * @groupdesc param
129
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
130
  *   parameter values through setters and getters, respectively.
131
  */
132
class ZeroShotNerModel(override val uid: String) extends RoBertaForQuestionAnswering {

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("ZeroShotNerModel"))

  /** Input Annotator Types: DOCUMENT, TOKEN
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN)

  /** Output Annotator Types: NAMED_ENTITY
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NAMED_ENTITY

  /** List of definitions of named entities: maps each entity label to the questions used to
    * extract it via question answering.
    *
    * @group param
    */
  private val entityDefinitions = new MapFeature[String, Array[String]](this, "entityDefinitions")

  /** Set definitions of named entities
    *
    * @group setParam
    */
  def setEntityDefinitions(definitions: Map[String, Array[String]]): this.type = {
    set(this.entityDefinitions, definitions)
  }

  /** Set definitions of named entities (Java-friendly overload; converts the Java collections to
    * the Scala representation before storing them).
    *
    * @group setParam
    */
  def setEntityDefinitions(definitions: util.HashMap[String, util.List[String]]): this.type = {
    val c = definitions.asScala.mapValues(_.asScala.toList.toArray).toMap
    set(this.entityDefinitions, c)
  }

  /** Get definitions of named entities; empty map when the feature was never set.
    *
    * @group getParam
    */
  private def getEntityDefinitions: scala.collection.immutable.Map[String, Array[String]] = {
    if (!entityDefinitions.isSet)
      return Map.empty
    $$(entityDefinitions)
  }

  /** Serialize the entity definitions as flat strings, one per entity, using "@@@" as the
    * separator between the label and each of its questions.
    */
  def getEntityDefinitionsStr: Array[String] = {
    getEntityDefinitions.map(x => x._1 + "@@@" + x._2.mkString("@@@")).toArray
  }

  /** Minimal confidence score required for a QA answer to be accepted as an entity prediction.
    *
    * @group param
    */
  var predictionThreshold =
    new FloatParam(this, "predictionThreshold", "Minimal score of predicted entity")

  /** Entity labels whose predictions are discarded from the output.
    *
    * @group param
    */
  var ignoreEntities = new StringArrayParam(this, "ignoreEntities", "List of entities to ignore")

  /** Get the minimum entity prediction score
    *
    * @group getParam
    */
  def getPredictionThreshold: Float = $(predictionThreshold)

  /** Set the minimum entity prediction score
    *
    * @group setParam
    */
  def setPredictionThreshold(value: Float): this.type = set(this.predictionThreshold, value)

  /** Get the list of questions to catch the distractor entity
    *
    * @group getParam
    */
  def getIgnoreEntities: Array[String] = $(ignoreEntities)

  /** Get the list of entities which are recognized
    *
    * @group getParam
    */

  def getEntities: Array[String] = getEntityDefinitions.keys.toArray

  /** Set the list of questions to catch the distractor entity
    *
    * @group setParam
    */
  def setIgnoreEntities(value: Array[String]): this.type = set(this.ignoreEntities, value)

  /** Turn each configured question into a DOCUMENT annotation so it can be paired with the input
    * document and fed to the question-answering model. The produced metadata carries the entity
    * label ("entity") and the originating question ("ner_question").
    */
  private def getNerQuestionAnnotations
      : scala.collection.immutable.Map[String, Array[Annotation]] = {
    getEntityDefinitions.map(nerDef => {
      (
        nerDef._1,
        nerDef._2.map(nerQ =>
          new Annotation(
            AnnotatorType.DOCUMENT,
            0,
            nerQ.length,
            nerQ,
            Map("entity" -> nerDef._1) ++ Map("ner_question" -> nerQ))))
    })
  }

  // Broadcast handle for the TensorFlow-backed classifier; populated once per session by
  // setModelIfNotSet and shared across executors.
  private var _model: Option[Broadcast[ZeroShotNerClassification]] = None

  /** Broadcast the TensorFlow model to the cluster if it has not been set yet. Idempotent: a
    * second call with a different wrapper is a no-op.
    */
  override def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: TensorflowWrapper): ZeroShotNerModel = {
    if (_model.isEmpty) {
      _model = Some(
        spark.sparkContext.broadcast(
          new ZeroShotNerClassification(
            tensorflowWrapper,
            sentenceStartTokenId,
            sentenceEndTokenId,
            padTokenId,
            false,
            configProtoBytes = getConfigProtoBytes,
            tags = Map.empty[String, Int],
            signatures = getSignatures,
            $$(merges),
            $$(vocabulary))))
    }

    this
  }

  // NOTE(review): throws NoSuchElementException if called before setModelIfNotSet — presumably
  // guaranteed by the loading path; confirm against callers.
  override def getModelIfNotSet: RoBertaClassification = _model.get.value

  setDefault(ignoreEntities -> Array(), predictionThreshold -> 0.01f)

  // Placeholder character written over already-recognized entities so repeated QA passes do not
  // re-detect them (see maskEntity / recognizeMultipleEntities).
  val maskSymbol = "_"

  /** True when the two inclusive (begin, end) character spans intersect. */
  private def spansOverlap(span1: (Int, Int), span2: (Int, Int)): Boolean = {
    !((span2._1 > span1._2) || (span1._1 > span2._2))
  }

  /** Run one QA pass over `document` with every question in `nerDefs`, keep answers whose score
    * exceeds the prediction threshold, and resolve overlaps by keeping the higher-confidence
    * prediction. Returns CHUNK annotations carrying entity/sentence/word/confidence/question
    * metadata.
    */
  private def recognizeEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]]): Seq[Annotation] = {
    val docPredictions = nerDefs
      .flatMap(nerDef => {
        // One (question, document) pair per configured question for this entity.
        val nerBatch = nerDef._2.map(nerQuestion => Array(nerQuestion, document)).toSeq
        val entityPredictions = super
          .batchAnnotate(nerBatch)
          .zip(nerBatch.map(_.head.result))
          .map(x => (x._1.head, x._2))
          .filter(x => x._1.result.nonEmpty)
          .filter(x =>
            // Missing "score" metadata is treated as Float.MinValue, i.e. always rejected.
            (if (x._1.metadata.contains("score"))
               x._1.metadata("score").toFloat
             else Float.MinValue) > getPredictionThreshold)
        entityPredictions.map(prediction =>
          new Annotation(
            AnnotatorType.CHUNK,
            prediction._1.begin,
            prediction._1.end,
            prediction._1.result,
            Map(
              "entity" -> nerDef._1,
              "sentence" -> document.metadata("sentence"),
              "word" -> prediction._1.result,
              "confidence" -> prediction._1.metadata("score"),
              "question" -> prediction._2)))
      })
      .toSeq
    // Detect overlapping predictions and choose the one with the higher score
    docPredictions
      .filter(x => ! $(ignoreEntities).contains(x.metadata("entity"))) // Discard ignored entities
      .filter(prediction => {
        // Keep a prediction only if no overlapping prediction has a strictly higher confidence.
        // NOTE(review): "confidence" values are compared as strings here — verify score
        // formatting makes lexicographic order match numeric order.
        !docPredictions
          .filter(_ != prediction)
          .exists(otherPrediction => {
            spansOverlap(
              (prediction.begin, prediction.end),
              (otherPrediction.begin, otherPrediction.end)) && (otherPrediction.metadata(
              "confidence") > prediction.metadata("confidence"))
          })
      })

  }

  /** Return the document text with `entity`'s span overwritten by `maskSymbol` followed by
    * spaces, preserving the original character offsets, so a subsequent QA pass cannot find the
    * same mention again. The tail is truncated at `maxSentenceLength` characters.
    */
  def maskEntity(document: Annotation, entity: Annotation): String = {
    // Entity offsets are absolute; convert to positions within this document's text.
    val entityStart = entity.begin - document.begin
    val entityEnd = entity.end - document.begin
    document.result.slice(0, entityStart) + maskSymbol + {
      entityStart to entityEnd - 2
    }.map(_ => " ").mkString + document.result.slice(entityEnd, $(maxSentenceLength))
  }

  /** Repeatedly extract entities from `document`: after each pass, every newly found entity is
    * masked out and the search is re-run (restricted to that entity's own questions) to catch
    * additional mentions of the same label. Spans overlapping already-recognized entities, and
    * answers equal to the mask symbol itself, are skipped; recursion stops when a pass finds
    * nothing new.
    */
  def recognizeMultipleEntities(
      document: Annotation,
      nerDefs: Map[String, Array[Annotation]],
      recognizedEntities: Seq[Annotation] = Seq()): Seq[Annotation] = {
    val newEntities = recognizeEntities(document, nerDefs)
      .filter(entity =>
        (!recognizedEntities
          .exists(recognizedEntity =>
            spansOverlap(
              (entity.begin, entity.end),
              (recognizedEntity.begin, recognizedEntity.end)))) && (entity.result != maskSymbol))

    newEntities ++ newEntities.flatMap { entity =>
      // Re-run on a copy of the document with this entity masked out, keeping only the
      // questions for this entity's label.
      val newDoc = new Annotation(
        document.annotatorType,
        document.begin,
        document.end,
        maskEntity(document, entity),
        document.metadata)
      recognizeMultipleEntities(
        newDoc,
        nerDefs.filter(x => x._1 == entity.metadata("entity")),
        recognizedEntities ++ newEntities)
    }
  }

  /** True when `token` lies inside `entity`'s span and belongs to the same sentence. */
  def isTokenInEntity(token: Annotation, entity: Annotation): Boolean = {
    (
      token.metadata("sentence") == entity.metadata("sentence")
      && (token.begin >= entity.begin) && (token.end <= entity.end)
    )
  }

  /** For each input row: recognize entity chunks in every DOCUMENT annotation, then emit one
    * NAMED_ENTITY annotation per TOKEN using BIO tagging — "B-"/"I-" + label for tokens inside a
    * recognized chunk, "O" for all other tokens.
    */
  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {

    batchedAnnotations.map(annotations => {
      val documents = annotations
        .filter(_.annotatorType == AnnotatorType.DOCUMENT)
        .toSeq
      val tokens = annotations.filter(_.annotatorType == AnnotatorType.TOKEN)
      // Expand each recognized chunk into per-token NAMED_ENTITY annotations with BIO prefixes.
      val entities = documents.flatMap { doc =>
        recognizeMultipleEntities(doc, getNerQuestionAnnotations).flatMap { entity =>
          tokens
            .filter(t => isTokenInEntity(t, entity))
            .zipWithIndex
            .map { case (token, i) =>
              // First token of the chunk gets "B-", the rest "I-".
              val bioPrefix = if (i == 0) "B-" else "I-"
              new Annotation(
                annotatorType = AnnotatorType.NAMED_ENTITY,
                begin = token.begin,
                end = token.end,
                result = bioPrefix + entity.metadata("entity"),
                metadata = Map(
                  "sentence" -> entity.metadata("sentence"),
                  "word" -> token.result,
                  "confidence" -> entity.metadata("confidence"),
                  "question" -> entity.metadata("question")))
            }
        }.toList
      }
      // Emit the entity annotation for tokens inside an entity, "O" otherwise — one output
      // annotation per input token, in token order.
      tokens
        .map(token => {
          val entity = entities.find(e => isTokenInEntity(token, e))
          if (entity.nonEmpty) {
            entity.get
          } else {
            new Annotation(
              annotatorType = AnnotatorType.NAMED_ENTITY,
              begin = token.begin,
              end = token.end,
              result = "O",
              metadata = Map("sentence" -> token.metadata("sentence"), "word" -> token.result))
          }
        })
        .toSeq
    })
  }

}
406

407
/** Provides `pretrained` and `load` entry points for [[ZeroShotNerModel]]. Since zero-shot NER
  * models are distributed as plain [[RoBertaForQuestionAnswering]] artifacts, both paths first
  * try to download/load a RoBERTa QA model and convert it, falling back to reading a native
  * ZeroShotNerModel.
  */
trait ReadablePretrainedZeroShotNer
    extends ParamsAndFeaturesReadable[ZeroShotNerModel]
    with HasPretrained[ZeroShotNerModel] {
  override val defaultModelName: Some[String] = Some("zero_shot_ner_roberta")

  /** Java compliant-overrides */
  override def pretrained(): ZeroShotNerModel =
    pretrained(defaultModelName.get, defaultLang, defaultLoc)

  override def pretrained(name: String): ZeroShotNerModel =
    pretrained(name, defaultLang, defaultLoc)

  override def pretrained(name: String, lang: String): ZeroShotNerModel =
    pretrained(name, lang, defaultLoc)

  /** Download `name` as a RoBERTa QA model and convert it; if that download fails with a
    * RuntimeException, retry the download as a native ZeroShotNerModel.
    */
  override def pretrained(name: String, lang: String, remoteLoc: String): ZeroShotNerModel = {
    try {
      ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
        ResourceDownloader
          .downloadModel(RoBertaForQuestionAnswering, name, Option(lang), remoteLoc))
    } catch {
      case _: java.lang.RuntimeException =>
        ResourceDownloader.downloadModel(ZeroShotNerModel, name, Option(lang), remoteLoc)
    }
  }

  /** Load from disk; if the stored model is actually a RoBERTa QA model (signalled by a
    * ClassCastException), load it as such and convert. The original ClassCastException is
    * rethrown when the fallback load fails too.
    */
  override def load(path: String): ZeroShotNerModel = {
    try {
      super.load(path)
    } catch {
      case e: java.lang.ClassCastException =>
        try {
          ZeroShotNerModel.getFromRoBertaForQuestionAnswering(
            RoBertaForQuestionAnswering.load(path))
        } catch {
          case _: Throwable => throw e
        }
    }
  }
}
447

448
/** Deserializes the TensorFlow graph of a stored [[ZeroShotNerModel]] and re-broadcasts it onto
  * the loaded instance. Registered as an extra reader via `addReader`.
  */
trait ReadZeroShotNerDLModel extends ReadTensorflowModel {
  this: ParamsAndFeaturesReadable[ZeroShotNerModel] =>

  override val tfFile: String = "roberta_classification_tensorflow"

  /** Read the TensorFlow model stored under `path` and attach it to `instance`. */
  def readTensorflow(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {

    val tf = readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
    instance.setModelIfNotSet(spark, tf)
  }

  addReader(readTensorflow)
}
461
/** Companion object: loading/pretrained support plus conversion from a plain
  * [[RoBertaForQuestionAnswering]] stage into a [[ZeroShotNerModel]].
  */
object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel {

  /** Convert a RoBERTa QA stage into a ZeroShotNerModel; any other stage (including one that is
    * already a ZeroShotNerModel) is returned unchanged.
    */
  def apply(model: PipelineStage): PipelineStage = {
    model match {
      // Guard needed because ZeroShotNerModel itself extends RoBertaForQuestionAnswering.
      case answering: RoBertaForQuestionAnswering if !model.isInstanceOf[ZeroShotNerModel] =>
        getFromRoBertaForQuestionAnswering(answering)
      case _ =>
        model
    }
  }

  /** Build a ZeroShotNerModel that reuses `model`'s vocabulary, merges, signatures, TensorFlow
    * weights, and all explicitly set params. Throws RuntimeException when vocabulary, merges, or
    * (if set) signatures are missing. Requires an active SparkSession to broadcast the weights.
    */
  def getFromRoBertaForQuestionAnswering(model: RoBertaForQuestionAnswering): ZeroShotNerModel = {
    val spark = SparkSession.builder.getOrCreate()

    val newModel = new ZeroShotNerModel()
      .setVocabulary(
        model.vocabulary.get.getOrElse(throw new RuntimeException("Vocabulary not set")))
      .setMerges(model.merges.get.getOrElse(throw new RuntimeException("Merges not set")))
      .setCaseSensitive(model.getCaseSensitive)
      .setBatchSize(model.getBatchSize)

    if (model.signatures.isSet)
      newModel.setSignatures(
        model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set")))

    newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper)

    // Copy every explicitly set Param from the source model by name.
    // NOTE(review): assumes every param name on RoBertaForQuestionAnswering also exists on
    // ZeroShotNerModel (true while the latter extends the former) — confirm if that changes.
    model
      .extractParamMap()
      .toSeq
      .foreach(x => {
        newModel.set(x.param.name, x.value)
      })

    newModel
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc