• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 13883000244

16 Mar 2025 11:44AM CUT coverage: 59.034% (-1.0%) from 60.072%
13883000244

Pull #14444

github

web-flow
Merge 6d717703b into 05000ab4a
Pull Request #14444: Sparknlp 1060 implement phi 3.5 vision

0 of 292 new or added lines in 5 files covered. (0.0%)

20 existing lines in 14 files now uncovered.

9413 of 15945 relevant lines covered (59.03%)

0.59 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

39.81
/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala
1
/*
2
 * Copyright 2017-2023 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.pretrained
18

19
import com.johnsnowlabs.nlp.annotators._
20
import com.johnsnowlabs.nlp.annotators.audio.{HubertForCTC, Wav2Vec2ForCTC, WhisperForCTC}
21
import com.johnsnowlabs.nlp.annotators.classifier.dl._
22
import com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel
23
import com.johnsnowlabs.nlp.annotators.cv._
24
import com.johnsnowlabs.nlp.annotators.er.EntityRulerModel
25
import com.johnsnowlabs.nlp.annotators.ld.dl.LanguageDetectorDL
26
import com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfModel
27
import com.johnsnowlabs.nlp.annotators.ner.dl.{NerDLModel, ZeroShotNerModel}
28
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
29
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
30
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
31
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
32
import com.johnsnowlabs.nlp.annotators.sda.pragmatic.SentimentDetectorModel
33
import com.johnsnowlabs.nlp.annotators.sda.vivekn.ViveknSentimentModel
34
import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
35
import com.johnsnowlabs.nlp.annotators.seq2seq._
36
import com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel
37
import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel
38
import com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel
39
import com.johnsnowlabs.nlp.annotators.ws.WordSegmenterModel
40
import com.johnsnowlabs.nlp.embeddings._
41
import com.johnsnowlabs.nlp.pretrained.ResourceType.ResourceType
42
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
43
import com.johnsnowlabs.nlp.{DocumentAssembler, PromptAssembler, TableAssembler, pretrained}
44
import com.johnsnowlabs.util._
45
import org.apache.hadoop.fs.FileSystem
46
import org.apache.spark.ml.util.DefaultParamsReadable
47
import org.apache.spark.ml.{PipelineModel, PipelineStage}
48
import org.slf4j.{Logger, LoggerFactory}
49

50
import scala.collection.mutable
51
import scala.collection.mutable.ListBuffer
52
import scala.concurrent.ExecutionContext.Implicits.global
53
import scala.concurrent.Future
54
import scala.util.{Failure, Success}
55

56
/** Contract for downloading pretrained Spark NLP resources (models, pipelines and their
  * metadata) from a remote store to the local cache.
  *
  * Implementations (e.g. an S3-backed downloader) resolve a [[ResourceRequest]] to a local
  * path, report download sizes, and manage the on-disk cache.
  */
trait ResourceDownloader {

  /** Download resource to local file
    *
    * @param request
    *   Resource request
    * @return
    *   downloaded file or None if resource is not found
    */
  def download(request: ResourceRequest): Option[String]

  /** Size in bytes of the remote resource, or None if it cannot be resolved. */
  def getDownloadSize(request: ResourceRequest): Option[Long]

  /** Remove the locally cached copy of the requested resource. */
  def clearCache(request: ResourceRequest): Unit

  /** Fetch (if stale or absent) and parse the metadata.json for the given remote folder,
    * returning one [[ResourceMetadata]] entry per published resource.
    */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

  /** Download a file straight from an S3 path, optionally unzipping it.
    *
    * @param s3FilePath
    *   key (or URI) of the file in the bucket
    * @param unzip
    *   whether to extract the archive after download (default true)
    * @return
    *   local path of the downloaded (and possibly extracted) file, or None on failure
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]

  // Shared Hadoop FileSystem handle, provided by the companion object.
  val fileSystem: FileSystem = ResourceDownloader.fileSystem

}
78

79
/** Entry point for resolving, listing and downloading pretrained Spark NLP resources.
  *
  * Maintains three [[ResourceDownloader]] instances (private, public and community buckets),
  * an in-memory cache of loaded models/pipelines, and helper methods to list what is
  * available for a given annotator, language and library version.
  */
object ResourceDownloader {

  private val logger: Logger = LoggerFactory.getLogger(this.getClass.toString)

  val fileSystem: FileSystem = OutputHelper.getFileSystem

  /** S3 bucket holding the private/public pretrained resources, read from configuration. */
  def s3Bucket: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3BucketKey)

  /** S3 bucket holding community-contributed resources, read from configuration. */
  def s3BucketCommunity: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCommunityS3BucketKey)

  /** Path inside the S3 bucket under which pretrained resources live. */
  def s3Path: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3PathKey)

  /** Local folder used to cache downloaded resources. */
  def cacheFolder: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCacheFolder)

  /** Default remote folder containing the public models. */
  val publicLoc = "public/models"

  // In-memory cache of already loaded models/pipelines, keyed by their request, so repeated
  // pretrained() calls don't re-download or re-load from disk.
  private val cache: mutable.Map[ResourceRequest, PipelineStage] =
    mutable.Map[ResourceRequest, PipelineStage]()

  /** Version of the running Spark session (used for compatibility filtering). */
  lazy val sparkVersion: Version = {
    val spark_version = ResourceHelper.spark.version
    Version.parse(spark_version)
  }

  /** Version of the Spark NLP library (used for compatibility filtering). */
  lazy val libVersion: Version = {
    Version.parse(Build.version)
  }

  var privateDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  var publicDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "public")
  var communityDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")

  /** Select the downloader for a folder: public models, community ("@user/...") or private. */
  def getResourceDownloader(folder: String): ResourceDownloader = {
    folder match {
      case this.publicLoc => publicDownloader
      case loc if loc.startsWith("@") => communityDownloader
      case _ => privateDownloader
    }
  }

  /** Reset the cache and recreate ResourceDownloader S3 credentials */
  def resetResourceDownloader(): Unit = {
    // BUGFIX: was `cache.empty`, which only returns a new empty map and leaves the
    // cache untouched; `clear()` actually removes all cached entries.
    cache.clear()
    this.privateDownloader = new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  }

  /** List all pretrained models in public name_lang */
  def listPublicModels(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.MODEL)
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with a
    * version of Spark NLP. If any of the optional arguments are not set, the filter is not
    * considered.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = annotator,
        lang = lang,
        version = version,
        resourceType = ResourceType.MODEL))
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    */
  def showPublicModels(annotator: String): Unit = showPublicModels(Some(annotator))

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    * @param lang
    *   Language of the pretrained models to display
    */
  def showPublicModels(annotator: String, lang: String): Unit =
    showPublicModels(Some(annotator), Some(lang))

  /** Prints all pretrained models for a particular annotator, that are compatible with a version
    * of Spark NLP.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(annotator: String, lang: String, version: String): Unit =
    showPublicModels(Some(annotator), Some(lang), Some(version))

  /** List all pretrained pipelines in public */
  def listPublicPipelines(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.PIPELINE)
  }

  /** Prints all Pipelines available for a language and a version of Spark NLP. By default shows
    * all languages and uses the current version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = None,
        lang = lang,
        version = version,
        resourceType = ResourceType.PIPELINE))
  }

  /** Prints all Pipelines available for a language and this version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    */
  def showPublicPipelines(lang: String): Unit = showPublicPipelines(Some(lang))

  /** Prints all Pipelines available for a language and a version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(lang: String, version: String): Unit =
    showPublicPipelines(Some(lang), Some(version))

  /** Returns models or pipelines in metadata json which has not been categorized yet.
    *
    * @return
    *   list of models or pipelines which are not categorized in metadata json
    */
  def listUnCategorizedResources(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.NOT_DEFINED)
  }

  /** Prints uncategorized resources for a language (any library version). */
  def showUnCategorizedResources(lang: String): Unit = {
    println(publicResourceString(None, Some(lang), None, resourceType = ResourceType.NOT_DEFINED))
  }

  /** Prints uncategorized resources for a language and Spark NLP version. */
  def showUnCategorizedResources(lang: String, version: String): Unit = {
    println(
      publicResourceString(
        None,
        Some(lang),
        Some(version),
        resourceType = ResourceType.NOT_DEFINED))
  }

  /** Render a list of "name:lang:version" entries as an ASCII table whose header depends on the
    * resource type. Column widths grow to the longest name/version in the list.
    */
  def showString(list: List[String], resourceType: ResourceType): String = {
    val sb = new StringBuilder
    // Minimum widths fit the widest header labels ("Pipeline/Model" = 14, "version" = 7).
    var max_length = 14
    var max_length_version = 7
    for (data <- list) {
      val temp = data.split(":")
      max_length = scala.math.max(temp(0).length, max_length)
      max_length_version = scala.math.max(temp(2).length, max_length_version)
    }
    // adding head
    sb.append("+")
    sb.append("-" * (max_length + 2))
    sb.append("+")
    sb.append("-" * 6)
    sb.append("+")
    sb.append("-" * (max_length_version + 2))
    sb.append("+\n")
    if (resourceType.equals(ResourceType.PIPELINE))
      sb.append(
        "| " + "Pipeline" + (" " * (max_length - 8)) + " | " + "lang" + " | " + "version" + " " * (max_length_version - 7) + " |\n")
    else if (resourceType.equals(ResourceType.MODEL))
      sb.append(
        "| " + "Model" + (" " * (max_length - 5)) + " | " + "lang" + " | " + "version" + " " * (max_length_version - 7) + " |\n")
    else
      sb.append(
        "| " + "Pipeline/Model" + (" " * (max_length - 14)) + " | " + "lang" + " | " + "version" + " " * (max_length_version - 7) + " |\n")

    sb.append("+")
    sb.append("-" * (max_length + 2))
    sb.append("+")
    sb.append("-" * 6)
    sb.append("+")
    sb.append("-" * (max_length_version + 2))
    sb.append("+\n")
    for (data <- list) {
      val temp = data.split(":")
      sb.append(
        "| " + temp(0) + (" " * (max_length - temp(0).length)) + " |  " + temp(1) + "  | " + temp(
          2) + " " * (max_length_version - temp(2).length) + " |\n")
    }
    // adding bottom
    sb.append("+")
    sb.append("-" * (max_length + 2))
    sb.append("+")
    sb.append("-" * 6)
    sb.append("+")
    sb.append("-" * (max_length_version + 2))
    sb.append("+\n")
    sb.toString()
  }

  /** Build the printable table of public resources matching the given filters. */
  def publicResourceString(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version),
      resourceType: ResourceType): String = {
    showString(
      listPretrainedResources(
        folder = publicLoc,
        resourceType,
        annotator = annotator,
        lang = lang,
        version = version match {
          case Some(ver) => Some(Version.parse(ver))
          case None => None
        }),
      resourceType)
  }

  /** Lists pretrained resource from metadata.json, depending on the set filters. The folder in
    * the S3 location and the resourceType is necessary. The other filters are optional and will
    * be ignored if not set.
    *
    * @param folder
    *   Folder in the S3 location
    * @param resourceType
    *   Type of the Resource. Can Either `ResourceType.MODEL`, `ResourceType.PIPELINE` or
    *   `ResourceType.NOT_DEFINED`
    * @param annotator
    *   Name of the model class
    * @param lang
    *   Language of the model
    * @param version
    *   Version that the model should be compatible with
    * @return
    *   A list of the available resources
    */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[Version] = None): List[String] = {
    val resourceList = new ListBuffer[String]()

    val resourceMetaData = getResourceMetadata(folder)

    for (meta <- resourceMetaData) {
      val isSameResourceType =
        meta.category.getOrElse(ResourceType.NOT_DEFINED).toString.equals(resourceType.toString)
      val isCompatibleWithVersion = version match {
        case Some(ver) => Version.isCompatible(ver, meta.libVersion)
        case None => true
      }
      val isSameAnnotator = annotator match {
        case Some(cls) => meta.annotator.getOrElse("").equalsIgnoreCase(cls)
        case None => true
      }
      val isSameLanguage = lang match {
        case Some(l) => meta.language.getOrElse("").equalsIgnoreCase(l)
        case None => true
      }

      // Use short-circuiting && (was bitwise-style &, which evaluated all predicates).
      if (isSameResourceType && isCompatibleWithVersion && isSameAnnotator && isSameLanguage) {
        resourceList += meta.name + ":" + meta.language.getOrElse("-") + ":" + meta.libVersion
          .getOrElse("-")
      }
    }
    resourceList.result()
  }

  /** Lists pretrained resources filtered by language only. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang))

  /** Lists pretrained resources filtered by library version only. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, version = Some(version))

  /** Lists pretrained resources filtered by language and library version. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang), version = Some(version))

  /** Distinct, sorted annotator class names that have at least one published resource. */
  def listAvailableAnnotators(folder: String = publicLoc): List[String] = {

    val resourceMetaData = getResourceMetadata(folder)

    resourceMetaData
      .map(_.annotator.getOrElse(""))
      .toSet
      .filter { a =>
        !a.equals("")
      }
      .toList
      .sorted
  }

  // Resolve the right downloader for the location and fetch its (possibly cached) metadata.
  private def getResourceMetadata(location: String): List[ResourceMetadata] = {
    getResourceDownloader(location).downloadMetadataIfNeed(location)
  }

  /** Prints the available annotator class names, one per line. */
  def showAvailableAnnotators(folder: String = publicLoc): Unit = {
    println(listAvailableAnnotators(folder).mkString("\n"))
  }

  /** Loads resource to path
    *
    * @param name
    *   Name of Resource
    * @param folder
    *   Subfolder in s3 where to search model (e.g. medicine)
    * @param language
    *   Desired language of Resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): String = {
    downloadResource(ResourceRequest(name, language, folder))
  }

  /** Loads resource to path
    *
    * @param request
    *   Request for resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(request: ResourceRequest): String = {
    val future = Future {
      // Community resources are addressed as "@user/..."; strip the marker for the request
      // but keep the original folder to pick the right downloader.
      val updatedRequest: ResourceRequest = if (request.folder.startsWith("@")) {
        request.copy(folder = request.folder.replace("@", ""))
      } else request
      getResourceDownloader(request.folder).download(updatedRequest)
    }

    var downloadFinished = false
    var path: Option[String] = None
    val fileSize = getDownloadSize(request)
    require(
      !fileSize.equals("-1"),
      s"Can not find ${request.name} inside ${request.folder} to download. Please make sure the name and location are correct!")
    println(request.name + " download started this may take some time.")
    println("Approximate size to download " + fileSize)

    // BUGFIX: register the completion callback exactly once. Previously onComplete was
    // called inside the polling loop, attaching a fresh callback every second for the
    // whole duration of the download.
    future.onComplete {
      case Success(value) =>
        path = value
        downloadFinished = true
      case Failure(exception) =>
        println(s"Error: ${exception.getMessage}")
        logger.error(exception.getMessage)
        path = None
        downloadFinished = true
    }

    // Poll until the future's callback flips the flag.
    while (!downloadFinished) {
      Thread.sleep(1000)
    }

    require(
      path.isDefined,
      s"Was not found appropriate resource to download for request: $request with downloader: $privateDownloader")
    println("Download done! Loading the resource.")
    path.get
  }

  /** Downloads a model from the default S3 bucket to the cache pretrained folder.
    * @param model
    *   the name of the key in the S3 bucket or s3 URI
    * @param folder
    *   the folder of the model
    * @param unzip
    *   used to unzip the model, by default true
    */
  def downloadModelDirectly(
      model: String,
      folder: String = publicLoc,
      unzip: Boolean = true): Unit = {
    getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
  }

  /** Download (or fetch from cache) and load a model through the given reader. */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): TModel = {
    downloadModel(reader, ResourceRequest(name, language, folder))
  }

  /** Download (or fetch from cache) and load the model for the given request. */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      request: ResourceRequest): TModel = {
    if (!cache.contains(request)) {
      val path = downloadResource(request)
      val model = reader.read.load(path)
      cache(request) = model
      model
    } else {
      cache(request).asInstanceOf[TModel]
    }
  }

  /** Download (or fetch from cache) and load a pretrained pipeline by name. */
  def downloadPipeline(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): PipelineModel = {
    downloadPipeline(ResourceRequest(name, language, folder))
  }

  /** Download (or fetch from cache) and load the pipeline for the given request. */
  def downloadPipeline(request: ResourceRequest): PipelineModel = {
    if (!cache.contains(request)) {
      val path = downloadResource(request)
      val model = PipelineModel.read.load(path)
      cache(request) = model
      model
    } else {
      cache(request).asInstanceOf[PipelineModel]
    }
  }

  /** Remove a resource from the on-disk caches and the in-memory cache, by name. */
  def clearCache(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): Unit = {
    clearCache(ResourceRequest(name, language, folder))
  }

  /** Remove a resource from all downloaders' on-disk caches and the in-memory cache. */
  def clearCache(request: ResourceRequest): Unit = {
    privateDownloader.clearCache(request)
    publicDownloader.clearCache(request)
    communityDownloader.clearCache(request)
    cache.remove(request)
  }

  /** Human-readable download size of the requested resource, or "-1" if it cannot be found. */
  def getDownloadSize(resourceRequest: ResourceRequest): String = {

    val updatedResourceRequest: ResourceRequest = if (resourceRequest.folder.startsWith("@")) {
      resourceRequest.copy(folder = resourceRequest.folder.replace("@", ""))
    } else resourceRequest

    val size = getResourceDownloader(resourceRequest.folder)
      .getDownloadSize(updatedResourceRequest)

    size match {
      case Some(downloadBytes) => FileHelper.getHumanReadableFileSize(downloadBytes)
      case None => "-1"
    }
  }

}
565

566
/** Categories a published resource can carry in metadata.json. */
object ResourceType extends Enumeration {
  type ResourceType = Value

  /** A pretrained model ("ml" in metadata.json). */
  val MODEL: ResourceType = Value("ml")

  /** A pretrained pipeline ("pl" in metadata.json). */
  val PIPELINE: ResourceType = Value("pl")

  /** A resource without an assigned category ("nd" in metadata.json). */
  val NOT_DEFINED: ResourceType = Value("nd")
}
572

573
/** A request for a pretrained resource, used as the key of the in-memory cache.
  *
  * @param name
  *   name of the resource, e.g. "ner_dl"
  * @param language
  *   optional language code of the resource, e.g. Some("en")
  * @param folder
  *   remote folder to search, defaults to the public models location
  * @param libVersion
  *   Spark NLP version the resource must be compatible with (defaults to this build)
  * @param sparkVersion
  *   Apache Spark version the resource must be compatible with (defaults to the running session)
  */
case class ResourceRequest(
    name: String,
    language: Option[String] = None,
    folder: String = ResourceDownloader.publicLoc,
    libVersion: Version = ResourceDownloader.libVersion,
    sparkVersion: Version = ResourceDownloader.sparkVersion)
579

580
/* convenience accessor for Py4J calls */
581
/* convenience accessor for Py4J calls */
object PythonResourceDownloader {

  /** Maps a class name (as passed from Python) to the Scala reader able to load it.
    *
    * BUGFIX: removed duplicate entries for "DeBertaForSequenceClassification" and
    * "DeBertaForTokenClassification" — duplicate keys in a Map literal silently
    * overwrite the earlier binding.
    */
  val keyToReader: mutable.Map[String, DefaultParamsReadable[_]] = mutable.Map(
    "DocumentAssembler" -> DocumentAssembler,
    "SentenceDetector" -> SentenceDetector,
    "TokenizerModel" -> TokenizerModel,
    "PerceptronModel" -> PerceptronModel,
    "NerCrfModel" -> NerCrfModel,
    "Stemmer" -> Stemmer,
    "NormalizerModel" -> NormalizerModel,
    "RegexMatcherModel" -> RegexMatcherModel,
    "LemmatizerModel" -> LemmatizerModel,
    "DateMatcher" -> DateMatcher,
    "TextMatcherModel" -> TextMatcherModel,
    "SentimentDetectorModel" -> SentimentDetectorModel,
    "ViveknSentimentModel" -> ViveknSentimentModel,
    "NorvigSweetingModel" -> NorvigSweetingModel,
    "SymmetricDeleteModel" -> SymmetricDeleteModel,
    "NerDLModel" -> NerDLModel,
    "WordEmbeddingsModel" -> WordEmbeddingsModel,
    "BertEmbeddings" -> BertEmbeddings,
    "DependencyParserModel" -> DependencyParserModel,
    "TypedDependencyParserModel" -> TypedDependencyParserModel,
    "UniversalSentenceEncoder" -> UniversalSentenceEncoder,
    "ElmoEmbeddings" -> ElmoEmbeddings,
    "ClassifierDLModel" -> ClassifierDLModel,
    "ContextSpellCheckerModel" -> ContextSpellCheckerModel,
    "AlbertEmbeddings" -> AlbertEmbeddings,
    "XlnetEmbeddings" -> XlnetEmbeddings,
    "SentimentDLModel" -> SentimentDLModel,
    "LanguageDetectorDL" -> LanguageDetectorDL,
    "StopWordsCleaner" -> StopWordsCleaner,
    "BertSentenceEmbeddings" -> BertSentenceEmbeddings,
    "MultiClassifierDLModel" -> MultiClassifierDLModel,
    "SentenceDetectorDLModel" -> SentenceDetectorDLModel,
    "T5Transformer" -> T5Transformer,
    "MarianTransformer" -> MarianTransformer,
    "WordSegmenterModel" -> WordSegmenterModel,
    "DistilBertEmbeddings" -> DistilBertEmbeddings,
    "RoBertaEmbeddings" -> RoBertaEmbeddings,
    "XlmRoBertaEmbeddings" -> XlmRoBertaEmbeddings,
    "LongformerEmbeddings" -> LongformerEmbeddings,
    "RoBertaSentenceEmbeddings" -> RoBertaSentenceEmbeddings,
    "XlmRoBertaSentenceEmbeddings" -> XlmRoBertaSentenceEmbeddings,
    "AlbertForTokenClassification" -> AlbertForTokenClassification,
    "BertForTokenClassification" -> BertForTokenClassification,
    "DeBertaForTokenClassification" -> DeBertaForTokenClassification,
    "DistilBertForTokenClassification" -> DistilBertForTokenClassification,
    "LongformerForTokenClassification" -> LongformerForTokenClassification,
    "RoBertaForTokenClassification" -> RoBertaForTokenClassification,
    "XlmRoBertaForTokenClassification" -> XlmRoBertaForTokenClassification,
    "XlnetForTokenClassification" -> XlnetForTokenClassification,
    "AlbertForSequenceClassification" -> AlbertForSequenceClassification,
    "BertForSequenceClassification" -> BertForSequenceClassification,
    "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
    "DistilBertForSequenceClassification" -> DistilBertForSequenceClassification,
    "LongformerForSequenceClassification" -> LongformerForSequenceClassification,
    "RoBertaForSequenceClassification" -> RoBertaForSequenceClassification,
    "XlmRoBertaForSequenceClassification" -> XlmRoBertaForSequenceClassification,
    "XlnetForSequenceClassification" -> XlnetForSequenceClassification,
    "GPT2Transformer" -> GPT2Transformer,
    "EntityRulerModel" -> EntityRulerModel,
    "Doc2VecModel" -> Doc2VecModel,
    "Word2VecModel" -> Word2VecModel,
    "DeBertaEmbeddings" -> DeBertaEmbeddings,
    "CamemBertEmbeddings" -> CamemBertEmbeddings,
    "AlbertForQuestionAnswering" -> AlbertForQuestionAnswering,
    "BertForQuestionAnswering" -> BertForQuestionAnswering,
    "DeBertaForQuestionAnswering" -> DeBertaForQuestionAnswering,
    "DistilBertForQuestionAnswering" -> DistilBertForQuestionAnswering,
    "LongformerForQuestionAnswering" -> LongformerForQuestionAnswering,
    "RoBertaForQuestionAnswering" -> RoBertaForQuestionAnswering,
    "XlmRoBertaForQuestionAnswering" -> XlmRoBertaForQuestionAnswering,
    "SpanBertCorefModel" -> SpanBertCorefModel,
    "ViTForImageClassification" -> ViTForImageClassification,
    "VisionEncoderDecoderForImageCaptioning" -> VisionEncoderDecoderForImageCaptioning,
    "SwinForImageClassification" -> SwinForImageClassification,
    "ConvNextForImageClassification" -> ConvNextForImageClassification,
    "Wav2Vec2ForCTC" -> Wav2Vec2ForCTC,
    "HubertForCTC" -> HubertForCTC,
    "WhisperForCTC" -> WhisperForCTC,
    "CamemBertForTokenClassification" -> CamemBertForTokenClassification,
    "TableAssembler" -> TableAssembler,
    "TapasForQuestionAnswering" -> TapasForQuestionAnswering,
    "CamemBertForSequenceClassification" -> CamemBertForSequenceClassification,
    "CamemBertForQuestionAnswering" -> CamemBertForQuestionAnswering,
    "ZeroShotNerModel" -> ZeroShotNerModel,
    "BartTransformer" -> BartTransformer,
    "BertForZeroShotClassification" -> BertForZeroShotClassification,
    "DistilBertForZeroShotClassification" -> DistilBertForZeroShotClassification,
    "RoBertaForZeroShotClassification" -> RoBertaForZeroShotClassification,
    "XlmRoBertaForZeroShotClassification" -> XlmRoBertaForZeroShotClassification,
    "BartForZeroShotClassification" -> BartForZeroShotClassification,
    "InstructorEmbeddings" -> InstructorEmbeddings,
    "E5Embeddings" -> E5Embeddings,
    "MPNetEmbeddings" -> MPNetEmbeddings,
    "CLIPForZeroShotClassification" -> CLIPForZeroShotClassification,
    "DeBertaForZeroShotClassification" -> DeBertaForZeroShotClassification,
    "BGEEmbeddings" -> BGEEmbeddings,
    "MPNetForSequenceClassification" -> MPNetForSequenceClassification,
    "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering,
    "LLAMA2Transformer" -> LLAMA2Transformer,
    "LLAMA3Transformer" -> LLAMA3Transformer,
    "M2M100Transformer" -> M2M100Transformer,
    "UAEEmbeddings" -> UAEEmbeddings,
    "AutoGGUFModel" -> AutoGGUFModel,
    "AlbertForZeroShotClassification" -> AlbertForZeroShotClassification,
    "MxbaiEmbeddings" -> MxbaiEmbeddings,
    "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings,
    "CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification,
    "BertForMultipleChoice" -> BertForMultipleChoice,
    "PromptAssembler" -> PromptAssembler,
    "CPMTransformer" -> CPMTransformer,
    "NomicEmbeddings" -> NomicEmbeddings,
    "NLLBTransformer" -> NLLBTransformer,
    "Phi3Transformer" -> Phi3Transformer,
    "QwenTransformer" -> QwenTransformer,
    "AutoGGUFEmbeddings" -> AutoGGUFEmbeddings,
    "Phi3Vision" -> Phi3Vision,
    "OLMoTransformer" -> OLMoTransformer)

  // List pairs of types such as the one with key type can load a pretrained model from the value type
  val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")

  /** Download a model by reader class name, delegating to [[ResourceDownloader.downloadModel]].
    *
    * @param readerStr
    *   name of the model class, e.g. "NerDLModel"
    * @param name
    *   name of the pretrained resource
    * @param language
    *   language code or null (Py4J passes null for missing arguments)
    * @param remoteLoc
    *   remote folder or null for the public models location
    * @throws RuntimeException
    *   if `readerStr` is not a supported model class
    */
  def downloadModel(
      readerStr: String,
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineStage = {

    // Some classes (e.g. ZeroShotNerModel) are loaded through another class's reader.
    val reader = keyToReader.getOrElse(
      if (typeMapper.contains(readerStr)) typeMapper(readerStr) else readerStr,
      throw new RuntimeException(s"Unsupported Model: $readerStr"))

    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)

    val model = ResourceDownloader.downloadModel(
      reader.asInstanceOf[DefaultParamsReadable[PipelineStage]],
      name,
      Option(language),
      correctedFolder)

    // Cast the model to the required type. This has to be done for each entry in the typeMapper map
    if (typeMapper.contains(readerStr) && readerStr == "ZeroShotNerModel")
      ZeroShotNerModel(model)
    else
      model
  }

  /** Download a pretrained pipeline by name (null-friendly Py4J wrapper). */
  def downloadPipeline(
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineModel = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadPipeline(name, Option(language), correctedFolder)
  }

  /** Clear a resource from the caches (null-friendly Py4J wrapper). */
  def clearCache(name: String, language: String = null, remoteLoc: String = null): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.clearCache(name, Option(language), correctedFolder)
  }

  /** Download a model file directly from S3 (null-friendly Py4J wrapper). */
  def downloadModelDirectly(
      model: String,
      remoteLoc: String = null,
      unzip: Boolean = true): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip)
  }

  /** Table of uncategorized public resources as a string (for Python display). */
  def showUnCategorizedResources(): String = {
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = None,
      version = None,
      resourceType = ResourceType.NOT_DEFINED)
  }

  /** Table of public pipelines for a language/version; null version means this build. */
  def showPublicPipelines(lang: String, version: String): String = {
    val ver: Option[String] = version match {
      case null => Some(Build.version)
      case _ => Some(version)
    }
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = Option(lang),
      version = ver,
      resourceType = ResourceType.PIPELINE)
  }

  /** Table of public models for an annotator/language/version; null version means this build. */
  def showPublicModels(annotator: String, lang: String, version: String): String = {
    val ver: Option[String] = version match {
      case null => Some(Build.version)
      case _ => Some(version)
    }
    ResourceDownloader.publicResourceString(
      annotator = Option(annotator),
      lang = Option(lang),
      version = ver,
      resourceType = ResourceType.MODEL)
  }

  /** Newline-separated list of available annotator class names (for Python display). */
  def showAvailableAnnotators(): String = {
    ResourceDownloader.listAvailableAnnotators().mkString("\n")
  }

  /** Human-readable download size of a resource, or "-1" if not found. */
  def getDownloadSize(name: String, language: String = "en", remoteLoc: String = null): String = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.getDownloadSize(ResourceRequest(name, Option(language), correctedFolder))
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc