JohnSnowLabs / spark-nlp / build 9929262317

14 Jul 2024 04:27PM UTC, coverage: 62.618% (+0.008%) from 62.61%
Push (via GitHub) by maziyarpanahi: "Bump version to 5.4.1 [run doc]"

2 of 2 new or added lines in 2 files covered (100.0%)
52 existing lines in 36 files now uncovered
8970 of 14325 relevant lines covered (62.62%)
0.63 hits per line

Source file (93.52% covered):
/src/main/scala/com/johnsnowlabs/nlp/training/CoNLL.scala
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.annotators.common.Annotated.{NerTaggedSentence, PosTaggedSentence}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.util.io.{ExternalResource, OutputHelper, ReadAs, ResourceHelper}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, DocumentAssembler}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer

case class CoNLLDocument(
    text: String,
    nerTagged: Seq[NerTaggedSentence],
    posTagged: Seq[PosTaggedSentence],
    docId: Option[String])

/** Helper class to load a CoNLL type dataset for training.
  *
  * The dataset should be in the format of
  * [[https://www.clips.uantwerpen.be/conll2003/ner/ CoNLL 2003]] and needs to be loaded with
  * `readDataset`. Other CoNLL datasets are not supported.
  *
  * Two types of input paths are supported:
  *
  * Folder: a path ending in `*`, representing a collection of CoNLL files within a directory.
  * E.g., 'path/to/multiple/conlls/*'. Using this pattern will result in all the files being
  * read into a single DataFrame. Some constraints apply to the schemas of the multiple files.
  *
  * File: a path to a single file. E.g., 'path/to/single_file.conll'
  *
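  * For example (paths are illustrative and `spark` is an active SparkSession):
  * {{{
  * // every CoNLL file under the directory is read into one DataFrame
  * val fromFolder = CoNLL().readDataset(spark, "path/to/multiple/conlls/*")
  * // a single file
  * val fromFile = CoNLL().readDataset(spark, "path/to/single_file.conll")
  * }}}
  *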
  * ==Example==
  * {{{
  * val trainingData = CoNLL().readDataset(spark, "src/test/resources/conll2003/eng.train")
  * trainingData.selectExpr("text", "token.result as tokens", "pos.result as pos", "label.result as label")
  *   .show(3, false)
  * +------------------------------------------------+----------------------------------------------------------+-------------------------------------+-----------------------------------------+
  * |text                                            |tokens                                                    |pos                                  |label                                    |
  * +------------------------------------------------+----------------------------------------------------------+-------------------------------------+-----------------------------------------+
  * |EU rejects German call to boycott British lamb .|[EU, rejects, German, call, to, boycott, British, lamb, .]|[NNP, VBZ, JJ, NN, TO, VB, JJ, NN, .]|[B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]|
  * |Peter Blackburn                                 |[Peter, Blackburn]                                        |[NNP, NNP]                           |[B-PER, I-PER]                           |
  * |BRUSSELS 1996-08-22                             |[BRUSSELS, 1996-08-22]                                    |[NNP, CD]                            |[B-LOC, O]                               |
  * +------------------------------------------------+----------------------------------------------------------+-------------------------------------+-----------------------------------------+
  *
  * trainingData.printSchema
  * root
  *  |-- text: string (nullable = true)
  *  |-- document: array (nullable = false)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- begin: integer (nullable = false)
  *  |    |    |-- end: integer (nullable = false)
  *  |    |    |-- result: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  *  |    |    |-- embeddings: array (nullable = true)
  *  |    |    |    |-- element: float (containsNull = false)
  *  |-- sentence: array (nullable = false)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- begin: integer (nullable = false)
  *  |    |    |-- end: integer (nullable = false)
  *  |    |    |-- result: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  *  |    |    |-- embeddings: array (nullable = true)
  *  |    |    |    |-- element: float (containsNull = false)
  *  |-- token: array (nullable = false)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- begin: integer (nullable = false)
  *  |    |    |-- end: integer (nullable = false)
  *  |    |    |-- result: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  *  |    |    |-- embeddings: array (nullable = true)
  *  |    |    |    |-- element: float (containsNull = false)
  *  |-- pos: array (nullable = false)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- begin: integer (nullable = false)
  *  |    |    |-- end: integer (nullable = false)
  *  |    |    |-- result: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  *  |    |    |-- embeddings: array (nullable = true)
  *  |    |    |    |-- element: float (containsNull = false)
  *  |-- label: array (nullable = false)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- annotatorType: string (nullable = true)
  *  |    |    |-- begin: integer (nullable = false)
  *  |    |    |-- end: integer (nullable = false)
  *  |    |    |-- result: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  *  |    |    |-- embeddings: array (nullable = true)
  *  |    |    |    |-- element: float (containsNull = false)
  * }}}
  *
  * @param documentCol
  *   Name of the `DOCUMENT` Annotator type column
  * @param sentenceCol
  *   Name of the Sentences of `DOCUMENT` Annotator type column
  * @param tokenCol
  *   Name of the `TOKEN` Annotator type column
  * @param posCol
  *   Name of the `POS` Annotator type column
  * @param conllLabelIndex
  *   Index of the column for the NER label in the dataset
  * @param conllPosIndex
  *   Index of the column for the POS tags in the dataset
  * @param conllDocIdCol
  *   Name of the column for the document ID in the dataset
  * @param conllTextCol
  *   Name of the column for the text in the dataset
  * @param labelCol
  *   Name of the `NAMED_ENTITY` Annotator type column
  * @param explodeSentences
  *   Whether to explode each sentence to a separate row
  * @param delimiter
  *   Delimiter used to separate columns inside the CoNLL file
  * @param includeDocId
  *   Whether to try to parse the document id from the third item in the -DOCSTART- line ("X" if
  *   not found)
  */
case class CoNLL(
    documentCol: String = "document",
    sentenceCol: String = "sentence",
    tokenCol: String = "token",
    posCol: String = "pos",
    conllLabelIndex: Int = 3,
    conllPosIndex: Int = 1,
    conllDocIdCol: String = "doc_id",
    conllTextCol: String = "text",
    labelCol: String = "label",
    explodeSentences: Boolean = true,
    delimiter: String = " ",
    includeDocId: Boolean = false) {
  /*
    Reads a dataset in CoNLL format and packs it into documents.
   */
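  /*
    Illustrative input in the expected CoNLL 2003 layout (space-delimited
    columns: word, POS tag, syntactic chunk, NER label; the rows below are a
    sketch, not copied from a real corpus file):

      -DOCSTART- -X- -X- O

      EU NNP B-NP B-ORG
      rejects VBZ B-VP O
      German JJ B-NP B-MISC

    Blank lines separate sentences and -DOCSTART- lines separate documents.
    With the defaults, conllPosIndex = 1 picks the POS column and
    conllLabelIndex = 3 picks the NER label column.
   */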
  def readDocs(er: ExternalResource): Seq[CoNLLDocument] = {
    val lines = ResourceHelper.parseLines(er)

    readLines(lines)
  }

  def clearTokens(tokens: Array[IndexedTaggedWord]): Array[IndexedTaggedWord] = {
    tokens.filter(t => t.word.trim().nonEmpty)
  }

  def readLines(lines: Array[String]): Seq[CoNLLDocument] = {
    var docId: Option[String] = None
    val doc = new StringBuilder()
    val lastSentence = ArrayBuffer.empty[(IndexedTaggedWord, IndexedTaggedWord)]

    val sentences = ArrayBuffer.empty[(TaggedSentence, TaggedSentence)]

    def addSentence(): Unit = {
      val nerTokens = clearTokens(lastSentence.map(t => t._1).toArray)
      val posTokens = clearTokens(lastSentence.map(t => t._2).toArray)

      if (nerTokens.nonEmpty) {
        assert(posTokens.nonEmpty)

        val ner = TaggedSentence(nerTokens)
        val pos = TaggedSentence(posTokens)

        sentences.append((ner, pos))
        lastSentence.clear()
      }
    }

    def closeDocument = {

      val result = (doc.toString, sentences.toList, docId)
      doc.clear()
      sentences.clear()

      if (result._1.nonEmpty) {
        Some(result._1, result._2, if (includeDocId) docId else None)
      } else
        None
    }
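
    // Each input line falls into one of three cases: a -DOCSTART- marker
    // closes the current document and records the new doc id (items.lift(2));
    // a blank or single-item line closes the current sentence (and, when
    // explodeSentences is set, the document); any other sufficiently long
    // line is a token row appended to the running text.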
    val docs = lines
      .flatMap { line =>
        val items = line.trim.split(delimiter)
        if (items.nonEmpty && items(0) == "-DOCSTART-") {
          addSentence()
          val closedDoc = closeDocument
          docId = items.lift(2)
          closedDoc
        } else if (items.length <= 1) {
          if (!explodeSentences && (doc.nonEmpty && !doc.endsWith(
              System.lineSeparator) && lastSentence.nonEmpty)) {
            doc.append(System.lineSeparator * 2)
          }
          addSentence()
          if (explodeSentences)
            closeDocument
          else
            None
        } else if (items.length > conllLabelIndex) {
          if (doc.nonEmpty && !doc.endsWith(System.lineSeparator()))
            doc.append(delimiter)

          val begin = doc.length
          doc.append(items(0))
          val end = doc.length - 1
          val tag = items(conllLabelIndex)
          val posTag = items(conllPosIndex)
          val ner = IndexedTaggedWord(items(0), tag, begin, end)
          val pos = IndexedTaggedWord(items(0), posTag, begin, end)
          lastSentence.append((ner, pos))
          None
        } else {
          None
        }
      }

    addSentence()

    val last = if (doc.nonEmpty) Seq((doc.toString, sentences.toList, docId)) else Seq.empty

    (docs ++ last).map {
      case (text, textSentences: Seq[(NerTaggedSentence, PosTaggedSentence)], docId) =>
        val (ner, pos) = textSentences.unzip
        CoNLLDocument(text, ner, pos, docId)
    }
  }

  def packNerTagged(sentences: Seq[NerTaggedSentence]): Seq[Annotation] = {
    NerTagged.pack(sentences)
  }

  def packAssembly(text: String, isTraining: Boolean = true): Seq[Annotation] = {
    new DocumentAssembler()
      .assemble(text, Map("training" -> isTraining.toString))
  }
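
  // Sentence boundaries below are reconstructed from character offsets: each
  // sentence spans from the smallest token begin offset to the largest token
  // end offset in the original text.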
  def packSentence(text: String, sentences: Seq[TaggedSentence]): Seq[Annotation] = {
    val indexedSentences = sentences.zipWithIndex.map { case (sentence, index) =>
      val start = sentence.indexedTaggedWords.map(t => t.begin).min
      val end = sentence.indexedTaggedWords.map(t => t.end).max
      val sentenceText = text.substring(start, end + 1)
      new Sentence(sentenceText, start, end, index)
    }

    SentenceSplit.pack(indexedSentences)
  }

  def packTokenized(text: String, sentences: Seq[TaggedSentence]): Seq[Annotation] = {
    val tokenizedSentences = sentences.zipWithIndex.map { case (sentence, index) =>
      val tokens = sentence.indexedTaggedWords.map(t => IndexedToken(t.word, t.begin, t.end))
      TokenizedSentence(tokens, index)
    }

    TokenizedWithSentence.pack(tokenizedSentences)
  }

  def packPosTagged(sentences: Seq[TaggedSentence]): Seq[Annotation] = {
    PosTagged.pack(sentences)
  }

  def removeSurroundingHyphens(text: String) =
    "-(.+)-".r.findFirstMatchIn(text).map(_.group(1)).getOrElse(text)

  val annotationType: ArrayType = ArrayType(Annotation.dataType)

  def getAnnotationType(
      column: String,
      annotatorType: String,
      addMetadata: Boolean = true): StructField = {
    if (!addMetadata)
      StructField(column, annotationType, nullable = false)
    else {
      val metadataBuilder: MetadataBuilder = new MetadataBuilder()
      metadataBuilder.putString("annotatorType", annotatorType)
      StructField(column, annotationType, nullable = false, metadataBuilder.build)
    }
  }

  def schema: StructType = {
    val docId = StructField(conllDocIdCol, StringType)
    val text = StructField(conllTextCol, StringType)
    val doc = getAnnotationType(documentCol, AnnotatorType.DOCUMENT)
    val sentence = getAnnotationType(sentenceCol, AnnotatorType.DOCUMENT)
    val token = getAnnotationType(tokenCol, AnnotatorType.TOKEN)
    val pos = getAnnotationType(posCol, AnnotatorType.POS)
    val label = getAnnotationType(labelCol, AnnotatorType.NAMED_ENTITY)

    if (includeDocId)
      StructType(Seq(docId, text, doc, sentence, token, pos, label))
    else
      StructType(Seq(text, doc, sentence, token, pos, label))
  }

  private def coreTransformation(doc: CoNLLDocument) = {
    val text = doc.text
    val labels = packNerTagged(doc.nerTagged)
    val docs = packAssembly(text)
    val sentences = packSentence(text, doc.nerTagged)
    val tokenized = packTokenized(text, doc.nerTagged)
    val posTagged = packPosTagged(doc.posTagged)
    (text, docs, sentences, tokenized, posTagged, labels)
  }

  private def coreTransformationWithDocId(doc: CoNLLDocument) = {
    val docId = removeSurroundingHyphens(doc.docId.getOrElse("X"))
    val (text, docs, sentences, tokenized, posTagged, labels) = coreTransformation(doc)
    (docId, text, docs, sentences, tokenized, posTagged, labels)
  }

  def packDocs(docs: Seq[CoNLLDocument], spark: SparkSession): Dataset[_] = {
    val preDf = if (includeDocId) {
      spark.createDataFrame(docs.map(coreTransformationWithDocId))
    } else {
      spark.createDataFrame(docs.map(coreTransformation))
    }
    spark.createDataFrame(preDf.rdd, schema)
  }

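  // A path ending in '*' takes the multi-file code path: every matched file
  // is read whole via wholeTextFiles, parsed with readLines, and the parsed
  // documents are persisted at the given storage level (DISK_ONLY by default)
  // before being packed into a DataFrame.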
  def readDataset(
      spark: SparkSession,
      path: String,
      readAs: String = ReadAs.TEXT.toString,
      parallelism: Int = 8,
      storageLevel: StorageLevel = StorageLevel.DISK_ONLY): Dataset[_] = {
    if (path.endsWith("*")) {
      val rdd = spark.sparkContext
        .wholeTextFiles(OutputHelper.parsePath(path), minPartitions = parallelism)
        .flatMap { case (_, content) =>
          val lines = content.split(System.lineSeparator)
          readLines(lines)
        }
        .persist(storageLevel)

      val preDf = if (includeDocId) {
        spark.createDataFrame(rdd.map(coreTransformationWithDocId))
      } else {
        spark.createDataFrame(rdd.map(coreTransformation))
      }
      spark.createDataFrame(preDf.rdd, schema)
    } else {
      val er = ExternalResource(path, readAs, Map("format" -> "text"))
      packDocs(readDocs(er), spark)
    }
  }

  def readDatasetFromLines(lines: Array[String], spark: SparkSession): Dataset[_] = {
    packDocs(readLines(lines), spark)
  }
}
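
For reference, a minimal end-to-end sketch of using the reader above (assuming
a local SparkSession; the training path is the one from the scaladoc example,
and the value names are illustrative):

  import com.johnsnowlabs.nlp.training.CoNLL
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .appName("conll-reader")
    .master("local[*]")
    .getOrCreate()

  // One row per sentence (explodeSentences = true by default), with
  // document, sentence, token, pos and label annotation columns.
  val training = CoNLL().readDataset(spark, "src/test/resources/conll2003/eng.train")
  training.select("text", "label.result").show(3, truncate = false)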