• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 7861513225

11 Feb 2024 11:05AM UTC coverage: 62.678% (-0.05%) from 62.731%
7861513225

Pull #14169

github

web-flow
Merge 13f2acde4 into 6010244ba
Pull Request #14169: Fixed a bug where models with an 'onnx_data' file did not work in DBFS/HDFS

8951 of 14281 relevant lines covered (62.68%)

0.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

42.31
/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/feature_extractor/WhisperPreprocessor.scala
1
package com.johnsnowlabs.nlp.annotators.audio.feature_extractor
2

3
import breeze.linalg.{DenseMatrix, DenseVector, max}
4
import breeze.signal.support.WindowFunctions.hanningWindow
5
import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.AudioUtils.matrixToFloatArray
6

7
/** Turns raw audio samples into log-mel spectrogram features for Whisper models.
  *
  * The Hann window and the mel filter bank are computed once, eagerly, when an instance is
  * constructed, so all parameters are validated up front.
  *
  * @param feature_size
  *   Number of mel filters (passed to the filter bank as `numMelFilters`); must be positive
  * @param hop_length
  *   Hop length handed to the spectrogram computation; must be positive
  * @param n_fft
  *   Frame length handed to the spectrogram computation and length of the Hann window; must be
  *   positive and smaller than `n_samples`
  * @param n_samples
  *   Length the waveform is truncated/padded to before feature extraction
  * @param padding_side
  *   Padding side forwarded to `Preprocessor.pad`
  * @param padding_value
  *   Value used when padding the waveform
  * @param sampling_rate
  *   Sampling rate of the input audio, used to build the mel filter bank
  */
class WhisperPreprocessor(
    override val feature_size: Int,
    val hop_length: Int,
    val n_fft: Int,
    val n_samples: Int,
    override val padding_side: String,
    override val padding_value: Float,
    override val sampling_rate: Int)
    extends Preprocessor(
      do_normalize = false,
      feature_size = feature_size,
      padding_side = padding_side,
      padding_value = padding_value,
      return_attention_mask = false,
      sampling_rate = sampling_rate)
    with Serializable {

  // These checks must run before `window` and `melFilterBank` below, which are built eagerly
  // from n_fft / feature_size during construction (Scala initializes vals in textual order).
  require(n_fft > 0, "n_fft must be greater than 0.")
  require(n_fft < n_samples, "n_fft should be smaller than n_samples.")
  require(hop_length > 0, "hop_length must be greater than 0.")
  require(feature_size > 0, "feature_size must be greater than 0.")

  /** Builds the Hann window used for the spectrogram frames.
    *
    * @param periodic
    *   When true, a Hann window of length `n_fft + 1` is computed and its last sample dropped,
    *   which yields a periodic window of length `n_fft`. When false, the Hann window of length
    *   `n_fft` is returned unchanged.
    * @return
    *   A window with `n_fft` samples
    */
  private def getHanningWindow(periodic: Boolean = true): DenseVector[Double] = {
    val windowLength = if (periodic) n_fft + 1 else n_fft
    val window = hanningWindow(windowLength)
    if (periodic) window(0 to -2) // Remove last element, so window is periodic
    else window
  }

  private val window: DenseVector[Double] = getHanningWindow()

  // Mel filter bank spanning 0 Hz to a fixed 8000 Hz upper bound.
  // NOTE(review): 8000 Hz equals the Nyquist frequency only for 16 kHz audio; presumably this
  // mirrors the reference Hugging Face implementation — confirm before changing sampling_rate.
  private val melFilterBank: DenseMatrix[Double] = AudioUtils.melFilterBank(
    numFrequencyBins = 1 + n_fft / 2,
    numMelFilters = feature_size,
    minFrequency = 0.0,
    maxFrequency = 8000.0,
    samplingRate = sampling_rate)

  /** Creates the log-mel spectrogram of the given float waveform and transforms it into features
    * for the Whisper model. The input is assumed not to have been preprocessed yet.
    *
    * Adapted from the Hugging Face Transformers implementation.
    *
    * @param rawFloats
    *   The waveform to transform into features
    * @return
    *   Extracted features, converted from the processed log-mel matrix by `matrixToFloatArray`
    */
  def extractFeatures(rawFloats: Array[Float]): Array[Array[Float]] = {

    // Truncate and then pad the waveform to n_samples using the shared Preprocessor helpers.
    val waveformVector: DenseVector[Double] = {
      val truncated = Preprocessor.truncate(rawFloats, n_samples)
      val padded = Preprocessor.pad(truncated, padding_value, n_samples, padding_side)

      DenseVector(padded.map(_.toDouble))
    }

    // Power (2.0) spectrogram projected onto the mel filter bank.
    val logSpectrogram: DenseMatrix[Double] = AudioUtils.calculateSpectrogram(
      waveform = waveformVector,
      window = window,
      frameLength = n_fft,
      hopLength = hop_length,
      power = 2.0d,
      melFilters = melFilterBank)

    val processedLogSpec: Array[Array[Float]] = {
      // Drop the last column (presumably the trailing frame) of the spectrogram.
      val logSpec = logSpectrogram(::, 0 to -2)
      // Limit the dynamic range: no value may lie more than 8.0 below the global maximum.
      val clamped = max(logSpec, max(logSpec) - 8.0)
      // Shift and rescale the clamped values.
      val scaled = (clamped + 4.0) / 4.0

      matrixToFloatArray(scaled)
    }

    processedLogSpec
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc