• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 4413880521

pending completion
4413880521

Pull #13648

github

GitHub
Merge 210854e6c into a9f10588b
Pull Request #13648: release/432-release-candidate

80 of 80 new or added lines in 5 files covered. (100.0%)

8591 of 12937 relevant lines covered (66.41%)

0.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

33.33
/src/main/scala/com/johnsnowlabs/client/aws/AWSGateway.scala
1
/*
2
 * Copyright 2017-2022 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package com.johnsnowlabs.client.aws
18

19
import com.amazonaws.auth.{AWSCredentials, AWSStaticCredentialsProvider}
20
import com.amazonaws.services.pi.model.InvalidArgumentException
21
import com.amazonaws.services.s3.model.{
22
  GetObjectRequest,
23
  ObjectMetadata,
24
  PutObjectResult,
25
  S3Object,
26
  S3ObjectSummary
27
}
28
import com.amazonaws.services.s3.transfer.{Transfer, TransferManagerBuilder}
29
import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
30
import com.amazonaws.{AmazonClientException, AmazonServiceException, ClientConfiguration}
31
import com.johnsnowlabs.client.CredentialParams
32
import com.johnsnowlabs.nlp.pretrained.ResourceMetadata
33
import com.johnsnowlabs.nlp.util.io.ResourceHelper
34
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader}
35
import org.apache.hadoop.fs.{FileSystem, Path}
36
import org.slf4j.{Logger, LoggerFactory}
37

38
import scala.jdk.CollectionConverters._
39
import java.io.File
40
import scala.util.control.NonFatal
41

42
/** Thin gateway over an Amazon S3 client used to read and write resources in S3.
  *
  * The underlying [[AmazonS3]] client is built lazily on first use and released by [[close]].
  *
  * @param accessKeyId
  *   AWS access key id (defaults to the configured external value)
  * @param secretAccessKey
  *   AWS secret access key (defaults to the configured external value)
  * @param sessionToken
  *   AWS session token (defaults to the configured external value)
  * @param awsProfile
  *   AWS profile name (defaults to the configured external value)
  * @param region
  *   AWS region; mandatory, building the client fails when it is null or empty
  * @param credentialsType
  *   "public" or "community" ignore the supplied keys and use "anonymous" credential
  *   parameters; any other value uses the supplied keys
  */
class AWSGateway(
    accessKeyId: String = ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalAccessKeyId),
    secretAccessKey: String =
      ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalSecretAccessKey),
    sessionToken: String =
      ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalSessionToken),
    awsProfile: String = ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalProfileName),
    region: String = ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalRegion),
    credentialsType: String = "private")
    extends AutoCloseable {

  import com.amazonaws.services.s3.model.ListObjectsV2Request

  // Use the class *name*: `getClass.toString` yields "class com.…", which breaks
  // package-based logger configuration.
  protected val logger: Logger = LoggerFactory.getLogger(this.getClass.getName)

  /** Set once [[client]] has been successfully built, so [[close]] never forces its creation. */
  @volatile private var clientInitialized: Boolean = false

  /** Lazily-built S3 client.
    *
    * @throws InvalidArgumentException
    *   if `region` is null or empty
    */
  lazy val client: AmazonS3 = {
    // Test for null first: calling isEmpty on a null region would throw an NPE instead.
    if (region == null || region.isEmpty) {
      throw new InvalidArgumentException(
        "Region argument is mandatory to create Amazon S3 client.")
    }
    val credentialParams =
      if (credentialsType == "public" || credentialsType == "community")
        CredentialParams("anonymous", "", "", "", region)
      else CredentialParams(accessKeyId, secretAccessKey, sessionToken, awsProfile, region)

    val awsCredentials = new AWSTokenCredentials
    val credentials: Option[AWSCredentials] = awsCredentials.buildCredentials(credentialParams)

    val s3Client = getAmazonS3Client(credentials)
    clientInitialized = true
    s3Client
  }

  /** Builds the S3 client for `region`, using static credentials when available.
    *
    * When no credentials could be built, a warning is logged and the client is built without
    * explicit credentials.
    */
  private def getAmazonS3Client(credentials: Option[AWSCredentials]): AmazonS3 = {
    val config = new ClientConfiguration()
    config.setSocketTimeout(ConfigLoader.getConfigIntValue(ConfigHelper.s3SocketTimeout))

    val builder = credentials match {
      case Some(awsCredentials) =>
        AmazonS3ClientBuilder
          .standard()
          .withCredentials(new AWSStaticCredentialsProvider(awsCredentials))
          .withClientConfiguration(config)
      case None =>
        logger.warn(
          "Unable to build AWS credential via AWSGateway chain, some parameter is missing or" +
            " malformed. S3 integration may not work well.")
        AmazonS3ClientBuilder
          .standard()
          .withClientConfiguration(config)
    }

    builder.withRegion(region).build()
  }

  /** Reads and parses `<s3Path>/<folder>/metadata.json` from `bucket`.
    *
    * @return
    *   the resources listed in the metadata file
    */
  def getMetadata(s3Path: String, folder: String, bucket: String): List[ResourceMetadata] = {
    val metaFile = getS3File(s3Path, folder, "metadata.json")
    val obj = client.getObject(bucket, metaFile)
    // Always close the S3Object so its HTTP connection is released back to the pool.
    try ResourceMetadata.readResources(obj.getObjectContent)
    finally obj.close()
  }

  /** Joins key parts with "/", dropping one trailing "/" from each part and skipping empty parts. */
  def getS3File(parts: String*): String =
    parts
      .map(_.stripSuffix("/"))
      .filter(_.nonEmpty)
      .mkString("/")

  /** Returns true if the object exists and false on HTTP 404.
    *
    * Other service errors are rethrown. Any other non-fatal error is wrapped in an `Exception`.
    */
  def doesS3ObjectExist(bucket: String, s3FilePath: String): Boolean =
    try {
      client.getObjectMetadata(bucket, s3FilePath)
      true
    } catch {
      case exception: AmazonServiceException =>
        if (exception.getStatusCode == 404) false else throw exception
      case NonFatal(unexpectedException) =>
        throw unexpectedError("doesS3ObjectExist", unexpectedException)
    }

  /** Returns true if at least one key in `bucket` starts with the prefix `s3FilePath`.
    *
    * An HTTP 404 yields false. Other service errors are rethrown. Any other non-fatal error is
    * wrapped in an `Exception`.
    */
  def doesS3FolderExist(bucket: String, s3FilePath: String): Boolean =
    try {
      // A single matching key is enough to prove existence; don't fetch a full page.
      val request = new ListObjectsV2Request()
        .withBucketName(bucket)
        .withPrefix(s3FilePath)
        .withMaxKeys(1)
      !client.listObjectsV2(request).getObjectSummaries.isEmpty
    } catch {
      case exception: AmazonServiceException =>
        if (exception.getStatusCode == 404) false else throw exception
      case NonFatal(unexpectedException) =>
        throw unexpectedError("doesS3FolderExist", unexpectedException)
    }

  /** Downloads the object into `tmpFile` and returns its metadata. */
  def getS3Object(bucket: String, s3FilePath: String, tmpFile: File): ObjectMetadata = {
    val req = new GetObjectRequest(bucket, s3FilePath)
    client.getObject(req, tmpFile)
  }

  /** Returns the S3 object. The caller is responsible for closing it. */
  def getS3Object(bucket: String, s3FilePath: String): S3Object =
    client.getObject(bucket, s3FilePath)

  /** Returns the content length of `<s3Path>/<folder>/<fileName>`, or None on HTTP 404.
    *
    * Other service errors are rethrown. Any other non-fatal error is wrapped in an `Exception`.
    */
  def getS3DownloadSize(
      s3Path: String,
      folder: String,
      fileName: String,
      bucket: String): Option[Long] =
    try {
      val s3FilePath = getS3File(s3Path, folder, fileName)
      Some(client.getObjectMetadata(bucket, s3FilePath).getContentLength)
    } catch {
      case exception: AmazonServiceException =>
        if (exception.getStatusCode == 404) None else throw exception
      case NonFatal(unexpectedException) =>
        throw unexpectedError("getS3DownloadSize", unexpectedException)
    }

  /** Uploads a local file to `bucket`/`s3FilePath`.
    *
    * @param sourceFilePath
    *   a plain local path, or a "file:" URI
    */
  def copyFileToS3(
      bucket: String,
      s3FilePath: String,
      sourceFilePath: String): PutObjectResult = {
    // java.io.File does not parse URIs: prefixing "file://" produced a bogus relative
    // "file:/..." path. Strip any scheme so both plain paths and file URIs resolve locally.
    val localPath = sourceFilePath.stripPrefix("file://").stripPrefix("file:")
    client.putObject(bucket, s3FilePath, new File(localPath))
  }

  /** Streams a file readable through Hadoop (local, HDFS, ...) to `bucket`/`s3FilePath`. */
  def copyInputStreamToS3(
      bucket: String,
      s3FilePath: String,
      sourceFilePath: String): PutObjectResult = {
    val path = new Path(sourceFilePath)
    // Resolve the filesystem from the path itself so scheme-qualified paths work.
    // Scheme-less paths still use the default filesystem.
    val fileSystem = path.getFileSystem(ResourceHelper.spark.sparkContext.hadoopConfiguration)
    val metadata = new ObjectMetadata()
    // Without a content length the SDK buffers the entire stream in memory before uploading.
    metadata.setContentLength(fileSystem.getFileStatus(path).getLen)
    val inputStream = fileSystem.open(path)
    try client.putObject(bucket, s3FilePath, inputStream, metadata)
    finally inputStream.close()
  }

  /** Downloads every object under `keyPrefix` in `bucketName` into `directoryPath`.
    *
    * Blocks until the transfer completes.
    */
  def downloadFilesFromDirectory(
      bucketName: String,
      keyPrefix: String,
      directoryPath: File): Unit = {
    val transferManager = TransferManagerBuilder
      .standard()
      .withS3Client(client)
      .build()
    try {
      val multipleFileDownload =
        transferManager.downloadDirectory(bucketName, keyPrefix, directoryPath)
      logger.info(multipleFileDownload.getDescription)
      waitForCompletion(multipleFileDownload)
    } catch {
      case e: AmazonServiceException =>
        throw wrapServiceException(
          "Amazon service error when downloading files from S3 directory: ",
          e)
    } finally {
      // `false`: the S3 client is owned by this gateway and must survive the transfer manager.
      // A plain shutdownNow() would also shut the shared client down.
      transferManager.shutdownNow(false)
    }
  }

  /** Blocks until `transfer` finishes, rethrowing failures with a prefixed message and cause. */
  private def waitForCompletion(transfer: Transfer): Unit =
    try transfer.waitForCompletion()
    catch {
      case e: AmazonServiceException =>
        throw wrapServiceException("Amazon service error: ", e)
      case e: AmazonClientException =>
        throw new AmazonClientException("Amazon client error: " + e.getMessage, e)
      case e: InterruptedException =>
        val interrupted = new InterruptedException("Transfer interrupted: " + e.getMessage)
        interrupted.initCause(e)
        throw interrupted
    }

  /** Lists all objects whose key starts with `s3Path`, following pagination.
    *
    * A single listObjectsV2 call returns at most one page of keys.
    */
  def listS3Files(bucket: String, s3Path: String): Array[S3ObjectSummary] =
    try {
      @scala.annotation.tailrec
      def collectPages(
          continuationToken: Option[String],
          acc: Vector[S3ObjectSummary]): Vector[S3ObjectSummary] = {
        val request = new ListObjectsV2Request().withBucketName(bucket).withPrefix(s3Path)
        continuationToken.foreach(request.setContinuationToken)
        val page = client.listObjectsV2(request)
        val collected = acc ++ page.getObjectSummaries.asScala
        Option(page.getNextContinuationToken).filter(_ => page.isTruncated) match {
          case next @ Some(_) => collectPages(next, collected)
          case None => collected
        }
      }
      collectPages(None, Vector.empty).toArray
    } catch {
      case e: AmazonServiceException =>
        throw wrapServiceException("Amazon service error: ", e)
      case NonFatal(unexpectedException) =>
        throw unexpectedError("listS3Files", unexpectedException)
    }

  /** Re-wraps a service exception with a prefixed message.
    *
    * Keeps the original as the cause and copies its status and error details.
    */
  private def wrapServiceException(
      prefix: String,
      e: AmazonServiceException): AmazonServiceException = {
    val wrapped = new AmazonServiceException(prefix + e.getMessage, e)
    wrapped.setStatusCode(e.getStatusCode)
    wrapped.setErrorCode(e.getErrorCode)
    wrapped.setErrorType(e.getErrorType)
    wrapped.setServiceName(e.getServiceName)
    wrapped.setRequestId(e.getRequestId)
    wrapped
  }

  /** Wraps an unexpected failure, keeping it as the cause so its stack trace is not lost. */
  private def unexpectedError(methodName: String, cause: Throwable): Exception =
    new Exception(s"Unexpected error in ${this.getClass.getName}.$methodName: $cause", cause)

  /** Shuts down the S3 client if it was ever built.
    *
    * Otherwise this is a no-op, so closing never creates a client.
    */
  override def close(): Unit = {
    if (clientInitialized) client.shutdown()
  }

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc