deepset-ai / haystack · build 12294726112
12 Dec 2024 10:42AM UTC · coverage: 90.472% (+0.1%) from 90.333%
Pull Request #8617: !feat: unify NLTKDocumentSplitter and DocumentSplitter (merge 0ecc8177b into 04fc187bc)
8100 of 8953 relevant lines covered (90.47%) · 0.9 hits per line

Source file: haystack/components/preprocessors/sentence_tokenizer.py (94.05% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import re
from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple

from haystack import logging
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install nltk'") as nltk_imports:
    import nltk

logger = logging.getLogger(__name__)

Language = Literal[
    "ru", "sl", "es", "sv", "tr", "cs", "da", "nl", "en", "et", "fi", "fr", "de", "el", "it", "no", "pl", "pt", "ml"
]

ISO639_TO_NLTK = {
    "ru": "russian",
    "sl": "slovene",
    "es": "spanish",
    "sv": "swedish",
    "tr": "turkish",
    "cs": "czech",
    "da": "danish",
    "nl": "dutch",
    "en": "english",
    "et": "estonian",
    "fi": "finnish",
    "fr": "french",
    "de": "german",
    "el": "greek",
    "it": "italian",
    "no": "norwegian",
    "pl": "polish",
    "pt": "portuguese",
    "ml": "malayalam",
}

QUOTE_SPANS_RE = re.compile(r"\W(\"+|\'+).*?\1")

if nltk_imports.is_successful():

    def load_sentence_tokenizer(
        language: Language, keep_white_spaces: bool = False
    ) -> nltk.tokenize.punkt.PunktSentenceTokenizer:
        """
        Utility function to load the nltk sentence tokenizer.

        :param language: The language for the tokenizer.
        :param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences.
        :returns: nltk sentence tokenizer.
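
        Example (a minimal sketch; assumes the NLTK "punkt_tab" data is available or downloadable):

            tokenizer = load_sentence_tokenizer(language="en", keep_white_spaces=True)
            spans = list(tokenizer.span_tokenize("First sentence. Second sentence."))
            # spans is a list of (start, end) character offsets, one tuple per detected sentence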
        """
        try:
            nltk.data.find("tokenizers/punkt_tab")
        except LookupError:
            try:
                nltk.download("punkt_tab")
            except FileExistsError as error:
                logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: {error}", error=error)

        language_name = ISO639_TO_NLTK.get(language)

        if language_name is not None:
            sentence_tokenizer = nltk.data.load(f"tokenizers/punkt_tab/{language_name}.pickle")
        else:
            logger.warning(
                "PreProcessor couldn't find the default sentence tokenizer model for {language}. "
                " Using English instead. You may train your own model and use the 'tokenizer_model_folder' parameter.",
                language=language,
            )
            sentence_tokenizer = nltk.data.load("tokenizers/punkt_tab/english.pickle")

        if keep_white_spaces:
            sentence_tokenizer._lang_vars = CustomPunktLanguageVars()

        return sentence_tokenizer

    class CustomPunktLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
        # The following adjustment of PunktSentenceTokenizer is inspired by:
        # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
        # It is needed for preserving whitespace while splitting text into sentences.
        _period_context_fmt = r"""
                %(SentEndChars)s             # a potential sentence ending
                \s*                          # match potential whitespace [ \t\n\x0B\f\r]
                (?=(?P<after_tok>
                    %(NonWord)s              # either other punctuation
                    |
                    (?P<next_tok>\S+)        # or some other token - original version: \s+(?P<next_tok>\S+)
                ))"""

        def period_context_re(self) -> re.Pattern:
            """
            Compiles and returns a regular expression to find contexts including possible sentence boundaries.

            :returns: A compiled regular expression pattern.
            """
            try:
                return self._re_period_context  # type: ignore
            except:  # noqa: E722
                self._re_period_context = re.compile(
                    self._period_context_fmt
                    % {
                        "NonWord": self._re_non_word_chars,
                        # SentEndChars might be followed by closing brackets, so we match them here.
                        "SentEndChars": self._re_sent_end_chars + r"[\)\]}]*",
                    },
                    re.UNICODE | re.VERBOSE,
                )
                return self._re_period_context


class SentenceSplitter:  # pylint: disable=too-few-public-methods
    """
    SentenceSplitter splits a text into sentences using the nltk sentence tokenizer.
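
    A minimal usage sketch (requires nltk and the punkt_tab tokenizer data; exact sentence
    boundaries depend on the punkt model):

        from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter

        splitter = SentenceSplitter(language="en")
        result = splitter.split_sentences("Moonlight slanted through the trees. An owl hooted twice.")
        # each entry is a dict of the form {"sentence": str, "start": int, "end": int}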
    """

    def __init__(
        self,
        language: Language = "en",
        use_split_rules: bool = True,
        extend_abbreviations: bool = True,
        keep_white_spaces: bool = False,
    ) -> None:
        """
        Initializes the SentenceSplitter with the specified language, split rules, and abbreviation handling.

        :param language: The language for the tokenizer. Default is "en".
        :param use_split_rules: If True, the additional split rules are used. If False, the rules are not used.
        :param extend_abbreviations: If True, the abbreviations used by NLTK's PunktTokenizer are extended by a list
            of curated abbreviations if available. If False, the default abbreviations are used.
            Currently supported languages are: en, de.
        :param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences.
        """
        self.language = language
        self.sentence_tokenizer = load_sentence_tokenizer(language, keep_white_spaces=keep_white_spaces)
        self.use_split_rules = use_split_rules
        if extend_abbreviations:
            abbreviations = SentenceSplitter._read_abbreviations(language)
            self.sentence_tokenizer._params.abbrev_types.update(abbreviations)
        self.keep_white_spaces = keep_white_spaces

    def split_sentences(self, text: str) -> List[Dict[str, Any]]:
        """
        Splits a text into sentences including references to original char positions for each split.

        :param text: The text to split.
        :returns: list of sentences with positions.
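
        Example (illustrative; exact boundaries depend on the NLTK punkt model):

            splitter = SentenceSplitter(language="en")
            splitter.split_sentences("First sentence. Second sentence.")
            # -> [{"sentence": "First sentence.", "start": 0, "end": 15},
            #     {"sentence": "Second sentence.", "start": 16, "end": 32}]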
        """
        sentence_spans = list(self.sentence_tokenizer.span_tokenize(text))
        if self.use_split_rules:
            sentence_spans = SentenceSplitter._apply_split_rules(text, sentence_spans)

        sentences = [{"sentence": text[start:end], "start": start, "end": end} for start, end in sentence_spans]
        return sentences

    @staticmethod
    def _apply_split_rules(text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        """
        Applies additional split rules to the sentence spans.

        :param text: The text to split.
        :param sentence_spans: The list of sentence spans to split.
        :returns: The list of sentence spans after applying the split rules.
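
        Example (a sketch with hypothetical punkt spans; a sentence boundary that lands on a
        cited question is merged back into a single span):

            text = 'He asked: "Is it done?" Yes.'
            SentenceSplitter._apply_split_rules(text, [(0, 23), (24, 28)])
            # -> [(0, 28)]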
        """
        new_sentence_spans = []
        quote_spans = [match.span() for match in QUOTE_SPANS_RE.finditer(text)]
        while sentence_spans:
            span = sentence_spans.pop(0)
            next_span = sentence_spans[0] if len(sentence_spans) > 0 else None
            while next_span and SentenceSplitter._needs_join(text, span, next_span, quote_spans):
                sentence_spans.pop(0)
                span = (span[0], next_span[1])
                next_span = sentence_spans[0] if len(sentence_spans) > 0 else None
            start, end = span
            new_sentence_spans.append((start, end))
        return new_sentence_spans

    @staticmethod
    def _needs_join(
        text: str, span: Tuple[int, int], next_span: Tuple[int, int], quote_spans: List[Tuple[int, int]]
    ) -> bool:
        """
        Checks if the spans need to be joined as parts of one sentence.

        This method determines whether two adjacent sentence spans should be joined back together as a single sentence.
        It's used to prevent incorrect sentence splitting in specific cases like quotations, numbered lists,
        and parenthetical expressions.

        :param text: The text containing the spans.
        :param span: Tuple of (start, end) positions for the current sentence span.
        :param next_span: Tuple of (start, end) positions for the next sentence span.
        :param quote_spans: All quoted spans within text.
        :returns:
            True if the spans need to be joined.
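
        Example (a sketch with hypothetical spans; a span that ends in a list numeration such
        as "1." is joined with the span that follows it):

            text = "1. Buy milk and bread."
            SentenceSplitter._needs_join(text, (0, 2), (3, 22), quote_spans=[])
            # -> True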
        """
        start, end = span
        next_start, next_end = next_span

        # sentence. sentence"\nsentence -> no split (end << quote_end)
        # sentence.", sentence -> no split (end < quote_end)
        # sentence?", sentence -> no split (end < quote_end)
        if any(quote_start < end < quote_end for quote_start, quote_end in quote_spans):
            # sentence boundary is inside a quote
            return True

        # sentence." sentence -> split (end == quote_end)
        # sentence?" sentence -> no split (end == quote_end)
        if any(quote_start < end == quote_end and text[quote_end - 2] == "?" for quote_start, quote_end in quote_spans):
            # question is cited
            return True

        if re.search(r"(^|\n)\s*\d{1,2}\.$", text[start:end]) is not None:
            # sentence ends with a numeration
            return True

        # next sentence starts with a bracket or we return False
        return re.search(r"^\s*[\(\[]", text[next_start:next_end]) is not None

    @staticmethod
    def _read_abbreviations(lang: Language) -> List[str]:
        """
        Reads the abbreviations for a given language from the abbreviations file.

        :param lang: The language to read the abbreviations for.
        :returns: List of abbreviations.
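
        Example (a sketch; the file is read as one abbreviation per line, and the file content
        shown here is hypothetical):

            # if data/abbreviations/en.txt consists of the two lines "dr" and "prof":
            SentenceSplitter._read_abbreviations("en")
            # -> a list like ["dr", "prof"]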
        """
        abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{lang}.txt"
        if not abbreviations_file.exists():
            logger.warning("No abbreviations file found for {language}. Using default abbreviations.", language=lang)
            return []

        abbreviations = abbreviations_file.read_text().split("\n")
        return abbreviations