
Coverage report: JohnSnowLabs / spark-nlp, build 4947838414 (GitHub, pending completion)
Pull Request #13796: Add unzip param to downloadModelDirectly in ResourceDownloader (merge 30bdeef19 into ef7906c5e)

39 of 39 new or added lines in 2 files covered (100.0%)
8632 of 13111 relevant lines covered (65.84%), 0.66 hits per line

Source file: /src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala (44.99% covered)
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.util.io

import com.amazonaws.AmazonServiceException
import com.johnsnowlabs.client.aws.AWSGateway
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.annotators.common.{TaggedSentence, TaggedWord}
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.util.io.ReadAs._
import com.johnsnowlabs.nlp.{DocumentAssembler, Finisher}
import com.johnsnowlabs.util.ConfigHelper
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

import java.io._
import java.net.{URI, URL, URLDecoder}
import java.nio.file
import java.nio.file.{Files, Paths}
import java.util.jar.JarFile
import scala.collection.mutable.{ArrayBuffer, Map => MMap}
import scala.io.BufferedSource
import scala.util.{Failure, Success, Try}

/** One-stop helper for IO management. Streams, sources and external input should be handled
  * from here.
  */
object ResourceHelper {

  def getActiveSparkSession: SparkSession =
    SparkSession.getActiveSession.getOrElse(
      SparkSession
        .builder()
        .appName("SparkNLP Default Session")
        .master("local[*]")
        .config("spark.driver.memory", "22G")
        .config("spark.driver.maxResultSize", "0")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.kryoserializer.buffer.max", "1000m")
        .getOrCreate())

  def getSparkSessionWithS3(
      awsAccessKeyId: String,
      awsSecretAccessKey: String,
      hadoopAwsVersion: String = ConfigHelper.hadoopAwsVersion,
      AwsJavaSdkVersion: String = ConfigHelper.awsJavaSdkVersion,
      region: String = "us-east-1",
      s3Impl: String = "org.apache.hadoop.fs.s3a.S3AFileSystem",
      pathStyleAccess: Boolean = true,
      credentialsProvider: String = "TemporaryAWSCredentialsProvider",
      awsSessionToken: Option[String] = None): SparkSession = {

    require(
      SparkSession.getActiveSession.isEmpty,
      "Spark session already running, can't apply new configuration for S3.")

    val sparkSession = SparkSession
      .builder()
      .appName("SparkNLP Session with S3 Support")
      .master("local[*]")
      .config("spark.driver.memory", "22G")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryoserializer.buffer.max", "1000M")
      .config("spark.driver.maxResultSize", "0")
      .config("spark.hadoop.fs.s3a.access.key", awsAccessKeyId)
      .config("spark.hadoop.fs.s3a.secret.key", awsSecretAccessKey)
      .config(ConfigHelper.awsExternalRegion, region)
      .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        s"org.apache.hadoop.fs.s3a.$credentialsProvider")
      .config("spark.hadoop.fs.s3a.impl", s3Impl)
      .config(
        "spark.jars.packages",
        "org.apache.hadoop:hadoop-aws:" + hadoopAwsVersion + ",com.amazonaws:aws-java-sdk:" + AwsJavaSdkVersion)
      .config("spark.hadoop.fs.s3a.path.style.access", pathStyleAccess.toString)

    if (credentialsProvider == "TemporaryAWSCredentialsProvider") {
      require(
        awsSessionToken.isDefined,
        "AWS Session token needs to be provided for TemporaryAWSCredentialsProvider.")
      sparkSession.config("spark.hadoop.fs.s3a.session.token", awsSessionToken.get)
    }

    sparkSession.getOrCreate()
  }
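  // Illustrative usage sketch (not part of the original source): building a session with
  // temporary S3 credentials. The key and token values below are placeholders.
  //
  //   val session = ResourceHelper.getSparkSessionWithS3(
  //     awsAccessKeyId = "<access-key-id>",
  //     awsSecretAccessKey = "<secret-access-key>",
  //     awsSessionToken = Some("<session-token>")) // required by TemporaryAWSCredentialsProvider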

  lazy val spark: SparkSession = getActiveSparkSession

  /** Structure for a SourceStream coming from compiled content */
  case class SourceStream(resource: String) {

    var fileSystem: Option[FileSystem] = None
    private val (pathExists: Boolean, path: Option[Path]) = OutputHelper.doesPathExists(resource)
    if (!pathExists) {
      throw new FileNotFoundException(s"file or folder: $resource not found")
    } else {
      fileSystem = Some(OutputHelper.getFileSystem(resource))
    }

    val pipe: Seq[InputStream] = getPipe(fileSystem.get)
    private val openBuffers: Seq[BufferedSource] = pipe.map(pp => {
      new BufferedSource(pp)("UTF-8")
    })
    val content: Seq[Iterator[String]] = openBuffers.map(c => c.getLines())

    private def getPipe(fileSystem: FileSystem): Seq[InputStream] = {
      if (fileSystem.getScheme == "s3a") {
        val awsGateway = new AWSGateway()
        val (bucket, s3Path) = parseS3URI(path.get.toString)
        val inputStreams = awsGateway.listS3Files(bucket, s3Path).map { summary =>
          val s3Object = awsGateway.getS3Object(bucket, summary.getKey)
          s3Object.getObjectContent
        }
        inputStreams
      } else {
        val files = fileSystem.listFiles(path.get, true)
        val buffer = ArrayBuffer.empty[InputStream]
        while (files.hasNext) buffer.append(fileSystem.open(files.next().getPath))
        buffer
      }
    }

    /** Copies the resource into a local temporary folder and returns the folder's URI.
      *
      * @param prefix
      *   Prefix for the temporary folder.
      * @return
      *   URI of the created temporary folder with the resource
      */
    def copyToLocal(prefix: String = "sparknlp_tmp_"): URI = {
      if (fileSystem.get.getScheme == "file")
        return URI.create(resource)

      val destination: file.Path = Files.createTempDirectory(prefix)

      val destinationUri = fileSystem.get.getScheme match {
        case "hdfs" =>
          fileSystem.get.copyToLocalFile(false, path.get, new Path(destination.toUri), true)
          if (fileSystem.get.getFileStatus(path.get).isDirectory)
            Paths.get(destination.toString, path.get.getName).toUri
          else destination.toUri
        case "dbfs" =>
          val dbfsPath = path.get.toString.replace("dbfs:/", "/dbfs/")
          val sourceFile = new File(dbfsPath)
          val targetFile = new File(destination.toString)
          if (sourceFile.isFile) FileUtils.copyFileToDirectory(sourceFile, targetFile)
          else FileUtils.copyDirectory(sourceFile, targetFile)
          targetFile.toURI
        case _ =>
          val files = fileSystem.get.listFiles(path.get, false)
          while (files.hasNext) {
            fileSystem.get.copyFromLocalFile(files.next.getPath, new Path(destination.toUri))
          }
          destination.toUri
      }

      destinationUri
    }

    def close(): Unit = {
      openBuffers.foreach(_.close())
      pipe.foreach(_.close)
    }
  }

  private def fixTarget(path: String): String = {
    val toSearch =
      s"^.*target\\${File.separator}.*scala-.*\\${File.separator}.*classes\\${File.separator}"
    if (path.matches(toSearch + ".*")) {
      path.replaceFirst(toSearch, "")
    } else {
      path
    }
  }

  /** Copies the remote resource to a local temporary folder and returns its absolute path.
    *
    * Currently, file:/, s3:/, hdfs:/ and dbfs:/ are supported.
    *
    * If the file is already on the local file system, just the absolute path will be returned
    * instead.
    * @param path
    *   Path to the resource
    * @return
    *   Absolute path to the temporary or local folder of the resource
    */
  def copyToLocal(path: String): String = try {
    val localUri =
      if (path.startsWith("s3:/") || path.startsWith("s3a:/")) { // Download directly from S3
        ResourceDownloader.downloadS3Directory(path)
      } else { // Use Source Stream
        val pathWithProtocol: String =
          if (URI.create(path).getScheme == null) new File(path).toURI.toURL.toString else path
        val resource = SourceStream(pathWithProtocol)
        resource.copyToLocal()
      }

    new File(localUri).getAbsolutePath // Platform independent path
  } catch {
    case awsE: AmazonServiceException =>
      println("Error while retrieving folder from S3. Make sure you have set the right " +
        "access keys with proper permissions in your configuration. For an example please see " +
        "https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner/mfa_ner_graphs_s3.ipynb")
      throw awsE
    case e: Exception =>
      val copyToLocalErrorMessage: String =
        "Please make sure the provided path exists and is accessible while keeping in mind only file:/, hdfs:/, dbfs:/ and s3:/ protocols are supported at the moment."
      println(
        s"$e \n Therefore, could not create temporary local directory for provided path $path. $copyToLocalErrorMessage")
      throw e
  }
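  // Illustrative usage sketch (not part of the original source): materialising a remote folder
  // locally before handing it to an API that can only read from the local file system. The path
  // below is a placeholder.
  //
  //   val localGraphsPath: String = ResourceHelper.copyToLocal("hdfs:///tmp/ner_graphs")
  //   // Returns the absolute path of a temporary local copy, or of the original folder if the
  //   // path was already local.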

  /** NOT thread safe. Do not call from executors. */
  def getResourceStream(path: String): InputStream = {
    if (new File(path).exists())
      new FileInputStream(new File(path))
    else {
      Option(getClass.getResourceAsStream(path))
        .getOrElse {
          Option(getClass.getClassLoader.getResourceAsStream(path))
            .getOrElse(throw new IllegalArgumentException(f"Wrong resource path $path"))
        }
    }
  }

  def getResourceFile(path: String): URL = {
    var dirURL = getClass.getResource(path)

    if (dirURL == null)
      dirURL = getClass.getClassLoader.getResource(path)

    dirURL
  }

  def listResourceDirectory(path: String): Seq[String] = {
    val dirURL = getResourceFile(path)

    if (dirURL != null && dirURL.getProtocol.equals("file") && new File(dirURL.toURI).exists()) {
      /* A file path: easy enough */
      return new File(dirURL.toURI).listFiles.sorted.map(_.getPath).map(fixTarget)
    } else if (dirURL == null) {
      /* path not in resources and not on disk */
      throw new FileNotFoundException(path)
    }

    if (dirURL.getProtocol.equals("jar")) {
      /* A JAR path */
      val jarPath =
        dirURL.getPath.substring(5, dirURL.getPath.indexOf("!")) // strip out only the JAR file
      val jar = new JarFile(URLDecoder.decode(jarPath, "UTF-8"))
      val entries = jar.entries()
      val result = new ArrayBuffer[String]()

      val pathToCheck = path
        .stripPrefix(File.separator.replaceAllLiterally("\\", "/"))
        .stripSuffix(File.separator) +
        File.separator.replaceAllLiterally("\\", "/")

      while (entries.hasMoreElements) {
        val name = entries.nextElement().getName.stripPrefix(File.separator)
        if (name.startsWith(pathToCheck)) { // filter according to the path
          var entry = name.substring(pathToCheck.length())
          val checkSubdir = entry.indexOf("/")
          if (checkSubdir >= 0) {
            // if it is a subdirectory, we just return the directory name
            entry = entry.substring(0, checkSubdir)
          }
          if (entry.nonEmpty) {
            result.append(pathToCheck + entry)
          }
        }
      }
      return result.distinct.sorted
    }

    throw new UnsupportedOperationException(s"Cannot list files for URL $dirURL")
  }

  /** General purpose key-value parser from a source. Currently reads only text files.
    *
    * @return
    */
  def parseKeyValueText(er: ExternalResource): Map[String, String] = {
    er.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(er.path)
        val res = sourceStream.content
          .flatMap(c =>
            c.map(line => {
              val kv = line.split(er.options("delimiter"))
              (kv.head.trim, kv.last.trim)
            }))
          .toMap
        sourceStream.close()
        res
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read
          .options(er.options)
          .format(er.options("format"))
          .options(er.options)
          .option("delimiter", er.options("delimiter"))
          .load(er.path)
          .toDF("key", "value")
        val keyValueStore = MMap.empty[String, String]
        dataset.as[(String, String)].foreach { kv =>
          keyValueStore(kv._1) = kv._2
        }
        keyValueStore.toMap
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }
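  // Illustrative usage sketch (not part of the original source): parsing a plain-text dictionary
  // where each line is "key<delimiter>value". The file path and delimiter are placeholders.
  //
  //   val lemmas: Map[String, String] = ResourceHelper.parseKeyValueText(
  //     ExternalResource("lemmas.txt", ReadAs.TEXT, Map("delimiter" -> "->")))
  //   // A line such as "better -> good" yields the entry "better" -> "good".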

  def parseKeyListValues(externalResource: ExternalResource): Map[String, List[String]] = {
    externalResource.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(externalResource.path)
        val keyValueStore = MMap.empty[String, List[String]]
        sourceStream.content.foreach(content =>
          content.foreach { line =>
            {
              val keyValues = line.split(externalResource.options("delimiter"))
              val key = keyValues.head
              val value = keyValues.drop(1).toList
              val storedValue = keyValueStore.get(key)
              if (storedValue.isDefined && !storedValue.contains(value)) {
                keyValueStore.update(key, storedValue.get ++ value)
              } else keyValueStore(key) = value
            }
          })
        sourceStream.close()
        keyValueStore.toMap
    }
  }

  def parseKeyArrayValues(externalResource: ExternalResource): Map[String, Array[Float]] = {
    externalResource.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(externalResource.path)
        val keyValueStore = MMap.empty[String, Array[Float]]
        sourceStream.content.foreach(content =>
          content.foreach { line =>
            {
              val keyValues = line.split(externalResource.options("delimiter"))
              val key = keyValues.head
              val value = keyValues.drop(1).map(x => x.toFloat)
              if (value.length > 1) {
                keyValueStore(key) = value
              }
            }
          })
        sourceStream.close()
        keyValueStore.toMap
    }
  }

  /** General purpose line parser from a source. Currently reads only text files.
    *
    * @return
    */
  def parseLines(er: ExternalResource): Array[String] = {
    er.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(er.path)
        val res = sourceStream.content.flatten.toArray
        sourceStream.close()
        res
      case SPARK =>
        import spark.implicits._
        spark.read
          .options(er.options)
          .format(er.options("format"))
          .load(er.path)
          .as[String]
          .collect
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }

  /** General purpose line parser from a source. Currently reads only text files.
    *
    * @return
    */
  def parseLinesIterator(er: ExternalResource): Seq[Iterator[String]] = {
    er.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(er.path)
        sourceStream.content
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }

  /** General purpose tuple parser from a source. Currently reads only text files.
    *
    * @return
    */
  def parseTupleText(er: ExternalResource): Array[(String, String)] = {
    er.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(er.path)
        val res = sourceStream.content
          .flatMap(c =>
            c.filter(_.nonEmpty)
              .map(line => {
                val kv = line.split(er.options("delimiter")).map(_.trim)
                (kv.head, kv.last)
              }))
          .toArray
        sourceStream.close()
        res
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read.options(er.options).format(er.options("format")).load(er.path)
        val lineStore = spark.sparkContext.collectionAccumulator[String]
        dataset.as[String].foreach(l => lineStore.add(l))
        val result = lineStore.value.toArray.map(line => {
          val kv = line.toString.split(er.options("delimiter")).map(_.trim)
          (kv.head, kv.last)
        })
        lineStore.reset()
        result
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }

  /** General purpose tuple parser from a source. Currently reads only text files.
    *
    * @return
    */
  def parseTupleSentences(er: ExternalResource): Array[TaggedSentence] = {
    er.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(er.path)
        val result = sourceStream.content
          .flatMap(c =>
            c.filter(_.nonEmpty)
              .map(line => {
                line
                  .split("\\s+")
                  .filter(kv => {
                    val s = kv.split(er.options("delimiter").head)
                    s.length == 2 && s(0).nonEmpty && s(1).nonEmpty
                  })
                  .map(kv => {
                    val p = kv.split(er.options("delimiter").head)
                    TaggedWord(p(0), p(1))
                  })
              }))
          .toArray
        sourceStream.close()
        result.map(TaggedSentence(_))
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read.options(er.options).format(er.options("format")).load(er.path)
        val result = dataset
          .as[String]
          .filter(_.nonEmpty)
          .map(line => {
            line
              .split("\\s+")
              .filter(kv => {
                val s = kv.split(er.options("delimiter").head)
                s.length == 2 && s(0).nonEmpty && s(1).nonEmpty
              })
              .map(kv => {
                val p = kv.split(er.options("delimiter").head)
                TaggedWord(p(0), p(1))
              })
          })
          .collect
        result.map(TaggedSentence(_))
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }

  def parseTupleSentencesDS(er: ExternalResource): Dataset[TaggedSentence] = {
    er.readAs match {
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read.options(er.options).format(er.options("format")).load(er.path)
        val result = dataset
          .as[String]
          .filter(_.nonEmpty)
          .map(line => {
            line
              .split("\\s+")
              .filter(kv => {
                val s = kv.split(er.options("delimiter").head)
                s.length == 2 && s(0).nonEmpty && s(1).nonEmpty
              })
              .map(kv => {
                val p = kv.split(er.options("delimiter").head)
                TaggedWord(p(0), p(1))
              })
          })
        result.map(TaggedSentence(_))
      case _ =>
        throw new Exception(
          "Unsupported readAs. If you're training POS with large dataset, consider PerceptronApproachDistributed")
    }
  }

  /** For keys with multiple values, this optimizer flattens the values and reverts them into
    * keys, so each value can be looked up in constant time.
    */
  def flattenRevertValuesAsKeys(er: ExternalResource): Map[String, String] = {
    er.readAs match {
      case TEXT =>
        val m: MMap[String, String] = MMap()
        val sourceStream = SourceStream(er.path)
        sourceStream.content.foreach(c =>
          c.foreach(line => {
            val kv = line.split(er.options("keyDelimiter")).map(_.trim)
            if (kv.length > 1) {
              val key = kv(0)
              val values = kv(1).split(er.options("valueDelimiter")).map(_.trim)
              values.foreach(m(_) = key)
            }
          }))
        sourceStream.close()
        m.toMap
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read.options(er.options).format(er.options("format")).load(er.path)
        val valueAsKeys = MMap.empty[String, String]
        dataset
          .as[String]
          .foreach(line => {
            val kv = line.split(er.options("keyDelimiter")).map(_.trim)
            if (kv.length > 1) {
              val key = kv(0)
              val values = kv(1).split(er.options("valueDelimiter")).map(_.trim)
              values.foreach(v => valueAsKeys(v) = key)
            }
          })
        valueAsKeys.toMap
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }
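  // Illustrative usage sketch (not part of the original source): a line such as
  // "good\tbetter,best" with keyDelimiter "\t" and valueDelimiter "," produces the reversed
  // entries "better" -> "good" and "best" -> "good", so each surface form can be looked up in
  // constant time. The file path and delimiters are placeholders.
  //
  //   val lemmaByForm: Map[String, String] = ResourceHelper.flattenRevertValuesAsKeys(
  //     ExternalResource("lemmas.txt", ReadAs.TEXT,
  //       Map("keyDelimiter" -> "\t", "valueDelimiter" -> ",")))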

  /** General purpose reader of saved DataFrames. Currently reads only Parquet format.
    *
    * @return
    */
  def readSparkDataFrame(er: ExternalResource): DataFrame = {
    er.readAs match {
      case SPARK =>
        val dataset = spark.read.options(er.options).format(er.options("format")).load(er.path)
        dataset
      case _ =>
        throw new Exception("Unsupported readAs - only accepts SPARK")
    }
  }

  def getWordCount(
      externalResource: ExternalResource,
      wordCount: MMap[String, Long] = MMap.empty[String, Long].withDefaultValue(0),
      pipeline: Option[PipelineModel] = None): MMap[String, Long] = {
    externalResource.readAs match {
      case TEXT =>
        val sourceStream = SourceStream(externalResource.path)
        val regex = externalResource.options("tokenPattern").r
        sourceStream.content.foreach(c =>
          c.foreach { line =>
            {
              val words: List[String] = regex.findAllMatchIn(line).map(_.matched).toList
              words.foreach(w =>
                // Builds a frequency map (word -> count) from the ExternalResource
                wordCount(w) += 1)
            }
          })
        sourceStream.close()
        if (wordCount.isEmpty)
          throw new FileNotFoundException(
            "Word count dictionary for spell checker does not exist or is empty")
        wordCount
      case SPARK =>
        import spark.implicits._
        val dataset = spark.read
          .options(externalResource.options)
          .format(externalResource.options("format"))
          .load(externalResource.path)
        val transformation = {
          if (pipeline.isDefined) {
            pipeline.get.transform(dataset)
          } else {
            val documentAssembler = new DocumentAssembler()
              .setInputCol("value")
            val tokenizer = new Tokenizer()
              .setInputCols("document")
              .setOutputCol("token")
              .setTargetPattern(externalResource.options("tokenPattern"))
            val finisher = new Finisher()
              .setInputCols("token")
              .setOutputCols("finished")
              .setAnnotationSplitSymbol("--")
            new Pipeline()
              .setStages(Array(documentAssembler, tokenizer, finisher))
              .fit(dataset)
              .transform(dataset)
          }
        }
        val wordCount = MMap.empty[String, Long].withDefaultValue(0)
        transformation
          .select("finished")
          .as[String]
          .foreach(text =>
            text
              .split("--")
              .foreach(t => {
                wordCount(t) += 1
              }))
        wordCount
      case _ => throw new IllegalArgumentException("format not available for word count")
    }
  }

  def getFilesContentBuffer(externalResource: ExternalResource): Seq[Iterator[String]] = {
    externalResource.readAs match {
      case TEXT =>
        SourceStream(externalResource.path).content
      case _ =>
        throw new Exception("Unsupported readAs")
    }
  }

  def listLocalFiles(path: String): List[File] = {

    val fileSystem = OutputHelper.getFileSystem

    val filesPath = fileSystem.getScheme match {
      case "hdfs" =>
        if (path.startsWith("file:")) {
          Option(new File(path.replace("file:", "")).listFiles())
        } else {
          val filesIterator = fileSystem.listFiles(new Path(path), false)
          val files: ArrayBuffer[File] = ArrayBuffer()

          while (filesIterator.hasNext) {
            val file = new File(filesIterator.next().getPath.toString)
            files.append(file)
          }

          Option(files.toArray)
        }
      case "dbfs" if path.startsWith("dbfs:") =>
        Option(new File(path.replace("dbfs:", "/dbfs/")).listFiles())
      case _ => Option(new File(path).listFiles())
    }

    val files = filesPath.getOrElse(throw new FileNotFoundException(s"folder: $path not found"))
    files.toList
  }

  def getFileFromPath(pathToFile: String): File = {
    val fileSystem = OutputHelper.getFileSystem
    val filePath = fileSystem.getScheme match {
      case "hdfs" =>
        if (pathToFile.startsWith("file:")) {
          new File(pathToFile.replace("file:", ""))
        } else new File(pathToFile)
      case "dbfs" if pathToFile.startsWith("dbfs:") =>
        new File(pathToFile.replace("dbfs:", "/dbfs/"))
      case _ => new File(pathToFile)
    }

    filePath
  }

  def validFile(path: String): Boolean = {

    if (path.isEmpty) return false

    var isValid = validLocalFile(path) match {
      case Success(value) => value
      case Failure(_) => false
    }

    if (!isValid) {
      validHadoopFile(path) match {
        case Success(value) => isValid = value
        case Failure(_) => isValid = false
      }
    }

    if (!isValid) {
      validDbfsFile(path) match {
        case Success(value) => isValid = value
        case Failure(_) => isValid = false
      }
    }

    isValid
  }

  private def validLocalFile(path: String): Try[Boolean] = Try {
    Files.exists(Paths.get(path))
  }

  private def validHadoopFile(path: String): Try[Boolean] = Try {
    val hadoopPath = new Path(path)
    val fileSystem = OutputHelper.getFileSystem
    fileSystem.exists(hadoopPath)
  }

  private def validDbfsFile(path: String): Try[Boolean] = Try {
    getFileFromPath(path).exists()
  }

  def moveFile(sourceFile: String, destinationFile: String): Unit = {

    val sourceFileSystem = OutputHelper.getFileSystem(sourceFile)

    if (destinationFile.startsWith("s3:")) {
      val s3Bucket = destinationFile.replace("s3://", "").split("/").head
      val s3Path = "s3:/" + destinationFile.substring(s"s3://$s3Bucket".length)

      if (sourceFileSystem.getScheme.equals("dbfs") || sourceFileSystem.getScheme.equals(
          "hdfs")) {
        val inputStream = getResourceStream(sourceFile)

        val destinationFile = sourceFile.split("/").last
        val tmpPath =
          if (sourceFileSystem.getScheme.equals("dbfs")) new Path("dbfs:/tmp")
          else new Path("hdfs:/tmp")
        if (!sourceFileSystem.exists(tmpPath)) sourceFileSystem.mkdirs(tmpPath)
        val sourceFilePath = tmpPath + "/" + destinationFile
        val outputStream = sourceFileSystem.create(new Path(sourceFilePath))

        val inputBytes = IOUtils.toByteArray(inputStream)
        outputStream.write(inputBytes)
        outputStream.close()

        OutputHelper.storeFileInS3(sourceFilePath, s3Bucket, s3Path)
      }

    } else {

      if (!sourceFileSystem.getScheme.equals("dbfs")) {
        val source = new Path(s"file:///$sourceFile")
        val destination = new Path(destinationFile)
        sourceFileSystem.copyFromLocalFile(source, destination)
      }
    }

  }

  def parseS3URI(s3URI: String): (String, String) = {
    val prefix = if (s3URI.startsWith("s3:")) "s3://" else "s3a://"
    val bucketName = s3URI.substring(prefix.length).split("/").head
    val key = s3URI.substring((prefix + bucketName).length + 1)

    require(bucketName.nonEmpty, "S3 bucket name is empty!")

    (bucketName, key)
  }
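  // Worked example (illustrative, bucket and key are placeholders):
  // parseS3URI("s3://my-bucket/models/pos/model.zip") returns ("my-bucket", "models/pos/model.zip");
  // s3a:// URIs are split the same way.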

  def parseGCPStorageURI(gcpStorageURI: String): (String, String) = {
    val prefix = "gs://"
    val bucketName = gcpStorageURI.substring(prefix.length).split("/").head
    val storagePath = gcpStorageURI.substring((prefix + bucketName).length + 1)

    require(bucketName.nonEmpty, "GCP Storage bucket name is empty!")

    (bucketName, storagePath)
  }

}