• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4951808959

pending completion
4951808959

Pull #13792

github

GitHub
Merge efe6b42df into ef7906c5e
Pull Request #13792: SPARKNLP-825 Adding multilabel param

7 of 7 new or added lines in 1 file covered. (100.0%)

8637 of 13128 relevant lines covered (65.79%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

66.67
/src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.common
18

19
import com.johnsnowlabs.ml.crf.TextSentenceLabels
20
import com.johnsnowlabs.nlp.Annotation
21
import com.johnsnowlabs.nlp.AnnotatorType.{NAMED_ENTITY, POS}
22
import com.johnsnowlabs.nlp.annotators.common.Annotated.{NerTaggedSentence, PosTaggedSentence}
23
import org.apache.spark.sql.{Dataset, Row}
24

25
import java.util
26
import scala.collection.Map
27
import scala.util.Random
28

29
trait Tagged[T >: TaggedSentence <: TaggedSentence] extends Annotated[T] {
30
  val emptyTag = "O"
1✔
31

32
  /** Rebuilds tagged sentences from a flat sequence of annotations.
    *
    * Tokens are obtained through `TokenizedWithSentence.unpack`. Tags are taken from the
    * annotations whose `annotatorType` equals this object's `annotatorType`, ordered by `begin`.
    * A token receives the `result` of the tag annotation that starts at the same offset, or
    * `emptyTag` when there is none.
    *
    * @param annotations
    *   annotations to read tokens and tags from
    * @return
    *   one tagged sentence per sentence returned by `TokenizedWithSentence.unpack`
    */
  override def unpack(annotations: Seq[Annotation]): Seq[T] = {

    val tokenized = TokenizedWithSentence.unpack(annotations)

    val tagAnnotations = annotations
      .filter(a => a.annotatorType == annotatorType)
      .sortBy(a => a.begin)
      .iterator

    // Cursor over the sorted tag annotations. It is shared by all sentences on purpose: the
    // iterator is only ever advanced, so the statement order below must be kept.
    var annotation: Option[Annotation] = None

    tokenized.map { sentence =>
      val tokens = sentence.indexedTokens.map { token =>
        // Advance until the current tag annotation starts at or after this token.
        while (tagAnnotations.hasNext && annotation.forall(_.begin < token.begin))
          annotation = Some(tagAnnotations.next())

        val tag = annotation
          .filter(_.begin == token.begin)
          .map(_.result)
          .getOrElse(emptyTag)

        // Extract the confidence score belonging to the tag: the "confidence" entry when
        // present, otherwise the entry keyed by the tag itself. Empty when there is no current
        // annotation or neither entry exists (previously signalled by a swallowed exception).
        // NOTE(review): the score is read from the current annotation even when it does not
        // start at this token (tag == emptyTag); kept as in the original - confirm intended.
        val metadata = annotation
          .flatMap(a => Option(a.metadata))
          .flatMap(meta => meta.get("confidence").orElse(meta.get(tag)))
          .map(score => Map(tag -> score))
          .getOrElse(Map.empty[String, String])

        IndexedTaggedWord(token.token, tag, token.begin, token.end, metadata = metadata)
      }

      new TaggedSentence(tokens)
    }
  }
70

71
  /** Converts tagged sentences into one annotation per tagged word.
    *
    * Each annotation carries the word's tag as its result and a metadata map holding the word
    * under "word", every entry of the word's confidence maps (if any), and the position of the
    * sentence within `items` under "sentence".
    *
    * @param items
    *   tagged sentences to serialise
    * @return
    *   annotations of type `annotatorType`, in sentence and word order
    */
  override def pack(items: Seq[T]): Seq[Annotation] = {
    items.zipWithIndex.flatMap { case (taggedSentence, sentenceIndex) =>
      taggedSentence.indexedTaggedWords.map { tag =>
        // A missing confidence contributes no entries, so a single expression covers both the
        // "with confidence" and "without confidence" cases.
        val metadata: Map[String, String] =
          Map("word" -> tag.word) ++
            tag.confidence.getOrElse(Array.empty[Map[String, String]]).flatten ++
            Map("sentence" -> sentenceIndex.toString)

        new Annotation(annotatorType, tag.begin, tag.end, tag.tag, metadata)
      }
    }
  }
87

88
  /** Collects the dataset and pairs every unpacked tagged sentence with its labels. Useful for
    * testing.
    *
    * @param dataset
    *   dataset to collect
    * @param taggedCols
    *   list of tagged columns
    * @param labelColumn
    *   label column
    * @return
    *   for every collected row, the `(labels, sentence)` pairs of its sentences
    */
  def collectLabeledInstances(
      dataset: Dataset[Row],
      taggedCols: Seq[String],
      labelColumn: String): Array[(TextSentenceLabels, T)] = {

    // The label column is selected first, so it sits at position 0 and the tagged columns
    // follow at positions 1..taggedCols.length.
    val rows = dataset.select(labelColumn, taggedCols: _*).collect()

    rows.flatMap { row =>
      val labelAnnotations = getAnnotations(row, 0)
      val sentenceAnnotations =
        taggedCols.indices.flatMap(i => getAnnotations(row, i + 1))
      val sentences = unpack(sentenceAnnotations)

      getLabelsFromTaggedSentences(sentences, labelAnnotations).zip(sentences)
    }
  }
115

116
  /** Reads the column at position `colNum` as a sequence of rows and converts each of them to
    * an [[Annotation]].
    *
    * @param row
    *   row to read from
    * @param colNum
    *   zero-based position of the column within the row
    * @return
    *   the annotations stored in that column
    */
  def getAnnotations(row: Row, colNum: Int): Seq[Annotation] = {
    val annotationRows = row.getAs[Seq[Row]](colNum)
    for (annotationRow <- annotationRows) yield Annotation(annotationRow)
  }
119

120
  /** Looks up a label for every word-start wordpiece of every sentence.
    *
    * @param sentences
    *   wordpiece sentences to label
    * @param labelAnnotations
    *   label annotations; sorted by `begin` before being searched
    * @return
    *   per sentence, the result of the first covering label annotation for each word-start
    *   piece, or `emptyTag` when `Annotation.searchCoverage` finds none
    */
  protected def getLabelsFromSentences(
      sentences: Seq[WordpieceEmbeddingsSentence],
      labelAnnotations: Seq[Annotation]): Seq[TextSentenceLabels] = {
    val labelsByOffset = labelAnnotations.sortBy(_.begin).toArray

    for (sentence <- sentences) yield {
      // Extract labels only for wordpieces that are at the beginning of tokens
      val wordStarts = sentence.tokens.filter(_.isWordStart)
      val labels = wordStarts.map { piece =>
        Annotation
          .searchCoverage(labelsByOffset, piece.begin, piece.end)
          .headOption
          .map(_.result)
          .getOrElse(emptyTag)
      }
      TextSentenceLabels(labels)
    }
  }
140

141
  /** Looks up a label for every tagged word of every sentence.
    *
    * @param sentences
    *   tagged sentences to label
    * @param labelAnnotations
    *   label annotations; sorted by `begin` before being searched
    * @return
    *   per sentence, the result of the first covering label annotation for each word, or
    *   `emptyTag` when `Annotation.searchCoverage` finds none
    */
  protected def getLabelsFromTaggedSentences(
      sentences: Seq[TaggedSentence],
      labelAnnotations: Seq[Annotation]): Seq[TextSentenceLabels] = {
    val orderedLabels = labelAnnotations.sortBy(_.begin).toArray

    def labelFor(begin: Int, end: Int): String =
      Annotation
        .searchCoverage(orderedLabels, begin, end)
        .headOption
        .map(_.result)
        .getOrElse(emptyTag)

    sentences.map { sentence =>
      val labels = sentence.indexedTaggedWords.map(word => labelFor(word.begin, word.end))
      TextSentenceLabels(labels)
    }
  }
159
}
160

161
/** [[Tagged]] implementation for part-of-speech tagged sentences. */
object PosTagged extends Tagged[PosTaggedSentence] {

  /** Annotator type used to select tag annotations when unpacking and to stamp them when
    * packing.
    */
  override def annotatorType: String = POS
}
164

165
object NerTagged extends Tagged[NerTaggedSentence] {
166
  /** Annotator type used to select tag annotations when unpacking and to stamp them when
    * packing.
    */
  override def annotatorType: String = NAMED_ENTITY
1✔
167

168
  /** Collects the dataset and builds `(labels, POS sentence, embeddings sentence)` triples.
    *
    * For every row the POS-tagged sentences and the wordpiece-embedding sentences are unpacked
    * from the same annotations, empty sentences are dropped, and both lists are ordered by the
    * `begin` of their first element before being zipped positionally.
    *
    * NOTE(review): the positional zip assumes both unpackers yield matching sentences after
    * filtering - confirm against callers.
    *
    * @param dataset
    *   dataset to collect
    * @param posTaggedCols
    *   columns holding the sentence annotations; read from row positions 1..n
    * @param labelColumn
    *   label column; read from row position 0
    * @return
    *   the triples of all rows, in row order
    */
  def collectTrainingInstancesWithPos(
      dataset: Dataset[Row],
      posTaggedCols: Seq[String],
      labelColumn: String)
      : Array[(TextSentenceLabels, PosTaggedSentence, WordpieceEmbeddingsSentence)] = {

    dataset
      .select(labelColumn, posTaggedCols: _*)
      .collect()
      .flatMap { row =>
        val labelAnnotations = this.getAnnotations(row, 0)
        val sentenceAnnotations =
          posTaggedCols.indices.flatMap(i => getAnnotations(row, i + 1))

        val posSentences = PosTagged
          .unpack(sentenceAnnotations)
          .filter(_.indexedTaggedWords.nonEmpty)
          .sortBy(_.indexedTaggedWords.head.begin)

        val embeddingSentences = WordpieceEmbeddingsSentence
          .unpack(sentenceAnnotations)
          .filter(_.tokens.nonEmpty)
          .sortBy(_.tokens.head.begin)

        val labels = getLabelsFromTaggedSentences(posSentences, labelAnnotations)
        val alignedSentences = posSentences.zip(embeddingSentences)

        labels.zip(alignedSentences).map { case (label, (posSentence, embeddings)) =>
          (label, posSentence, embeddings)
        }
      }
  }
199

200
  /** Streams the dataset as batches of labelled wordpiece sentences without collecting it.
    *
    * FIXME: ColNums not always in the given order
    *
    * @param dataset
    *   dataset to iterate over; it is randomized before iteration
    * @param sentenceCols
    *   columns holding the sentence annotations; read from row positions 1..n
    * @param labelColumn
    *   label column; read from row position 0
    * @param batchSize
    *   maximum number of rows consumed per batch (a batch holds all sentences of those rows)
    * @return
    *   iterator of batches; calling `next()` once the rows are exhausted yields an empty array
    */
  def iterateOnDataframe(
      dataset: Dataset[Row],
      sentenceCols: Seq[String],
      labelColumn: String,
      batchSize: Int): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {

    new Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] {

      import com.johnsnowlabs.nlp.annotators.common.DatasetHelpers._

      // Send batches, don't collect(), only keeping a single batch in memory anytime
      val it: util.Iterator[Row] = dataset
        .select(labelColumn, sentenceCols: _*)
        .randomize // to improve training
        .toLocalIterator()

      // create a batch
      override def next(): Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] = {
        // A builder appends in amortised constant time; concatenating arrays row by row would
        // re-copy everything accumulated so far on each iteration.
        val batch = Array.newBuilder[(TextSentenceLabels, WordpieceEmbeddingsSentence)]
        var count = 0

        while (it.hasNext && count < batchSize) {
          count += 1
          val nextRow = it.next()

          val labelAnnotations = getAnnotations(nextRow, 0)
          val sentenceAnnotations =
            (1 to sentenceCols.length).flatMap(idx => getAnnotations(nextRow, idx))
          val sentences = WordpieceEmbeddingsSentence.unpack(sentenceAnnotations)
          val labels = getLabelsFromSentences(sentences, labelAnnotations)

          batch ++= labels.zip(sentences)
        }
        batch.result()
      }

      override def hasNext: Boolean = it.hasNext
    }

  }
242

243
  /** Shuffles the given rows, extracts their labelled wordpiece sentences and hands them to
    * `slice` together with `batchSize`.
    *
    * FIXME: ColNums not always in the given order
    *
    * @param inputArray
    *   already collected rows; labels at position 0, sentence columns at positions 1..n
    * @param sentenceCols
    *   columns holding the sentence annotations (only their count is used here)
    * @param batchSize
    *   batch size passed on to `slice`
    * @return
    *   the iterator produced by `slice` (presumably batches of `batchSize` elements - verify
    *   against `DatasetHelpers.slice`)
    */
  def iterateOnArray(
      inputArray: Array[Row],
      sentenceCols: Seq[String],
      batchSize: Int): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
    import com.johnsnowlabs.nlp.annotators.common.DatasetHelpers._

    val shuffledRows = Random.shuffle(inputArray.toSeq)

    val labelledSentences = shuffledRows.flatMap { row =>
      val labelAnnotations = this.getAnnotations(row, 0)
      val sentenceAnnotations =
        sentenceCols.indices.flatMap(i => getAnnotations(row, i + 1))
      val sentences = WordpieceEmbeddingsSentence.unpack(sentenceAnnotations)

      getLabelsFromSentences(sentences, labelAnnotations).zip(sentences)
    }

    slice(labelledSentences, batchSize)
  }
263
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc