• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4947838414

pending completion
4947838414

Pull #13796

github

GitHub
Merge 30bdeef19 into ef7906c5e
Pull Request #13796: Add unzip param to downloadModelDirectly in ResourceDownloader

39 of 39 new or added lines in 2 files covered. (100.0%)

8632 of 13111 relevant lines covered (65.84%)

0.66 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

34.29
/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.pretrained
18

19
import com.amazonaws.services.s3.model.ObjectMetadata
20
import com.johnsnowlabs.client.aws.AWSGateway
21
import com.johnsnowlabs.client.gcp.GCPGateway
22
import com.johnsnowlabs.nlp.util.io.ResourceHelper
23
import com.johnsnowlabs.util.{ConfigHelper, FileHelper}
24
import org.apache.commons.io.IOUtils
25
import org.apache.hadoop.fs.Path
26
import org.apache.spark.sql.SparkSession
27

28
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
29
import java.nio.file.Files
30
import java.sql.Timestamp
31
import java.util.Calendar
32
import java.util.zip.ZipInputStream
33
import scala.collection.mutable
34

35
class S3ResourceDownloader(
36
    bucket: => String,
37
    s3Path: => String,
38
    cacheFolder: => String,
39
    credentialsType: => String,
40
    region: String = "us-east-1")
41
    extends ResourceDownloader {
42

43
  /** Metadata state per repository folder, consulted by `downloadMetadataIfNeed`. */
  private val repoFolder2Metadata: mutable.Map[String, RepositoryMetadata] =
    mutable.Map.empty[String, RepositoryMetadata]

  /** Cache location as a Hadoop path; may point to a file system folder or a cloud URI. */
  val cachePath = new Path(cacheFolder)

  // A cache folder that lives in cloud storage (s3/gs prefix) is never created here; any other
  // cache folder is created up front when it is missing.
  if (!(doesCacheFolderInCloud() || fileSystem.exists(cachePath))) {
    fileSystem.mkdirs(cachePath)
  }

  /** Gateway to the source bucket, built on first use from the region and credentials type. */
  lazy val awsGateway = new AWSGateway(region = region, credentialsType = credentialsType)
52

53
  /** Tells whether the configured cache folder starts with one of the cloud prefixes
    * recognised by this class ("s3" or "gs").
    */
  private def doesCacheFolderInCloud(): Boolean =
    Seq("s3", "gs").exists(prefix => cacheFolder.startsWith(prefix))
56

57
  /** Returns the resource metadata of a repository folder.
    *
    * A state stored in `repoFolder2Metadata` is reused when it was downloaded less than five
    * minutes ago; otherwise the metadata is fetched through the AWS gateway.
    *
    * NOTE(review): no code in this class adds entries to `repoFolder2Metadata`, so as written
    * every call ends up fetching the metadata again. Confirm whether storing the fetched
    * metadata was intended.
    *
    * @param folder
    *   Repository folder whose metadata is requested
    * @return
    *   Metadata entries of that folder
    */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = {
    val knownState = repoFolder2Metadata.get(folder)
    val staleBefore = getTimestamp(-5)

    knownState match {
      case Some(state) if !state.lastMetadataDownloaded.before(staleBefore) =>
        state.metadata
      case _ =>
        awsGateway.getMetadata(s3Path, folder, bucket)
    }
  }
70

71
  /** Resolves a request against the metadata of its folder.
    *
    * @param request
    *   Resource request
    * @return
    *   Whatever `ResourceMetadata.resolveResource` selects for the request, if anything
    */
  def resolveLink(request: ResourceRequest): Option[ResourceMetadata] =
    ResourceMetadata.resolveResource(downloadMetadataIfNeed(request.folder), request)
75

76
  /** Download resource to local file
    *
    * The request is resolved against the repository metadata and the object is checked for
    * existence in the source bucket. Depending on the cache path the resource is then
    *   - unzipped into S3 when the cache path starts with "s3",
    *   - unzipped into GCP Storage when the cache path starts with "gs",
    *   - downloaded (and unzipped if needed) through the Hadoop file system otherwise.
    *
    * For both cloud targets nothing is uploaded when the model folder already exists there.
    *
    * @param request
    *   Resource request
    * @return
    *   Downloaded file or None if resource is not found
    */
  override def download(request: ResourceRequest): Option[String] = {
    val link = resolveLink(request)
    link.flatMap { resource =>
      val s3FilePath = awsGateway.getS3File(s3Path, request.folder, resource.fileName)
      if (!awsGateway.doesS3ObjectExist(bucket, s3FilePath)) {
        None
      } else {
        // Named so that they do not shadow the `s3Path` constructor parameter.
        val s3CachePattern = "^s3.*".r
        val gcpCachePattern = "^gs.*".r

        val sourceS3URI = s"s3a://$bucket/$s3FilePath"

        // Only the cloud-cache branches need the model name. Deriving it lazily keeps a file
        // name without ".zip" from raising StringIndexOutOfBoundsException before the
        // file-system branch, which never uses the model name, is reached.
        lazy val modelName = {
          val zipFile = sourceS3URI.split("/").last
          zipFile.substring(0, zipFile.indexOf(".zip"))
        }

        cachePath.toString match {
          case s3CachePattern() =>
            val destinationS3URI = cachePath.toString.replace("s3:", "s3a:")
            val modelExists =
              doesModelExistInExternalCloudStorage(modelName, destinationS3URI, "S3")

            if (!modelExists) {
              val destinationKey = unzipInExternalCloudStorage(
                ResourceHelper.spark,
                sourceS3URI,
                destinationS3URI,
                "S3")
              Option(destinationKey)
            } else {
              Option(destinationS3URI + "/" + modelName)
            }

          case gcpCachePattern() =>
            val modelExists =
              doesModelExistInExternalCloudStorage(modelName, cachePath.toString, "GCP")

            if (!modelExists) {
              val destination = unzipInExternalCloudStorage(
                ResourceHelper.spark,
                sourceS3URI,
                cachePath.toString,
                "GCP")
              Option(destination)
            } else {
              Option(cachePath.toString + "/" + modelName)
            }

          case _ =>
            val destinationFile = new Path(cachePath.toString, resource.fileName)
            downloadAndUnzipFile(destinationFile, resource, s3FilePath)
        }
      }
    }
  }
142

143
  /** Checks whether the folder `<destination path>/<modelName>` already exists in the external
    * cloud storage used as cache.
    *
    * @param modelName
    *   Name of the model folder to look for
    * @param destinationURI
    *   URI of the cache location in the external storage
    * @param destinationCloud
    *   "S3" or "GCP"; any other value is not handled by the match below
    * @return
    *   Result of the existence check performed by the corresponding gateway
    */
  private def doesModelExistInExternalCloudStorage(
      modelName: String,
      destinationURI: String,
      destinationCloud: String): Boolean =
    destinationCloud match {
      case "S3" =>
        // Credentials for the destination bucket come from the Hadoop S3 configuration.
        val (accessKeyId, secretKey, sessionToken) = ConfigHelper.getHadoopS3Config
        val destinationGateway = new AWSGateway(accessKeyId, secretKey, sessionToken)
        val (bucketName, keyPrefix) = ResourceHelper.parseS3URI(destinationURI)

        destinationGateway.doesS3FolderExist(bucketName, s"$keyPrefix/$modelName")

      case "GCP" =>
        val (gcpGateway, bucketName, storagePrefix) = getGCPStorageConfig(destinationURI)

        gcpGateway.doesFolderExist(bucketName, s"$storagePrefix/$modelName")
    }
168

169
  /** Streams a zipped model out of the source S3 bucket and uploads every file entry, one by
    * one, below `<destinationStorageURI>/<modelName>` in the external cloud storage.
    *
    * @param sparkSession
    *   Session whose Hadoop configuration provides the destination S3 credentials
    * @param sourceS3URI
    *   URI of the zipped model in the source bucket
    * @param destinationStorageURI
    *   URI of the cache location in the external storage
    * @param destinationCloud
    *   "S3" uploads to S3; any other value uploads to GCP Storage
    * @return
    *   `destinationStorageURI + "/" + modelName`
    */
  private def unzipInExternalCloudStorage(
      sparkSession: SparkSession,
      sourceS3URI: String,
      destinationStorageURI: String,
      destinationCloud: String): String = {

    val (sourceBucketName, sourceKey) = ResourceHelper.parseS3URI(sourceS3URI)

    // Derived before the stream is opened so that a malformed key cannot leak the connection.
    val zipFile = sourceKey.split("/").last
    val modelName = zipFile.substring(0, zipFile.indexOf(".zip"))

    // The destination does not change between entries: resolve it at most once, and only when
    // the archive actually contains a file to upload.
    lazy val s3Destination = getS3Config(sparkSession, destinationStorageURI)
    lazy val gcpDestination = getGCPStorageConfig(destinationStorageURI)

    val zippedModel = awsGateway.getS3Object(sourceBucketName, sourceKey)
    val zipInputStream = new ZipInputStream(zippedModel.getObjectContent)
    try {
      var zipEntry = zipInputStream.getNextEntry

      println(s"Uploading model $modelName to external Cloud Storage URI: $destinationStorageURI")
      while (zipEntry != null) {
        if (!zipEntry.isDirectory) {
          // Each entry is buffered in memory and uploaded from that buffer.
          val outputStream = new ByteArrayOutputStream()
          IOUtils.copy(zipInputStream, outputStream)
          val entryBytes = outputStream.toByteArray
          val inputStream = new ByteArrayInputStream(entryBytes)

          if (destinationCloud == "S3") {
            val (awsGatewayDestination, destinationBucketName, destinationKey) = s3Destination
            val fileName = s"$modelName/${zipEntry.getName}"
            val destinationS3Path = destinationKey + "/" + fileName

            // The size is already known, so pass it along with the stream instead of leaving
            // the S3 client to work it out from the stream.
            val objectMetadata = new ObjectMetadata()
            objectMetadata.setContentLength(entryBytes.length.toLong)

            awsGatewayDestination.client.putObject(
              destinationBucketName,
              destinationS3Path,
              inputStream,
              objectMetadata)

          } else {
            val (gcpGateway, destinationBucketName, destinationStoragePath) = gcpDestination
            val destinationGCPStoragePath =
              s"$destinationStoragePath/$modelName/${zipEntry.getName}"

            gcpGateway.copyFileToGCPStorage(
              destinationBucketName,
              destinationGCPStoragePath,
              inputStream)
          }

        }
        zipEntry = zipInputStream.getNextEntry
      }
    } finally {
      // Closing the zip stream also closes the object content stream it wraps.
      zipInputStream.close()
    }
    destinationStorageURI + "/" + modelName
  }
219

220
  /** Builds the gateway and location used to write into an S3 cache folder.
    *
    * Credentials are read from the `fs.s3a.*` entries of the Hadoop configuration; an entry
    * that is not set is treated as an empty string.
    *
    * NOTE(review): the check below only fails when BOTH the access key and the secret key are
    * empty, while the message reads as if each of them were required - confirm the intent.
    *
    * @param sparkSession
    *   Session providing the Hadoop configuration
    * @param destinationS3URI
    *   URI of the cache location in S3
    * @return
    *   Destination gateway, bucket name and key parsed from the URI
    */
  private def getS3Config(sparkSession: SparkSession, destinationS3URI: String) = {
    val hadoopConf = sparkSession.sparkContext.hadoopConfiguration
    def settingOrEmpty(key: String): String = Option(hadoopConf.get(key)).getOrElse("")

    val accessKeyId = settingOrEmpty("fs.s3a.access.key")
    val secretAccessKey = settingOrEmpty("fs.s3a.secret.key")
    val sessionToken = settingOrEmpty("fs.s3a.session.token")

    if (accessKeyId == "" && secretAccessKey == "") {
      throw new IllegalAccessException(
        "Using S3 as cachePath requires to define access.key and secret.key hadoop configuration")
    }

    val destinationGateway = new AWSGateway(accessKeyId, secretAccessKey, sessionToken)
    val (bucketName, key) = ResourceHelper.parseS3URI(destinationS3URI)

    (destinationGateway, bucketName, key)
  }
242

243
  /** Builds the gateway and location used to write into a GCP Storage cache folder.
    *
    * @param destinationGCPStorageURI
    *   URI of the cache location in GCP Storage
    * @return
    *   A new GCP gateway plus the bucket name and storage path parsed from the URI
    */
  private def getGCPStorageConfig(destinationGCPStorageURI: String) = {
    val gateway = new GCPGateway()
    ResourceHelper.parseGCPStorageURI(destinationGCPStorageURI) match {
      case (bucketName, storagePath) => (gateway, bucketName, storagePath)
    }
  }
250

251
  /** Downloads a resolved resource into the cache folder and unzips it when it is zipped.
    *
    * Nothing is downloaded when either the file or its extracted folder (the destination path
    * without its last four characters) is already present in the cache.
    *
    * NOTE(review): for a resource that is not zipped only the file NAME is returned, not its
    * full path, whereas the zipped case returns a full path - confirm callers expect this.
    *
    * @param destinationFile
    *   Location of the downloaded file inside the cache folder
    * @param resource
    *   Metadata of the resource (file name, checksum, zipped flag)
    * @param s3FilePath
    *   Key of the resource in the source bucket
    * @return
    *   Path of the extracted folder for zipped resources, the file name otherwise
    */
  def downloadAndUnzipFile(
      destinationFile: Path,
      resource: ResourceMetadata,
      s3FilePath: String): Option[String] = {

    val extractedPath = withoutArchiveSuffix(destinationFile)
    if (!(fileSystem.exists(destinationFile) || fileSystem.exists(new Path(extractedPath)))) {
      fetchS3ObjectToCache(s3FilePath, destinationFile, resource.fileName, resource.checksum)
    }

    if (resource.isZipped) {
      // if not already unzipped
      if (!fileSystem.exists(new Path(extractedPath))) {
        extractZipAndDeleteArchive(destinationFile, extractedPath)
      }
      Some(extractedPath)
    } else {
      Some(destinationFile.getName)
    }
  }

  /** Downloads a file straight from the source bucket into the cache folder, optionally
    * unzipping it. No checksum validation is performed.
    *
    * Accepted forms of `s3FilePath`:
    *   - s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
    *   - public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
    *
    * @param s3FilePath
    *   Full S3 URI or plain key of the file
    * @param unzip
    *   Whether the downloaded file must be extracted
    * @return
    *   Path of the extracted folder when `unzip` is true, path of the downloaded file otherwise
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean): Option[String] = {
    val s3Key =
      if (s3FilePath.startsWith("s3")) ResourceHelper.parseS3URI(s3FilePath)._2
      else s3FilePath
    val fileName = s3Key.split("/").last
    val destinationFile = new Path(cachePath.toString + "/" + fileName)

    if (unzip) {
      val extractedPath = withoutArchiveSuffix(destinationFile)
      if (!(fileSystem.exists(destinationFile) || fileSystem.exists(new Path(extractedPath)))) {
        fetchS3ObjectToCache(s3Key, destinationFile, fileName, expectedChecksum = "")
      }
      // if not already unzipped
      if (!fileSystem.exists(new Path(extractedPath))) {
        extractZipAndDeleteArchive(destinationFile, extractedPath)
      }
      Some(extractedPath)
    } else {
      // Without extraction the downloaded file itself is the result, so it must be present
      // regardless of any extracted folder, and its own path (not the path stripped of its
      // last four characters) is what gets returned.
      if (!fileSystem.exists(destinationFile)) {
        fetchS3ObjectToCache(s3Key, destinationFile, fileName, expectedChecksum = "")
      }
      Some(destinationFile.toString)
    }
  }

  /** Returns the path without its last four characters, i.e. without a ".zip" extension. */
  private def withoutArchiveSuffix(archive: Path): String = {
    val location = archive.toString
    location.substring(0, location.length - 4)
  }

  /** Downloads a key of the source bucket to a temp file, validates the checksum when one is
    * given (non-empty) and moves the file to its destination in the cache.
    */
  private def fetchS3ObjectToCache(
      s3Key: String,
      destination: Path,
      tmpFilePrefix: String,
      expectedChecksum: String): Unit = {
    // 1. Create tmp file
    val tmpFile = Files.createTempFile(tmpFilePrefix, "").toFile
    try {
      // 2. Download content to tmp file
      awsGateway.getS3Object(bucket, s3Key, tmpFile)

      // 3. validate checksum
      if (!expectedChecksum.equals(""))
        require(
          FileHelper.generateChecksum(tmpFile.toString).equals(expectedChecksum),
          "Checksum validation failed!")

      // 4. Move tmp file to destination
      fileSystem.moveFromLocalFile(new Path(tmpFile.toString), destination)
    } finally {
      // Do not leave the temp file behind when the download or the validation fails. After a
      // successful move the file is expected to be gone already, making this a no-op.
      tmpFile.delete()
    }
  }

  /** Extracts every file entry of a zip archive below `targetDir` and deletes the archive. */
  private def extractZipAndDeleteArchive(archive: Path, targetDir: String): Unit = {
    // Checked before anything is opened so that a failure cannot leak a stream.
    require(archive.toString.endsWith(".zip"), "Not a zip file.")

    val zis = new ZipInputStream(fileSystem.open(archive))
    try {
      var entry = zis.getNextEntry
      while (entry != null) {
        if (!entry.isDirectory) {
          val outputStream = fileSystem.create(new Path(targetDir, entry.getName))
          try IOUtils.copy(zis, outputStream)
          finally outputStream.close()
        }
        zis.closeEntry()
        entry = zis.getNextEntry
      }
    } finally {
      zis.close()
    }
    // delete the zip file
    fileSystem.delete(archive, true)
  }
368

369
  /** Looks up the download size of the resource matching a request.
    *
    * @param request
    *   Resource request
    * @return
    *   Size reported by the AWS gateway, or None when the request cannot be resolved or the
    *   gateway reports no size
    */
  override def getDownloadSize(request: ResourceRequest): Option[Long] =
    for {
      resource <- resolveLink(request)
      size <- awsGateway.getS3DownloadSize(s3Path, request.folder, resource.fileName, bucket)
    } yield size
375

376
  /** Removes the cached copy of the resource matching a request: the downloaded file and, for
    * zipped resources, the extracted folder next to it.
    *
    * @param request
    *   Resource request
    */
  override def clearCache(request: ResourceRequest): Unit = {
    val resolved =
      ResourceMetadata.resolveResource(downloadMetadataIfNeed(request.folder), request)

    resolved.foreach { resource =>
      val cachedFile = new Path(cachePath.toString, resource.fileName)
      if (fileSystem.exists(cachedFile)) fileSystem.delete(cachedFile, true)

      if (resource.isZipped) {
        // The extracted folder carries the name of the archive without its ".zip" extension.
        val cachedLocation = cachedFile.toString
        val extensionStart = cachedLocation.length - 4
        require(cachedLocation.substring(extensionStart) == ".zip")

        val extractedFolder = new Path(cachedLocation.substring(0, extensionStart))
        if (fileSystem.exists(extractedFolder)) fileSystem.delete(extractedFolder, true)
      }
    }
  }
394

395
  /** Current time shifted by a number of minutes.
    *
    * @param addMinutes
    *   Minutes to add to the current time; negative values move into the past
    * @return
    *   The shifted time as a SQL timestamp
    */
  private def getTimestamp(addMinutes: Int = 0): Timestamp = {
    val shifted = Calendar.getInstance()
    shifted.add(Calendar.MINUTE, addMinutes)
    new Timestamp(shifted.getTimeInMillis)
  }
402

403
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc