• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4947838414

pending completion
4947838414

Pull #13796

github

GitHub
Merge 30bdeef19 into ef7906c5e
Pull Request #13796: Add unzip param to downloadModelDirectly in ResourceDownloader

39 of 39 new or added lines in 2 files covered. (100.0%)

8632 of 13111 relevant lines covered (65.84%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

42.33
/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala
1
/*
2
 * Copyright 2017-2023 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.pretrained
18

19
import com.johnsnowlabs.client.aws.AWSGateway
import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.annotators.audio.{HubertForCTC, Wav2Vec2ForCTC}
import com.johnsnowlabs.nlp.annotators.classifier.dl._
import com.johnsnowlabs.nlp.annotators.coref.SpanBertCorefModel
import com.johnsnowlabs.nlp.annotators.cv.{
  ConvNextForImageClassification,
  SwinForImageClassification,
  ViTForImageClassification
}
import com.johnsnowlabs.nlp.annotators.er.EntityRulerModel
import com.johnsnowlabs.nlp.annotators.ld.dl.LanguageDetectorDL
import com.johnsnowlabs.nlp.annotators.ner.crf.NerCrfModel
import com.johnsnowlabs.nlp.annotators.ner.dl.{NerDLModel, ZeroShotNerModel}
import com.johnsnowlabs.nlp.annotators.parser.dep.DependencyParserModel
import com.johnsnowlabs.nlp.annotators.parser.typdep.TypedDependencyParserModel
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronModel
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.annotators.sda.pragmatic.SentimentDetectorModel
import com.johnsnowlabs.nlp.annotators.sda.vivekn.ViveknSentimentModel
import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
import com.johnsnowlabs.nlp.annotators.seq2seq.{
  BartTransformer,
  GPT2Transformer,
  MarianTransformer,
  T5Transformer
}
import com.johnsnowlabs.nlp.annotators.spell.context.ContextSpellCheckerModel
import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingModel
import com.johnsnowlabs.nlp.annotators.spell.symmetric.SymmetricDeleteModel
import com.johnsnowlabs.nlp.annotators.ws.WordSegmenterModel
import com.johnsnowlabs.nlp.embeddings._
import com.johnsnowlabs.nlp.pretrained.ResourceType.ResourceType
import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
import com.johnsnowlabs.nlp.{DocumentAssembler, TableAssembler, pretrained}
import com.johnsnowlabs.util._
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.SparkFiles
import org.apache.spark.ml.util.DefaultParamsReadable
import org.apache.spark.ml.{PipelineModel, PipelineStage}
import org.slf4j.{Logger, LoggerFactory}

import java.io.File
import java.net.URI
import java.nio.file.Paths
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success}
69

70
trait ResourceDownloader {

  /** Download resource to local file
    *
    * @param request
    *   Resource request
    * @return
    *   downloaded file or None if resource is not found
    */
  def download(request: ResourceRequest): Option[String]

  /** Size of the resource in bytes, or None if it cannot be determined. */
  def getDownloadSize(request: ResourceRequest): Option[Long]

  /** Removes the locally cached copy of the requested resource, if any. */
  def clearCache(request: ResourceRequest): Unit

  /** Returns the resource metadata listing for the given S3 folder, fetching it first if
    * needed (per the method name — exact staleness policy is implementation-defined).
    */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

  /** Downloads a single file from S3, optionally unzipping it.
    *
    * @param s3FilePath
    *   path of the file inside the bucket
    * @param unzip
    *   whether to unzip the downloaded archive (default true)
    * @return
    *   local path of the result, or None on failure
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]

  // Hadoop FileSystem shared by all downloader implementations (backs the local cache).
  val fileSystem: FileSystem = ResourceDownloader.fileSystem

}
92

93
object ResourceDownloader {
94

95
  private val logger: Logger = LoggerFactory.getLogger(this.getClass.toString)

  // Hadoop FileSystem backing the local pretrained cache folder.
  val fileSystem: FileSystem = OutputHelper.getFileSystem

  // S3 coordinates are re-read from configuration on every access (def, not val) so that
  // runtime configuration overrides take effect.
  def s3Bucket: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3BucketKey)

  def s3BucketCommunity: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCommunityS3BucketKey)

  def s3Path: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedS3PathKey)

  def cacheFolder: String = ConfigLoader.getConfigStringValue(ConfigHelper.pretrainedCacheFolder)

  // Default S3 folder that hosts the public pretrained models.
  val publicLoc = "public/models"

  // In-memory cache of already loaded models/pipelines, keyed by their request.
  private val cache: mutable.Map[ResourceRequest, PipelineStage] =
    mutable.Map[ResourceRequest, PipelineStage]()

  // Spark version of the running session, parsed lazily on first use.
  lazy val sparkVersion: Version = {
    val spark_version = ResourceHelper.spark.version
    Version.parse(spark_version)
  }

  // Spark NLP library version, parsed lazily on first use.
  lazy val libVersion: Version = {
    Version.parse(Build.version)
  }

  // These are `var`s so resetResourceDownloader() can rebuild them with fresh credentials.
  var privateDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  var publicDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "public")
  var communityDownloader: ResourceDownloader =
    new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")
1✔
128

129
  /** Selects the downloader responsible for the given S3 folder: the public downloader for
    * the public models folder, the community downloader for "@user" folders, and the private
    * downloader otherwise.
    */
  def getResourceDownloader(folder: String): ResourceDownloader = {
    if (folder == publicLoc) publicDownloader
    else if (folder.startsWith("@")) communityDownloader
    else privateDownloader
  }
136

137
  /** Reset the model cache and recreate the private ResourceDownloader S3 credentials.
    *
    * NOTE(review): only `privateDownloader` is recreated here; `publicDownloader` and
    * `communityDownloader` keep their original credentials — confirm this is intended.
    */
  def resetResourceDownloader(): Unit = {
    // Bug fix: `cache.empty` only constructed a new empty map and discarded it, leaving the
    // cached models in place. `clear()` actually empties the cache.
    cache.clear()
    this.privateDownloader = new S3ResourceDownloader(s3Bucket, s3Path, cacheFolder, "private")
  }
142

143
  /** Lists every pretrained model in the public folder, formatted as `name:lang:version`. */
  def listPublicModels(): List[String] =
    listPretrainedResources(folder = publicLoc, resourceType = ResourceType.MODEL)
147

148
  /** Prints all pretrained models for a particular annotator model that are compatible with a
    * version of Spark NLP. Optional arguments that are not set are ignored as filters.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    val table = publicResourceString(
      annotator = annotator,
      lang = lang,
      version = version,
      resourceType = ResourceType.MODEL)
    println(table)
  }

  /** Prints all pretrained models of an annotator compatible with this version of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    */
  def showPublicModels(annotator: String): Unit =
    showPublicModels(annotator = Some(annotator))

  /** Prints all pretrained models of an annotator for a language, compatible with this version
    * of Spark NLP.
    *
    * @param annotator
    *   Name of the annotator class
    * @param lang
    *   Language of the pretrained models to display
    */
  def showPublicModels(annotator: String, lang: String): Unit =
    showPublicModels(Some(annotator), Some(lang))

  /** Prints all pretrained models of an annotator for a language and a Spark NLP version.
    *
    * @param annotator
    *   Name of the model class, for example "NerDLModel"
    * @param lang
    *   Language of the pretrained models to display, for example "en"
    * @param version
    *   Version of Spark NLP that the model should be compatible with, for example "3.2.3"
    */
  def showPublicModels(annotator: String, lang: String, version: String): Unit =
    showPublicModels(Some(annotator), Some(lang), Some(version))
1✔
202

203
  /** Lists every pretrained pipeline in the public folder, formatted as `name:lang:version`. */
  def listPublicPipelines(): List[String] =
    listPretrainedResources(folder = publicLoc, resourceType = ResourceType.PIPELINE)

  /** Prints all Pipelines available for a language and a version of Spark NLP. By default
    * shows all languages and uses the current version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version)): Unit = {
    val table = publicResourceString(
      annotator = None,
      lang = lang,
      version = version,
      resourceType = ResourceType.PIPELINE)
    println(table)
  }

  /** Prints all Pipelines available for a language under this version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    */
  def showPublicPipelines(lang: String): Unit =
    showPublicPipelines(lang = Some(lang))

  /** Prints all Pipelines available for a language and a version of Spark NLP.
    *
    * @param lang
    *   Language of the Pipeline
    * @param version
    *   Version of Spark NLP
    */
  def showPublicPipelines(lang: String, version: String): Unit =
    showPublicPipelines(Some(lang), Some(version))
1✔
243

244
  /** Returns models or pipelines in metadata json which have not been categorized yet.
    *
    * @return
    *   list of models or pipelines which are not categorized in metadata json
    */
  def listUnCategorizedResources(): List[String] =
    listPretrainedResources(folder = publicLoc, resourceType = ResourceType.NOT_DEFINED)

  /** Prints the uncategorized resources available for a language. */
  def showUnCategorizedResources(lang: String): Unit =
    println(publicResourceString(None, Some(lang), None, resourceType = ResourceType.NOT_DEFINED))

  /** Prints the uncategorized resources available for a language and a Spark NLP version. */
  def showUnCategorizedResources(lang: String, version: String): Unit =
    println(
      publicResourceString(
        None,
        Some(lang),
        Some(version),
        resourceType = ResourceType.NOT_DEFINED))
266

267
  /** Renders `name:lang:version` entries as an ASCII table.
    *
    * @param list
    *   entries formatted as "name:lang:version"
    * @param resourceType
    *   determines the header label (Pipeline, Model, or Pipeline/Model)
    * @return
    *   the formatted table as a single string
    */
  def showString(list: List[String], resourceType: ResourceType): String = {
    // Column widths grow to fit the longest name/version, with minimums matching the widest
    // possible header ("Pipeline/Model" = 14, "version" = 7).
    var maxName = 14
    var maxVersion = 7
    for (data <- list) {
      val parts = data.split(":")
      maxName = scala.math.max(parts(0).length, maxName)
      maxVersion = scala.math.max(parts(2).length, maxVersion)
    }

    val sb = new StringBuilder

    // Appends one "+----+------+----+" separator row (was duplicated three times before).
    def appendSeparator(): Unit = {
      sb.append("+")
      sb.append("-" * (maxName + 2))
      sb.append("+")
      sb.append("-" * 6)
      sb.append("+")
      sb.append("-" * (maxVersion + 2))
      sb.append("+\n")
    }

    val title =
      if (resourceType.equals(ResourceType.PIPELINE)) "Pipeline"
      else if (resourceType.equals(ResourceType.MODEL)) "Model"
      else "Pipeline/Model"

    appendSeparator()
    // Padding uses title.length instead of the previous hard-coded 8/5/14 literals.
    sb.append(
      "| " + title + (" " * (maxName - title.length)) + " | " + "lang" + " | " + "version" +
        " " * (maxVersion - 7) + " |\n")
    appendSeparator()
    for (data <- list) {
      val parts = data.split(":")
      sb.append(
        "| " + parts(0) + (" " * (maxName - parts(0).length)) + " |  " + parts(1) + "  | " +
          parts(2) + " " * (maxVersion - parts(2).length) + " |\n")
    }
    appendSeparator()
    sb.toString()
  }
318

319
  /** Builds the printable table of public resources matching the given filters. */
  def publicResourceString(
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[String] = Some(Build.version),
      resourceType: ResourceType): String = {
    // Parse the textual version filter once; None stays None.
    val parsedVersion = version.map(Version.parse)
    val resources = listPretrainedResources(
      folder = publicLoc,
      resourceType,
      annotator = annotator,
      lang = lang,
      version = parsedVersion)
    showString(resources, resourceType)
  }
336

337
  /** Lists pretrained resources from metadata.json, depending on the set filters. The folder
    * in the S3 location and the resourceType are necessary. The other filters are optional and
    * will be ignored if not set.
    *
    * @param folder
    *   Folder in the S3 location
    * @param resourceType
    *   Type of the Resource. Can either be `ResourceType.MODEL`, `ResourceType.PIPELINE` or
    *   `ResourceType.NOT_DEFINED`
    * @param annotator
    *   Name of the model class
    * @param lang
    *   Language of the model
    * @param version
    *   Version that the model should be compatible with
    * @return
    *   A list of the available resources formatted as "name:lang:version"
    */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      annotator: Option[String] = None,
      lang: Option[String] = None,
      version: Option[Version] = None): List[String] = {
    val resourceMetaData = getResourceMetadata(folder)

    resourceMetaData
      .filter { meta =>
        val isSameResourceType =
          meta.category.getOrElse(ResourceType.NOT_DEFINED).toString.equals(resourceType.toString)
        // `&&` short-circuits (the original `&` evaluated every condition unconditionally);
        // `Option.forall` treats an unset filter as "matches everything".
        isSameResourceType &&
        version.forall(ver => Version.isCompatible(ver, meta.libVersion)) &&
        annotator.forall(cls => meta.annotator.getOrElse("").equalsIgnoreCase(cls)) &&
        lang.forall(l => meta.language.getOrElse("").equalsIgnoreCase(l))
      }
      .map(meta =>
        meta.name + ":" + meta.language.getOrElse("-") + ":" + meta.libVersion.getOrElse("-"))
  }
388

389
  /** Lists pretrained resources in a folder of the given type, filtered by language. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang))

  /** Lists pretrained resources in a folder of the given type, filtered by version. */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, version = Some(version))

  /** Lists pretrained resources in a folder of the given type, filtered by language and
    * version.
    */
  def listPretrainedResources(
      folder: String,
      resourceType: ResourceType,
      lang: String,
      version: Version): List[String] =
    listPretrainedResources(folder, resourceType, lang = Some(lang), version = Some(version))
×
407

408
  /** Returns the distinct annotator class names available in the given folder, sorted
    * alphabetically. Entries without an annotator name are skipped.
    */
  def listAvailableAnnotators(folder: String = publicLoc): List[String] =
    getResourceMetadata(folder)
      .flatMap(_.annotator)
      .filter(_.nonEmpty)
      .distinct
      .sorted
421

422
  /** Fetches the resource metadata listing for the given S3 folder via the downloader that
    * owns that folder (downloading the metadata first if needed).
    */
  private def getResourceMetadata(location: String): List[ResourceMetadata] = {
    getResourceDownloader(location).downloadMetadataIfNeed(location)
  }
425

426
  /** Prints the available annotator class names for a folder, one per line. */
  def showAvailableAnnotators(folder: String = publicLoc): Unit = {
    println(listAvailableAnnotators(folder).mkString("\n"))
  }
429

430
  /** Loads resource to path
    *
    * @param name
    *   Name of Resource
    * @param folder
    *   Subfolder in s3 where to search model (e.g. medicine)
    * @param language
    *   Desired language of Resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): String = {
    // Delegates to the request-based overload below.
    downloadResource(ResourceRequest(name, language, folder))
  }
447

448
  /** Loads resource to path
    *
    * @param request
    *   Request for resource
    * @return
    *   path of downloaded resource
    */
  def downloadResource(request: ResourceRequest): String = {
    // Strip the community "@" marker before handing the request to the concrete downloader.
    val updatedRequest: ResourceRequest =
      if (request.folder.startsWith("@")) request.copy(folder = request.folder.replace("@", ""))
      else request

    // Verify the resource exists before starting the download (the original started the
    // download future first, even when the resource was missing).
    val fileSize = getDownloadSize(request)
    require(
      !fileSize.equals("-1"),
      s"Can not find ${request.name} inside ${request.folder} to download. Please make sure the name and location are correct!")
    println(request.name + " download started this may take some time.")
    println("Approximate size to download " + fileSize)

    // Run the download on the global execution context and block until it finishes.
    // The original polled in a 1 s loop and re-registered an `onComplete` callback on every
    // iteration, spinning on an unsynchronized `var`; a single `Await.result` is equivalent
    // and avoids both the callback pile-up and the cross-thread visibility hazard.
    val future = Future {
      getResourceDownloader(request.folder).download(updatedRequest)
    }
    val path: Option[String] =
      try Await.result(future, Duration.Inf)
      catch {
        case NonFatal(exception) =>
          println(s"Error: ${exception.getMessage}")
          logger.error(exception.getMessage)
          None
      }

    require(
      path.isDefined,
      s"Was not found appropriate resource to download for request: $request with downloader: $privateDownloader")
    println("Download done! Loading the resource.")
    path.get
  }
493

494
  /** Downloads a model from the default S3 bucket to the cache pretrained folder.
    * @param model
    *   the name of the key in the S3 bucket or s3 URI
    * @param folder
    *   the folder of the model
    * @param unzip
    *   used to unzip the model, by default true
    */
  def downloadModelDirectly(
      model: String,
      folder: String = publicLoc,
      unzip: Boolean = true): Unit = {
    // The folder decides which downloader (public/community/private) performs the fetch.
    getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
  }
508

509
  /** Downloads and loads a model of type `TModel` by name, language and folder, caching the
    * loaded instance in memory.
    */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): TModel =
    downloadModel(reader, ResourceRequest(name, language, folder))

  /** Downloads and loads the model for the given request, reusing the in-memory cache when
    * the same request was served before.
    */
  def downloadModel[TModel <: PipelineStage](
      reader: DefaultParamsReadable[TModel],
      request: ResourceRequest): TModel = {
    // getOrElseUpdate only downloads/loads on a cache miss.
    cache
      .getOrElseUpdate(request, reader.read.load(downloadResource(request)))
      .asInstanceOf[TModel]
  }
529

530
  /** Downloads and loads a pretrained pipeline by name, language and folder, caching it. */
  def downloadPipeline(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): PipelineModel =
    downloadPipeline(ResourceRequest(name, language, folder))

  /** Downloads and loads the pipeline for the given request, reusing the in-memory cache. */
  def downloadPipeline(request: ResourceRequest): PipelineModel = {
    // getOrElseUpdate only downloads/loads on a cache miss.
    cache
      .getOrElseUpdate(request, PipelineModel.read.load(downloadResource(request)))
      .asInstanceOf[PipelineModel]
  }
547

548
  /** Clears the cached copies of a resource identified by name, language and folder. */
  def clearCache(
      name: String,
      language: Option[String] = None,
      folder: String = publicLoc): Unit =
    clearCache(ResourceRequest(name, language, folder))

  /** Removes the resource from every downloader's disk cache and from the in-memory cache. */
  def clearCache(request: ResourceRequest): Unit = {
    Seq(privateDownloader, publicDownloader, communityDownloader)
      .foreach(_.clearCache(request))
    cache.remove(request)
  }
561

562
  /** Returns a human-readable download size for the requested resource, or "-1" when the
    * resource cannot be found.
    */
  def getDownloadSize(resourceRequest: ResourceRequest): String = {
    // Strip the community "@" marker before querying the concrete downloader.
    val updatedResourceRequest: ResourceRequest =
      if (resourceRequest.folder.startsWith("@"))
        resourceRequest.copy(folder = resourceRequest.folder.replace("@", ""))
      else resourceRequest

    getResourceDownloader(resourceRequest.folder)
      .getDownloadSize(updatedResourceRequest)
      .fold("-1")(FileHelper.getHumanReadableFileSize)
  }
577

578
  /** Downloads the provided S3 path to a local temporary directory and returns the location of
    * the folder.
    *
    * @param s3Path
    *   S3 URL to the resource
    * @param tempLocalPath
    *   local target directory; Spark's root scratch directory is used when empty
    * @param isIndex
    *   forwarded to the AWS gateway — presumably marks index downloads; TODO confirm semantics
    * @return
    *   URI of the local path to the temporary folder of the resource
    */
  def downloadS3Directory(
      s3Path: String,
      tempLocalPath: String = "",
      isIndex: Boolean = false): URI = {

    val (bucketName, keyPrefix) = ResourceHelper.parseS3URI(s3Path)
    val (accessKey, secretKey, sessionToken) = ConfigHelper.getHadoopS3Config
    val region = ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalRegion)
    // Private access requires all three credentials plus a non-empty region.
    val privateS3Defined =
      accessKey != null && secretKey != null && sessionToken != null && region.nonEmpty

    val awsGateway =
      if (privateS3Defined)
        new AWSGateway(accessKey, secretKey, sessionToken, region = region)
      else {
        // Partial credentials are deliberately ignored; log and fall back to anonymous access.
        if (accessKey != null || secretKey != null || sessionToken != null)
          logger.info(
            "Not all configs set for private S3 access. Defaulting to public downloader.")
        new AWSGateway(credentialsType = "public")
      }

    val directory = if (tempLocalPath.isEmpty) SparkFiles.getRootDirectory() else tempLocalPath
    awsGateway.downloadFilesFromDirectory(bucketName, keyPrefix, new File(directory), isIndex)
    Paths.get(directory, keyPrefix).toUri

  }
612

613
}
614

615
/** Categories a pretrained resource can belong to, as encoded in metadata.json ("ml", "pl",
  * or "nd"). Kept as a scala.Enumeration for compatibility with existing metadata handling.
  */
object ResourceType extends Enumeration {
  type ResourceType = Value
  val MODEL: pretrained.ResourceType.Value = Value("ml")
  val PIPELINE: pretrained.ResourceType.Value = Value("pl")
  val NOT_DEFINED: pretrained.ResourceType.Value = Value("nd")
}
621

622
/** Identifies a pretrained resource to download; also serves as the in-memory cache key.
  *
  * @param name
  *   name of the model or pipeline
  * @param language
  *   optional language filter
  * @param folder
  *   S3 folder to search (defaults to the public models folder)
  * @param libVersion
  *   Spark NLP version the resource must be compatible with
  * @param sparkVersion
  *   Spark version the resource must be compatible with
  */
case class ResourceRequest(
    name: String,
    language: Option[String] = None,
    folder: String = ResourceDownloader.publicLoc,
    libVersion: Version = ResourceDownloader.libVersion,
    sparkVersion: Version = ResourceDownloader.sparkVersion)
628

629
/* convenience accessor for Py4J calls */
630
object PythonResourceDownloader {
631

632
  /** Maps Python-facing class names to their Scala readers for Py4J `downloadModel` calls.
    *
    * NOTE: duplicate entries for "DeBertaForSequenceClassification" and
    * "DeBertaForTokenClassification" were removed; duplicated Map keys collapse to the last
    * occurrence anyway, so behavior is unchanged.
    */
  val keyToReader: mutable.Map[String, DefaultParamsReadable[_]] = mutable.Map(
    "DocumentAssembler" -> DocumentAssembler,
    "SentenceDetector" -> SentenceDetector,
    "TokenizerModel" -> TokenizerModel,
    "PerceptronModel" -> PerceptronModel,
    "NerCrfModel" -> NerCrfModel,
    "Stemmer" -> Stemmer,
    "NormalizerModel" -> NormalizerModel,
    "RegexMatcherModel" -> RegexMatcherModel,
    "LemmatizerModel" -> LemmatizerModel,
    "DateMatcher" -> DateMatcher,
    "TextMatcherModel" -> TextMatcherModel,
    "SentimentDetectorModel" -> SentimentDetectorModel,
    "ViveknSentimentModel" -> ViveknSentimentModel,
    "NorvigSweetingModel" -> NorvigSweetingModel,
    "SymmetricDeleteModel" -> SymmetricDeleteModel,
    "NerDLModel" -> NerDLModel,
    "WordEmbeddingsModel" -> WordEmbeddingsModel,
    "BertEmbeddings" -> BertEmbeddings,
    "DependencyParserModel" -> DependencyParserModel,
    "TypedDependencyParserModel" -> TypedDependencyParserModel,
    "UniversalSentenceEncoder" -> UniversalSentenceEncoder,
    "ElmoEmbeddings" -> ElmoEmbeddings,
    "ClassifierDLModel" -> ClassifierDLModel,
    "ContextSpellCheckerModel" -> ContextSpellCheckerModel,
    "AlbertEmbeddings" -> AlbertEmbeddings,
    "XlnetEmbeddings" -> XlnetEmbeddings,
    "SentimentDLModel" -> SentimentDLModel,
    "LanguageDetectorDL" -> LanguageDetectorDL,
    "StopWordsCleaner" -> StopWordsCleaner,
    "BertSentenceEmbeddings" -> BertSentenceEmbeddings,
    "MultiClassifierDLModel" -> MultiClassifierDLModel,
    "SentenceDetectorDLModel" -> SentenceDetectorDLModel,
    "T5Transformer" -> T5Transformer,
    "MarianTransformer" -> MarianTransformer,
    "WordSegmenterModel" -> WordSegmenterModel,
    "DistilBertEmbeddings" -> DistilBertEmbeddings,
    "RoBertaEmbeddings" -> RoBertaEmbeddings,
    "XlmRoBertaEmbeddings" -> XlmRoBertaEmbeddings,
    "LongformerEmbeddings" -> LongformerEmbeddings,
    "RoBertaSentenceEmbeddings" -> RoBertaSentenceEmbeddings,
    "XlmRoBertaSentenceEmbeddings" -> XlmRoBertaSentenceEmbeddings,
    "AlbertForTokenClassification" -> AlbertForTokenClassification,
    "BertForTokenClassification" -> BertForTokenClassification,
    "DeBertaForTokenClassification" -> DeBertaForTokenClassification,
    "DistilBertForTokenClassification" -> DistilBertForTokenClassification,
    "LongformerForTokenClassification" -> LongformerForTokenClassification,
    "RoBertaForTokenClassification" -> RoBertaForTokenClassification,
    "XlmRoBertaForTokenClassification" -> XlmRoBertaForTokenClassification,
    "XlnetForTokenClassification" -> XlnetForTokenClassification,
    "AlbertForSequenceClassification" -> AlbertForSequenceClassification,
    "BertForSequenceClassification" -> BertForSequenceClassification,
    "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
    "DistilBertForSequenceClassification" -> DistilBertForSequenceClassification,
    "LongformerForSequenceClassification" -> LongformerForSequenceClassification,
    "RoBertaForSequenceClassification" -> RoBertaForSequenceClassification,
    "XlmRoBertaForSequenceClassification" -> XlmRoBertaForSequenceClassification,
    "XlnetForSequenceClassification" -> XlnetForSequenceClassification,
    "GPT2Transformer" -> GPT2Transformer,
    "EntityRulerModel" -> EntityRulerModel,
    "Doc2VecModel" -> Doc2VecModel,
    "Word2VecModel" -> Word2VecModel,
    "DeBertaEmbeddings" -> DeBertaEmbeddings,
    "CamemBertEmbeddings" -> CamemBertEmbeddings,
    "AlbertForQuestionAnswering" -> AlbertForQuestionAnswering,
    "BertForQuestionAnswering" -> BertForQuestionAnswering,
    "DeBertaForQuestionAnswering" -> DeBertaForQuestionAnswering,
    "DistilBertForQuestionAnswering" -> DistilBertForQuestionAnswering,
    "LongformerForQuestionAnswering" -> LongformerForQuestionAnswering,
    "RoBertaForQuestionAnswering" -> RoBertaForQuestionAnswering,
    "XlmRoBertaForQuestionAnswering" -> XlmRoBertaForQuestionAnswering,
    "SpanBertCorefModel" -> SpanBertCorefModel,
    "ViTForImageClassification" -> ViTForImageClassification,
    "SwinForImageClassification" -> SwinForImageClassification,
    "ConvNextForImageClassification" -> ConvNextForImageClassification,
    "Wav2Vec2ForCTC" -> Wav2Vec2ForCTC,
    "HubertForCTC" -> HubertForCTC,
    "CamemBertForTokenClassification" -> CamemBertForTokenClassification,
    "TableAssembler" -> TableAssembler,
    "TapasForQuestionAnswering" -> TapasForQuestionAnswering,
    "CamemBertForSequenceClassification" -> CamemBertForSequenceClassification,
    "CamemBertForQuestionAnswering" -> CamemBertForQuestionAnswering,
    "ZeroShotNerModel" -> ZeroShotNerModel,
    "BartTransformer" -> BartTransformer,
    "BertForZeroShotClassification" -> BertForZeroShotClassification,
    "DistilBertForZeroShotClassification" -> DistilBertForZeroShotClassification,
    "RoBertaForZeroShotClassification" -> RoBertaForZeroShotClassification)
×
721

722
  // List pairs of types such as the one with key type can load a pretrained model from the value type
  // (e.g. ZeroShotNerModel is materialized through RoBertaForQuestionAnswering's reader).
  val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")
×
724

725
  /** Loads a pretrained model for Py4J callers, resolving the Scala reader by class name.
    * Throws a RuntimeException for unknown model names.
    */
  def downloadModel(
      readerStr: String,
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineStage = {

    // Some wrappers load through another annotator's reader (see `typeMapper`).
    val readerKey = typeMapper.getOrElse(readerStr, readerStr)
    val reader = keyToReader.getOrElse(
      readerKey,
      throw new RuntimeException(s"Unsupported Model: $readerStr"))

    val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)

    val model = ResourceDownloader.downloadModel(
      reader.asInstanceOf[DefaultParamsReadable[PipelineStage]],
      name,
      Option(language),
      correctedFolder)

    // Cast the model to the required type. This has to be done for each entry in `typeMapper`.
    if (typeMapper.contains(readerStr) && readerStr == "ZeroShotNerModel")
      ZeroShotNerModel(model)
    else model
  }
749

750
  /** Loads a pretrained pipeline for Py4J callers; a null remoteLoc means the public folder. */
  def downloadPipeline(
      name: String,
      language: String = null,
      remoteLoc: String = null): PipelineModel =
    ResourceDownloader.downloadPipeline(
      name,
      Option(language),
      Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc))

  /** Clears the cached copy of a resource for Py4J callers. */
  def clearCache(name: String, language: String = null, remoteLoc: String = null): Unit =
    ResourceDownloader.clearCache(
      name,
      Option(language),
      Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc))

  /** Downloads a model archive directly, optionally unzipping it, for Py4J callers. */
  def downloadModelDirectly(
      model: String,
      remoteLoc: String = null,
      unzip: Boolean = true): Unit =
    ResourceDownloader.downloadModelDirectly(
      model,
      Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc),
      unzip)
770

771
  /** Returns the table of uncategorized resources as a string (Py4J-friendly). */
  def showUnCategorizedResources(): String =
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = None,
      version = None,
      resourceType = ResourceType.NOT_DEFINED)

  /** Returns the public-pipelines table; a null version means the current build version. */
  def showPublicPipelines(lang: String, version: String): String = {
    val ver = Option(version).getOrElse(Build.version)
    ResourceDownloader.publicResourceString(
      annotator = None,
      lang = Option(lang),
      version = Some(ver),
      resourceType = ResourceType.PIPELINE)
  }

  /** Returns the public-models table; a null version means the current build version. */
  def showPublicModels(annotator: String, lang: String, version: String): String = {
    val ver = Option(version).getOrElse(Build.version)
    ResourceDownloader.publicResourceString(
      annotator = Option(annotator),
      lang = Option(lang),
      version = Some(ver),
      resourceType = ResourceType.MODEL)
  }
802

803
  /** Returns the available annotator names joined by newlines (Py4J-friendly). */
  def showAvailableAnnotators(): String =
    ResourceDownloader.listAvailableAnnotators().mkString("\n")

  /** Returns the human-readable download size of a named resource. */
  def getDownloadSize(name: String, language: String = "en", remoteLoc: String = null): String = {
    val folder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
    ResourceDownloader.getDownloadSize(ResourceRequest(name, Option(language), folder))
  }
811
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc