• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 15252839065

26 May 2025 11:30AM UTC coverage: 52.115% (-0.6%) from 52.715%
15252839065

Pull #14585

github

web-flow
Merge 625e5c10f into 56512b006
Pull Request #14585: SparkNLP 1131 - Introducing Florance-2

0 of 199 new or added lines in 4 files covered. (0.0%)

50 existing lines in 33 files now uncovered.

9931 of 19056 relevant lines covered (52.11%)

0.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.11
/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.annotators.cv
18

19
import com.johnsnowlabs.ml.ai.ViTClassifier
20
import com.johnsnowlabs.ml.tensorflow.{
21
  ReadTensorflowModel,
22
  TensorflowWrapper,
23
  WriteTensorflowModel
24
}
25
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
26
import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel}
27
import com.johnsnowlabs.ml.util.LoadExternalModel.{
28
  loadJsonStringAsset,
29
  modelSanityCheck,
30
  notSupportedEngineError
31
}
32
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
33
import com.johnsnowlabs.nlp.AnnotatorType.{CATEGORY, IMAGE}
34
import com.johnsnowlabs.nlp._
35
import com.johnsnowlabs.nlp.annotators.classifier.dl.XlmRoBertaForQuestionAnswering
36
import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor
37
import com.johnsnowlabs.nlp.serialization.MapFeature
38
import org.apache.spark.broadcast.Broadcast
39
import org.apache.spark.ml.param.IntArrayParam
40
import org.apache.spark.ml.util.Identifiable
41
import org.apache.spark.sql.SparkSession
42
import org.json4s._
43
import org.json4s.jackson.JsonMethods._
44

45
/** Vision Transformer (ViT) for image classification.
46
  *
47
  * ViT is a transformer-based alternative to the convolutional neural networks usually used for
48
  * image recognition tasks.
49
  *
50
  * Pretrained models can be loaded with `pretrained` of the companion object:
51
  * {{{
52
  * val imageClassifier = ViTForImageClassification.pretrained()
53
  *   .setInputCols("image_assembler")
54
  *   .setOutputCol("class")
55
  * }}}
56
  * The default model is `"image_classifier_vit_base_patch16_224"`, if no name is provided.
57
  *
58
  * For available pretrained models please see the
59
  * [[https://sparknlp.org/models?task=Image+Classification Models Hub]].
60
  *
61
  * Models from the HuggingFace 🤗 Transformers library are also compatible with Spark NLP 🚀. To
62
  * see which models are compatible and how to import them see
63
  * [[https://github.com/JohnSnowLabs/spark-nlp/discussions/5669]] and to see more extended
64
  * examples, see
65
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/cv/ViTImageClassificationTestSpec.scala ViTImageClassificationTestSpec]].
66
  *
67
  * '''References:'''
68
  *
69
  * [[https://arxiv.org/abs/2010.11929 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale]]
70
  *
71
  * '''Paper Abstract:'''
72
  *
73
  * ''While the Transformer architecture has become the de-facto standard for natural language
74
  * processing tasks, its applications to computer vision remain limited. In vision, attention is
75
  * either applied in conjunction with convolutional networks, or used to replace certain
76
  * components of convolutional networks while keeping their overall structure in place. We show
77
  * that this reliance on CNNs is not necessary and a pure transformer applied directly to
78
  * sequences of image patches can perform very well on image classification tasks. When
79
  * pre-trained on large amounts of data and transferred to multiple mid-sized or small image
80
  * recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains
81
  * excellent results compared to state-of-the-art convolutional networks while requiring
82
  * substantially fewer computational resources to train.''
83
  *
84
  * ==Example==
85
  * {{{
86
  * import com.johnsnowlabs.nlp.annotator._
87
  * import com.johnsnowlabs.nlp.ImageAssembler
88
  * import org.apache.spark.ml.Pipeline
89
  *
90
  * val imageDF: DataFrame = spark.read
91
  *   .format("image")
92
  *   .option("dropInvalid", value = true)
93
  *   .load("src/test/resources/image/")
94
  *
95
  * val imageAssembler = new ImageAssembler()
96
  *   .setInputCol("image")
97
  *   .setOutputCol("image_assembler")
98
  *
99
  * val imageClassifier = ViTForImageClassification
100
  *   .pretrained()
101
  *   .setInputCols("image_assembler")
102
  *   .setOutputCol("class")
103
  *
104
  * val pipeline = new Pipeline().setStages(Array(imageAssembler, imageClassifier))
105
  * val pipelineDF = pipeline.fit(imageDF).transform(imageDF)
106
  *
107
  * pipelineDF
108
  *   .selectExpr("reverse(split(image.origin, '/'))[0] as image_name", "class.result")
109
  *   .show(truncate = false)
110
  * +-----------------+----------------------------------------------------------+
111
  * |image_name       |result                                                    |
112
  * +-----------------+----------------------------------------------------------+
113
  * |palace.JPEG      |[palace]                                                  |
114
  * |egyptian_cat.jpeg|[Egyptian cat]                                            |
115
  * |hippopotamus.JPEG|[hippopotamus, hippo, river horse, Hippopotamus amphibius]|
116
  * |hen.JPEG         |[hen]                                                     |
117
  * |ostrich.JPEG     |[ostrich, Struthio camelus]                               |
118
  * |junco.JPEG       |[junco, snowbird]                                         |
119
  * |bluetick.jpg     |[bluetick]                                                |
120
  * |chihuahua.jpg    |[Chihuahua]                                               |
121
  * |tractor.JPEG     |[tractor]                                                 |
122
  * |ox.JPEG          |[ox]                                                      |
123
  * +-----------------+----------------------------------------------------------+
124
  * }}}
125
  *
126
  * @param uid
127
  *   required uid for storing annotator to disk
128
  * @groupname anno Annotator types
129
  * @groupdesc anno
130
  *   Required input and expected output annotator types
131
  * @groupname Ungrouped Members
132
  * @groupname param Parameters
133
  * @groupname setParam Parameter setters
134
  * @groupname getParam Parameter getters
135
  * @groupname Ungrouped Members
136
  * @groupprio param  1
137
  * @groupprio anno  2
138
  * @groupprio Ungrouped 3
139
  * @groupprio setParam  4
140
  * @groupprio getParam  5
141
  * @groupdesc param
142
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
143
  *   parameter values through setters and getters, respectively.
144
  */
145
class ViTForImageClassification(override val uid: String)
    extends AnnotatorModel[ViTForImageClassification]
    with HasBatchedAnnotateImage[ViTForImageClassification]
    with HasImageFeatureProperties
    with WriteTensorflowModel
    with WriteOnnxModel
    with WriteOpenvinoModel
    with HasEngine {

  /** Creates the annotator with a random uid prefixed with `ViTForImageClassification`. */
  def this() = this(Identifiable.randomUID("ViTForImageClassification"))

  /** Output annotator type : CATEGORY
    *
    * @group anno
    */
  override val outputAnnotatorType: AnnotatorType = CATEGORY

  /** Input annotator type : IMAGE
    *
    * @group anno
    */
  override val inputAnnotatorTypes: Array[AnnotatorType] = Array(IMAGE)

  /** ConfigProto from tensorflow, serialized into byte array. Get with
    * `config_proto.SerializeToString()`
    *
    * @group param
    */
  val configProtoBytes = new IntArrayParam(
    this,
    "configProtoBytes",
    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")

  /** Sets the serialized TensorFlow ConfigProto (see [[configProtoBytes]]).
    *
    * @group setParam
    */
  def setConfigProtoBytes(bytes: Array[Int]): ViTForImageClassification.this.type =
    set(this.configProtoBytes, bytes)

  /** Returns the serialized TensorFlow ConfigProto, with each stored Int narrowed to a Byte, or
    * `None` when it was never set.
    *
    * @group getParam
    */
  def getConfigProtoBytes: Option[Array[Byte]] =
    get(this.configProtoBytes).map(_.map(_.toByte))

  /** Labels used to decode predicted IDs back to string tags (label name -> id).
    *
    * @group param
    */
  val labels: MapFeature[String, BigInt] = new MapFeature(this, "labels").setProtected()

  /** @group setParam */
  def setLabels(value: Map[String, BigInt]): this.type = set(labels, value)

  /** Returns the label names known to this model (the keys of [[labels]]). */
  def getClasses: Array[String] = $$(labels).keys.toArray

  /** TensorFlow model signatures of the loaded SavedModel.
    *
    * @group param
    */
  val signatures =
    new MapFeature[String, String](model = this, name = "signatures").setProtected()

  /** @group setParam */
  def setSignatures(value: Map[String, String]): this.type = {
    set(signatures, value)
    this
  }

  /** @group getParam */
  def getSignatures: Option[Map[String, String]] = get(this.signatures)

  // Broadcast wrapper around the engine-specific classifier; populated once by setModelIfNotSet.
  private var _model: Option[Broadcast[ViTClassifier]] = None

  /** Returns the classifier held by the broadcast variable.
    *
    * @throws NoSuchElementException
    *   if [[setModelIfNotSet]] has not been called yet
    * @group getParam
    */
  def getModelIfNotSet: ViTClassifier =
    _model
      .getOrElse(
        throw new NoSuchElementException(
          "The ViTForImageClassification model has not been set; call setModelIfNotSet first."))
      .value

  /** Builds a [[ViTClassifier]] from the given engine wrapper(s) and broadcasts it, unless a model
    * is already set (in which case this is a no-op).
    *
    * The classifier also receives the current `configProtoBytes`, `labels` and `signatures`, so
    * those must be set before calling this method.
    *
    * @param spark
    *   session whose SparkContext broadcasts the model
    * @param tensorflowWrapper
    *   TensorFlow model, if the engine is TensorFlow
    * @param onnxWrapper
    *   ONNX model, if the engine is ONNX
    * @param openvinoWrapper
    *   OpenVINO model, if the engine is OpenVINO
    * @param preprocessor
    *   image preprocessing settings handed to the classifier constructor
    * @group setParam
    */
  def setModelIfNotSet(
      spark: SparkSession,
      tensorflowWrapper: Option[TensorflowWrapper],
      onnxWrapper: Option[OnnxWrapper],
      openvinoWrapper: Option[OpenvinoWrapper],
      preprocessor: Preprocessor): this.type = {
    if (_model.isEmpty) {

      _model = Some(
        spark.sparkContext.broadcast(
          new ViTClassifier(
            tensorflowWrapper,
            onnxWrapper,
            openvinoWrapper,
            configProtoBytes = getConfigProtoBytes,
            tags = $$(labels),
            preprocessor = preprocessor,
            signatures = getSignatures)))
    }
    this
  }

  setDefault(batchSize -> 2)

  /** Classifies every non-empty image in the batch and returns the predictions grouped by the
    * row they came from.
    *
    * Images whose `result` is empty are not sent to the model and produce no annotation. A row
    * with no classified image yields an empty sequence.
    *
    * @param batchedAnnotations
    *   Annotations that correspond to inputAnnotationCols generated by previous annotators if any
    * @return
    *   one sequence of annotations per input row, in input order
    */
  override def batchAnnotate(
      batchedAnnotations: Seq[Array[AnnotationImage]]): Seq[Seq[Annotation]] = {

    // Pair every image with the index of its row and keep only the images that carry data;
    // these (and only these) are sent to the model.
    val nonEmptyImagesWithRow: Seq[(AnnotationImage, Int)] = batchedAnnotations.zipWithIndex
      .flatMap { case (annotations, rowIndex) => annotations.map(image => (image, rowIndex)) }
      .filter { case (image, _) => image.result.nonEmpty }

    val nonEmptyImages = nonEmptyImagesWithRow.map(_._1).toArray

    val allAnnotations: Seq[Annotation] =
      if (nonEmptyImages.nonEmpty) {
        getModelIfNotSet.predict(
          images = nonEmptyImages,
          batchSize = $(batchSize),
          preprocessor = Preprocessor(
            do_normalize = getDoNormalize,
            do_resize = getDoResize,
            feature_extractor_type = getFeatureExtractorType,
            image_mean = getImageMean,
            image_std = getImageStd,
            resample = getResample,
            size = getSize))
      } else {
        Seq.empty[Annotation]
      }

    // Predictions line up with `nonEmptyImages`, NOT with every input image, so they must be
    // zipped with the row indices of the non-empty images only. This assumes predict returns
    // one annotation per input image, in input order (NOTE(review): confirm in ViTClassifier).
    // groupBy keeps the original order inside each row.
    val annotationsByRow: Map[Int, Seq[Annotation]] = allAnnotations
      .zip(nonEmptyImagesWithRow.map(_._2))
      .groupBy(_._2)
      .map { case (rowIndex, pairs) => rowIndex -> pairs.map(_._1) }

    batchedAnnotations.indices.map(rowIndex =>
      annotationsByRow.getOrElse(rowIndex, Seq.empty[Annotation]))
  }

  /** Persists the params/features and then the engine-specific model files under `path`.
    *
    * @throws Exception
    *   if the configured engine is not TensorFlow, ONNX or OpenVINO
    */
  override def onWrite(path: String, spark: SparkSession): Unit = {
    super.onWrite(path, spark)
    val suffix = "_image_classification"

    getEngine match {
      case TensorFlow.name =>
        writeTensorflowModelV2(
          path,
          spark,
          getModelIfNotSet.tensorflowWrapper.get,
          suffix,
          ViTForImageClassification.tfFile,
          configProtoBytes = getConfigProtoBytes)
      case ONNX.name =>
        writeOnnxModel(
          path,
          spark,
          getModelIfNotSet.onnxWrapper.get,
          suffix,
          ViTForImageClassification.onnxFile)
      case Openvino.name =>
        writeOpenvinoModel(
          path,
          spark,
          getModelIfNotSet.openvinoWrapper.get,
          "openvino_model.xml",
          ViTForImageClassification.openvinoFile)
      case _ =>
        // Previously an unknown engine surfaced as a bare MatchError.
        throw new Exception(notSupportedEngineError)
    }
  }

}
342

343
trait ReadablePretrainedViTForImageModel
    extends ParamsAndFeaturesReadable[ViTForImageClassification]
    with HasPretrained[ViTForImageClassification] {

  /** Name of the model fetched by `pretrained()` when the caller gives no name. */
  override val defaultModelName: Some[String] =
    Some("image_classifier_vit_base_patch16_224")

  // Java-compliant overrides: every overload below forwards its arguments unchanged to the
  // corresponding HasPretrained implementation.

  /** Downloads/loads the default pretrained model. */
  override def pretrained(): ViTForImageClassification =
    super.pretrained()

  /** Downloads/loads the pretrained model called `name`. */
  override def pretrained(name: String): ViTForImageClassification =
    super.pretrained(name)

  /** Downloads/loads the pretrained model called `name` for language `lang`. */
  override def pretrained(name: String, lang: String): ViTForImageClassification =
    super.pretrained(name, lang)

  /** Downloads/loads the pretrained model called `name` for language `lang` from `remoteLoc`. */
  override def pretrained(
      name: String,
      lang: String,
      remoteLoc: String): ViTForImageClassification =
    super.pretrained(name, lang, remoteLoc)
}
361

362
trait ReadViTForImageDLModel
    extends ReadTensorflowModel
    with ReadOnnxModel
    with ReadOpenvinoModel {
  this: ParamsAndFeaturesReadable[ViTForImageClassification] =>

  override val tfFile: String = "image_classification_tensorflow"
  override val onnxFile: String = "image_classification_onnx"
  override val openvinoFile: String = "image_classification_openvino"

  /** Reader registered with [[addReader]]: restores the engine-specific model of a saved
    * [[ViTForImageClassification]] and attaches it to `instance`.
    *
    * @throws Exception
    *   if the instance's engine is not TensorFlow, ONNX or OpenVINO
    */
  def readModel(instance: ViTForImageClassification, path: String, spark: SparkSession): Unit = {

    // Build the preprocessor from the params saved with the instance (the same getters
    // batchAnnotate uses) instead of hard-coding normalize/resize/extractor type, so a saved
    // model's own settings are honoured.
    val preprocessor = Preprocessor(
      do_normalize = instance.getDoNormalize,
      do_resize = instance.getDoResize,
      feature_extractor_type = instance.getFeatureExtractorType,
      image_mean = instance.getImageMean,
      image_std = instance.getImageStd,
      resample = instance.getResample,
      size = instance.getSize)

    instance.getEngine match {
      case TensorFlow.name =>
        val tfWrapper =
          readTensorflowModel(path, spark, tfFile, initAllTables = false)

        instance.setModelIfNotSet(spark, Some(tfWrapper), None, None, preprocessor)
      case ONNX.name =>
        val onnxWrapper =
          readOnnxModel(path, spark, onnxFile, zipped = true, useBundle = false, None)

        instance.setModelIfNotSet(spark, None, Some(onnxWrapper), None, preprocessor)

      case Openvino.name =>
        val openvinoWrapper =
          readOpenvinoModel(path, spark, "vit_for_image_classification_openvino")
        instance.setModelIfNotSet(spark, None, None, Some(openvinoWrapper), preprocessor)

      case _ =>
        throw new Exception(notSupportedEngineError)
    }

  }

  addReader(readModel)

  /** Parses the content of `labels.json` into a label -> id map.
    *
    * Ids may be JSON integers or numeric strings (e.g. `"3"`); both become `BigInt`. The old
    * unchecked cast let string ids through and they failed later with a ClassCastException.
    *
    * @throws IllegalArgumentException
    *   if the document is not a JSON object or an id is neither an integer nor a string
    * @throws NumberFormatException
    *   if a string id is not a valid integer
    */
  private def parseLabels(labelJsonContent: String): Map[String, BigInt] =
    parse(labelJsonContent, useBigIntForLong = true) match {
      case JObject(fields) =>
        fields.map {
          case (label, JInt(id)) => label -> id
          case (label, JString(id)) => label -> BigInt(id.trim)
          case (label, other) =>
            throw new IllegalArgumentException(
              s"Unsupported id for label '$label' in labels.json: $other")
        }.toMap
      case other =>
        throw new IllegalArgumentException(
          s"labels.json must contain a JSON object mapping labels to ids, got: $other")
    }

  /** Imports an exported model from `modelPath` (engine detected by `modelSanityCheck`) together
    * with its `labels.json` and `preprocessor_config.json` assets.
    *
    * @throws Exception
    *   if the detected engine is unsupported, or a TensorFlow model has no signatures
    */
  def loadSavedModel(modelPath: String, spark: SparkSession): ViTForImageClassification = {

    val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)

    val labelJsonContent = loadJsonStringAsset(localModelPath, "labels.json")
    val labelJsonMap = parseLabels(labelJsonContent)

    val preprocessorConfigJsonContent =
      loadJsonStringAsset(localModelPath, "preprocessor_config.json")
    val preprocessorConfig =
      Preprocessor.loadPreprocessorConfig(preprocessorConfigJsonContent)

    /*Universal parameters for all engines*/
    val annotatorModel = new ViTForImageClassification()
      .setLabels(labelJsonMap)
      .setDoNormalize(preprocessorConfig.do_normalize)
      .setDoResize(preprocessorConfig.do_resize)
      .setFeatureExtractorType(preprocessorConfig.feature_extractor_type)
      .setImageMean(preprocessorConfig.image_mean)
      .setImageStd(preprocessorConfig.image_std)
      .setResample(preprocessorConfig.resample)
      .setSize(preprocessorConfig.size)

    annotatorModel.set(annotatorModel.engine, detectedEngine)

    detectedEngine match {
      case TensorFlow.name =>
        val (tfWrapper, signatures) =
          TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

        val modelSignatures = signatures.getOrElse(
          throw new Exception("Cannot load signature definitions from model!"))

        // setSignatures must run before setModelIfNotSet, which reads getSignatures when it
        // builds the classifier.
        annotatorModel
          .setSignatures(modelSignatures)
          .setModelIfNotSet(spark, Some(tfWrapper), None, None, preprocessorConfig)

      case ONNX.name =>
        val onnxWrapper =
          OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)

        annotatorModel
          .setModelIfNotSet(spark, None, Some(onnxWrapper), None, preprocessorConfig)

      case Openvino.name =>
        val ovWrapper: OpenvinoWrapper =
          OpenvinoWrapper.read(
            spark,
            localModelPath,
            zipped = false,
            useBundle = true,
            detectedEngine = detectedEngine)
        annotatorModel
          .setModelIfNotSet(spark, None, None, Some(ovWrapper), preprocessorConfig)

      case _ =>
        throw new Exception(notSupportedEngineError)
    }

    annotatorModel
  }
}
477

478
/** Companion object of [[ViTForImageClassification]].
  *
  * It mixes in the `pretrained(...)` loaders and the engine-specific reader / `loadSavedModel`
  * importer. See the class documentation for usage.
  */
object ViTForImageClassification
    extends ReadablePretrainedViTForImageModel
    with ReadViTForImageDLModel
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc