• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4992350528

pending completion
4992350528

Pull #13797

github

GitHub
Merge 424c7ff18 into ef7906c5e
Pull Request #13797: SPARKNLP-835: ProtectedParam and ProtectedFeature

24 of 24 new or added lines in 6 files covered. (100.0%)

8643 of 13129 relevant lines covered (65.83%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.4
/src/main/scala/com/johnsnowlabs/nlp/annotators/er/EntityRulerModel.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.er
18

19
import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT, TOKEN}
20
import com.johnsnowlabs.nlp.annotators.common._
21
import com.johnsnowlabs.nlp.serialization.StructFeature
22
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, HasPretrained, HasSimpleAnnotate}
23
import com.johnsnowlabs.storage.Database.{ENTITY_REGEX_PATTERNS, Name}
24
import com.johnsnowlabs.storage._
25
import org.apache.spark.broadcast.Broadcast
26
import org.apache.spark.ml.PipelineModel
27
import org.apache.spark.ml.param.{BooleanParam, StringArrayParam}
28
import org.apache.spark.ml.util.Identifiable
29
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
30
import org.slf4j.{Logger, LoggerFactory}
31

32
/** Instantiated model of the [[EntityRulerApproach]]. For usage and examples see the
  * documentation of the main class.
  *
  * @param uid
  *   internally required UID to make it writable
  * @groupname anno Annotator types
  * @groupdesc anno
  *   Required input and expected output annotator types
  * @groupname param Parameters
  * @groupname setParam Parameter setters
  * @groupname getParam Parameter getters
  * @groupname Ungrouped Members
  * @groupprio param  1
  * @groupprio anno  2
  * @groupprio Ungrouped 3
  * @groupprio setParam  4
  * @groupprio getParam  5
  * @groupdesc param
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
  *   parameter values through setters and getters, respectively.
  */
class EntityRulerModel(override val uid: String)
    extends AnnotatorModel[EntityRulerModel]
    with HasSimpleAnnotate[EntityRulerModel]
    with HasStorageModel {

  def this() = this(Identifiable.randomUID("ENTITY_RULER"))

  // Fixed logger category: it was the copy-pasted string "Credentials", which routed this
  // class's warnings to an unrelated logger name.
  private val logger: Logger = LoggerFactory.getLogger("EntityRulerModel")

  // Keeping this parameter for backward compatibility
  private[er] val enablePatternRegex =
    new BooleanParam(this, "enablePatternRegex", "Enables regex pattern match")

  private[er] val useStorage =
    new BooleanParam(this, "useStorage", "Whether to use RocksDB storage to serialize patterns")

  private[er] val regexEntities =
    new StringArrayParam(this, "regexEntities", "entities defined in regex patterns")

  // In-memory fallback for patterns/regex patterns when RocksDB storage is disabled.
  private[er] val entityRulerFeatures: StructFeature[EntityRulerFeatures] =
    new StructFeature[EntityRulerFeatures](
      this,
      "Structure to store data when RocksDB is not used")

  private[er] val sentenceMatch = new BooleanParam(
    this,
    "sentenceMatch",
    "Whether to find match at sentence level (regex only). True: sentence level. False: token level")

  // Serialized keyword-matching automaton; None when this model only uses regex patterns.
  private[er] val ahoCorasickAutomaton: StructFeature[Option[AhoCorasickAutomaton]] =
    new StructFeature[Option[AhoCorasickAutomaton]](this, "AhoCorasickAutomaton")

  private[er] def setRegexEntities(value: Array[String]): this.type = set(regexEntities, value)

  private[er] def setEntityRulerFeatures(value: EntityRulerFeatures): this.type =
    set(entityRulerFeatures, value)

  private[er] def setUseStorage(value: Boolean): this.type = set(useStorage, value)

  private[er] def setSentenceMatch(value: Boolean): this.type = set(sentenceMatch, value)

  private[er] def setAhoCorasickAutomaton(value: Option[AhoCorasickAutomaton]): this.type =
    set(ahoCorasickAutomaton, value)

  // Broadcast copy of the automaton, shared across executors for the lifetime of this model.
  private var automatonModel: Option[Broadcast[AhoCorasickAutomaton]] = None

  /** Broadcasts the automaton once per model instance; later calls are no-ops. */
  def setAutomatonModelIfNotSet(
      spark: SparkSession,
      automaton: Option[AhoCorasickAutomaton]): this.type = {
    if (automatonModel.isEmpty && automaton.isDefined) {
      automatonModel = Some(spark.sparkContext.broadcast(automaton.get))
    }
    this
  }

  /** Returns the broadcast automaton if available, otherwise the serialized feature value. */
  def getAutomatonModelIfNotSet: Option[AhoCorasickAutomaton] = {
    automatonModel match {
      case Some(broadcast) => Some(broadcast.value)
      // $$(...) already yields an Option; the previous `if isDefined ... else None` was redundant.
      case None => $$(ahoCorasickAutomaton)
    }
  }

  // Added defaults for sentenceMatch and regexEntities: both are read unconditionally in
  // annotate()/computeAnnotatedEntitiesByRegex, and $(param) on an unset param throws
  // NoSuchElementException.
  setDefault(
    useStorage -> false,
    caseSensitive -> true,
    enablePatternRegex -> false,
    sentenceMatch -> false,
    regexEntities -> Array.empty[String])

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  val inputAnnotatorTypes: Array[String] = Array(DOCUMENT)
  override val optionalInputAnnotatorTypes: Array[String] = Array(TOKEN)
  val outputAnnotatorType: AnnotatorType = CHUNK

  /** Validates that a TOKEN-producing annotator is present whenever regex entities are
    * configured, since token-level regex matching needs tokenized input.
    */
  override def _transform(
      dataset: Dataset[_],
      recursivePipeline: Option[PipelineModel]): DataFrame = {
    if ($(regexEntities).nonEmpty) {
      val tokenFields = dataset.schema.fields
        .filter(field => field.metadata.contains("annotatorType"))
        .filter(field => field.metadata.getString("annotatorType") == TOKEN)
      if (tokenFields.isEmpty) {
        throw new IllegalArgumentException(
          s"Missing $TOKEN annotator. Regex patterns requires it in your pipeline")
      }
    }
    super._transform(dataset, recursivePipeline)
  }

  override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
    this.setAutomatonModelIfNotSet(dataset.sparkSession, $$(ahoCorasickAutomaton))
    dataset
  }

  /** takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param annotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   any number of annotations processed for every input annotation. Not necessary one to one
    *   relationship
    */
  def annotate(annotations: Seq[Annotation]): Seq[Annotation] = {
    val sentences = SentenceSplit.unpack(annotations)
    val annotatedEntitiesByRegex = computeAnnotatedEntitiesByRegex(annotations, sentences)

    val annotatedEntitiesByKeywords = getAutomatonModelIfNotSet match {
      case Some(automaton) => sentences.flatMap(automaton.searchPatternsInText)
      case None => Seq.empty[Annotation]
    }

    annotatedEntitiesByRegex ++ annotatedEntitiesByKeywords
  }

  /** Runs regex matching either per sentence or per token, reading patterns from RocksDB when
    * storage is enabled and from the serialized feature otherwise.
    */
  private def computeAnnotatedEntitiesByRegex(
      annotations: Seq[Annotation],
      sentences: Seq[Sentence]): Seq[Annotation] = {
    if ($(regexEntities).nonEmpty) {
      val regexPatternsReader =
        if ($(useStorage))
          Some(getReader(Database.ENTITY_REGEX_PATTERNS).asInstanceOf[RegexPatternsReader])
        else None

      if ($(sentenceMatch)) {
        annotateEntitiesFromRegexPatternsBySentence(sentences, regexPatternsReader)
      } else {
        val tokenizedWithSentences = TokenizedWithSentence.unpack(annotations)
        annotateEntitiesFromRegexPatterns(tokenizedWithSentences, regexPatternsReader)
      }
    } else Seq()
  }

  /** Emits one CHUNK annotation per token whose text matches any configured regex entity. */
  private def annotateEntitiesFromRegexPatterns(
      tokenizedWithSentences: Seq[TokenizedSentence],
      regexPatternsReader: Option[RegexPatternsReader]): Seq[Annotation] = {

    tokenizedWithSentences.flatMap { tokenizedWithSentence =>
      tokenizedWithSentence.indexedTokens.flatMap { indexedToken =>
        getMatchedEntity(indexedToken.token, regexPatternsReader).map { entity =>
          val entityMetadata = getEntityMetadata(Some(entity))
          Annotation(
            CHUNK,
            indexedToken.begin,
            indexedToken.end,
            indexedToken.token,
            entityMetadata ++ Map("sentence" -> tokenizedWithSentence.sentenceIndex.toString))
        }
      }
    }
  }

  /** Returns the first entity whose patterns match the token; warns when several entities match. */
  private def getMatchedEntity(
      token: String,
      regexPatternsReader: Option[RegexPatternsReader]): Option[String] = {

    val matchesByEntity = $(regexEntities).flatMap { regexEntity =>
      val regexPatterns: Option[Seq[String]] = regexPatternsReader match {
        case Some(rpr) => rpr.lookup(regexEntity)
        case None => $$(entityRulerFeatures).regexPatterns.get(regexEntity)
      }
      regexPatterns match {
        case Some(patterns) if patterns.exists(_.r.findFirstIn(token).isDefined) =>
          Some(regexEntity)
        case _ => None
      }
    }.toSeq

    if (matchesByEntity.size > 1) {
      logger.warn("More than one entity found. Sending the first element of the array")
    }

    matchesByEntity.headOption
  }

  /** Finds, per sentence, the first match of every regex pattern, de-overlaps the resulting
    * character intervals, and returns matched spans paired with their entity label, sorted by
    * begin offset.
    */
  private def getMatchedEntityBySentence(
      sentence: Sentence,
      regexPatternsReader: Option[RegexPatternsReader]): Array[(IndexedToken, String)] = {

    val matchesByEntity = $(regexEntities)
      .flatMap { regexEntity =>
        val regexPatterns: Option[Seq[String]] = regexPatternsReader match {
          case Some(rpr) => rpr.lookup(regexEntity)
          case None => $$(entityRulerFeatures).regexPatterns.get(regexEntity)
        }
        if (regexPatterns.isDefined) {

          val resultMatches = regexPatterns.get.flatMap { regexPattern =>
            val matchedResult = regexPattern.r.findFirstMatchIn(sentence.content)
            if (matchedResult.isDefined) {
              // Offsets are shifted to document coordinates; end is inclusive.
              val begin = matchedResult.get.start + sentence.start
              val end = matchedResult.get.end + sentence.start - 1
              // Explicit tuple: the original relied on deprecated argument auto-tupling.
              Some((matchedResult.get.toString(), begin, end, regexEntity))
            } else None
          }

          val intervals =
            resultMatches.map(resultMatch => List(resultMatch._2, resultMatch._3)).toList
          val mergedIntervals = EntityRulerUtil.mergeIntervals(intervals)
          // Keep only matches whose [begin, end] survived interval merging untouched.
          val filteredMatches =
            resultMatches.filter(x => mergedIntervals.contains(List(x._2, x._3)))

          if (filteredMatches.nonEmpty) Some(filteredMatches) else None
        } else None
      }
      .flatten
      .sortBy(_._2)

    matchesByEntity.map(matches => (IndexedToken(matches._1, matches._2, matches._3), matches._4))
  }

  /** Emits one CHUNK annotation per sentence-level regex match. */
  private def annotateEntitiesFromRegexPatternsBySentence(
      sentences: Seq[Sentence],
      patternsReader: Option[RegexPatternsReader]): Seq[Annotation] = {

    sentences.flatMap { sentence =>
      val matchedEntities = getMatchedEntityBySentence(sentence, patternsReader)
      matchedEntities.map { case (indexedToken, label) =>
        val entityMetadata = getEntityMetadata(Some(label))
        Annotation(
          CHUNK,
          indexedToken.begin,
          indexedToken.end,
          indexedToken.token,
          entityMetadata ++ Map("sentence" -> sentence.index.toString))
      }
    }
  }

  /** Splits a label of the form "entity" or "entity,id" into metadata entries.
    * Callers always pass a defined Option; `labelData.get` keeps the original contract.
    */
  private def getEntityMetadata(labelData: Option[String]): Map[String, String] = {
    labelData.get
      .split(",")
      .zipWithIndex
      .flatMap { case (metadata, index) =>
        if (index == 0) {
          Map("entity" -> metadata)
        } else Map("id" -> metadata)
      }
      .toMap
  }

  /** Only touches RocksDB when storage-backed serialization is enabled. */
  override def deserializeStorage(path: String, spark: SparkSession): Unit = {
    if ($(useStorage)) {
      // Dropped the spurious type ascriptions (path: String, spark: SparkSession) from the call.
      super.deserializeStorage(path, spark)
    }
  }

  override def onWrite(path: String, spark: SparkSession): Unit = {
    if ($(useStorage)) {
      super.onWrite(path, spark)
    }
  }

  protected val databases: Array[Name] = EntityRulerModel.databases

  protected def createReader(database: Name, connection: RocksDBConnection): StorageReader[_] = {
    new RegexPatternsReader(connection)
  }
}
323

324
/** Companion read/download behavior for [[EntityRulerModel]]: declares the storage databases the
  * model depends on and forwards the standard `pretrained` overloads from [[HasPretrained]].
  */
trait ReadablePretrainedEntityRuler
    extends StorageReadable[EntityRulerModel]
    with HasPretrained[EntityRulerModel] {

  // RocksDB databases required when the model was serialized with useStorage enabled.
  override val databases: Array[Name] = Array(ENTITY_REGEX_PATTERNS)

  // No default pretrained model is published for the entity ruler.
  override val defaultModelName: Option[String] = None

  // Overloads below only narrow the return type to EntityRulerModel; the download logic lives
  // entirely in HasPretrained.
  override def pretrained(): EntityRulerModel = super.pretrained()

  override def pretrained(name: String): EntityRulerModel = super.pretrained(name)

  override def pretrained(name: String, lang: String): EntityRulerModel =
    super.pretrained(name, lang)

  override def pretrained(name: String, lang: String, remoteLoc: String): EntityRulerModel =
    super.pretrained(name, lang, remoteLoc)

}
343

/** Companion object: entry point for reading serialized or pretrained [[EntityRulerModel]]s. */
object EntityRulerModel extends ReadablePretrainedEntityRuler
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc