• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4951808959

pending completion
4951808959

Pull #13792

github

GitHub
Merge efe6b42df into ef7906c5e
Pull Request #13792: SPARKNLP-825 Adding multilabel param

7 of 7 new or added lines in 1 file covered. (100.0%)

8637 of 13128 relevant lines covered (65.79%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/scala/com/johnsnowlabs/nlp/HasClassifierActivationProperties.scala
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.{FloatParam, Param}

21
trait HasClassifierActivationProperties extends ParamsAndFeaturesWritable {

  /** Activation function applied to the raw classification logits: either `softmax`
    * (multi-class: probabilities over all labels sum to 1.0) or `sigmoid` (multi-label:
    * each label gets an independent probability between 0.0 and 1.0). (Default: `softmax`)
    *
    * @group param
    */
  val activation: Param[String] = new Param(
    this,
    "activation",
    "Whether to calculate logits via Softmax or Sigmoid. Default is Softmax")

  /** Choose the threshold to determine which logits are considered to be positive or negative.
    * (Default: `0.5f`). The value should be between 0.0 and 1.0. Changing the threshold value
    * will affect the resulting labels and can be used to adjust the balance between precision and
    * recall in the classification process.
    *
    * @group param
    */
  val threshold: FloatParam = new FloatParam(
    this,
    "threshold",
    "Choose the threshold to determine which logits are considered to be positive or negative")

  /** Whether or not the result should be multi-class (the sum of all probabilities is 1.0) or
    * multi-label (each label has a probability between 0.0 to 1.0).
    * Default is False i.e. multi-class
    *
    * @group param
    */
  val multilabel: Param[Boolean] = new Param(
    this,
    "multilabel",
    "Whether or not the result should be multi-class (the sum of all probabilities is 1.0) or multi-label (each label has a probability between 0.0 to 1.0). Default is False i.e. multi-class")

  setDefault(activation -> ActivationFunction.softmax, threshold -> 0.5f, multilabel -> false)

  /** Gets the configured activation function name (`"softmax"` or `"sigmoid"`).
    *
    * @group getParam
    */
  def getActivation: String = $(activation)

  /** Sets the activation function applied to the classification logits.
    *
    * Any value other than [[ActivationFunction.sigmoid]] (including unrecognized strings)
    * falls back to [[ActivationFunction.softmax]].
    *
    * @group setParam
    */
  def setActivation(value: String): this.type = {
    // Normalize first, then set once: unknown inputs silently default to softmax,
    // preserving the original fall-through behavior.
    val normalized = value match {
      case ActivationFunction.sigmoid => ActivationFunction.sigmoid
      case _ => ActivationFunction.softmax
    }
    set(this.activation, normalized)
  }

  /** Choose the threshold to determine which logits are considered to be positive or negative.
    * (Default: `0.5f`). The value should be between 0.0 and 1.0. Changing the threshold value
    * will affect the resulting labels and can be used to adjust the balance between precision and
    * recall in the classification process.
    *
    * @group setParam
    */
  def setThreshold(threshold: Float): this.type =
    set(this.threshold, threshold)

  /** Set whether or not the result should be multi-class (the sum of all probabilities is 1.0) or
    * multi-label (each label has a probability between 0.0 to 1.0).
    * Default is False i.e. multi-class.
    *
    * Also switches [[activation]] to match: sigmoid for multi-label (independent per-label
    * probabilities), softmax for multi-class (a distribution over all labels).
    *
    * @group setParam
    */
  def setMultilabel(value: Boolean): this.type = {
    if (value) {
      setActivation(ActivationFunction.sigmoid)
    } else setActivation(ActivationFunction.softmax)
    set(this.multilabel, value)
  }

}
98

99
/** Names of the supported activation functions for classifier logits.
  *
  * Used as the legal values of the `activation` param in
  * `HasClassifierActivationProperties`.
  */
object ActivationFunction {

  /** Multi-class activation: probabilities over all labels sum to 1.0. */
  val softmax: String = "softmax"

  /** Multi-label activation: each label gets an independent probability in [0.0, 1.0]. */
  val sigmoid: String = "sigmoid"

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc