• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 9930120094

14 Jul 2024 07:11PM UTC coverage: 62.618% (-0.5%) from 63.11%
9930120094

push

github

web-flow
Merge pull request #14350 from JohnSnowLabs/release/541-release-candidate

* Fixing default names for Phi2 and MistralAI

* Phi2 is 2.7B in size

---------

Co-authored-by: ahmedlone127 <ahmedlone127@gmail.com>

5 of 8 new or added lines in 4 files covered. (62.5%)

423 existing lines in 46 files now uncovered.

8970 of 14325 relevant lines covered (62.62%)

0.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.audio
18

19
import com.johnsnowlabs.ml.ai.Whisper
20
import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
21
import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWrappers
22
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
23
import com.johnsnowlabs.ml.tensorflow.{
24
  ReadTensorflowModel,
25
  TensorflowWrapper,
26
  WriteTensorflowModel
27
}
28
import com.johnsnowlabs.ml.util.LoadExternalModel.{
29
  loadJsonStringAsset,
30
  modelSanityCheck,
31
  notSupportedEngineError
32
}
33
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
34
import com.johnsnowlabs.nlp._
35
import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.{Preprocessor, WhisperPreprocessor}
36
import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature}
37
import com.johnsnowlabs.util.Version
38
import org.apache.spark.broadcast.Broadcast
39
import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, Param}
40
import org.apache.spark.ml.util.Identifiable
41
import org.apache.spark.sql.SparkSession
42
import org.json4s._
43
import org.json4s.jackson.JsonMethods._
44

45
/** Whisper Model with a language modeling head on top for Connectionist Temporal Classification
46
  * (CTC).
47
  *
48
  * Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of
49
  * multilingual and multitask supervised data collected from the web. It can transcribe in multiple
50
  * languages, as well as translate from those languages into English.
51
  *
52
  * The audio needs to be provided pre-processed as an array of floats.
53
  *
54
  * For multilingual models, the language and the task (transcribe or translate) can be set with
55
  * `setLanguage` and `setTask`.
56
  *
57
  * Note that at the moment, this annotator only supports greedy search and only Spark versions
58
  * 3.4 and up are supported.
59
  *
60
  * Pretrained models can be loaded with `pretrained` of the companion object:
61
  * {{{
62
  * val speechToText = WhisperForCTC.pretrained()
63
  *   .setInputCols("audio_assembler")
64
  *   .setOutputCol("text")
65
  * }}}
66
  * The default model is `"asr_whisper_tiny_opt"`, if no name is provided.
67
  *
68
  * For available pretrained models please see the [[https://sparknlp.org/models Models Hub]].
69
  *
70
  * To see which models are compatible and how to import them see
71
  * [[https://github.com/JohnSnowLabs/spark-nlp/discussions/5669]] and to see more extended
72
  * examples, see
73
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTCTest.scala WhisperForCTCTestSpec]].
74
  *
75
  * '''References:'''
76
  *
77
  * [[https://arxiv.org/abs/2212.04356 Robust Speech Recognition via Large-Scale Weak Supervision]]
78
  *
79
  * '''Paper Abstract:'''
80
  *
81
  * ''We study the capabilities of speech processing systems trained simply to predict large
82
  * amounts of transcripts of audio on the internet. When scaled to 680,000 hours of multilingual
83
  * and multitask supervision, the resulting models generalize well to standard benchmarks and are
84
  * often competitive with prior fully supervised results but in a zero-shot transfer setting
85
  * without the need for any fine-tuning. When compared to humans, the models approach their
86
  * accuracy and robustness. We are releasing models and inference code to serve as a foundation
87
  * for further work on robust speech processing.''
88
  *
89
  * ==Example==
90
  *
91
  * {{{
92
  * import spark.implicits._
93
  * import com.johnsnowlabs.nlp.base._
94
  * import com.johnsnowlabs.nlp.annotators._
95
  * import com.johnsnowlabs.nlp.annotators.audio.WhisperForCTC
96
  * import org.apache.spark.ml.Pipeline
97
  *
98
  * val audioAssembler: AudioAssembler = new AudioAssembler()
99
  *   .setInputCol("audio_content")
100
  *   .setOutputCol("audio_assembler")
101
  *
102
  * val speechToText: WhisperForCTC = WhisperForCTC
103
  *   .pretrained()
104
  *   .setInputCols("audio_assembler")
105
  *   .setOutputCol("text")
106
  *
107
  * val pipeline: Pipeline = new Pipeline().setStages(Array(audioAssembler, speechToText))
108
  *
109
  * val bufferedSource =
110
  *   scala.io.Source.fromFile("src/test/resources/audio/txt/librispeech_asr_0.txt")
111
  *
112
  * val rawFloats = bufferedSource
113
  *   .getLines()
114
  *   .map(_.split(",").head.trim.toFloat)
115
  *   .toArray
116
  * bufferedSource.close
117
  *
118
  * val processedAudioFloats = Seq(rawFloats).toDF("audio_content")
119
  *
120
  * val result = pipeline.fit(processedAudioFloats).transform(processedAudioFloats)
121
  * result.select("text.result").show(truncate = false)
122
  * +------------------------------------------------------------------------------------------+
123
  * |result                                                                                    |
124
  * +------------------------------------------------------------------------------------------+
125
  * |[ Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.]|
126
  * +------------------------------------------------------------------------------------------+
127
  * }}}
128
  *
129
  * @param uid
130
  *   required uid for storing annotator to disk
131
  * @groupname anno Annotator types
132
  * @groupdesc anno
133
  *   Required input and expected output annotator types
134
  * @groupname Ungrouped Members
135
  * @groupname param Parameters
136
  * @groupname setParam Parameter setters
137
  * @groupname getParam Parameter getters
138
  * @groupname Ungrouped Members
139
  * @groupprio param  1
140
  * @groupprio anno  2
141
  * @groupprio Ungrouped 3
142
  * @groupprio setParam  4
143
  * @groupprio getParam  5
144
  * @groupdesc param
145
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
146
  *   parameter values through setters and getters, respectively.
147
  */
148
class WhisperForCTC(override val uid: String)
    extends AnnotatorModel[WhisperForCTC]
    with HasBatchedAnnotateAudio[WhisperForCTC]
    with HasAudioFeatureProperties
    with WriteTensorflowModel
    with WriteOnnxModel
    with HasEngine
    with HasGeneratorProperties
    with HasProtectedParams {

  /** Transcriptions are emitted as DOCUMENT annotations.
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = AnnotatorType.DOCUMENT

  /** Only AUDIO annotations are accepted as input.
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.AUDIO)

  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
    * type
    */
  def this() = this(Identifiable.randomUID("WhisperForCTC"))

  /** Optional language to set for the transcription. The imported model needs to support multiple
    * languages.
    *
    * @group param
    */
  val language: Param[String] =
    new Param[String](
      this,
      "language",
      "Optional parameter to set the language for the transcription.")

  /** Sets the language for the audio, formatted as a language code enclosed in `<|` and `|>`,
    * e.g. `<|en|>`. Three-letter codes such as `<|haw|>` are accepted as well. Check the model
    * description for supported languages.
    *
    * The model must be multilingual and already loaded, because the token is looked up in the
    * model's vocabulary.
    *
    * @group setParam
    */
  def setLanguage(value: String): this.type = {
    require(getIsMultilingual, "Only multilingual models can have the language set.")
    require(
      hasLanguageTokenFormat(value),
      "The language does not have the right format." +
        " Should be a two or three letter code enclosed in angle brackets with a vertical line (e.g. <|en|>).")
    require(getModelIfNotSet.tokenInVocabulary(value), "Language was not found in vocabulary.")
    set(language, value)
    this
  }

  /** Whether `value` looks like `<|xx|>` or `<|xxx|>`.
    *
    * A strict two-letter check would reject Whisper's three-letter language tokens
    * (e.g. `<|haw|>`). The vocabulary lookup in [[setLanguage]] remains the final authority.
    */
  private def hasLanguageTokenFormat(value: String): Boolean =
    (value.length == 6 || value.length == 7) && value.startsWith("<|") && value.endsWith("|>")

  /** @group getParam */
  def getLanguage: Option[String] = get(this.language)

  /** Sets the formatted task for the audio. Either `<|translate|>` or `<|transcribe|>`.
    *
    * Only multilingual models can do translation.
    *
    * @group setParam
    */
  override def setTask(value: String): this.type = {
    require(
      getIsMultilingual,
      "Only multilingual models can have tasks. For single language models, the default task will be transcribe.")
    require(
      value == "<|translate|>" || value == "<|transcribe|>",
      "Task should be either '<|translate|>' or '<|transcribe|>'")
    set(task, value)
    this
  }

  /** Whether or not the model is multilingual.
    *
    * @group param
    */
  val isMultilingual: ProtectedParam[Boolean] =
    new BooleanParam(this, "isMultilingual", "Whether or not the model is multilingual.")
      .setProtected()

  /** @group setParam */
  def setIsMultilingual(value: Boolean): this.type = {
    set(isMultilingual, value)
    this
  }

  /** @group getParam */
  def getIsMultilingual: Boolean = getOrDefault(this.isMultilingual)

  /** It contains TF model signatures for the loaded saved model
    *
    * @group param
    */
  val signatures: MapFeature[String, String] =
    new MapFeature[String, String](model = this, name = "signatures").setProtected()

  /** @group setParam */
  def setSignatures(value: Map[String, String]): this.type = {
    set(signatures, value)
    this
  }

  /** @group getParam */
  def getSignatures: Option[Map[String, String]] = get(this.signatures)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group param
    */
  val configProtoBytes: IntArrayParam = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group setParam
    */
  def setConfigProtoBytes(bytes: Array[Int]): this.type =
    set(this.configProtoBytes, bytes)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * config_proto.SerializeToString()
    *
    * @group getParam
    */
  def getConfigProtoBytes: Option[Array[Byte]] =
    get(this.configProtoBytes).map(_.map(_.toByte))

  /** Vocabulary used to encode the words to ids */
  protected[nlp] val vocabulary: MapFeature[String, Int] =
    new MapFeature(this, "vocabulary").setProtected()

  def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value)

  def getVocabulary: Map[String, Int] = $$(vocabulary)

  /** Additional special tokens (token -> id) handed to the underlying [[Whisper]] model. */
  protected[nlp] val addedSpecialTokens: MapFeature[String, Int] =
    new MapFeature(this, "addedSpecialTokens").setProtected()

  protected[nlp] def setAddedSpecialTokens(value: Map[String, Int]): this.type =
    set(addedSpecialTokens, value)

  protected[nlp] def getAddedSpecialTokens: Map[String, Int] = $$(addedSpecialTokens)

  /** Generation settings (special token ids, vocabulary size, suppressed/forced tokens) handed to
    * the underlying [[Whisper]] model.
    */
  protected[nlp] val generationConfig: StructFeature[GenerationConfig] =
    new StructFeature(this, "generationConfig").setProtected()

  protected[nlp] def setGenerationConfig(value: GenerationConfig): this.type =
    set(generationConfig, value)

  protected[nlp] def getGenerationConfig: GenerationConfig = $$(generationConfig)

  /** Audio feature-extraction settings handed to the underlying [[Whisper]] model. */
  protected[nlp] val preprocessor: StructFeature[WhisperPreprocessor] =
    new StructFeature(this, "preprocessor").setProtected()

  protected[nlp] def setPreprocessor(value: WhisperPreprocessor): this.type =
    set(preprocessor, value)

  protected[nlp] def getPreprocessor: WhisperPreprocessor = $$(preprocessor)

  setDefault(
    minOutputLength -> 0,
    maxOutputLength -> 448,
    doSample -> false,
    temperature -> 1.0,
    topK -> 1,
    topP -> 1.0,
    repetitionPenalty -> 1.0,
    noRepeatNgramSize -> 0,
    batchSize -> 2,
    beamSize -> 1,
    nReturnSequences -> 1,
    isMultilingual -> true)

  private var _model: Option[Broadcast[Whisper]] = None

  /** Returns the broadcast [[Whisper]] model.
    *
    * @throws NoSuchElementException
    *   if no model has been set yet (see [[setModelIfNotSet]])
    * @group getParam
    */
  def getModelIfNotSet: Whisper =
    _model
      .map(_.value)
      .getOrElse(
        throw new NoSuchElementException(
          "WhisperForCTC: the underlying model has not been set. " +
            "Load it with pretrained(), load() or loadSavedModel() first."))

  /** Builds the [[Whisper]] model from the given wrappers and broadcasts it, unless a model is
    * already set (in which case this is a no-op).
    *
    * The signatures, vocabulary, added special tokens, generation config and preprocessor must be
    * set before calling this, because they are read here.
    *
    * @group setParam
    */
  def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: Option[TensorflowWrapper],
      onnxWrappers: Option[EncoderDecoderWrappers]): this.type = {
    if (_model.isEmpty) {
      val preprocessor = getPreprocessor

      _model = Some(
        spark.sparkContext.broadcast(
          new Whisper(
            tensorflowWrapper,
            onnxWrappers,
            configProtoBytes = getConfigProtoBytes,
            signatures = getSignatures,
            preprocessor = preprocessor,
            vocabulary = getVocabulary,
            addedSpecialTokens = getAddedSpecialTokens,
            generationConfig = getGenerationConfig)))
    }
    this
  }

  /** Persists the params/features and then the model graph(s) for the current engine.
    *
    * @throws Exception
    *   with `notSupportedEngineError` if the engine is neither TensorFlow nor ONNX (previously a
    *   bare `MatchError`)
    */
  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)
    getEngine match {
      case TensorFlow.name =>
        writeTensorflowModelV2(
          path,
          spark,
          getModelIfNotSet.tensorflowWrapper.get,
          WhisperForCTC.suffix,
          WhisperForCTC.tfFile,
          configProtoBytes = getConfigProtoBytes,
          savedSignatures = getSignatures)
      case ONNX.name =>
        val wrappers = getModelIfNotSet.onnxWrappers.get
        // These names must match the ones read back in ReadWhisperForCTCDLModel.readModel.
        writeOnnxModels(
          path,
          spark,
          Seq(
            (wrappers.encoder, "encoder_model"),
            (wrappers.decoder, "decoder_model"),
            (wrappers.decoderWithPast, "decoder_with_past_model")),
          WhisperForCTC.suffix)
      case _ =>
        throw new Exception(notSupportedEngineError)
    }
  }

  /** Takes audio annotations and produces transcribed document annotations.
    *
    * @param batchedAnnotations
    *   Audio annotations in batches
    * @return
    *   Transcribed audio as DOCUMENT type annotation. An empty row yields an empty result.
    */
  override def batchAnnotate(
      batchedAnnotations: Seq[Array[AnnotationAudio]]): Seq[Seq[Annotation]] = {
    batchedAnnotations.map { audioAnnotations =>
      if (audioAnnotations.nonEmpty) {
        getModelIfNotSet.generateFromAudio(
          batchAudio = audioAnnotations,
          batchSize = getBatchSize,
          maxOutputLength = getMaxOutputLength,
          minOutputLength = getMinOutputLength,
          doSample = getDoSample,
          beamSize = getBeamSize,
          numReturnSequences = getNReturnSequences,
          temperature = getTemperature,
          topK = getTopK,
          topP = getTopP,
          repetitionPenalty = getRepetitionPenalty,
          noRepeatNgramSize = getNoRepeatNgramSize,
          randomSeed = getRandomSeed,
          task = getTask,
          language = getLanguage)
      } else Seq.empty
    }
  }

}
403

404
/** Pretrained-model loading for [[WhisperForCTC]].
  *
  * The overrides below only delegate to [[HasPretrained]]. They exist so that Java callers see
  * each `pretrained` overload with a concrete `WhisperForCTC` return type.
  */
trait ReadablePretrainedWhisperForCTCModel
    extends ParamsAndFeaturesReadable[WhisperForCTC]
    with HasPretrained[WhisperForCTC] {

  /** Name of the model fetched when `pretrained()` is called without a name. */
  override val defaultModelName: Some[String] = Some("asr_whisper_tiny_opt")

  /** Language tag used when none is given explicitly. */
  override val defaultLang: String = "xx"

  /** Loads the default pretrained model. */
  override def pretrained(): WhisperForCTC =
    super.pretrained()

  /** Loads the pretrained model called `name`. */
  override def pretrained(name: String): WhisperForCTC =
    super.pretrained(name)

  /** Loads the pretrained model called `name` for language `lang`. */
  override def pretrained(name: String, lang: String): WhisperForCTC =
    super.pretrained(name, lang)

  /** Loads the pretrained model called `name` for language `lang` from `remoteLoc`. */
  override def pretrained(name: String, lang: String, remoteLoc: String): WhisperForCTC =
    super.pretrained(name, lang, remoteLoc)
}
421

422
/** Reading of serialized [[WhisperForCTC]] models, and import of externally exported Whisper
  * models (TensorFlow SavedModel or ONNX encoder/decoder graphs).
  */
trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel {
  this: ParamsAndFeaturesReadable[WhisperForCTC] =>

  override val tfFile: String = "whisper_ctc_tensorflow"
  override val onnxFile: String = "whisper_ctc_onnx"
  val suffix: String = "_whisper_ctc"

  /** Fails with `IllegalArgumentException` unless Spark is at least version 3.4.
    *
    * NOTE(review): the check compares `Version.parse(spark.version).toFloat` against 3.4. Confirm
    * how `toFloat` treats a two-digit minor version (e.g. a hypothetical "3.10") before relying
    * on it in that case.
    */
  private def checkVersion(spark: SparkSession): Unit = {
    val version = Version.parse(spark.version).toFloat
    require(version >= 3.4, "WhisperForCTC requires Spark versions 3.4 and up.")
  }

  /** Restores the model graph(s) of a serialized [[WhisperForCTC]] into `instance`, based on the
    * engine recorded in its params.
    *
    * @throws Exception
    *   with `notSupportedEngineError` for engines other than TensorFlow and ONNX
    */
  def readModel(instance: WhisperForCTC, path: String, spark: SparkSession): Unit = {
    checkVersion(spark)

    instance.getEngine match {
      case TensorFlow.name =>
        val tfWrapper = readTensorflowModel(
          path,
          spark,
          WhisperForCTC.suffix,
          savedSignatures = instance.getSignatures)
        instance.setModelIfNotSet(spark, Some(tfWrapper), None)

      case ONNX.name =>
        // These names must match the ones written in WhisperForCTC.onWrite.
        val wrappers =
          readOnnxModels(
            path,
            spark,
            Seq("encoder_model", "decoder_model", "decoder_with_past_model"),
            WhisperForCTC.suffix,
            dataFilePostfix = ".onnx_data")

        val onnxWrappers = EncoderDecoderWrappers(
          wrappers("encoder_model"),
          decoder = wrappers("decoder_model"),
          decoderWithPast = wrappers("decoder_with_past_model"))

        instance.setModelIfNotSet(spark, None, Some(onnxWrappers))
      case _ =>
        throw new Exception(notSupportedEngineError)
    }
  }

  addReader(readModel)

  /** Imports an exported Whisper model from `modelPath`.
    *
    * Reads `preprocessor_config.json`, `vocab.json` and `config.json` from the model directory.
    * `added_tokens.json` is read best-effort: an `IllegalArgumentException` while loading or
    * parsing it results in no added tokens.
    *
    * The model is treated as multilingual unless `forced_decoder_ids` has exactly one entry.
    *
    * @throws Exception
    *   with `notSupportedEngineError` if the detected engine is neither TensorFlow nor ONNX
    */
  def loadSavedModel(modelPath: String, spark: SparkSession): WhisperForCTC = {
    checkVersion(spark)

    implicit val formats: DefaultFormats.type = DefaultFormats // for json4s

    val (localModelPath, detectedEngine) =
      modelSanityCheck(modelPath, isEncoderDecoder = true, withPast = true)

    val ppJsonString: String = loadJsonStringAsset(localModelPath, "preprocessor_config.json")

    val preprocessor: WhisperPreprocessor =
      Preprocessor.loadPreprocessorConfig(ppJsonString).asInstanceOf[WhisperPreprocessor]

    val addedTokens: Map[String, Int] =
      try {
        parseTokenIds(
          loadJsonStringAsset(localModelPath, "added_tokens.json"),
          "added_tokens.json")
      } catch {
        case _: IllegalArgumentException =>
          Map.empty
      }

    val vocabMap: Map[String, Int] =
      parseTokenIds(loadJsonStringAsset(localModelPath, "vocab.json"), "vocab.json")

    val modelConfig: JValue =
      parse(loadJsonStringAsset(localModelPath, "config.json"))

    val beginSuppressTokens: Array[Int] =
      (modelConfig \ "begin_suppress_tokens").extract[Array[Int]]

    val suppressTokenIds: Array[Int] =
      (modelConfig \ "suppress_tokens").extract[Array[Int]]

    // Each entry is a (decoding position, token id) pair.
    val forcedDecoderIds: Array[(Int, Int)] =
      (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map {
        case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 =>
          (idxWithTokenId(0), idxWithTokenId(1))
        case _ =>
          throw new Exception(
            "Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.")
      }

    val maxOutputLength = (modelConfig \ "max_length").extract[Int]
    val bosTokenId = (modelConfig \ "decoder_start_token_id").extract[Int]
    val eosTokenId = (modelConfig \ "eos_token_id").extract[Int]
    val padTokenId = (modelConfig \ "pad_token_id").extract[Int]
    val vocabSize = (modelConfig \ "vocab_size").extract[Int]

    // 3 means multilingual (for official models), e.g. [<|en|>, <|transcribe|>, <|notimestamps|>]
    // Single language models only force the force token to be <|notimestamps|>
    // Some custom models might have no forced tokens at all, assume its multilingual
    val isMultilingual = forcedDecoderIds.length != 1

    def arrayOrNone[T](array: Array[T]): Option[Array[T]] =
      if (array.nonEmpty) Some(array) else None

    val annotatorModel = new WhisperForCTC()
      .setVocabulary(vocabMap)
      .setMaxOutputLength(maxOutputLength)
      .setDoNormalize(preprocessor.do_normalize)
      .setReturnAttentionMask(preprocessor.return_attention_mask)
      .setPaddingSide(preprocessor.padding_side)
      .setPaddingValue(preprocessor.padding_value)
      .setFeatureSize(preprocessor.feature_size)
      .setSamplingRate(preprocessor.sampling_rate)
      .setAddedSpecialTokens(addedTokens)
      .setGenerationConfig(
        GenerationConfig(
          bosTokenId,
          padTokenId,
          eosTokenId,
          vocabSize,
          arrayOrNone(beginSuppressTokens),
          arrayOrNone(suppressTokenIds),
          arrayOrNone(forcedDecoderIds)))
      .setPreprocessor(preprocessor)
      .setIsMultilingual(isMultilingual)

    annotatorModel.set(annotatorModel.engine, detectedEngine)

    detectedEngine match {
      case TensorFlow.name =>
        val (tfWrapper, signatures) =
          TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

        val _signatures =
          signatures.getOrElse(throw new Exception("Cannot load signature definitions from model!"))

        // Signatures must be set before setModelIfNotSet, which reads them via getSignatures.
        annotatorModel
          .setSignatures(_signatures)
          .setModelIfNotSet(spark, Some(tfWrapper), None)

      case ONNX.name =>
        // Arguments are evaluated left to right, so the graphs are read in this order.
        val onnxWrappers = EncoderDecoderWrappers(
          readOnnxGraph(spark, localModelPath, "encoder_model"),
          readOnnxGraph(spark, localModelPath, "decoder_model"),
          readOnnxGraph(spark, localModelPath, "decoder_with_past_model"))

        annotatorModel
          .setModelIfNotSet(spark, None, Some(onnxWrappers))

      case _ =>
        throw new Exception(notSupportedEngineError)
    }

    annotatorModel
  }

  /** Parses a JSON object of token -> id into a map.
    *
    * @throws IllegalArgumentException
    *   if an id does not fit in an `Int`. The message names `fileName`.
    */
  private def parseTokenIds(jsonString: String, fileName: String): Map[String, Int] =
    parse(jsonString, useBigIntForLong = true).values
      .asInstanceOf[Map[String, BigInt]]
      .map {
        case (key, value) if value.isValidInt => (key, value.toInt)
        case _ =>
          throw new IllegalArgumentException(
            s"Could not convert BigInt to Int while parsing $fileName")
      }

  /** Reads one ONNX graph (`modelName`) of an exported model from the unzipped local directory. */
  private def readOnnxGraph(spark: SparkSession, localModelPath: String, modelName: String) =
    OnnxWrapper.read(
      spark,
      localModelPath,
      zipped = false,
      useBundle = true,
      modelName = modelName,
      onnxFileSuffix = None)
}
623

624
/** Companion object of [[WhisperForCTC]].
  *
  * It mixes in pretrained-model loading (`pretrained`), reading of serialized models
  * (`readModel`) and import of exported models (`loadSavedModel`). Please refer to the class for
  * the full documentation.
  */
object WhisperForCTC
    extends ReadablePretrainedWhisperForCTCModel
    with ReadWhisperForCTCDLModel
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc