• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 10656334014

01 Sep 2024 06:19PM CUT coverage: 62.392% (-0.02%) from 62.41%
10656334014

Pull #14355

github

web-flow
Merge 2a3ee298b into 50a69662f
Pull Request #14355: Implementing Mxbai Embeddings

0 of 2 new or added lines in 1 file covered. (0.0%)

27 existing lines in 7 files now uncovered.

8967 of 14372 relevant lines covered (62.39%)

0.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.59
/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.nlp.pretrained
18

19
import com.johnsnowlabs.client.CloudResources
20
import com.johnsnowlabs.client.aws.AWSGateway
21
import com.johnsnowlabs.client.util.CloudHelper
22
import com.johnsnowlabs.util.FileHelper
23
import org.apache.hadoop.fs.Path
24

25
import java.io.File
26
import java.nio.file.Files
27
import java.util.zip.ZipInputStream
28
import scala.collection.mutable
29

30
/** [[ResourceDownloader]] that resolves pretrained resources through a per-folder
  * `metadata.json` stored in S3 and downloads them into a cache folder.
  *
  * All by-name parameters are re-evaluated every time they are referenced.
  *
  * @param bucket
  *   S3 bucket that holds the metadata and the resources
  * @param s3Path
  *   Key prefix inside the bucket, passed to `AWSGateway.getS3File`
  * @param cacheFolder
  *   Folder where downloaded resources are stored; when it is not a cloud path it is created on
  *   construction if missing
  * @param credentialsType
  *   Credentials type handed to [[AWSGateway]]
  * @param region
  *   AWS region handed to [[AWSGateway]]
  */
class S3ResourceDownloader(
    bucket: => String,
    s3Path: => String,
    cacheFolder: => String,
    credentialsType: => String,
    region: String = "us-east-1")
    extends ResourceDownloader {

  /** Last metadata state seen per repository folder, used to skip re-parsing unchanged files. */
  private val repoFolder2Metadata: mutable.Map[String, RepositoryMetadata] =
    mutable.Map[String, RepositoryMetadata]()

  val cachePath: Path = new Path(cacheFolder)

  if (!CloudHelper.isCloudPath(cacheFolder) && !fileSystem.exists(cachePath)) {
    fileSystem.mkdirs(cachePath)
  }

  lazy val awsGateway: AWSGateway =
    new AWSGateway(region = region, credentialsType = credentialsType)

  /** Returns the resource metadata of `folder`, re-reading `metadata.json` from S3 only when no
    * state is cached for the folder or the cached state is older than the S3 object.
    *
    * @param folder
    *   Repository folder whose `metadata.json` is requested
    * @return
    *   Parsed metadata entries, either cached or freshly read
    */
  def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = {
    val metadataFilePath = awsGateway.getS3File(s3Path, folder, "metadata.json")
    val metadataObject = awsGateway.client.getObject(bucket, metadataFilePath)
    // try/finally so the S3 object is released even when parsing the metadata fails
    try {
      val lastModifiedTimeInS3 = metadataObject.getObjectMetadata.getLastModified
      repoFolder2Metadata.get(folder) match {
        case Some(cached) if !cached.lastModified.before(lastModifiedTimeInS3) =>
          cached.metadata
        case _ =>
          val metadata = ResourceMetadata.readResources(metadataObject.getObjectContent)
          repoFolder2Metadata(folder) = RepositoryMetadata(
            folder,
            lastModifiedTimeInS3,
            java.util.Date.from(java.time.Instant.now()),
            metadata)
          metadata
      }
    } finally {
      metadataObject.close()
    }
  }

  /** Resolves a request against the (possibly refreshed) metadata of its folder.
    *
    * @param request
    *   Resource request
    * @return
    *   Matching metadata entry, or None when `ResourceMetadata.resolveResource` finds none
    */
  def resolveLink(request: ResourceRequest): Option[ResourceMetadata] = {
    val metadata = downloadMetadataIfNeed(request.folder)
    ResourceMetadata.resolveResource(metadata, request)
  }

  /** Download resource to local file
    *
    * @param request
    *   Resource request
    * @return
    *   Downloaded file or None if resource is not found
    */
  override def download(request: ResourceRequest): Option[String] =
    resolveLink(request).flatMap { resource =>
      val s3FilePath = awsGateway.getS3File(s3Path, request.folder, resource.fileName)

      if (!awsGateway.doesS3ObjectExist(bucket, s3FilePath)) {
        None
      } else if (CloudHelper.isCloudPath(cachePath.toString)) {
        val sourceS3URI = s"s3a://$bucket/$s3FilePath"
        val zipFile = sourceS3URI.split("/").last
        // The model name is only needed for cloud caches. Fall back to the full file name when
        // there is no ".zip" marker instead of failing on substring(0, -1).
        val zipMarkerAt = zipFile.indexOf(".zip")
        val modelName = if (zipMarkerAt >= 0) zipFile.substring(0, zipMarkerAt) else zipFile
        CloudResources.downloadModelFromCloud(
          awsGateway,
          cachePath.toString,
          modelName,
          sourceS3URI)
      } else {
        val destinationFile = new Path(cachePath.toString, resource.fileName)
        downloadAndUnzipFile(destinationFile, resource, s3FilePath)
      }
    }

  /** Downloads `s3FilePath` to `destinationFile` unless it (or its unzipped folder) is already
    * cached, validates the checksum when the resource declares one, and unzips zipped resources.
    *
    * @param destinationFile
    *   Target path of the downloaded file inside the cache
    * @param resource
    *   Metadata of the resource being downloaded
    * @param s3FilePath
    *   S3 key of the resource
    * @return
    *   Path of the unzipped folder for zipped resources, otherwise the destination file name
    */
  def downloadAndUnzipFile(
      destinationFile: Path,
      resource: ResourceMetadata,
      s3FilePath: String): Option[String] = {

    // Destination without its last four characters, i.e. without ".zip" for zipped resources.
    val splitPath = destinationFile.toString.substring(0, destinationFile.toString.length - 4)

    if (!(fileSystem.exists(destinationFile) || fileSystem.exists(new Path(splitPath)))) {
      fetchToCache(
        s3FilePath,
        resource.fileName,
        destinationFile,
        Option(resource.checksum).filter(_.nonEmpty))
    }

    if (resource.isZipped) {
      // if not already unzipped
      if (!fileSystem.exists(new Path(splitPath))) {
        unzipAndRemoveArchive(destinationFile, splitPath)
      }
      Some(splitPath)
    } else {
      // NOTE(review): this returns only the file name, while the zipped branch returns a full
      // path. Kept as-is for compatibility; confirm that callers expect it.
      Some(destinationFile.getName)
    }
  }

  /** Downloads a file addressed either by a full S3 URI or by a bucket-relative key into the
    * cache folder and optionally unzips it.
    *
    * @param s3FilePath
    *   Either an S3 URI or a key relative to `bucket`
    * @param unzip
    *   Whether the downloaded archive should be extracted
    * @return
    *   The cache path of the file without its last four characters
    */
  def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean): Option[String] = {
    // handle s3FilePath options:
    // 1--> s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
    // 2--> public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
    val newS3FilePath =
      if (CloudHelper.isS3Path(s3FilePath)) CloudHelper.parseS3URI(s3FilePath)._2
      else s3FilePath

    val s3File = newS3FilePath.split("/").last

    val destinationFile = new Path(cachePath.toString + "/" + s3File)
    val splitPath = destinationFile.toString.substring(0, destinationFile.toString.length - 4)

    if (!(fileSystem.exists(destinationFile) || fileSystem.exists(new Path(splitPath)))) {
      fetchToCache(newS3FilePath, s3File, destinationFile, expectedChecksum = None)
    }

    if (unzip && !fileSystem.exists(new Path(splitPath))) {
      unzipAndRemoveArchive(destinationFile, splitPath)
    }

    // NOTE(review): splitPath is returned even when unzip is false, in which case the file
    // lives at destinationFile. Kept as-is for compatibility; confirm against callers.
    Some(splitPath)
  }

  /** Downloads `s3Key` into a temp file, optionally validates its checksum and moves it to
    * `destinationFile`; the temp file is removed when any step fails.
    */
  private def fetchToCache(
      s3Key: String,
      tmpPrefix: String,
      destinationFile: Path,
      expectedChecksum: Option[String]): Unit = {
    // 1. Create tmp file
    val tmpFile: File = Files.createTempFile(tmpPrefix, "").toFile
    try {
      // 2. Download content to tmp file
      awsGateway.getS3Object(bucket, s3Key, tmpFile)

      // 3. validate checksum
      expectedChecksum.foreach { expected =>
        require(
          FileHelper.generateChecksum(tmpFile.toString) == expected,
          "Checksum validation failed!")
      }

      // 4. Move tmp file to destination
      fileSystem.moveFromLocalFile(new Path(tmpFile.toString), destinationFile)
    } finally {
      // Best-effort cleanup: after a successful move the temp file should already be gone;
      // after a failure this removes the leftover download.
      tmpFile.delete()
    }
  }

  /** Extracts `zipFile` into `targetDir` and deletes the archive afterwards. On failure the
    * partially extracted `targetDir` is removed so that it is not mistaken for a cached resource
    * by later existence checks. Callers must only invoke this when `targetDir` does not exist.
    */
  private def unzipAndRemoveArchive(zipFile: Path, targetDir: String): Unit = {
    // Validate before opening the stream so that a failed check cannot leak it.
    require(zipFile.toString.endsWith(".zip"), "Not a zip file.")

    try {
      val zis = new ZipInputStream(fileSystem.open(zipFile))
      try {
        val buf = Array.ofDim[Byte](8192)
        var entry = zis.getNextEntry
        while (entry != null) {
          if (!entry.isDirectory) {
            require(
              isSafeEntryName(entry.getName),
              s"Zip entry escapes the target folder: ${entry.getName}")
            val outputStream = fileSystem.create(new Path(targetDir, entry.getName))
            try {
              var bytesRead = zis.read(buf, 0, buf.length)
              while (bytesRead > -1) {
                outputStream.write(buf, 0, bytesRead)
                bytesRead = zis.read(buf, 0, buf.length)
              }
            } finally {
              outputStream.close()
            }
          }
          zis.closeEntry()
          entry = zis.getNextEntry
        }
      } finally {
        zis.close()
      }
    } catch {
      case scala.util.control.NonFatal(failure) =>
        try fileSystem.delete(new Path(targetDir), true)
        catch {
          case scala.util.control.NonFatal(cleanupFailure) =>
            failure.addSuppressed(cleanupFailure)
        }
        throw failure
    }

    // delete the zip file
    fileSystem.delete(zipFile, true)
  }

  /** True when a zip entry name is relative and has no `..` segment (guards against Zip Slip). */
  private def isSafeEntryName(name: String): Boolean =
    !name.startsWith("/") && !name.startsWith("\\") && !name.split("[/\\\\]").contains("..")

  /** Size of the resource matching `request` as reported by `AWSGateway.getS3DownloadSize`.
    *
    * @param request
    *   Resource request
    * @return
    *   Download size, or None when the request cannot be resolved or no size is reported
    */
  override def getDownloadSize(request: ResourceRequest): Option[Long] =
    resolveLink(request).flatMap { resource =>
      awsGateway.getS3DownloadSize(s3Path, request.folder, resource.fileName, bucket)
    }

  /** Removes the cached file of the requested resource and, for zipped resources, the folder it
    * was extracted to.
    *
    * @param request
    *   Resource request
    */
  override def clearCache(request: ResourceRequest): Unit = {
    val metadata = downloadMetadataIfNeed(request.folder)

    ResourceMetadata.resolveResource(metadata, request).foreach { resource =>
      val cachedFile = new Path(cachePath.toString, resource.fileName)
      if (fileSystem.exists(cachedFile))
        fileSystem.delete(cachedFile, true)

      if (resource.isZipped) {
        require(cachedFile.toString.endsWith(".zip"))
        val unzippedFile = new Path(cachedFile.toString.stripSuffix(".zip"))
        if (fileSystem.exists(unzippedFile))
          fileSystem.delete(unzippedFile, true)
      }
    }
  }

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc