
JohnSnowLabs / spark-nlp / build 10656322261

01 Sep 2024 06:18PM UTC coverage: 62.431% (-0.2%) from 62.618%

Event: push · CI: github · Committer: web-flow
Introducing CamemBertForZeroShotClassification annotator (#14354)

* [SPARKNLP-856] Introducing CamemBertForZeroShotClassification

* [SPARKNLP-856] Adding notebook examples for CamemBertForZeroShotClassification

* [SPARKNLP-856] Adding CamemBertForZeroShotClassification to ResourceDownloader
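
The new annotator follows the zero-shot classification API shared by Spark NLP's other *ForZeroShotClassification transformers. A minimal pipeline sketch under that assumption; the default pretrained model, the French sample text, and the candidate labels below are illustrative, not taken from the commit:

import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotator.{CamemBertForZeroShotClassification, Tokenizer}
import org.apache.spark.ml.Pipeline

// Assumes an active SparkSession named `spark`.
import spark.implicits._

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

val tokenizer = new Tokenizer()
  .setInputCols("document")
  .setOutputCol("token")

// pretrained() with no arguments resolves a default model through the
// ResourceDownloader; candidate labels are chosen freely at inference time.
val zeroShotClassifier = CamemBertForZeroShotClassification
  .pretrained()
  .setInputCols("document", "token")
  .setOutputCol("class")
  .setCandidateLabels(Array("sport", "politique", "science"))

val pipeline = new Pipeline()
  .setStages(Array(documentAssembler, tokenizer, zeroShotClassifier))

val data = Seq("L'équipe de France joue ce soir au Parc des Princes.").toDF("text")
val result = pipeline.fit(data).transform(data)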

0 of 2 new or added lines in 1 file covered. (0.0%)

155 existing lines in 39 files now uncovered.

8967 of 14363 relevant lines covered (62.43%)

0.62 hits per line

Source File

/src/main/scala/com/johnsnowlabs/nlp/training/PubTator.scala (98.31% covered)
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.annotator.{PerceptronModel, SentenceDetector, Tokenizer}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, DocumentAssembler, Finisher}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/** The PubTator format includes medical papers’ titles, abstracts, and tagged chunks.
  *
  * For more information see
  * [[http://bioportal.bioontology.org/ontologies/EDAM?p=classes&conceptid=format_3783 PubTator Docs]]
  * and [[http://github.com/chanzuckerberg/MedMentions MedMentions Docs]].
  *
  * `readDataset` is used to create a Spark DataFrame from a PubTator text file.
  *
  * ==Example==
  * {{{
  * import com.johnsnowlabs.nlp.training.PubTator
  *
  * val pubTatorFile = "./src/test/resources/corpus_pubtator_sample.txt"
  * val pubTatorDataSet = PubTator().readDataset(ResourceHelper.spark, pubTatorFile)
  * pubTatorDataSet.show(1)
  * +--------+--------------------+--------------------+--------------------+-----------------------+---------------------+-----------------------+
  * |  doc_id|      finished_token|        finished_pos|        finished_ner|finished_token_metadata|finished_pos_metadata|finished_label_metadata|
  * +--------+--------------------+--------------------+--------------------+-----------------------+---------------------+-----------------------+
  * |25763772|[DCTN4, as, a, mo...|[NNP, IN, DT, NN,...|[B-T116, O, O, O,...|   [[sentence, 0], [...| [[word, DCTN4], [...|   [[word, DCTN4], [...|
  * +--------+--------------------+--------------------+--------------------+-----------------------+---------------------+-----------------------+
  * }}}
  */
case class PubTator() {

  def readDataset(spark: SparkSession, path: String, isPaddedToken: Boolean = true): DataFrame = {
    val pubtator = spark.sparkContext.textFile(path)

    // Title ("|t|") and abstract ("|a|") lines carry the document text;
    // group them by document id and concatenate into a single text column.
    val titles = pubtator.filter(x => x.contains("|a|") | x.contains("|t|"))
    val titlesText = titles
      .map(x => x.split("\\|"))
      .groupBy(_.head)
      .map(x => (x._1.toInt, x._2.foldLeft(Seq[String]())((a, b) => a ++ Seq(b.last))))
      .map(x => (x._1, x._2.mkString(" ")))
    val df = spark.createDataFrame(titlesText).toDF("doc_id", "text")

    // Split the raw text into sentences and tokens.
    val docAsm = new DocumentAssembler().setInputCol("text").setOutputCol("document")
    val setDet = new SentenceDetector().setInputCols("document").setOutputCol("sentence")
    val tknz = new Tokenizer().setInputCols("sentence").setOutputCol("token")
    val pl = new Pipeline().setStages(Array(docAsm, setDet, tknz))
    val nlpDf = pl.fit(df).transform(df)

    // The remaining non-empty lines are tab-separated entity annotations:
    // doc_id, begin, end (exclusive in PubTator, hence the -1 to make it
    // inclusive like Annotation), mention text, semantic type, concept id.
    val annotations = pubtator.filter(x => !x.contains("|a|") & !x.contains("|t|") & x.nonEmpty)
    val splitAnnotations = annotations
      .map(_.split("\\t"))
      .map(x => (x(0), x(1).toInt, x(2).toInt - 1, x(3), x(4), x(5)))

    // Collect each document's annotations into a list of CHUNK annotations.
    val docAnnotations = splitAnnotations
      .groupBy(_._1)
      .map(x => (x._1, x._2))
      .map(x =>
        (
          x._1.toInt,
          x._2.zipWithIndex
            .map(a =>
              (new Annotation(
                AnnotatorType.CHUNK,
                a._1._2,
                a._1._3,
                a._1._4,
                Map("entity" -> a._1._5, "chunk" -> a._2.toString),
                Array[Float]())))
            .toList))
    val chunkMeta = new MetadataBuilder().putString("annotatorType", AnnotatorType.CHUNK).build()
    val annDf = spark
      .createDataFrame(docAnnotations)
      .toDF("doc_id", "chunk")
      .withColumn("chunk", col("chunk").as("chunk", chunkMeta))
    val alignedDf =
      nlpDf.join(annDf, Seq("doc_id")).selectExpr("doc_id", "sentence", "token", "chunk")

    // Convert chunk-level labels to token-level IOB tags: a token covered by a
    // chunk gets "B-" at the chunk start, "I-" inside it, and "O" when uncovered.
    val iobTagging = udf((tokens: Seq[Row], chunkLabels: Seq[Row]) => {
      val tokenAnnotations = tokens.map(Annotation(_))
      val labelAnnotations = chunkLabels.map(Annotation(_))
      tokenAnnotations.map(ta => {
        val tokenLabel = labelAnnotations.find(la => la.begin <= ta.begin && la.end >= ta.end)
        val tokenTag = {
          if (tokenLabel.isEmpty) "O"
          else {
            val tokenCSV = tokenLabel.get.metadata("entity")
            if (tokenCSV == "UnknownType") "O"
            else {
              val tokenPrefix = if (ta.begin == tokenLabel.get.begin) "B-" else "I-"
              // Optionally zero-pad the semantic type id, e.g. "T16" -> "T016".
              val token = if (isPaddedToken) {
                "T" + "%03d".format(tokenCSV.split(",")(0).slice(1, 4).toInt)
              } else tokenCSV
              tokenPrefix + token
            }
          }
        }

        Annotation(
          AnnotatorType.NAMED_ENTITY,
          ta.begin,
          ta.end,
          tokenTag,
          Map("word" -> ta.result))
      })
    })
    val labelMeta =
      new MetadataBuilder().putString("annotatorType", AnnotatorType.NAMED_ENTITY).build()
    val taggedDf =
      alignedDf.withColumn("label", iobTagging(col("token"), col("chunk")).as("label", labelMeta))

    // Add part-of-speech tags and flatten all annotations into finished columns.
    val pos =
      PerceptronModel.pretrained().setInputCols(Array("sentence", "token")).setOutputCol("pos")
    val finisher = new Finisher().setInputCols("token", "pos", "label").setIncludeMetadata(true)
    val finishingPipeline = new Pipeline().setStages(Array(pos, finisher))
    finishingPipeline
      .fit(taggedDf)
      .transform(taggedDf)
      .withColumnRenamed("finished_label", "finished_ner") // CoNLL generator expects finished_ner
  }
}
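
With finished_label renamed to finished_ner, the returned DataFrame is shaped for Spark NLP's CoNLL export utility. A short usage sketch, assuming the CoNLLGenerator from com.johnsnowlabs.util and its DataFrame-based exportConllFiles overload; the paths are illustrative:

import com.johnsnowlabs.nlp.training.PubTator
import com.johnsnowlabs.util.CoNLLGenerator

// Assumes an active SparkSession named `spark`.
val pubTatorDf = PubTator().readDataset(spark, "./src/test/resources/corpus_pubtator_sample.txt")

// Writes CoNLL-formatted files from finished_token / finished_pos / finished_ner.
CoNLLGenerator.exportConllFiles(pubTatorDf, "./pubtator_conll")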