JohnSnowLabs / spark-nlp / 7340940136

27 Dec 2023 06:28PM UTC · coverage: 62.876% · first build (7340940136)

Pull Request #14112: Release/521 release candidate (via github / web-flow)
Merge 64ecc94ab into e9099b0f1

20 of 29 new or added lines in 6 files covered (68.97%)
8958 of 14247 relevant lines covered (62.88%)
0.63 hits per line

Source File: /src/main/scala/com/johnsnowlabs/nlp/HasBatchedAnnotate.scala (76.47% of relevant lines covered)

Per-line coverage markers: 1✔ = covered (1 hit), × = not covered, NEW = line added in this pull request.
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.IntParam
import org.apache.spark.sql.Row

trait HasBatchedAnnotate[M <: Model[M]] {

  this: RawAnnotator[M] =>

  /** Size of every batch (Default depends on model).
    *
    * @group param
    */
  val batchSize = new IntParam(this, "batchSize", "Size of every batch.")  // 1✔

  /** Size of every batch.
    *
    * @group setParam
    */
  def setBatchSize(size: Int): this.type = {
    val recommended = size
    require(recommended > 0, "batchSize must be greater than 0")  // ×
    set(this.batchSize, recommended)  // ×
  }

  /** Size of every batch.
    *
    * @group getParam
    */
  def getBatchSize: Int = $(batchSize)  // 1✔

  def batchProcess(rows: Iterator[_]): Iterator[Row] = {
    val groupedRows = rows.grouped(getBatchSize)  // 1✔

    groupedRows.flatMap {  // 1✔
      case batchRow: Seq[Row] => processBatchRows(batchRow)  // 1✔ NEW
      case singleRow: Row => processBatchRows(Seq(singleRow))  // × NEW
      case _ => Seq(Row.empty)  // ×
    }
  }

  private def processBatchRows(batchedRows: Seq[Row]): Seq[Row] = {
    val inputAnnotations = batchedRows.map(row => {  // 1✔
      getInputCols.flatMap(inputCol => {  // 1✔
        row.getAs[Seq[Row]](inputCol).map(Annotation(_))  // 1✔
      })
    })
    val outputAnnotations = batchAnnotate(inputAnnotations)  // 1✔
    batchedRows
      .zip(outputAnnotations)  // 1✔
      .map { case (row, annotations) =>  // 1✔
        row.toSeq ++ Array(annotations.map(a => Row(a.productIterator.toSeq: _*)))  // 1✔
      }
      .map(Row.fromSeq)  // 1✔
  }

  /** Takes a document and annotations and produces new annotations of this annotator's annotation
    * type
    *
    * @param batchedAnnotations
    *   Annotations in batches that correspond to inputAnnotationCols generated by previous
    *   annotators if any
    * @return
    *   any number of annotations processed for every batch of input annotations. Not necessarily
    *   a one-to-one relationship
    *
    * IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences
    * that belong to the same original row !! (challenging)
    */
  def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]]

}
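
To make the batchAnnotate contract concrete (exactly one output Seq[Annotation] per input row, in the original order, so processBatchRows can zip the results back onto their source rows), here is a minimal sketch of an implementer. ToyUppercaseAnnotator is a hypothetical name, not part of this file or of spark-nlp; the sketch assumes spark-nlp's AnnotatorModel base class supplies the remaining plumbing (transform, transformSchema, copy), which may vary between versions.

import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasBatchedAnnotate}
import org.apache.spark.ml.util.Identifiable

// Hypothetical example: upper-cases every DOCUMENT annotation, processing partition
// rows in batches of `batchSize` via the inherited batchProcess machinery.
class ToyUppercaseAnnotator(override val uid: String)
    extends AnnotatorModel[ToyUppercaseAnnotator]
    with HasBatchedAnnotate[ToyUppercaseAnnotator] {

  def this() = this(Identifiable.randomUID("TOY_UPPERCASE"))

  override val inputAnnotatorTypes: Array[String] = Array(AnnotatorType.DOCUMENT)
  override val outputAnnotatorType: String = AnnotatorType.DOCUMENT

  setDefault(batchSize -> 4)

  // The outer Seq must line up element-for-element with batchedAnnotations: one result
  // sequence per input row, in the same order, so each row keeps its own annotations.
  override def batchAnnotate(
      batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] =
    batchedAnnotations.map { rowAnnotations =>
      rowAnnotations.toSeq.map(ann => ann.copy(result = ann.result.toUpperCase))
    }
}

In a Pipeline such an annotator would be configured like any other, e.g. new ToyUppercaseAnnotator().setInputCols("document").setOutputCol("upper").setBatchSize(8); batchProcess then groups each partition's rows into batches of getBatchSize before handing them to batchAnnotate.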