• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4992350528

pending completion
4992350528

Pull #13797

github

GitHub
Merge 424c7ff18 into ef7906c5e
Pull Request #13797: SPARKNLP-835: ProtectedParam and ProtectedFeature

24 of 24 new or added lines in 6 files covered. (100.0%)

8643 of 13129 relevant lines covered (65.83%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.1
/src/main/scala/com/johnsnowlabs/nlp/annotators/GraphExtraction.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators
18

19
import com.johnsnowlabs.nlp.AnnotatorType._
20
import com.johnsnowlabs.nlp._
21
import com.johnsnowlabs.nlp.annotators.common.LabeledDependency.DependencyInfo
22
import com.johnsnowlabs.nlp.annotators.common.{LabeledDependency, NerTagged}
23
import com.johnsnowlabs.nlp.annotators.ner.NerTagsEncoding
24
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
25
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
26
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
27
import com.johnsnowlabs.nlp.util.GraphBuilder
28
import org.apache.spark.ml.PipelineModel
29
import org.apache.spark.ml.param.{BooleanParam, IntParam, Param, StringArrayParam}
30
import org.apache.spark.ml.util.Identifiable
31
import org.apache.spark.sql.functions.array
32
import org.apache.spark.sql.{DataFrame, Dataset}
33

34
/** Extracts a dependency graph between entities.
35
  *
36
  * The GraphExtraction class takes e.g. extracted entities from a
37
  * [[com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel NerDLModel]] and creates a dependency tree
38
  * which describes how the entities relate to each other. For that a triple store format is used.
39
  * Nodes represent the entities and the edges represent the relations between those entities. The
40
  * graph can then be used to find relevant relationships between words.
41
  *
42
  * Both the
43
  * [[com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel DependencyParserModel]] and
44
  * [[com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel TypedDependencyParserModel]]
45
  * need to be present in the pipeline. There are two ways to set them:
46
  *
47
  *   1. Both Annotators are present in the pipeline already. The dependencies are taken
48
  *      implicitly from these two Annotators.
49
  *   1. Setting `setMergeEntities` to `true` will download the default pretrained models for
50
  *      those two Annotators automatically. The specific models can also be set with
51
  *      `setDependencyParserModel` and `setTypedDependencyParserModel`:
52
  *      {{{
53
  *            val graph_extraction = new GraphExtraction()
54
  *              .setInputCols("document", "token", "ner")
55
  *              .setOutputCol("graph")
56
  *              .setRelationshipTypes(Array("prefer-LOC"))
57
  *              .setMergeEntities(true)
58
  *            //.setDependencyParserModel(Array("dependency_conllu", "en",  "public/models"))
59
  *            //.setTypedDependencyParserModel(Array("dependency_typed_conllu", "en",  "public/models"))
60
  *      }}}
61
  *
62
  * To transform the resulting graph into a more generic form such as RDF, see the
63
  * [[com.johnsnowlabs.nlp.GraphFinisher GraphFinisher]].
64
  *
65
  * ==Example==
66
  * {{{
67
  * import spark.implicits._
68
  * import com.johnsnowlabs.nlp.base.DocumentAssembler
69
  * import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
70
  * import com.johnsnowlabs.nlp.annotators.Tokenizer
71
  * import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel
72
  * import com.johnsnowlabs.nlp.embeddings.WordEmbeddingsModel
73
  * import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
74
  * import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
75
  * import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
76
  * import org.apache.spark.ml.Pipeline
77
  * import com.johnsnowlabs.nlp.annotators.GraphExtraction
78
  *
79
  * val documentAssembler = new DocumentAssembler()
80
  *   .setInputCol("text")
81
  *   .setOutputCol("document")
82
  *
83
  * val sentence = new SentenceDetector()
84
  *   .setInputCols("document")
85
  *   .setOutputCol("sentence")
86
  *
87
  * val tokenizer = new Tokenizer()
88
  *   .setInputCols("sentence")
89
  *   .setOutputCol("token")
90
  *
91
  * val embeddings = WordEmbeddingsModel.pretrained()
92
  *   .setInputCols("sentence", "token")
93
  *   .setOutputCol("embeddings")
94
  *
95
  * val nerTagger = NerDLModel.pretrained()
96
  *   .setInputCols("sentence", "token", "embeddings")
97
  *   .setOutputCol("ner")
98
  *
99
  * val posTagger = PerceptronModel.pretrained()
100
  *   .setInputCols("sentence", "token")
101
  *   .setOutputCol("pos")
102
  *
103
  * val dependencyParser = DependencyParserModel.pretrained()
104
  *   .setInputCols("sentence", "pos", "token")
105
  *   .setOutputCol("dependency")
106
  *
107
  * val typedDependencyParser = TypedDependencyParserModel.pretrained()
108
  *   .setInputCols("dependency", "pos", "token")
109
  *   .setOutputCol("dependency_type")
110
  *
111
  * val graph_extraction = new GraphExtraction()
112
  *   .setInputCols("document", "token", "ner")
113
  *   .setOutputCol("graph")
114
  *   .setRelationshipTypes(Array("prefer-LOC"))
115
  *
116
  * val pipeline = new Pipeline().setStages(Array(
117
  *   documentAssembler,
118
  *   sentence,
119
  *   tokenizer,
120
  *   embeddings,
121
  *   nerTagger,
122
  *   posTagger,
123
  *   dependencyParser,
124
  *   typedDependencyParser,
125
  *   graph_extraction
126
  * ))
127
  *
128
  * val data = Seq("You and John prefer the morning flight through Denver").toDF("text")
129
  * val result = pipeline.fit(data).transform(data)
130
  *
131
  * result.select("graph").show(false)
132
  * +-----------------------------------------------------------------------------------------------------------------+
133
  * |graph                                                                                                            |
134
  * +-----------------------------------------------------------------------------------------------------------------+
135
  * |[[node, 13, 18, prefer, [relationship -> prefer,LOC, path1 -> prefer,nsubj,morning,flat,flight,flat,Denver], []]]|
136
  * +-----------------------------------------------------------------------------------------------------------------+
137
  * }}}
138
  *
139
  * @see
140
  *   [[com.johnsnowlabs.nlp.GraphFinisher GraphFinisher]] to output the paths in a more generic
141
  *   format, like RDF
142
  * @param uid
143
  *   required uid for storing annotator to disk
144
  * @groupname anno Annotator types
145
  * @groupdesc anno
146
  *   Required input and expected output annotator types
147
  * @groupname Ungrouped Members
148
  * @groupname param Parameters
149
  * @groupname setParam Parameter setters
150
  * @groupname getParam Parameter getters
151
  * @groupname Ungrouped Members
152
  * @groupprio param  1
153
  * @groupprio anno  2
154
  * @groupprio Ungrouped 3
155
  * @groupprio setParam  4
156
  * @groupprio getParam  5
157
  * @groupdesc param
158
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
159
  *   parameter values through setters and getters, respectively.
160
  */
161
class GraphExtraction(override val uid: String)
    extends AnnotatorModel[GraphExtraction]
    with HasSimpleAnnotate[GraphExtraction] {

  /** Annotator reference id. Generates a fresh uid when none is supplied. */
  def this() = this(Identifiable.randomUID("GRAPH_EXTRACTOR"))
166

167
  /** Find paths between a pair of token and entity (Default: `Array()`).
    * Each entry has the form "TOKEN-ENTITY", e.g. "prefer-LOC".
    *
    * @group param
    */
  val relationshipTypes = new StringArrayParam(
    this,
    "relationshipTypes",
    "Find paths between a pair of token and entity")

  /** Find paths between a pair of entities (Default: `Array()`).
    * Each entry has the form "ENTITY-ENTITY", e.g. "PER-LOC".
    *
    * @group param
    */
  val entityTypes =
    new StringArrayParam(this, "entityTypes", "Find paths between a pair of entities")

  /** When set to true find paths between entities (Default: `true`, see `setDefault` below)
    *
    * @group param
    */
  val explodeEntities =
    new BooleanParam(this, "explodeEntities", "When set to true find paths between entities")

  /** Tokens to be considered as root to start traversing the paths (Default: `Array()`). Use it
    * along with `explodeEntities`
    *
    * @group param
    */
  val rootTokens = new StringArrayParam(
    this,
    "rootTokens",
    "Tokens to be consider as root to start traversing the paths. Use it along with explodeEntities")

  /** Maximum sentence size that the annotator will process (Default: `1000`). Above this, the
    * sentence is skipped
    *
    * @group param
    */
  val maxSentenceSize = new IntParam(
    this,
    "maxSentenceSize",
    "Maximum sentence size that the annotator will process. Above this, the sentence is skipped")

  /** Minimum sentence size that the annotator will process (Default: `2`). Below this, the
    * sentence is skipped
    *
    * @group param
    */
  val minSentenceSize = new IntParam(
    this,
    "minSentenceSize",
    "Minimum sentence size that the annotator will process. Below this, the sentence is skipped")

  /** Merge same neighboring entities as a single token (Default: `true`, see `setDefault` below)
    *
    * @group param
    */
  val mergeEntities =
    new BooleanParam(this, "mergeEntities", "Merge same neighboring entities as a single token")

  /** IOB format to apply when merging entities (Default: `"IOB2"`). Values: IOB or IOB2
    *
    * @group param
    */
  val mergeEntitiesIOBFormat = new Param[String](
    this,
    "mergeEntitiesIOBFormat",
    "IOB format to apply when merging entities. Values: IOB or IOB2")

  /** Whether to include edges when building paths (Default: `true`)
    *
    * @group param
    */
  val includeEdges =
    new BooleanParam(this, "includeEdges", "Whether to include edges when building paths")

  /** Delimiter symbol used for path output (Default: `","`)
    *
    * @group param
    */
  val delimiter = new Param[String](this, "delimiter", "Delimiter symbol used for path output")

  /** Coordinates (name, lang, remoteLoc) to a pretrained POS model (Default: `Array()`)
    *
    * @group param
    */
  val posModel = new StringArrayParam(
    this,
    "posModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained POS model")

  /** Coordinates (name, lang, remoteLoc) to a pretrained Dependency Parser model (Default:
    * `Array()`)
    *
    * @group param
    */
  val dependencyParserModel = new StringArrayParam(
    this,
    "dependencyParserModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained Dependency Parser model")

  /** Coordinates (name, lang, remoteLoc) to a pretrained Typed Dependency Parser model (Default:
    * `Array()`)
    *
    * @group param
    */
  val typedDependencyParserModel = new StringArrayParam(
    this,
    "typedDependencyParserModel",
    "Coordinates (name, lang, remoteLoc) to a pretrained Typed Dependency Parser model")
278
  /** Sets "TOKEN-ENTITY" pairs to find paths between a token and an entity.
    *
    * @group setParam
    */
  def setRelationshipTypes(value: Array[String]): this.type = set(relationshipTypes, value)

  /** Sets "ENTITY-ENTITY" pairs to find paths between two entities.
    *
    * @group setParam
    */
  def setEntityTypes(value: Array[String]): this.type = set(entityTypes, value)

  /** @group setParam */
  def setExplodeEntities(value: Boolean): this.type = set(explodeEntities, value)

  /** @group setParam */
  def setRootTokens(value: Array[String]): this.type = set(rootTokens, value)

  /** @group setParam */
  def setMaxSentenceSize(value: Int): this.type = set(maxSentenceSize, value)

  /** @group setParam */
  def setMinSentenceSize(value: Int): this.type = set(minSentenceSize, value)

  /** @group setParam */
  def setMergeEntities(value: Boolean): this.type = set(mergeEntities, value)

  /** Sets the IOB format ("IOB" or "IOB2") used when merging entities.
    *
    * @group setParam
    */
  def setMergeEntitiesIOBFormat(value: String): this.type = set(mergeEntitiesIOBFormat, value)

  /** @group setParam */
  def setIncludeEdges(value: Boolean): this.type = set(includeEdges, value)

  /** @group setParam */
  def setDelimiter(value: String): this.type = set(delimiter, value)

  /** @group setParam */
  def setPosModel(value: Array[String]): this.type = set(posModel, value)

  /** @group setParam */
  def setDependencyParserModel(value: Array[String]): this.type =
    set(dependencyParserModel, value)

  /** @group setParam */
  def setTypedDependencyParserModel(value: Array[String]): this.type =
    set(typedDependencyParserModel, value)
319
  // Defaults: explode paths between all entity pairs, merge neighboring entities (IOB2),
  // include edge labels in paths, and use "," as the path delimiter.
  setDefault(
    entityTypes -> Array(),
    explodeEntities -> true,
    maxSentenceSize -> 1000,
    minSentenceSize -> 2,
    mergeEntities -> true,
    rootTokens -> Array(),
    relationshipTypes -> Array(),
    includeEdges -> true,
    delimiter -> ",",
    posModel -> Array(),
    dependencyParserModel -> Array(),
    typedDependencyParserModel -> Array(),
    mergeEntitiesIOBFormat -> "IOB2")
334
  /** (left, right) pairs parsed from `entityTypes` entries of the form "LEFT-RIGHT". */
  private lazy val allowedEntityRelationships = $(entityTypes).map { spec =>
    val parts = spec.split("-")
    (parts.head, parts.last)
  }.distinct

  /** (token, entity) pairs parsed from `relationshipTypes` entries of the form "TOKEN-ENTITY". */
  private lazy val allowedRelationshipTypes = $(relationshipTypes).map { spec =>
    val parts = spec.split("-")
    (parts.head, parts.last)
  }.distinct
344
  // Populated by beforeAnnotate only when mergeEntities == true; otherwise they remain None.
  private var pretrainedPos: Option[PerceptronModel] = None
  private var pretrainedDependencyParser: Option[DependencyParserModel] = None
  private var pretrainedTypedDependencyParser: Option[TypedDependencyParserModel] =
    None
348

349
  /** Adds the output graph column to the dataset.
    *
    * When `mergeEntities` is true, delegates to the default `_transform`, which triggers
    * `beforeAnnotate` and loads pretrained POS/dependency models. Otherwise both DEPENDENCY
    * and LABELED_DEPENDENCY columns must already exist in the dataset; they are appended to
    * the configured input columns and fed to the annotate UDF.
    *
    * @param dataset
    *   input dataset containing the annotated columns
    * @param recursivePipeline
    *   optional pipeline used by the default transform path
    * @throws IllegalArgumentException
    *   when the dependency annotator columns are missing and `mergeEntities` is false
    */
  override def _transform(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): DataFrame = {
    if ($(mergeEntities)) {
      super._transform(dataset, recursivePipeline)
    } else {
      // Locate dependency-parser output columns already present in the incoming dataset.
      val structFields = dataset.schema.fields
        .filter(field => field.metadata.contains("annotatorType"))
        .filter(field =>
          field.metadata.getString("annotatorType") == DEPENDENCY ||
            field.metadata.getString("annotatorType") == LABELED_DEPENDENCY)
      if (structFields.length < 2) {
        // Fixed message grammar: previously read "or setMergeEntities parameter to True".
        throw new IllegalArgumentException(
          s"Missing either $DEPENDENCY or $LABELED_DEPENDENCY annotators. " +
            s"Make sure such annotators exist in your pipeline or set the mergeEntities parameter to true")
      }

      val columnNames = structFields.map(structField => structField.name)
      val inputCols = getInputCols ++ columnNames
      dataset.withColumn(
        getOutputCol,
        wrapColumnMetadata(dfAnnotate(array(inputCols.map(c => dataset.col(c)): _*))))
    }
  }
374

375
  /** Loads the pretrained POS, dependency and typed dependency parser models before
    * annotation, but only when `mergeEntities` is enabled. The dataset is returned unchanged.
    */
  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {

    if ($(mergeEntities)) {
      pretrainedPos = Some(PretrainedAnnotations.getPretrainedPos($(posModel)))
      pretrainedDependencyParser = Some(
        PretrainedAnnotations.getDependencyParser($(dependencyParserModel)))
      // NOTE(review): unlike posModel and dependencyParserModel above, the
      // typedDependencyParserModel coordinates are not forwarded here — the default
      // pretrained model is always used. Confirm whether this is intentional.
      pretrainedTypedDependencyParser = Some(TypedDependencyParserModel.pretrained())
    }

    dataset
  }
386

387
  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessary one to one
    *   relationship
    */
  override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    // Sentences whose text length falls outside [minSentenceSize, maxSentenceSize] are skipped.
    val sentenceIndexesToSkip = annotations
      .filter(_.annotatorType == AnnotatorType.DOCUMENT)
      .filter(annotation =>
        annotation.result.length > $(maxSentenceSize) || annotation.result.length < $(
          minSentenceSize))
      .map(annotation => annotation.metadata("sentence"))
      .toList
      .distinct

    // Annotations without a "sentence" key are treated as belonging to sentence "0".
    val annotationsToProcess = annotations.filter(annotation => {
      !sentenceIndexesToSkip.contains(annotation.metadata.getOrElse("sentence", "0"))
    })

    if (annotationsToProcess.isEmpty) {
      // Emit a single empty NODE so the output column is always populated downstream.
      Seq(Annotation(NODE, 0, 0, "", Map()))
    } else {
      computeAnnotatePaths(annotationsToProcess)
    }
  }
416

417
  /** Groups annotations by sentence index and extracts graph path annotations per sentence. */
  private def computeAnnotatePaths(annotations: Seq[Annotation]): Seq[Annotation] = {
    val annotationsBySentence = annotations
      .groupBy(token => token.metadata.getOrElse("sentence", "0").toInt)
      .toSeq
      .sortBy(_._1)
      .map(annotationBySentence => annotationBySentence._2)

    val graphPaths = annotationsBySentence.flatMap { sentenceAnnotations =>
      val annotationsWithDependencies = getAnnotationsWithDependencies(sentenceAnnotations)
      val tokens = annotationsWithDependencies.filter(_.annotatorType == AnnotatorType.TOKEN)
      // Tokens whose "entity" metadata is not "O" are the NER-labeled entities.
      val nerEntities = annotationsWithDependencies.filter(annotation =>
        annotation.annotatorType == TOKEN && annotation.metadata("entity") != "O")

      if (nerEntities.isEmpty) {
        // No labeled entities in this sentence: emit a single empty NODE placeholder.
        Seq(Annotation(NODE, 0, 0, "", Map()))
      } else {
        val dependencyData = LabeledDependency.unpackHeadAndRelation(annotationsWithDependencies)
        val annotationsInfo = AnnotationsInfo(tokens, nerEntities, dependencyData)

        // Node 0 acts as a virtual root; dependency i maps to graph node i + 1.
        val graph = new GraphBuilder(dependencyData.length + 1)
        dependencyData.zipWithIndex.foreach { case (dependencyInfo, index) =>
          graph.addEdge(dependencyInfo.headIndex, index + 1)
        }

        if ($(explodeEntities)) {
          extractGraphsFromEntities(annotationsInfo, graph)
        } else {
          extractGraphsFromRelationships(annotationsInfo, graph)
        }
      }
    }

    graphPaths

  }
452

453
  /** Enriches sentence annotations with dependency information: from pretrained models when
    * `mergeEntities` is enabled, otherwise from annotators already present in the pipeline.
    */
  private def getAnnotationsWithDependencies(
      sentenceAnnotations: Seq[Annotation]): Seq[Annotation] =
    if ($(mergeEntities)) getPretrainedAnnotations(sentenceAnnotations)
    else getEntityAnnotations(sentenceAnnotations)
461

462
  /** Runs the pretrained POS, dependency and typed dependency parser models over the sentence
    * to produce the dependency annotations needed for graph building.
    *
    * Requires `beforeAnnotate` to have populated the pretrained model fields
    * (`.get` would throw otherwise).
    */
  private def getPretrainedAnnotations(annotationsToProcess: Seq[Annotation]): Seq[Annotation] = {

    val relatedAnnotatedTokens = mergeRelatedTokens(annotationsToProcess)
    val sentence = annotationsToProcess.filter(_.annotatorType == AnnotatorType.DOCUMENT)

    // POS tags are computed over the merged entity tokens.
    val posInput = sentence ++ relatedAnnotatedTokens
    val posAnnotations = PretrainedAnnotations.getPosOutput(posInput, pretrainedPos.get)

    // Unlabeled dependencies need the sentence, the merged tokens and the POS tags.
    val dependencyParserInput = sentence ++ relatedAnnotatedTokens ++ posAnnotations
    val dependencyParserAnnotations =
      PretrainedAnnotations.getDependencyParserOutput(
        dependencyParserInput,
        pretrainedDependencyParser.get)

    // Typed (labeled) dependencies build on top of the unlabeled dependency output.
    val typedDependencyParserInput =
      relatedAnnotatedTokens ++ posAnnotations ++ dependencyParserAnnotations
    val typedDependencyParserAnnotations = PretrainedAnnotations.getTypedDependencyParserOutput(
      typedDependencyParserInput,
      pretrainedTypedDependencyParser.get)

    relatedAnnotatedTokens ++ dependencyParserAnnotations ++ typedDependencyParserAnnotations
  }
484

485
  /** Attaches each token's NER label as "entity" metadata and keeps the dependency parser
    * annotations produced by the existing pipeline.
    *
    * Assumes NAMED_ENTITY annotations align one-to-one, in order, with TOKEN annotations.
    */
  private def getEntityAnnotations(annotationsToProcess: Seq[Annotation]): Seq[Annotation] = {
    val entityAnnotations = annotationsToProcess.filter(_.annotatorType == NAMED_ENTITY)
    val tokensWithEntity =
      annotationsToProcess.filter(_.annotatorType == TOKEN).zipWithIndex.map {
        case (annotation, index) =>
          val tag = entityAnnotations(index).result
          // Strip the IOB prefix ("B-"/"I-"); a bare one-character tag (e.g. "O") is kept as is.
          val entity = if (tag.length == 1) tag else tag.substring(2)
          val metadata = annotation.metadata ++ Map("entity" -> entity)
          Annotation(
            annotation.annotatorType,
            annotation.begin,
            annotation.end,
            annotation.result,
            metadata)
      }
    val dependencyParserAnnotations = annotationsToProcess.filter(annotation =>
      annotation.annotatorType == DEPENDENCY || annotation.annotatorType == LABELED_DEPENDENCY)

    tokensWithEntity ++ dependencyParserAnnotations
  }
505

506
  /** Merges neighboring tokens belonging to the same entity into single TOKEN annotations
    * using the configured IOB format, tagging each with its entity label ("O" included).
    */
  private def mergeRelatedTokens(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = NerTagged.unpack(annotations)
    // Keep only DOCUMENT annotations that actually contain tagged words within their span.
    val docs = annotations.filter(a =>
      a.annotatorType == AnnotatorType.DOCUMENT && sentences.exists(b =>
        b.indexedTaggedWords.exists(c => c.begin >= a.begin && c.end <= a.end)))

    val entities = sentences.zip(docs.zipWithIndex).flatMap { case (sentence, doc) =>
      NerTagsEncoding.fromIOB(
        sentence,
        doc._1,
        sentenceIndex = doc._2,
        // includeNoneEntities keeps "O" tokens so the dependency tree stays complete.
        includeNoneEntities = true,
        format = $(mergeEntitiesIOBFormat))
    }

    entities.map(entity =>
      Annotation(
        TOKEN,
        entity.start,
        entity.end,
        entity.text,
        Map("sentence" -> entity.sentenceId, "entity" -> entity.entity)))
  }
529

530
  /** Finds paths between every allowed pair of entities, starting either from the sentence's
    * dependency root or from the user-provided `rootTokens`. Prints a warning when no paths
    * are found.
    */
  private def extractGraphsFromEntities(
      annotationsInfo: AnnotationsInfo,
      graph: GraphBuilder): Seq[Annotation] = {

    var rootIndices: Array[Int] = Array()
    var sourceDependencies: Array[DependencyInfo] = Array()

    if ($(rootTokens).isEmpty) {
      // Default root: the dependency whose head index is 0 (the sentence root).
      // NOTE(review): `.head` throws if no dependency has headIndex == 0 — assumed to always
      // exist in well-formed dependency parser output; confirm.
      sourceDependencies = annotationsInfo.dependencyData
        .filter(dependencyInfo => dependencyInfo.headIndex == 0)
        .toArray
      rootIndices = Array(annotationsInfo.dependencyData.indexOf(sourceDependencies.head) + 1)
    } else {
      // User-provided roots: every dependency whose token matches a requested root token.
      sourceDependencies = $(rootTokens).flatMap(rootToken =>
        annotationsInfo.dependencyData.filter(_.token == rootToken))
      rootIndices = sourceDependencies.map(sourceDependency =>
        annotationsInfo.dependencyData.indexOf(sourceDependency) + 1)
    }

    val entitiesPairData =
      getEntitiesData(annotationsInfo.nerEntities, annotationsInfo.dependencyData)
    val annotatedPaths = rootIndices.flatMap(rootIndex =>
      getAnnotatedPaths(entitiesPairData, graph, rootIndex, annotationsInfo))

    if (annotatedPaths.isEmpty && $(rootTokens).nonEmpty) {
      println(
        s"[WARN] Not found paths between given roots: [${$(rootTokens).mkString(",")}] and" +
          s" entities pairs: ${entitiesPairData.map(x => x.entities).mkString(",")}.\n" +
          s"This could mean there are no more labeled tokens below the given roots or NER didn't label any token.\n" +
          s"$entitiesWarnMessage")
    }

    if (annotatedPaths.isEmpty && $(rootTokens).isEmpty) {
      println(
        s"[WARN] Not found paths between the root [${sourceDependencies.head.token}] and " +
          s" entities pairs ${entitiesPairData.map(x => x.entities).mkString(",")}.\n" +
          s"This could mean there are no more labeled tokens below the root or NER didn't label any token.\n" +
          s"$entitiesWarnMessage")
    }

    annotatedPaths
  }
572

573
  /** Builds the hint message pointing users at example notebooks when no paths are found.
    *
    * Fix: the base URI previously ended with "/" while each fragment also started with "/",
    * producing double-slash URLs ("...english//graph-extraction/...").
    */
  private def entitiesWarnMessage: String = {
    // No trailing slash here; the path fragments below supply their own separator.
    val notebooksURI =
      "https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/text/english"
    val relationshipTypesNotebook =
      s"$notebooksURI/graph-extraction/graph_extraction_roots_paths.ipynb"
    val displayNotebook = s"$notebooksURI/graph-extraction/graph_extraction_helper_display.ipynb"
    val message =
      s"You can try using relationshipTypes parameter, check this notebook: $relationshipTypesNotebook \n" +
        s"You can also use spark-nlp-display to visualize Dependency Parser and NER output to help identify the kind of relations you can extract" +
        s", check this example: $displayNotebook"
    message
  }
585

586
  /** Finds paths for each configured "TOKEN-ENTITY" relationship: the token acts as the path
    * root and every NER entity of the matching type as a target. Produces one NODE annotation
    * per root token that reaches at least one entity, storing paths as path1, path2, ...
    */
  private def extractGraphsFromRelationships(
      annotationsInfo: AnnotationsInfo,
      graph: GraphBuilder): Seq[Annotation] = {

    val annotatedGraphPaths = allowedRelationshipTypes.flatMap { relationshipTypes =>
      // Graph node indexes are token positions shifted by 1 (node 0 is the virtual root).
      val rootData = annotationsInfo.tokens
        .filter(_.result == relationshipTypes._1)
        .map(token => (token, annotationsInfo.tokens.indexOf(token) + 1))
      val entityIndexes = annotationsInfo.nerEntities
        .filter(_.metadata("entity") == relationshipTypes._2)
        .map(nerEntity => annotationsInfo.tokens.indexOf(nerEntity) + 1)

      rootData.flatMap { rootInfo =>
        val paths = entityIndexes.flatMap(entityIndex =>
          buildPath(graph, (rootInfo._2, entityIndex), annotationsInfo.dependencyData))
        // Number the found paths as metadata keys: path1, path2, ...
        val pathsMap = paths.zipWithIndex.flatMap { case (path, index) =>
          Map(s"path${(index + 1).toString}" -> path)
        }.toMap
        if (paths.nonEmpty) {
          Some(
            Annotation(
              NODE,
              rootInfo._1.begin,
              rootInfo._1.end,
              rootInfo._1.result,
              Map(
                "relationship" -> s"${rootInfo._1.result},${relationshipTypes._2}") ++ pathsMap))
        } else {
          None
        }
      }
    }
    annotatedGraphPaths
  }
620

621
  /** Finds the path in the graph from the root node to the target node and renders it as a
    * delimiter-separated string of tokens, optionally prefixed with their relation edges.
    * Returns None when no path exists.
    */
  private def buildPath(
      graph: GraphBuilder,
      nodesIndexes: (Int, Int),
      dependencyData: Seq[DependencyInfo]): Option[String] = {
    val (rootIndex, targetIndex) = nodesIndexes
    val pathTokens = graph.findPath(rootIndex, targetIndex).map { nodeIndex =>
      val info = dependencyData(nodeIndex - 1)
      if ($(includeEdges)) {
        // The root node and "*root*" relations carry no incoming edge label.
        val edge =
          if (info.relation == "*root*" || nodeIndex == rootIndex) ""
          else info.relation + $(delimiter)
        edge + info.token
      } else {
        info.token
      }
    }
    if (pathTokens.isEmpty) None else Some(pathTokens.mkString($(delimiter)))
  }
640

641
  /** Builds one NODE annotation per entity pair reachable from `rootIndex`, carrying the
    * left and right paths from the root token to each entity in the pair.
    */
  private def getAnnotatedPaths(
      entitiesPairData: List[EntitiesPairInfo],
      graph: GraphBuilder,
      rootIndex: Int,
      annotationsInfo: AnnotationsInfo): Seq[Annotation] = {

    val tokens = annotationsInfo.tokens
    val dependencyData = annotationsInfo.dependencyData

    // A pair is kept only when BOTH of its entities are reachable from the root.
    val paths = entitiesPairData.flatMap { entitiesPairInfo =>
      val leftPath =
        buildPath(graph, (rootIndex, entitiesPairInfo.entitiesIndex._1), dependencyData)
      val rightPath =
        buildPath(graph, (rootIndex, entitiesPairInfo.entitiesIndex._2), dependencyData)
      if (leftPath.nonEmpty && rightPath.nonEmpty) {
        Some(GraphInfo(entitiesPairInfo.entities, leftPath, rightPath))
      } else None
    }

    // The annotation anchors on the root token (rootIndex is 1-based over tokens).
    val sourceToken = tokens(rootIndex - 1)
    val annotatedPaths = paths.map { path =>
      val leftEntity = path.entities._1
      val rightEntity = path.entities._2
      val leftPathTokens = path.leftPath
      val rightPathTokens = path.rightPath

      Annotation(
        NODE,
        sourceToken.begin,
        sourceToken.end,
        sourceToken.result,
        Map(
          "entities" -> s"$leftEntity,$rightEntity",
          "left_path" -> leftPathTokens.mkString($(delimiter)),
          "right_path" -> rightPathTokens.mkString($(delimiter))))
    }
    annotatedPaths
  }
679

680
  /** Computes, for every candidate entity pair, the graph node index pair plus the entity
    * labels. With no configured `entityTypes` all entity combinations are considered;
    * otherwise only the allowed (left, right) entity-type pairs, excluding pairs that share
    * the same span.
    */
  private def getEntitiesData(
      annotatedEntities: Seq[Annotation],
      dependencyData: Seq[DependencyInfo]): List[EntitiesPairInfo] = {
    var annotatedEntitiesPairs: List[(Annotation, Annotation)] = List()
    if (allowedEntityRelationships.isEmpty) {
      annotatedEntitiesPairs =
        annotatedEntities.combinations(2).map(entity => (entity.head, entity.last)).toList
    } else {
      annotatedEntitiesPairs = allowedEntityRelationships
        .flatMap(entities => getAnnotatedNerEntitiesPairs(entities, annotatedEntities))
        .filter(entities =>
          entities._1.begin != entities._2.begin && entities._1.end != entities._2.end)
        .toList
    }

    val entitiesPairData = annotatedEntitiesPairs.map { annotatedEntityPair =>
      // Locate each entity in the dependency data by matching its character span.
      // NOTE(review): `.head` assumes every entity token has matching dependency info — confirm.
      val dependencyInfoLeft = dependencyData.filter(dependencyInfo =>
        dependencyInfo.beginToken == annotatedEntityPair._1.begin && dependencyInfo.endToken == annotatedEntityPair._1.end)
      val dependencyInfoRight = dependencyData.filter(dependencyInfo =>
        dependencyInfo.beginToken == annotatedEntityPair._2.begin && dependencyInfo.endToken == annotatedEntityPair._2.end)
      val indexLeft = dependencyData.indexOf(dependencyInfoLeft.head) + 1
      val indexRight = dependencyData.indexOf(dependencyInfoRight.head) + 1

      EntitiesPairInfo(
        (indexLeft, indexRight),
        (annotatedEntityPair._1.metadata("entity"), annotatedEntityPair._2.metadata("entity")))
    }
    entitiesPairData.distinct
  }
709

710
  /** Produces all (left, right) pairs of annotated entities whose entity labels match the
    * given entity-type pair. The larger side drives the outer loop, so result ordering
    * matches the original pairing order.
    */
  private def getAnnotatedNerEntitiesPairs(
      entities: (String, String),
      annotatedEntities: Seq[Annotation]): List[(Annotation, Annotation)] = {

    val leftEntities = annotatedEntities.filter(_.metadata("entity") == entities._1)
    val rightEntities = annotatedEntities.filter(_.metadata("entity") == entities._2)

    if (leftEntities.length > rightEntities.length) {
      (for {
        left <- leftEntities
        right <- rightEntities
      } yield (left, right)).toList
    } else {
      (for {
        right <- rightEntities
        left <- leftEntities
      } yield (left, right)).toList
    }
  }
730

731
  /** Graph node index pair and the corresponding (leftEntity, rightEntity) labels. */
  private case class EntitiesPairInfo(entitiesIndex: (Int, Int), entities: (String, String))

  /** Paths from a root token to each entity of a pair; a side is None when unreachable. */
  private case class GraphInfo(
      entities: (String, String),
      leftPath: Option[String],
      rightPath: Option[String])

  /** Per-sentence bundle of tokens, NER entities and unpacked dependency info. */
  private case class AnnotationsInfo(
      tokens: Seq[Annotation],
      nerEntities: Seq[Annotation],
      dependencyData: Seq[DependencyInfo])
742

743
  /** Output annotator types: NODE
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = NODE

  /** Input annotator types: DOCUMENT, TOKEN, NAMED_ENTITY
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[String] = Array(DOCUMENT, TOKEN, NAMED_ENTITY)

  /** Optional input annotator types: DEPENDENCY, LABELED_DEPENDENCY. When these columns are
    * absent and `mergeEntities` is true, pretrained models supply them instead.
    *
    * @group anno
    */
  override val optionalInputAnnotatorTypes: Array[String] = Array(DEPENDENCY, LABELED_DEPENDENCY)

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc