• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4992350528

pending completion
4992350528

Pull #13797

github

GitHub
Merge 424c7ff18 into ef7906c5e
Pull Request #13797: SPARKNLP-835: ProtectedParam and ProtectedFeature

24 of 24 new or added lines in 6 files covered. (100.0%)

8643 of 13129 relevant lines covered (65.83%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.45
/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/dep/GreedyTransition/DependencyMaker.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.parser.dep.GreedyTransition
18

19
import com.johnsnowlabs.nlp.annotators.parser.dep.{Perceptron, Tagger}
20

21
/** Inspired on https://github.com/mdda/ConciseGreedyDependencyParser-in-Scala */
22
class DependencyMaker(tagger: Tagger) extends Serializable {
  // Parser transition moves, encoded as small Ints so they can index perceptron score vectors.
  val SHIFT: Move = 0; val RIGHT: Move = 1; val LEFT: Move = 2; val INVALID: Move = -1

  /** Renders a set of moves as e.g. "{SHIFT, LEFT}" — debugging aid only. */
  def movesString(s: Set[Move]) = {
    val moveNames = Vector[ClassName]("INVALID", "SHIFT", "RIGHT", "LEFT") // NB: requires a +1
    s.toList.sorted.map(i => moveNames(i + 1)).mkString("{", ", ", "}")
  }

  // One weight class per real move: SHIFT / RIGHT / LEFT
  private val perceptron = new Perceptron(3)

  /** Immutable partial dependency parse over a sentence of `n` words.
    *
    * @param n      number of words; the ROOT pseudo-node lives at index n
    * @param heads  head index chosen so far for each word (0 until assigned)
    * @param lefts  left-children lists per head, indices 0..n, most recent first
    * @param rights right-children lists per head, indices 0..n, most recent first
    */
  case class ParseState(
      n: Int,
      heads: Vector[Int],
      lefts: Vector[List[Int]],
      rights: Vector[List[Int]]) { // NB: Insert at start, not at end...

    /** Returns a new state where `child` points at `head`, and `child` is prepended to
      * the appropriate left/right children list of `head`. */
    def add(head: Int, child: Int): ParseState = {
      if (child < head) {
        copy(
          heads = heads.updated(child, head),
          lefts = lefts.updated(head, child :: lefts(head)))
      } else {
        copy(
          heads = heads.updated(child, head),
          rights = rights.updated(head, child :: rights(head)))
      }
    }
  }

  /** Fresh [[ParseState]]: all heads provisionally 0, empty child lists for every
    * possible head — words (0 .. n-1) plus the ROOT at index n, hence n + 1 slots. */
  def ParseStateInit(n: Int): ParseState = {
    val heads = Vector.fill(n)(0)
    val lefts = Vector.fill(n + 1)(List.empty[Int])
    val rights = Vector.fill(n + 1)(List.empty[Int])
    ParseState(n, heads, lefts, rights)
  }

  /** One configuration of the transition parser: buffer position `i`, the stack of
    * pending word indices, and the partial parse built so far. */
  case class CurrentState(i: Int, stack: List[Int], parse: ParseState) {

    def transition(move: Move): CurrentState = move match {
      // i either increases and lengthens the stack,
      case SHIFT => CurrentState(i + 1, i :: stack, parse)
      // or stays the same, and shortens the stack, and manipulates left & right
      case RIGHT =>
        CurrentState(i, stack.tail, parse.add(stack.tail.head, stack.head)) // as in Arc-Standard
      case LEFT => CurrentState(i, stack.tail, parse.add(i, stack.head)) // as in Arc-Eager
    }

    /** Moves that are structurally legal here — depends only on i and stack depth,
      * not on the parse content itself. */
    def getValidMoves: Set[Move] =
      List[Move](
        if (i < parse.n) SHIFT else INVALID, // i.e. not yet at the last word in sentence  // was n-1
        if (stack.length >= 2) RIGHT else INVALID,
        if (stack.length >= 1) LEFT else INVALID // Original version
        // if(stack.length>=1 && stack.head != parse.n)  LEFT  else INVALID // See page 405 for second condition
      ).filterNot(_ == INVALID).toSet

    /** Moves that can still lead to the gold parse — the "dynamic-oracle Arc-Hybrid"
      * of Goldberg and Nivre (2013), "Training Deterministic Parsers with
      * Non-Deterministic Oracles" (TACL 2013), bottom left of p.405 / top right of p.411. */
    def getGoldMoves(goldHeads: Vector[Int]): Set[Move] = {
      // True when any word in `others` is the gold head or the gold child of `target`
      def depsBetween(target: Int, others: List[Int]) =
        others.exists(word => goldHeads(word) == target || goldHeads(target) == word)

      val valid = getValidMoves

      if (stack.isEmpty || (valid.contains(SHIFT) && goldHeads(i) == stack.head)) {
        Set(SHIFT) // First condition obvious, second rather weird
      } else if (goldHeads(stack.head) == i) {
        Set(LEFT) // This move is a must, since the gold_heads tell us to do it
      } else {
        // Original Python logic flipped over: build a non-gold set, return valid -- nonGold

        // If the word second in the stack is its gold head, LEFT would discard that arc
        val leftIncorrect = stack.length >= 2 && goldHeads(stack.head) == stack.tail.head

        // If there are any dependencies between i and the stack, pushing i will lose them
        // (the SHIFT-validity guard protects against running over the end of the words)
        val dontPushI = valid.contains(SHIFT) && depsBetween(i, stack)

        // If there are any dependencies between the stack head and the remainder of the
        // buffer, popping the stack will lose them ('until' is EXCLUSIVE of parse.n)
        val dontPopStack = depsBetween(stack.head, ((i + 1) until parse.n).toList)

        val nonGold = List[Move](
          if (leftIncorrect) LEFT else INVALID,
          if (dontPushI) SHIFT else INVALID,
          if (dontPopStack) LEFT else INVALID,
          if (dontPopStack) RIGHT else INVALID).toSet

        // The (remaining) moves are 'gold'
        valid -- nonGold
      }
    }

    /** Builds the binary feature map for this configuration; every present feature
      * carries weight 1. Feature templates follow the usual greedy-parser recipe. */
    def extractFeatures(words: Vector[Word], tags: Vector[ClassName]): Map[Feature, Score] = {
      // Top three stack entries; "" where the stack is shallower than 3.
      // Applies to both Word and ClassName (depth is implicit from stack length).
      def getStackContext[T <: String](data: Vector[T]): (T, T, T) =
        (
          if (stack.length > 0) data(stack(0)) else "".asInstanceOf[T],
          if (stack.length > 1) data(stack(1)) else "".asInstanceOf[T],
          if (stack.length > 2) data(stack(2)) else "".asInstanceOf[T])

      // Next three buffer entries; "" past the end of the sentence.
      def getBufferContext[T <: String](data: Vector[T]): (T, T, T) =
        (
          if (i + 0 < parse.n) data(i + 0) else "".asInstanceOf[T],
          if (i + 1 < parse.n) data(i + 1) else "".asInstanceOf[T],
          if (i + 2 < parse.n) data(i + 2) else "".asInstanceOf[T])

      // Valency (dependent count) plus the first two dependents of the node at `idx`;
      // idx < 0 encodes "empty stack".
      def getParseContext[T <: String](
          idx: Int,
          deps: Vector[List[Int]],
          data: Vector[T]): (Int, T, T) = {
        if (idx < 0) {
          (0, "".asInstanceOf[T], "".asInstanceOf[T])
        } else {
          val dependencies = deps(idx)
          val valency = dependencies.length
          (
            valency,
            if (valency > 0) data(dependencies(0)) else "".asInstanceOf[T],
            if (valency > 1) data(dependencies(1)) else "".asInstanceOf[T])
        }
      }

      // Set up the context pieces --- the word (W) and tag (T) of:
      //   s0,1,2: top three words on the stack
      //   n0,1,2: next three words of the buffer (including this one)
      //   n0b1, n0b2: two leftmost children of the current buffer word
      //   s0b1, s0b2: two leftmost children of the top word of the stack
      //   s0f1, s0f2: two rightmost children of the top word of the stack
      val n0 = i // Just for notational consistency
      val s0 = if (stack.isEmpty) -1 else stack.head

      val (ws0, ws1, ws2) = getStackContext(words)
      val (ts0, ts1, ts2) = getStackContext(tags)

      val (wn0, wn1, wn2) = getBufferContext(words)
      val (tn0, tn1, tn2) = getBufferContext(tags)

      val (vn0b, wn0b1, wn0b2) = getParseContext(n0, parse.lefts, words)
      val (_, tn0b1, tn0b2) = getParseContext(n0, parse.lefts, tags)

      val (vn0f, wn0f1, wn0f2) = getParseContext(n0, parse.rights, words)
      val (_, tn0f1, tn0f2) = getParseContext(n0, parse.rights, tags)

      val (vs0b, ws0b1, ws0b2) = getParseContext(s0, parse.lefts, words)
      val (_, ts0b1, ts0b2) = getParseContext(s0, parse.lefts, tags)

      val (vs0f, ws0f1, ws0f2) = getParseContext(s0, parse.rights, words)
      val (_, ts0f1, ts0f2) = getParseContext(s0, parse.rights, tags)

      // String-distance, capped at 5 (NB: n0 always > s0, by construction) // WAS :: ds0n0
      val dist = if (s0 >= 0) math.min(n0 - s0, 5) else 0

      // Constant feature — acts sort of like a prior
      val bias = Feature("bias", "")

      // NOTE(review): `word != 0` compares a String against an Int and is therefore
      // always true, so empty-context ("") features are emitted as well. Kept verbatim
      // for compatibility with existing trained models — confirm whether `word.nonEmpty`
      // was intended before changing.
      val wordUnigrams =
        for (word <- List(wn0, wn1, wn2, ws0, ws1, ws2, wn0b1, wn0b2, ws0b1, ws0b2, ws0f1, ws0f2)
          if (word != 0)) yield Feature("w", word)

      val tagUnigrams =
        for (tag <- List(tn0, tn1, tn2, ts0, ts1, ts2, tn0b1, tn0b2, ts0b1, ts0b2, ts0f1, ts0f2)
          if (tag != 0)) yield Feature("t", tag)

      val wordTagPairs =
        for (((word, tag), idx) <- List(
            (wn0, tn0),
            (wn1, tn1),
            (wn2, tn2),
            (ws0, ts0)).zipWithIndex
          if (word != 0 || tag != 0))
          yield Feature(s"wt$idx", s"w=$word t=$tag")

      val bigrams = Set(
        Feature("w s0n0", s"$ws0 $wn0"),
        Feature("t s0n0", s"$ts0 $tn0"),
        Feature("t n0n1", s"$tn0 $tn1"))

      val trigrams = Set(
        Feature("wtw nns", s"$wn0/$tn0 $ws0"),
        Feature("wtt nns", s"$wn0/$tn0 $ts0"),
        Feature("wtw ssn", s"$ws0/$ts0 $wn0"),
        Feature("wtt ssn", s"$ws0/$ts0 $tn0"))

      val quadgrams = Set(Feature("wtwt", s"$ws0/$ts0 $wn0/$tn0"))

      // NOTE(review): (ts0, ts0f1, tn0) appears twice and (ts0, ts1, ts1) repeats ts1 —
      // possibly (ts0, ts1, ts2) was intended. Kept verbatim: zipWithIndex makes each
      // template a distinct feature name, so this is behavior, not a no-op.
      val tagTrigrams =
        for (((t0, t1, t2), idx) <- List(
            (tn0, tn1, tn2),
            (ts0, tn0, tn1),
            (ts0, ts1, tn0),
            (ts0, ts1, ts1),
            (ts0, ts0f1, tn0),
            (ts0, ts0f1, tn0),
            (ts0, tn0, tn0b1),
            (ts0, ts0b1, ts0b2),
            (ts0, ts0f1, ts0f2),
            (tn0, tn0b1, tn0b2)).zipWithIndex
          if (t0 != 0 || t1 != 0 || t2 != 0))
          yield Feature(s"ttt-$idx", s"$t0 $t1 $t2")

      val valencyAndDistance =
        for (((str, v), idx) <- List(
            (ws0, vs0f),
            (ws0, vs0b),
            (wn0, vn0b),
            (ts0, vs0f),
            (ts0, vs0b),
            (tn0, vn0b),
            (ws0, dist),
            (wn0, dist),
            (ts0, dist),
            (tn0, dist),
            ("t" + tn0 + ts0, dist),
            ("w" + wn0 + ws0, dist)).zipWithIndex
          if str.length > 0 || v != 0)
          yield Feature(s"val$idx", s"$str $v")

      val featureSetCombined = Set(bias) ++ bigrams ++ trigrams ++ quadgrams ++
        wordUnigrams.toSet ++ tagUnigrams.toSet ++ wordTagPairs.toSet ++
        tagTrigrams.toSet ++ valencyAndDistance.toSet

      // All weights on this set of features are == 1
      featureSetCombined.map(f => (f, 1: Score)).toMap
    }

  }

  /** Trains on the sentences in a seeded random order and returns the mean
    * per-sentence head accuracy. Returns 0 for an empty input (was NaN via 0/0). */
  def train(sentences: List[Sentence], seed: Int): Float = {
    if (sentences.isEmpty) 0.0f
    else {
      val rand = new scala.util.Random(seed)
      rand.shuffle(sentences).map(s => trainSentence(s)).sum / sentences.length
    }
  }

  /** Trains the perceptron on one sentence; returns the fraction of correct heads. */
  def trainSentence(sentence: Sentence): Float =
    goodness(sentence, process(sentence, train = true))

  /** Parses a sentence with the averaged weights; one predicted head index per word. */
  def parse(sentence: Sentence): List[Int] = process(sentence, train = false)

  /** Runs the transition parser over `sentence`; when `train` is set, also updates the
    * perceptron towards the best gold move at each step.
    *
    * NB: Our structure is a 'pure' list of words — the ROOT is the virtual index n.
    * (Previously the sentence had 1 entry pre-pended and the stack started at {1}.)
    */
  def process(sentence: Sentence, train: Boolean): List[Int] = {
    // Vectors, since we access them at random (not sequentially)
    val words = sentence.map(_.norm).toVector
    val tags = tagger.tag(sentence).toVector
    val goldHeads = sentence.map(_.dep).toVector

    @annotation.tailrec
    def moveThroughSentenceFrom(state: CurrentState): CurrentState = {
      val validMoves = state.getValidMoves
      if (validMoves.isEmpty) {
        state // This is the answer!
      } else {
        val features = state.extractFeatures(words, tags)

        // Scores are produced for invalid moves too; only valid ones are considered below
        val score =
          perceptron.score(features, if (train) perceptron.current else perceptron.average)

        // Pick the highest-scoring valid move (sort descending via negation)
        val guess = validMoves.map(m => (-score(m), m)).toList.minBy(_._1)._2

        if (train) { // Update the perceptron
          val goldMoves = state.getGoldMoves(goldHeads)
          if (goldMoves.nonEmpty) {
            val best = goldMoves.map(m => (-score(m), m)).toList.minBy(_._1)._2
            perceptron.update(best, guess, features.keys)
          }
          // else: no reachable gold move — deliberately skip the update rather than fail
        }

        moveThroughSentenceFrom(state.transition(guess))
      }
    }

    // This annotates the list of words so that parse.heads is its best guess when it finishes
    val finalState = moveThroughSentenceFrom(
      CurrentState(1, List(0), ParseStateInit(sentence.length)))

    finalState.parse.heads.toList
  }

  /** Fraction of predicted heads (`fit`) that match the gold heads of `sentence`. */
  def goodness(sentence: Sentence, fit: List[Int]): Float = {
    val gold = sentence.map(_.dep).toVector
    fit.zip(gold).count(pair => pair._1 == pair._2) / gold.length.toFloat
  }

  override def toString: String = perceptron.toString()

  /** Sanity check for the dynamic oracle: replays gold moves over `sentence` and
    * reports whether they reconstruct (essentially) all gold heads. Prints a summary. */
  def testGoldMoves(sentence: Sentence): Boolean = {
    val words = sentence.map(_.norm).toVector
    val tags = tagger.tag(sentence).toVector
    val goldHeads = sentence.map(_.dep).toVector

    @annotation.tailrec
    def moveThroughSentenceFrom(state: CurrentState): CurrentState = {
      val validMoves = state.getValidMoves
      if (validMoves.isEmpty) {
        state // This is the answer!
      } else {
        state.extractFeatures(words, tags) // kept for parity with training; result unused
        val goldMoves = state.getGoldMoves(goldHeads)
        if (goldMoves.isEmpty) {
          throw new Exception(s"No Gold Moves at ${state.i}/${state.parse.n}!")
        }
        moveThroughSentenceFrom(state.transition(goldMoves.toList.head))
      }
    }

    // This annotates the list of words so that parse.heads is its best guess when it finishes
    val finalState = moveThroughSentenceFrom(
      CurrentState(1, List(0), ParseStateInit(sentence.length)))

    def pctFitFmtStr(correct01: Float) = {
      val correctPct = correct01 * 100
      val correctStars = (0 until 100).map(i => if (i < correctPct) "x" else "-").mkString
      f"${correctPct}%6.1f%% :: $correctStars"
    }

    // FIX: was Int/Int division (always 0 or 1, so the printout showed only 0% or
    // 100% and the threshold below degenerated) — divide as Float, like goodness()
    val correct = finalState.parse.heads
      .zip(goldHeads)
      .count(pair => pair._1 == pair._2) / goldHeads.length.toFloat
    println(s"""       index : ${goldHeads.indices.map(v => f"${v}%2d")}""")
    println(s"""Gold   Moves : ${goldHeads.map(v => f"${v}%2d")}""")
    println(s""" Found Moves : ${finalState.parse.heads.map(v => f"${v}%2d")}""")
    println(f"Dependency GoldMoves correct = ${pctFitFmtStr(correct)}")

    correct > 0.99
  }

  /** The serialized perceptron, one line per iterator element (see companion `load`). */
  def getDependencyAsArray: Iterator[String] = {
    this.toString().split(System.lineSeparator()).toIterator
  }

}
389

390
object DependencyMaker {

  /** Reconstructs a [[DependencyMaker]] from previously serialized perceptron weights.
    *
    * @param lines  serialized perceptron, one element per line (as produced by
    *               `getDependencyAsArray` / `toString`)
    * @param tagger the part-of-speech tagger the parser will use
    * @return a DependencyMaker whose perceptron has been loaded from `lines`
    */
  def load(lines: Iterator[String], tagger: Tagger): DependencyMaker = {
    val maker = new DependencyMaker(tagger)
    maker.perceptron.load(lines)
    maker
  }

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc