• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 18685790193

21 Oct 2025 01:39PM UTC coverage: 55.216%. First build
18685790193

Pull #14676

github

web-flow
Merge 427de3761 into b827818c7
Pull Request #14676: Spark NLP 6.2.0 Release

147 of 185 new or added lines in 7 files covered. (79.46%)

11924 of 21595 relevant lines covered (55.22%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.61
/src/main/scala/com/johnsnowlabs/reader/HTMLReader.scala
1
/*
2
 * Copyright 2017-2024 John Snow Labs
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *    http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16
package com.johnsnowlabs.reader
17

18
import com.johnsnowlabs.nlp.util.io.ResourceHelper
19
import com.johnsnowlabs.nlp.util.io.ResourceHelper.{isValidURL, validFile}
20
import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
21
import com.johnsnowlabs.reader.util.HTMLParser
22
import com.johnsnowlabs.reader.util.HTMLParser.tableElementToJson
23
import org.apache.spark.sql.DataFrame
24
import org.apache.spark.sql.functions.{col, udf}
25
import org.jsoup.Jsoup
26
import org.jsoup.nodes.{Document, Element, Node, TextNode}
27

28
import java.util.UUID
29
import scala.collection.JavaConverters._
30
import scala.collection.mutable
31
import scala.collection.mutable.ArrayBuffer
32

33
/** Class to parse and read HTML files.
  *
  * @param titleFontSize
  *   Minimum font size threshold in pixels used as part of heuristic rules to detect title
  *   elements based on formatting (e.g., bold, centered, capitalized). By default, it is set to
  *   16.
  * @param storeContent
  *   Whether to include the raw HTML in the output DataFrame as a separate 'content' column,
  *   alongside the structured output. For URL inputs the response body is fetched once and both
  *   stored and parsed. By default, it is set to false.
  * @param timeout
  *   Timeout value in seconds for reading remote HTML resources. Applied when fetching content
  *   from URLs. By default, it is set to 0, which Jsoup interprets as "no timeout".
  * @param includeTitleTag
  *   Whether to prepend the document's `<title>` text as a TITLE element. By default, it is set
  *   to false.
  * @param outputFormat
  *   How table elements are rendered: "plain-text" (default), "html-table" (compacted outer
  *   HTML) or "json-table". Any other value falls back to plain text.
  * @param headers
  *   Headers sent with every URL request.
  *
  * Two types of input paths are supported for the reader,
  *
  * htmlPath: this is a path to a directory of HTML files or a path to an HTML file, e.g.
  * "path/html/files"
  *
  * url: this is the URL or set of URLs of a website, e.g. "https://www.wikipedia.org"
  *
  * ==Example==
  * {{{
  * val path = "./html-files/fake-html.html"
  * val htmlReader = new HTMLReader()
  * val htmlDF = htmlReader.read(path)
  * }}}
  *
  * {{{
  * htmlDF.show()
  * +--------------------+--------------------+
  * |                path|                html|
  * +--------------------+--------------------+
  * |file:/content/htm...|[{Title, My First...|
  * +--------------------+--------------------+
  *
  * htmlDF.printSchema()
  * root
  *  |-- path: string (nullable = true)
  *  |-- html: array (nullable = true)
  *  |    |-- element: struct (containsNull = true)
  *  |    |    |-- elementType: string (nullable = true)
  *  |    |    |-- content: string (nullable = true)
  *  |    |    |-- metadata: map (nullable = true)
  *  |    |    |    |-- key: string
  *  |    |    |    |-- value: string (valueContainsNull = true)
  * }}}
  * For more examples please refer to this
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/reader/SparkNLP_HTML_Reader_Demo.ipynb notebook]].
  */

class HTMLReader(
    titleFontSize: Int = 16,
    storeContent: Boolean = false,
    timeout: Int = 0,
    includeTitleTag: Boolean = false,
    outputFormat: String = "plain-text",
    headers: Map[String, String] = Map.empty)
    extends Serializable {

  private lazy val spark = ResourceHelper.spark
  import spark.implicits._

  private var outputColumn = "html"

  /** Sets the name of the column that receives the parsed HTML elements (default "html").
    *
    * @throws IllegalArgumentException
    *   if `value` is empty.
    */
  def setOutputColumn(value: String): this.type = {
    require(value.nonEmpty, "Output column name cannot be empty.")
    outputColumn = value
    this
  }

  /** @return the name of the column that receives the parsed HTML elements. */
  def getOutputColumn: String = outputColumn

  /** Reads a local HTML file / directory of HTML files, or a single URL.
    *
    * @param inputSource
    *   a path to an HTML file or directory, or a URL, e.g. "https://www.wikipedia.org"
    * @return
    *   DataFrame with a `path` (files) or `url` (URLs) column, the parsed elements in the
    *   output column and, when `storeContent` is true, the raw HTML in a `content` column.
    * @throws IllegalArgumentException
    *   if `inputSource` is neither a valid file path nor a valid URL.
    */
  def read(inputSource: String): DataFrame = {
    if (validFile(inputSource) && !inputSource.startsWith("http")) {
      val htmlDf = datasetWithTextFile(spark, inputSource)
        .withColumn(outputColumn, parseHtmlUDF(col("content")))
      if (storeContent) htmlDf.select("path", "content", outputColumn)
      else htmlDf.select("path", outputColumn)
    } else if (isValidURL(inputSource)) {
      readURLs(Seq(inputSource))
    } else {
      throw new IllegalArgumentException(s"Invalid inputSource: $inputSource")
    }
  }

  /** Reads several URLs; entries that are not valid URLs are skipped.
    *
    * @param inputURLs
    *   list of URLs, e.g. Array("https://www.wikipedia.org", "https://www.example.com")
    * @return
    *   DataFrame with a `url` column, the parsed elements in the output column and, when
    *   `storeContent` is true, the raw HTML in a `content` column.
    */
  def read(inputURLs: Array[String]): DataFrame = {
    val validURLs = inputURLs.filter(url => isValidURL(url)).toSeq
    readURLs(validURLs)
  }

  /** Builds the URL DataFrame shared by both `read` overloads. */
  private def readURLs(urls: Seq[String]): DataFrame = {
    val urlDf = spark.createDataset(urls).toDF("url")
    if (storeContent) {
      // Fetch each page once and parse the stored body, so `content` and the parsed output
      // always describe the same response.
      urlDf
        .withColumn("content", fetchURLContentUDF(col("url")))
        .withColumn(outputColumn, parseHtmlUDF(col("content")))
        .select("url", "content", outputColumn)
    } else {
      urlDf
        .withColumn(outputColumn, parseURLUDF(col("url")))
        .select("url", outputColumn)
    }
  }

  /** Spark UDF: HTML string -> parsed elements. */
  private val parseHtmlUDF = udf((html: String) => htmlToHTMLElement(html))

  /** Spark UDF: URL -> fetched and parsed elements. */
  private val parseURLUDF = udf((url: String) => urlToHTMLElement(url))

  /** Spark UDF: URL -> raw response body. Marked nondeterministic so the optimizer does not
    * inline it into the downstream parse expression (which would fetch the page twice).
    */
  private val fetchURLContentUDF =
    udf((url: String) => fetchURLContent(url)).asNondeterministic()

  /** Extracts elements from the document body, optionally prefixed by the document title.
    * Any exception raised while traversing is returned as a single ERROR element instead of
    * being thrown.
    */
  private def startTraversalFromBody(document: Document): Array[HTMLElement] = {
    try {
      val elements = extractElements(document.body())
      val docTitle = document.title().trim

      if (docTitle.nonEmpty && includeTitleTag) {
        val titleElem = HTMLElement(
          ElementType.TITLE,
          content = docTitle,
          metadata = mutable.Map.empty[String, String])
        Array(titleElem) ++ elements
      } else {
        elements
      }
    } catch {
      case e: Exception =>
        Array(
          HTMLElement(ElementType.ERROR, s"Could not parse HTML: ${e.getMessage}", mutable.Map()))
    }
  }

  /** Parses an in-memory HTML string into elements. */
  def htmlToHTMLElement(html: String): Array[HTMLElement] =
    startTraversalFromBody(Jsoup.parse(html))

  /** Fetches `url` and parses it into elements. Errors raised by the fetch itself are not
    * caught here and propagate to the caller.
    */
  def urlToHTMLElement(url: String): Array[HTMLElement] =
    startTraversalFromBody(connect(url).get())

  /** Fetches `url` and returns the raw response body. */
  private def fetchURLContent(url: String): String = connect(url).execute().body()

  /** Jsoup connection configured with this reader's headers and timeout (seconds -> ms). */
  private def connect(url: String) =
    Jsoup
      .connect(url)
      .headers(headers.asJava)
      .timeout(timeout * 1000)

  private case class NodeMetadata(tagName: Option[String], hidden: Boolean, var visited: Boolean)

  /** Walks the DOM under `root` depth-first and emits typed elements (links, tables, list
    * items, code blocks, paragraphs, titles, images). Non-title elements get a `parent_id`
    * pointing at the most recently emitted title; `<hr>` with a page-break style advances
    * `pageNumber`.
    */
  private def extractElements(root: Node): Array[HTMLElement] = {
    var sentenceIndex = 0
    val elements = ArrayBuffer[HTMLElement]()
    val trackingNodes = mutable.Map[Node, NodeMetadata]()
    var pageNumber = 1

    // Id of the most recently emitted title; later non-title elements are linked to it.
    var currentParentId: Option[String] = None

    def newUUID(): String = UUID.randomUUID().toString

    // Stamps `metadata` with the next sentence index.
    def assignSentence(metadata: mutable.Map[String, String]): Unit = {
      metadata("sentence") = sentenceIndex.toString
      sentenceIndex += 1
    }

    def markVisited(node: Node): Unit = trackingNodes(node).visited = true

    // Emits a non-title element, attaching it to the current title when there is one.
    def emitChild(
        elementType: String,
        content: String,
        metadata: mutable.Map[String, String]): Unit = {
      metadata("element_id") = newUUID()
      currentParentId.foreach(pid => metadata("parent_id") = pid)
      elements += HTMLElement(elementType, content = content, metadata = metadata)
    }

    // Emits a title element and makes it the parent of the elements that follow.
    def emitTitle(content: String, metadata: mutable.Map[String, String]): Unit = {
      val titleId = newUUID()
      metadata("element_id") = titleId
      elements += HTMLElement(ElementType.TITLE, content = content, metadata = metadata)
      currentParentId = Some(titleId)
    }

    def isNodeHidden(node: Node): Boolean = node match {
      case elem: Element =>
        // Strip whitespace so "display: none" is detected as well as "display:none".
        val style = elem.attr("style").toLowerCase.replaceAll("\\s+", "")
        val isHiddenByStyle =
          style.contains("display:none") || style.contains("visibility:hidden")
        val isHiddenByAttribute = elem.hasAttr("hidden") || elem.attr("aria-hidden") == "true"
        isHiddenByStyle || isHiddenByAttribute
      case _ => false
    }

    // Concatenates the visible text under `nodes`, marking every node it reaches as visited.
    def collectTextFromNodes(nodes: List[Node]): String = {
      val textBuffer = ArrayBuffer[String]()

      def traverseAndCollect(node: Node): Unit = {
        val isHiddenNode = trackingNodes
          .getOrElseUpdate(
            node,
            NodeMetadata(tagName = getTagName(node), hidden = isNodeHidden(node), visited = true))
          .hidden
        if (!isHiddenNode) {
          node match {
            case textNode: TextNode =>
              markVisited(textNode)
              val text = textNode.text().trim
              if (text.nonEmpty) textBuffer += text

            case elem: Element =>
              markVisited(elem)
              val text = elem.ownText().trim
              if (text.nonEmpty) textBuffer += text
              elem.childNodes().asScala.foreach(traverseAndCollect)

            case _ => // Ignore other node types
          }
        }
      }

      nodes.foreach(traverseAndCollect)
      textBuffer.mkString(" ").replaceAll("\\s+", " ").trim
    }

    def traverse(node: Node, tagName: Option[String]): Unit = {
      trackingNodes.getOrElseUpdate(
        node,
        NodeMetadata(tagName = tagName, hidden = isNodeHidden(node), visited = false))

      node.childNodes().forEach { childNode =>
        trackingNodes.getOrElseUpdate(
          childNode,
          NodeMetadata(tagName = tagName, hidden = isNodeHidden(childNode), visited = false))
      }

      // A hidden node is not descended into, so its whole subtree produces no output.
      if (!trackingNodes(node).hidden) {
        node match {
          case element: Element =>
            val visitedNode = trackingNodes(element).visited
            val pageMetadata: mutable.Map[String, String] =
              mutable.Map("pageNumber" -> pageNumber.toString)

            element.tagName() match {
              case "a" =>
                assignSentence(pageMetadata)
                val href = element.attr("href").trim
                val linkText = element.text().trim
                if (href.nonEmpty && linkText.nonEmpty && !visitedNode) {
                  markVisited(element)
                  emitChild(ElementType.LINK, s"[$linkText]($href)", pageMetadata)
                }

              case "table" =>
                assignSentence(pageMetadata)
                val tableContent = renderTable(element)
                if (tableContent.nonEmpty && !visitedNode) {
                  markVisited(element)
                  emitChild(ElementType.TABLE, tableContent, pageMetadata)
                }

              case "li" =>
                assignSentence(pageMetadata)
                val itemText = element.text().trim
                if (itemText.nonEmpty && !visitedNode) {
                  markVisited(element)
                  emitChild(ElementType.LIST_ITEM, itemText, pageMetadata)
                }

              case "pre" =>
                // Prefer the text of a nested <code> element when there is one.
                val codeText = Option(element.getElementsByTag("code").first())
                  .getOrElse(element)
                  .text()
                  .trim
                if (codeText.nonEmpty && !visitedNode) {
                  assignSentence(pageMetadata)
                  markVisited(element)
                  emitChild(ElementType.UNCATEGORIZED_TEXT, codeText, pageMetadata)
                }

              case _ if isParagraphLikeElement(element) =>
                if (!visitedNode) {
                  val classType = classifyParagraphElement(element)
                  // Children are traversed first so nested links/images/etc. are emitted too.
                  element.childNodes().asScala.foreach { childNode =>
                    traverse(childNode, getTagName(childNode))
                  }

                  classType match {
                    case ElementType.NARRATIVE_TEXT =>
                      val aggregatedText =
                        collectTextFromNodes(element.childNodes().asScala.toList)
                      if (aggregatedText.nonEmpty) {
                        assignSentence(pageMetadata)
                        markVisited(element)
                        emitChild(ElementType.NARRATIVE_TEXT, aggregatedText, pageMetadata)
                      }

                    case ElementType.TITLE =>
                      val titleText = element.text().trim
                      if (titleText.nonEmpty) {
                        assignSentence(pageMetadata)
                        markVisited(element)
                        emitTitle(titleText, pageMetadata)
                      }

                    case ElementType.UNCATEGORIZED_TEXT =>
                      val text = element.text().trim
                      if (text.nonEmpty) {
                        assignSentence(pageMetadata)
                        markVisited(element)
                        emitChild(ElementType.UNCATEGORIZED_TEXT, text, pageMetadata)
                      }
                  }
                }

              case _ if isTitleElement(element) && !visitedNode =>
                markVisited(element)
                val titleText = element.text().trim
                if (titleText.nonEmpty) {
                  assignSentence(pageMetadata)
                  emitTitle(titleText, pageMetadata)
                }

              case "hr" =>
                if (element.attr("style").toLowerCase.contains("page-break")) pageNumber += 1

              case "img" =>
                assignSentence(pageMetadata)
                val src = element.attr("src").trim
                val alt = element.attr("alt").trim
                if (src.nonEmpty && !visitedNode) {
                  markVisited(element)
                  val imgMetadata = mutable.Map[String, String]("alt" -> alt) ++ pageMetadata

                  // For inline base64 data URIs, keep the header (before ',') as metadata and
                  // the payload as content.
                  val commaIndex = src.indexOf(',')
                  val contentValue =
                    if (src.toLowerCase.contains("base64") && commaIndex > 0) {
                      imgMetadata("encoding") = src.substring(0, commaIndex)
                      src.substring(commaIndex + 1)
                    } else {
                      src
                    }

                  val width = element.attr("width").trim
                  val height = element.attr("height").trim
                  if (width.nonEmpty) imgMetadata("width") = width
                  if (height.nonEmpty) imgMetadata("height") = height
                  emitChild(ElementType.IMAGE, contentValue, imgMetadata)
                }

              case _ =>
                element.childNodes().asScala.foreach { childNode =>
                  traverse(childNode, getTagName(childNode))
                }
            }

          case _ => // Non-element nodes are not emitted on their own.
        }
      }
    }

    traverse(root, getTagName(root))
    elements.toArray
  }

  /** Renders a `<table>` according to `outputFormat`; unknown formats fall back to plain text. */
  private def renderTable(table: Element): String = outputFormat match {
    case "html-table" =>
      table
        .outerHtml()
        .replaceAll("\\n", "")
        .replaceAll(">\\s+<", "><")
        .replaceAll("^\\s+|\\s+$", "")
    case "json-table" => tableElementToJson(table)
    case _ => extractNestedTableContent(table).trim
  }

  /** True for `<p>`, and for `<div>` elements that carry typographic styling or bold text. */
  private def isParagraphLikeElement(elem: Element): Boolean = {
    val tag = elem.tagName().toLowerCase
    val style = elem.attr("style").toLowerCase
    (tag == "p") ||
    (tag == "div" && (
      style.contains("font-size") ||
        style.contains("line-height") ||
        style.contains("margin") ||
        elem.getElementsByTag("b").size() > 0 ||
        elem.getElementsByTag("strong").size() > 0
    ))
  }

  private def getTagName(node: Node): Option[String] = {
    node match {
      case element: Element => Some(element.tagName())
      case _ => None
    }
  }

  /** Classifies a paragraph-like element as TITLE, NARRATIVE_TEXT or UNCATEGORIZED_TEXT. */
  private def classifyParagraphElement(element: Element): String = {
    if (isFormattedAsTitle(element)) {
      ElementType.TITLE
    } else if (isTextElement(element)) {
      ElementType.NARRATIVE_TEXT
    } else {
      ElementType.UNCATEGORIZED_TEXT
    }
  }

  private def isTitleElement(element: Element): Boolean = {
    val tag = element.tagName().toLowerCase
    val style = element.attr("style").toLowerCase
    val role = element.attr("role").toLowerCase
    HTMLParser.isTitleElement(tag, style, role, titleFontSize)
  }

  private def isTextElement(elem: Element): Boolean = {
    !isFormattedAsTitle(elem) &&
    (elem.attr("style").toLowerCase.contains("text") ||
      elem.tagName().toLowerCase == "p" ||
      (elem.tagName().toLowerCase == "div" && isParagraphLikeElement(elem)))
  }

  /** True when the element contains `<b>`/`<strong>` or its style passes the font-size rule. */
  private def isFormattedAsTitle(elem: Element): Boolean = {
    val style = elem.attr("style").toLowerCase
    val hasBoldTag =
      elem.getElementsByTag("b").size() > 0 || elem.getElementsByTag("strong").size() > 0
    hasBoldTag || HTMLParser.isFormattedAsTitle(style, titleFontSize)
  }

  /** Joins the own-text of `elem` and every descendant element with single spaces. */
  private def extractNestedTableContent(elem: Element): String = {
    val textBuffer = ArrayBuffer[String]()
    val processedElements = mutable.Set[Node]() // Set to track processed elements

    // Recursive function to collect text from the element and its children
    def collectText(node: Node): Unit = {
      node match {
        case childElem: Element =>
          if (!processedElements.contains(childElem)) {
            processedElements += childElem

            val directText = childElem.ownText().trim
            if (directText.nonEmpty) textBuffer += directText

            childElem.childNodes().asScala.foreach(collectText)
          }

        case _ => // Ignore other node types
      }
    }

    // Start the recursive text collection
    collectText(elem)
    textBuffer.mkString(" ")
  }

}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc