• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 7861513225

11 Feb 2024 11:05AM UTC coverage: 62.678% (-0.05%) from 62.731%
7861513225

Pull #14169

github

web-flow
Merge 13f2acde4 into 6010244ba
Pull Request #14169: Fixed a bug where models that have an 'onnx_data' file did not work in dbfs/hdfs

8951 of 14281 relevant lines covered (62.68%)

0.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

41.72
/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala
1
/*
2
 * Copyright 2017-2023 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.pretrained
18

19
import com.johnsnowlabs.nlp.annotators._
20
import com.johnsnowlabs.nlp.annotators.audio.{HubertForCTC, Wav2Vec2ForCTC, WhisperForCTC}
21
import com.johnsnowlabs.nlp.annotators.classifier.dl._
22
import com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel
23
import com.johnsnowlabs.nlp.annotators.cv._
24
import com.johnsnowlabs.nlp.annotators.er.EntityRulerModel
25
import com.johnsnowlabs.nlp.annotators.ld.dl.LanguageDetectorDL
26
import com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfModel
27
import com.johnsnowlabs.nlp.annotators.ner.dl.{NerDLModel, ZeroShotNerModel}
28
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
29
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
30
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
31
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
32
import com.johnsnowlabs.nlp.annotators.sda.pragmatic.SentimentDetectorModel
33
import com.johnsnowlabs.nlp.annotators.sda.vivekn.ViveknSentimentModel
34
import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
35
import com.johnsnowlabs.nlp.annotators.seq2seq.{
36
  BartTransformer,
37
  GPT2Transformer,
38
  LLAMA2Transformer,
39
  M2M100Transformer,
40
  MarianTransformer,
41
  T5Transformer
42
}
43
import com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel
44
import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel
45
import com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel
46
import com.johnsnowlabs.nlp.annotators.ws.WordSegmenterModel
47
import com.johnsnowlabs.nlp.embeddings._
48
import com.johnsnowlabs.nlp.pretrained.ResourceType.ResourceType
49
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
50
import com.johnsnowlabs.nlp.{DocumentAssembler, TableAssembler, pretrained}
51
import com.johnsnowlabs.util._
52
import org.apache.hadoop.fs.FileSystem
53
import org.apache.spark.ml.util.DefaultParamsReadable
54
import org.apache.spark.ml.{PipelineModel, PipelineStage}
55
import org.slf4j.{Logger, LoggerFactory}
56

57
import scala.collection.mutable
58
import scala.collection.mutable.ListBuffer
59
import scala.concurrent.ExecutionContext.Implicits.global
60
import scala.concurrent.Future
61
import scala.util.{Failure, Success}
62

63
/** Common contract of the components that fetch pretrained resources (models, pipelines,
  * metadata) from remote storage into a local cache.
  */
trait ResourceDownloader {

  /** File system shared by every downloader implementation; taken from the companion object. */
  val fileSystem: FileSystem = ResourceDownloader.fileSystem

  /** Fetches the resource described by `request` into local storage.
    *
    * @param request
    *   description of the wanted resource
    * @return
    *   local path of the downloaded resource, or `None` when it cannot be found
    */
  def download(request: ResourceRequest): Option[String]

  /** Size in bytes of the resource described by `request`, if the downloader can determine it. */
  def getDownloadSize(request: ResourceRequest): Option[Long]

  /** Removes any locally cached copy of the resource described by `request`. */
  def clearCache(request: ResourceRequest): Unit

  /** Returns the metadata entries for `folder`, fetching them first when required. */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

  /** Downloads the file at `s3FilePath`, unzipping it unless `unzip` is false.
    *
    * @return
    *   local path of the result, or `None` when nothing was downloaded
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]
}
85

86
object ResourceDownloader {

  private val logger: Logger = LoggerFactory.getLogger(this.getClass.toString)

  /** File system used by every downloader, as returned by `OutputHelper.getFileSystem`. */
  val fileSystem: FileSystem = OutputHelper.getFileSystem

  /** S3 bucket with the official pretrained resources (looked up in the config on every access). */
  def s3Bucket: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3BucketKey)

  /** S3 bucket with community resources (looked up in the config on every access). */
  def s3BucketCommunity: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCommunityS3BucketKey)

  /** Path prefix inside the S3 buckets (looked up in the config on every access). */
  def s3Path: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3PathKey)

  /** Local cache folder for downloaded resources (looked up in the config on every access). */
  def cacheFolder: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCacheFolder)

  /** Folder of the public pretrained resources; the default location of every request. */
  val publicLoc = "public/models"

  // Already loaded models/pipelines, keyed by the request that produced them.
  private val cache: mutable.Map[ResourceRequest, PipelineStage] =
    mutable.Map[ResourceRequest, PipelineStage]()

  /** Version of the active Spark session; default `sparkVersion` of a [[ResourceRequest]]. */
  lazy val sparkVersion: Version = {
    val spark_version = ResourceHelper.spark.version
    Version.parse(spark_version)
  }

  /** Version of this Spark NLP build; default `libVersion` of a [[ResourceRequest]]. */
  lazy val libVersion: Version = {
    Version.parse(Build.version)
  }

  var privateDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  var publicDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "public")
  var communityDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")

  /** Picks the downloader for a folder: the public folder, "@"-prefixed community folders, or
    * the private downloader for anything else.
    */
  def getResourceDownloader(folder: String): ResourceDownloader = {
    folder match {
      case this.publicLoc => publicDownloader
      case loc if loc.startsWith("@") => communityDownloader
      case _ => privateDownloader
    }
  }

  /** Strips the "@" marker of community folders; the downloaders expect the bare folder name. */
  private def withoutCommunityMarker(request: ResourceRequest): ResourceRequest =
    if (request.folder.startsWith("@")) request.copy(folder = request.folder.replace("@", ""))
    else request

  /** Reset the cache and recreate ResourceDownloader S3 credentials */
  def resetResourceDownloader(): Unit = {
    // `cache.empty` would only build a new, unrelated empty map; `clear()` really drops entries.
    cache.clear()
    this.privateDownloader = new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  }

  /** List all pretrained models in public name_lang */
  def listPublicModels(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.MODEL)
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with a
    * version of Spark NLP. If any of the optional arguments are not set, the filter is not
    * considered.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = annotator,
        lang = lang,
        version = version,
        resourceType = ResourceType.MODEL))
  }

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    */
  def showPublicModels(annotator: String): Unit = showPublicModels(Some(annotator))

  /** Prints all pretrained models for a particular annotator model, that are compatible with this
    * version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    * @param lang
    *   Language of the pretrained models to display
    */
  def showPublicModels(annotator: String, lang: String): Unit =
    showPublicModels(Some(annotator), Some(lang))

  /** Prints all pretrained models for a particular annotator, that are compatible with a version
    * of Spark NLP.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(annotator: String, lang: String, version: String): Unit =
    showPublicModels(Some(annotator), Some(lang), Some(version))

  /** List all pretrained pipelines in public */
  def listPublicPipelines(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.PIPELINE)
  }

  /** Prints all Pipelines available for a language and a version of Spark NLP. By default shows
    * all languages and uses the current version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    println(
      publicResourceString(
        annotator = None,
        lang = lang,
        version = version,
        resourceType = ResourceType.PIPELINE))
  }

  /** Prints all Pipelines available for a language and this version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    */
  def showPublicPipelines(lang: String): Unit = showPublicPipelines(Some(lang))

  /** Prints all Pipelines available for a language and a version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(lang: String, version: String): Unit =
    showPublicPipelines(Some(lang), Some(version))

  /** Returns models or pipelines in metadata json which has not been categorized yet.
    *
    * @return
    *   list of models or pipelines which are not categorized in metadata json
    */
  def listUnCategorizedResources(): List[String] = {
    listPretrainedResources(folder = publicLoc, ResourceType.NOT_DEFINED)
  }

  /** Prints the uncategorized public resources for `lang`, for any Spark NLP version. */
  def showUnCategorizedResources(lang: String): Unit = {
    println(publicResourceString(None, Some(lang), None, resourceType = ResourceType.NOT_DEFINED))
  }

  /** Prints the uncategorized public resources for `lang` compatible with `version`. */
  def showUnCategorizedResources(lang: String, version: String): Unit = {
    println(
      publicResourceString(
        None,
        Some(lang),
        Some(version),
        resourceType = ResourceType.NOT_DEFINED))
  }

  /** Renders "name:lang:version" entries as an ASCII table whose first column title depends on
    * `resourceType`.
    *
    * @param list
    *   entries in the "name:lang:version" form produced by `listPretrainedResources`
    * @param resourceType
    *   selects the first column title ("Pipeline", "Model" or "Pipeline/Model")
    */
  def showString(list: List[String], resourceType: ResourceType): String = {
    val rows = list.map(_.split(":"))
    // Minimum widths fit the "Pipeline/Model" and "version" titles.
    val nameWidth = (14 :: rows.map(_(0).length)).max
    val versionWidth = (7 :: rows.map(_(2).length)).max

    val border =
      "+" + "-" * (nameWidth + 2) + "+" + "-" * 6 + "+" + "-" * (versionWidth + 2) + "+\n"
    val title = resourceType match {
      case ResourceType.PIPELINE => "Pipeline"
      case ResourceType.MODEL => "Model"
      case _ => "Pipeline/Model"
    }

    val sb = new StringBuilder
    sb.append(border)
    sb.append(
      "| " + title + (" " * (nameWidth - title.length)) + " | " + "lang" + " | " + "version" +
        " " * (versionWidth - 7) + " |\n")
    sb.append(border)
    rows.foreach { cols =>
      sb.append(
        "| " + cols(0) + (" " * (nameWidth - cols(0).length)) + " |  " + cols(1) + "  | " +
          cols(2) + " " * (versionWidth - cols(2).length) + " |\n")
    }
    sb.append(border)
    sb.toString()
  }

  /** Renders the public resources matching the filters as a table (see `showString`). The
    * optional `version` string is parsed into a [[Version]] before filtering.
    */
  def publicResourceString(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version),
      resourceType: ResourceType): String = {
    showString(
      listPretrainedResources(
        folder = publicLoc,
        resourceType,
        annotator = annotator,
        lang = lang,
        version = version.map(ver => Version.parse(ver))),
      resourceType)
  }

  /** Lists pretrained resource from metadata.json, depending on the set filters. The folder in
    * the S3 location and the resourceType is necessary. The other filters are optional and will
    * be ignored if not set.
    *
    * @param folder
    *   Folder in the S3 location
    * @param resourceType
    *   Type of the Resource. Can Either `ResourceType.MODEL`, `ResourceType.PIPELINE` or
    *   `ResourceType.NOT_DEFINED`
    * @param annotator
    *   Name of the model class
    * @param lang
    *   Language of the model
    * @param version
    *   Version that the model should be compatible with
    * @return
    *   A list of the available resources, each as "name:lang:version"
    */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[Version] = None): List[String] = {
    getResourceMetadata(folder)
      .filter { meta =>
        // Entries without a category count as NOT_DEFINED.
        val isSameResourceType =
          meta.category.getOrElse(ResourceType.NOT_DEFINED).toString.equals(resourceType.toString)
        val isCompatibleWithVersion = version.forall(ver => Version.isCompatible(ver, meta.libVersion))
        val isSameAnnotator = annotator.forall(cls => meta.annotator.getOrElse("").equalsIgnoreCase(cls))
        val isSameLanguage = lang.forall(l => meta.language.getOrElse("").equalsIgnoreCase(l))
        isSameResourceType && isCompatibleWithVersion && isSameAnnotator && isSameLanguage
      }
      .map { meta =>
        meta.name + ":" + meta.language.getOrElse("-") + ":" + meta.libVersion.getOrElse("-")
      }
  }

  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang))

  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, version = Some(version))

  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang), version = Some(version))

  /** Sorted, de-duplicated annotator names from the metadata of `folder`; entries without an
    * annotator are skipped.
    */
  def listAvailableAnnotators(folder: String = publicLoc): List[String] = {
    getResourceMetadata(folder)
      .map(_.annotator.getOrElse(""))
      .toSet
      .filter { a =>
        !a.equals("")
      }
      .toList
      .sorted
  }

  // Metadata for `location`, fetched by the downloader that serves that location.
  private def getResourceMetadata(location: String): List[ResourceMetadata] = {
    getResourceDownloader(location).downloadMetadataIfNeed(location)
  }

  /** Prints the annotator names returned by `listAvailableAnnotators`, one per line. */
  def showAvailableAnnotators(folder: String = publicLoc): Unit = {
    println(listAvailableAnnotators(folder).mkString("\n"))
  }

  /** Loads resource to path
    *
    * @param name
    *   Name of Resource
    * @param folder
    *   Subfolder in s3 where to search model (e.g. medicine)
    * @param language
    *   Desired language of Resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): String = {
    downloadResource(ResourceRequest(name, language, folder))
  }

  /** Loads resource to path, blocking until the download has finished.
    *
    * @param request
    *   Request for resource
    * @return
    *   path of downloaded resource
    * @throws IllegalArgumentException
    *   if the resource cannot be found or the download produced no path
    */
  def downloadResource(request: ResourceRequest): String = {
    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    // The original folder (possibly "@"-prefixed) selects the downloader; the downloader itself
    // receives the request without the "@" marker.
    val downloader = getResourceDownloader(request.folder)
    val updatedRequest = withoutCommunityMarker(request)

    // Check that the resource exists before starting the download, so an unknown resource
    // fails fast instead of leaving a download running in the background.
    val fileSize = getDownloadSize(request)
    require(
      !fileSize.equals("-1"),
      s"Can not find ${request.name} inside ${request.folder} to download. Please make sure the name and location are correct!")
    println(request.name + " download started this may take some time.")
    println("Approximate size to download " + fileSize)

    // Block on the future directly. Polling with onComplete would register a new callback on
    // every iteration and share non-volatile flags between threads.
    val future = Future(downloader.download(updatedRequest))
    val path: Option[String] = Await.ready(future, Duration.Inf).value match {
      case Some(Success(value)) => value
      case Some(Failure(exception)) =>
        println(s"Error: ${exception.getMessage}")
        logger.error(exception.getMessage)
        None
      case None => None // not reachable after Await.ready; kept for exhaustiveness
    }

    require(
      path.isDefined,
      s"Was not found appropriate resource to download for request: $request with downloader: $downloader")
    println("Download done! Loading the resource.")
    path.get
  }

  /** Downloads a model from the default S3 bucket to the cache pretrained folder.
    * @param model
    *   the name of the key in the S3 bucket or s3 URI
    * @param folder
    *   the folder of the model
    * @param unzip
    *   used to unzip the model, by default true
    */
  def downloadModelDirectly(
      model: String,
      folder: String = publicLoc,
      unzip: Boolean = true): Unit = {
    getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
  }

  /** Downloads (or returns the cached copy of) a model and loads it with `reader`. */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): TModel = {
    downloadModel(reader, ResourceRequest(name, language, folder))
  }

  /** Downloads (or returns the cached copy of) the model described by `request`. */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      request: ResourceRequest): TModel = {
    cache.get(request) match {
      case Some(cached) => cached.asInstanceOf[TModel]
      case None =>
        val model = reader.read.load(downloadResource(request))
        cache(request) = model
        model
    }
  }

  /** Downloads (or returns the cached copy of) a pretrained pipeline. */
  def downloadPipeline(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): PipelineModel = {
    downloadPipeline(ResourceRequest(name, language, folder))
  }

  /** Downloads (or returns the cached copy of) the pipeline described by `request`. */
  def downloadPipeline(request: ResourceRequest): PipelineModel = {
    cache.get(request) match {
      case Some(cached) => cached.asInstanceOf[PipelineModel]
      case None =>
        val model = PipelineModel.read.load(downloadResource(request))
        cache(request) = model
        model
    }
  }

  /** Clears every cached copy of the given resource. */
  def clearCache(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): Unit = {
    clearCache(ResourceRequest(name, language, folder))
  }

  /** Removes `request` from the cache of every downloader and from the in-memory model cache. */
  def clearCache(request: ResourceRequest): Unit = {
    privateDownloader.clearCache(request)
    publicDownloader.clearCache(request)
    communityDownloader.clearCache(request)
    cache.remove(request)
  }

  /** Human-readable download size of the requested resource, or "-1" when the downloader reports
    * no size for it.
    */
  def getDownloadSize(resourceRequest: ResourceRequest): String = {
    getResourceDownloader(resourceRequest.folder)
      .getDownloadSize(withoutCommunityMarker(resourceRequest))
      .map(downloadBytes => FileHelper.getHumanReadableFileSize(downloadBytes))
      .getOrElse("-1")
  }

}
572

573
/** Categories of pretrained resources. The name of each value is the short code used for it
  * ("ml", "pl", "nd").
  */
object ResourceType extends Enumeration {
  type ResourceType = Value

  /** A single annotator model. */
  val MODEL: ResourceType = Value("ml")

  /** A complete pretrained pipeline. */
  val PIPELINE: ResourceType = Value("pl")

  /** Fallback used when a resource carries no category. */
  val NOT_DEFINED: ResourceType = Value("nd")
}
579

580
/** Describes a pretrained resource to fetch.
  *
  * @param name
  *   name of the resource
  * @param language
  *   optional language of the resource
  * @param folder
  *   remote folder to look in; defaults to the public folder, and an "@" prefix routes the
  *   request to the community downloader
  * @param libVersion
  *   Spark NLP version the resource must suit; defaults to the running build
  * @param sparkVersion
  *   Spark version the resource must suit; defaults to the active Spark session's version
  */
case class ResourceRequest(
    name: String,
    language: Option[String] = None,
    folder: String = ResourceDownloader.publicLoc,
    libVersion: Version = ResourceDownloader.libVersion,
    sparkVersion: Version = ResourceDownloader.sparkVersion
)
586

587
/* convenience accessor for Py4J calls */
588
object PythonResourceDownloader {

  /** Annotator class name (as sent from Python) to the reader that loads it. Each key appears
    * once; duplicated keys in a Map literal silently overwrite earlier entries.
    */
  val keyToReader: mutable.Map[String, DefaultParamsReadable[_]] = mutable.Map(
    "DocumentAssembler" -> DocumentAssembler,
    "SentenceDetector" -> SentenceDetector,
    "TokenizerModel" -> TokenizerModel,
    "PerceptronModel" -> PerceptronModel,
    "NerCrfModel" -> NerCrfModel,
    "Stemmer" -> Stemmer,
    "NormalizerModel" -> NormalizerModel,
    "RegexMatcherModel" -> RegexMatcherModel,
    "LemmatizerModel" -> LemmatizerModel,
    "DateMatcher" -> DateMatcher,
    "TextMatcherModel" -> TextMatcherModel,
    "SentimentDetectorModel" -> SentimentDetectorModel,
    "ViveknSentimentModel" -> ViveknSentimentModel,
    "NorvigSweetingModel" -> NorvigSweetingModel,
    "SymmetricDeleteModel" -> SymmetricDeleteModel,
    "NerDLModel" -> NerDLModel,
    "WordEmbeddingsModel" -> WordEmbeddingsModel,
    "BertEmbeddings" -> BertEmbeddings,
    "DependencyParserModel" -> DependencyParserModel,
    "TypedDependencyParserModel" -> TypedDependencyParserModel,
    "UniversalSentenceEncoder" -> UniversalSentenceEncoder,
    "ElmoEmbeddings" -> ElmoEmbeddings,
    "ClassifierDLModel" -> ClassifierDLModel,
    "ContextSpellCheckerModel" -> ContextSpellCheckerModel,
    "AlbertEmbeddings" -> AlbertEmbeddings,
    "XlnetEmbeddings" -> XlnetEmbeddings,
    "SentimentDLModel" -> SentimentDLModel,
    "LanguageDetectorDL" -> LanguageDetectorDL,
    "StopWordsCleaner" -> StopWordsCleaner,
    "BertSentenceEmbeddings" -> BertSentenceEmbeddings,
    "MultiClassifierDLModel" -> MultiClassifierDLModel,
    "SentenceDetectorDLModel" -> SentenceDetectorDLModel,
    "T5Transformer" -> T5Transformer,
    "MarianTransformer" -> MarianTransformer,
    "WordSegmenterModel" -> WordSegmenterModel,
    "DistilBertEmbeddings" -> DistilBertEmbeddings,
    "RoBertaEmbeddings" -> RoBertaEmbeddings,
    "XlmRoBertaEmbeddings" -> XlmRoBertaEmbeddings,
    "LongformerEmbeddings" -> LongformerEmbeddings,
    "RoBertaSentenceEmbeddings" -> RoBertaSentenceEmbeddings,
    "XlmRoBertaSentenceEmbeddings" -> XlmRoBertaSentenceEmbeddings,
    "AlbertForTokenClassification" -> AlbertForTokenClassification,
    "BertForTokenClassification" -> BertForTokenClassification,
    "DeBertaForTokenClassification" -> DeBertaForTokenClassification,
    "DistilBertForTokenClassification" -> DistilBertForTokenClassification,
    "LongformerForTokenClassification" -> LongformerForTokenClassification,
    "RoBertaForTokenClassification" -> RoBertaForTokenClassification,
    "XlmRoBertaForTokenClassification" -> XlmRoBertaForTokenClassification,
    "XlnetForTokenClassification" -> XlnetForTokenClassification,
    "AlbertForSequenceClassification" -> AlbertForSequenceClassification,
    "BertForSequenceClassification" -> BertForSequenceClassification,
    "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
    "DistilBertForSequenceClassification" -> DistilBertForSequenceClassification,
    "LongformerForSequenceClassification" -> LongformerForSequenceClassification,
    "RoBertaForSequenceClassification" -> RoBertaForSequenceClassification,
    "XlmRoBertaForSequenceClassification" -> XlmRoBertaForSequenceClassification,
    "XlnetForSequenceClassification" -> XlnetForSequenceClassification,
    "GPT2Transformer" -> GPT2Transformer,
    "EntityRulerModel" -> EntityRulerModel,
    "Doc2VecModel" -> Doc2VecModel,
    "Word2VecModel" -> Word2VecModel,
    "DeBertaEmbeddings" -> DeBertaEmbeddings,
    "CamemBertEmbeddings" -> CamemBertEmbeddings,
    "AlbertForQuestionAnswering" -> AlbertForQuestionAnswering,
    "BertForQuestionAnswering" -> BertForQuestionAnswering,
    "DeBertaForQuestionAnswering" -> DeBertaForQuestionAnswering,
    "DistilBertForQuestionAnswering" -> DistilBertForQuestionAnswering,
    "LongformerForQuestionAnswering" -> LongformerForQuestionAnswering,
    "RoBertaForQuestionAnswering" -> RoBertaForQuestionAnswering,
    "XlmRoBertaForQuestionAnswering" -> XlmRoBertaForQuestionAnswering,
    "SpanBertCorefModel" -> SpanBertCorefModel,
    "ViTForImageClassification" -> ViTForImageClassification,
    "VisionEncoderDecoderForImageCaptioning" -> VisionEncoderDecoderForImageCaptioning,
    "SwinForImageClassification" -> SwinForImageClassification,
    "ConvNextForImageClassification" -> ConvNextForImageClassification,
    "Wav2Vec2ForCTC" -> Wav2Vec2ForCTC,
    "HubertForCTC" -> HubertForCTC,
    "WhisperForCTC" -> WhisperForCTC,
    "CamemBertForTokenClassification" -> CamemBertForTokenClassification,
    "TableAssembler" -> TableAssembler,
    "TapasForQuestionAnswering" -> TapasForQuestionAnswering,
    "CamemBertForSequenceClassification" -> CamemBertForSequenceClassification,
    "CamemBertForQuestionAnswering" -> CamemBertForQuestionAnswering,
    "ZeroShotNerModel" -> ZeroShotNerModel,
    "BartTransformer" -> BartTransformer,
    "BertForZeroShotClassification" -> BertForZeroShotClassification,
    "DistilBertForZeroShotClassification" -> DistilBertForZeroShotClassification,
    "RoBertaForZeroShotClassification" -> RoBertaForZeroShotClassification,
    "XlmRoBertaForZeroShotClassification" -> XlmRoBertaForZeroShotClassification,
    "BartForZeroShotClassification" -> BartForZeroShotClassification,
    "InstructorEmbeddings" -> InstructorEmbeddings,
    "E5Embeddings" -> E5Embeddings,
    "MPNetEmbeddings" -> MPNetEmbeddings,
    "CLIPForZeroShotClassification" -> CLIPForZeroShotClassification,
    "DeBertaForZeroShotClassification" -> DeBertaForZeroShotClassification,
    "BGEEmbeddings" -> BGEEmbeddings,
    "MPNetForSequenceClassification" -> MPNetForSequenceClassification,
    "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering,
    "LLAMA2Transformer" -> LLAMA2Transformer,
    "M2M100Transformer" -> M2M100Transformer)

  // List pairs of types such as the one with key type can load a pretrained model from the value type
  val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")

  /** Downloads a pretrained model for the annotator named `readerStr`. Null `language` and
    * `remoteLoc` (from Python) mean "no language" and the public folder.
    *
    * @throws RuntimeException
    *   if `readerStr` (after `typeMapper` substitution) has no registered reader
    */
  def downloadModel(
      readerStr: String,
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineStage = {

    // Some annotators are loaded through another annotator's reader (see typeMapper).
    val reader = keyToReader.getOrElse(
      typeMapper.getOrElse(readerStr, readerStr),
      throw new RuntimeException(s"Unsupported Model: $readerStr"))

    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)

    val model = ResourceDownloader.downloadModel(
      reader.asInstanceOf[DefaultParamsReadable[PipelineStage]],
      name,
      Option(language),
      correctedFolder)

    // Convert the loaded model back to the requested type; one case per typeMapper entry.
    readerStr match {
      case "ZeroShotNerModel" => ZeroShotNerModel(model)
      case _ => model
    }
  }

  /** Downloads a pretrained pipeline; a null `remoteLoc` means the public folder. */
  def downloadPipeline(
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineModel = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadPipeline(name, Option(language), correctedFolder)
  }

  /** Clears every cached copy of a resource; a null `remoteLoc` means the public folder. */
  def clearCache(name: String, language: String = null, remoteLoc: String = null): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.clearCache(name, Option(language), correctedFolder)
  }

  /** Downloads `model` straight from S3; a null `remoteLoc` means the public folder. */
  def downloadModelDirectly(
      model: String,
      remoteLoc: String = null,
      unzip: Boolean = true): Unit = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip)
  }

  /** Table of all uncategorized public resources, for any language and version. */
  def showUnCategorizedResources(): String = {
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = None,
      version = None,
      resourceType = ResourceType.NOT_DEFINED)
  }

  /** Table of public pipelines; a null `version` means the running Spark NLP version. */
  def showPublicPipelines(lang: String, version: String): String = {
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = Option(lang),
      version = Some(Option(version).getOrElse(Build.version)),
      resourceType = ResourceType.PIPELINE)
  }

  /** Table of public models; a null `version` means the running Spark NLP version. */
  def showPublicModels(annotator: String, lang: String, version: String): String = {
    ResourceDownloader.publicResourceString(
      annotator = Option(annotator),
      lang = Option(lang),
      version = Some(Option(version).getOrElse(Build.version)),
      resourceType = ResourceType.MODEL)
  }

  /** Names of the annotators with public resources, one per line. */
  def showAvailableAnnotators(): String = {
    ResourceDownloader.listAvailableAnnotators().mkString("\n")
  }

  /** Human-readable download size of a resource, or "-1" when it cannot be determined. */
  def getDownloadSize(name: String, language: String = "en", remoteLoc: String = null): String = {
    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.getDownloadSize(ResourceRequest(name, Option(language), correctedFolder))
  }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc