Coveralls coverage report: JohnSnowLabs / spark-nlp, build 4947838414 (pending completion)

Pull Request #13796 (GitHub): Add unzip param to downloadModelDirectly in ResourceDownloader
Merge 30bdeef19 into ef7906c5e

39 of 39 new or added lines in 2 files covered (100.0%)
8632 of 13111 relevant lines covered (65.84%)
0.66 hits per line
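
The change under test adds an unzip parameter to ResourceDownloader.downloadModelDirectly. A minimal sketch of how a call might look; the model path and destination folder are placeholders, and the exact signature is an assumption inferred from the PR title (only the unzip flag is confirmed by it):

import com.johnsnowlabs.nlp.pretrained.ResourceDownloader

// Hypothetical call: download a model archive into the local cache and keep it zipped.
// Arguments other than `unzip` are illustrative placeholders.
ResourceDownloader.downloadModelDirectly(
  "public/models/some_model_en_4.4.0_3.0_0000000000000.zip", // placeholder S3 path
  "public/models",
  unzip = false)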

Source file: /src/main/scala/com/johnsnowlabs/storage/HasStorage.scala (50.0% of lines covered)

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.storage

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.johnsnowlabs.nlp.HasCaseSensitiveProperties
import com.johnsnowlabs.nlp.annotators.param.ExternalResourceParam
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs}
import com.johnsnowlabs.storage.Database.Name
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader, FileHelper}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, SparkSession}

import java.nio.file.{Files, Paths, StandardCopyOption}
import java.util.UUID

trait HasStorage extends HasStorageRef with HasStorageOptions with HasCaseSensitiveProperties {

  protected val databases: Array[Database.Name]

  /** Path to the external resource.
    * @group param
    */
  val storagePath = new ExternalResourceParam(this, "storagePath", "path to file")

  /** @group setParam */
  def setStoragePath(path: String, readAs: String): this.type =
    set(storagePath, new ExternalResource(path, readAs, Map.empty[String, String]))

  /** @group setParam */
  def setStoragePath(path: String, readAs: ReadAs.Value): this.type =
    setStoragePath(path, readAs.toString)

  /** @group getParam */
  def getStoragePath: Option[ExternalResource] = get(storagePath)

  protected val missingRefMsg: String = s"Please set storageRef param in $this."

  /** Implemented by concrete annotators: reads the external resource (or the fit dataset)
    * and writes its contents into the provided storage writers.
    */
  protected def index(
      fitDataset: Dataset[_],
      storageSourcePath: Option[String],
      readAs: Option[ReadAs.Value],
      writers: Map[Database.Name, StorageWriter[_]],
      readOptions: Option[Map[String, String]] = None): Unit

  /** Implemented by concrete annotators: creates the writer used to populate a database. */
  protected def createWriter(database: Name, connection: RocksDBConnection): StorageWriter[_]

  /** Indexes the resource into one local RocksDB instance per database. */
  private def indexDatabases(
      databases: Array[Database.Name],
      resource: Option[ExternalResource],
      localFiles: Array[String],
      fitDataset: Dataset[_],
      spark: SparkContext): Unit = {

    require(
      databases.length == localFiles.length,
      "Storage temp locations must be equal to the amount of databases")

    lazy val connections = databases
      .zip(localFiles)
      .map { case (database, localFile) => (database, RocksDBConnection.getOrCreate(localFile)) }

    val writers = connections
      .map { case (db, conn) =>
        (db, createWriter(db, conn))
      }
      .toMap[Database.Name, StorageWriter[_]]

    val storageSourcePath = resource.map(r => importIfS3(r.path, spark).toUri.toString)
    if (resource.isDefined && new Path(resource.get.path)
        .getFileSystem(spark.hadoopConfiguration)
        .getScheme != "file") {
      // The resource lives on a non-local file system: copy it to a temporary local file first.
      val uri = new java.net.URI(storageSourcePath.get.replaceAllLiterally("\\", "/"))
      val fs = FileSystem.get(uri, spark.hadoopConfiguration)

      /** ToDo: What if the file is too large to copy to local? Index directly from hadoop? */
      val tmpFile = Files.createTempFile("sparknlp_", ".str").toAbsolutePath.toString
      fs.copyToLocalFile(new Path(storageSourcePath.get), new Path(tmpFile))
      index(fitDataset, Some(tmpFile), resource.map(_.readAs), writers, resource.map(_.options))
      FileHelper.delete(tmpFile)
    } else {
      index(
        fitDataset,
        storageSourcePath,
        resource.map(_.readAs),
        writers,
        resource.map(_.options))
    }

    writers.values.foreach(_.close())
    connections.map(_._2).foreach(_.close())
  }

  /** Indexes the resource into temporary local databases and distributes them to the cluster. */
  private def preload(
      fitDataset: Dataset[_],
      resource: Option[ExternalResource],
      spark: SparkSession,
      databases: Array[Database.Name]): Unit = {

    val sparkContext = spark.sparkContext

    // Create one temporary local destination per database
    val tmpLocalDestinations = {
      databases.map(_ =>
        Files
          .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_idx")
          .toAbsolutePath
          .toString)
    }

    // Index the resource into the temporary local RocksDB instances
    indexDatabases(databases, resource, tmpLocalDestinations, fitDataset, sparkContext)

    val locators =
      databases.map(database => StorageLocator(database.toString, $(storageRef), spark))

    // Copy each local index to the cluster file system
    tmpLocalDestinations.zip(locators).foreach { case (tmpLocalDestination, locator) =>
      /** tmpFiles indexed must be explicitly set to be local files */
      val uri =
        "file://" + new java.net.URI(tmpLocalDestination.replaceAllLiterally("\\", "/")).getPath
      StorageHelper.sendToCluster(
        new Path(uri),
        locator.clusterFilePath,
        locator.clusterFileName,
        locator.destinationScheme,
        sparkContext)
    }

    // Open RocksDB connections for the copies on the cluster
    locators.foreach(locator => RocksDBConnection.getOrCreate(locator.clusterFileName))
  }

  /** If the path uses the s3a scheme, downloads the resource to the local cache (when not
    * already present) and returns the local path; otherwise returns the path unchanged.
    */
  private def importIfS3(path: String, spark: SparkContext): Path = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    var src = new Path(path)
    if (uri.getScheme != null) {
      if (uri.getScheme.equals("s3a")) {
        var accessKeyId = ConfigLoader.getConfigStringValue(ConfigHelper.accessKeyId)
        var secretAccessKey = ConfigLoader.getConfigStringValue(ConfigHelper.secretAccessKey)

        if (accessKeyId == "" || secretAccessKey == "") {
          val defaultCredentials = new DefaultAWSCredentialsProviderChain().getCredentials
          accessKeyId = defaultCredentials.getAWSAccessKeyId
          secretAccessKey = defaultCredentials.getAWSSecretKey
        }
        var old_key = ""
        var old_secret = ""
        if (spark.hadoopConfiguration.get("fs.s3a.access.key") != null) {
          old_key = spark.hadoopConfiguration.get("fs.s3a.access.key")
          old_secret = spark.hadoopConfiguration.get("fs.s3a.secret.key")
        }
        try {
          val dst = new Path(ResourceDownloader.cacheFolder, src.getName)
          if (!Files.exists(Paths.get(dst.toUri.getPath))) {
            // download the S3 resource locally using the configured keys
            spark.hadoopConfiguration.set("fs.s3a.access.key", accessKeyId)
            spark.hadoopConfiguration.set("fs.s3a.secret.key", secretAccessKey)
            val s3fs = FileSystem.get(uri, spark.hadoopConfiguration)

            val dst_tmp = new Path(ResourceDownloader.cacheFolder, src.getName + "_tmp")

            s3fs.copyToLocalFile(src, dst_tmp)
            // rename to the original file name
            Files.move(
              Paths.get(dst_tmp.toUri.getRawPath),
              Paths.get(dst.toUri.getRawPath),
              StandardCopyOption.REPLACE_EXISTING)

          }
          src = new Path(dst.toUri.getPath)
        } finally {
          // restore the previous keys
          if (!old_key.equals("")) {
            spark.hadoopConfiguration.set("fs.s3a.access.key", old_key)
            spark.hadoopConfiguration.set("fs.s3a.secret.key", old_secret)
          }
        }

      }
    }
    src
  }

  private var preloaded = false

  /** Indexes the given external resource into this annotator's storage databases, at most once. */
  def indexStorage(fitDataset: Dataset[_], resource: Option[ExternalResource]): Unit = {
    if (!preloaded) {
      preloaded = true
      require(isDefined(storageRef), missingRefMsg)
      preload(fitDataset, resource, fitDataset.sparkSession, databases)
    }
  }

}
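
For context, HasStorage is normally exercised through a concrete annotator rather than used directly. A minimal sketch, assuming com.johnsnowlabs.nlp.embeddings.WordEmbeddings (which mixes in this trait); the file path, storage ref, and dimension are placeholders:

import com.johnsnowlabs.nlp.embeddings.WordEmbeddings
import com.johnsnowlabs.nlp.util.io.ReadAs

// Configure an annotator that mixes in HasStorage. Fitting a pipeline containing it
// triggers indexStorage(), which indexes the file into RocksDB and ships the index to the cluster.
val embeddings = new WordEmbeddings()
  .setStoragePath("path/to/custom_embeddings.txt", ReadAs.TEXT) // placeholder path
  .setStorageRef("custom_glove")                                // required by HasStorageRef
  .setDimension(100)                                            // placeholder dimension
  .setInputCols("document", "token")
  .setOutputCol("embeddings")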