• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nielstron / quantulum3 / 946

pending completion
946

cron

travis-ci-com

nielstron
Merge branch 'dev'

467 of 467 new or added lines in 14 files covered. (100.0%)

1812 of 1847 relevant lines covered (98.11%)

4.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.66
/quantulum3/tests/test_classifier.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
"""
5✔
4
:mod:`Quantulum` tests.
5
"""
6

7
from __future__ import division
5✔
8

9
import json
5✔
10
import os
5✔
11
import sys
5✔
12
import unittest
5✔
13
import urllib.request
5✔
14
from pathlib import Path
5✔
15
from tempfile import TemporaryDirectory
5✔
16
from typing import Any
5✔
17
from unittest.mock import MagicMock, patch
5✔
18

19
import joblib
5✔
20
import wikipedia
5✔
21

22
from .. import classifier as clf
5✔
23
from .. import language, load
5✔
24
from .. import parser as p
5✔
25
from .test_setup import (
5✔
26
    add_type_equalities,
27
    get_classifier_path,
28
    get_entity_paths,
29
    get_unit_paths,
30
    load_error_tests,
31
    load_expand_tests,
32
    load_quantity_tests,
33
    multilang,
34
)
35

36
# printf-style templates wrapping text in ANSI SGR colour codes (94 and 91),
# followed by a reset (0).
COLOR1 = "\033[94m%s\033[0m"
COLOR2 = "\033[91m%s\033[0m"

# Directory part of this module's path; "." when that part is empty.
TOPDIR = os.path.dirname(__file__) or "."
# Fixture directory sitting next to this module.
TEST_DATA_DIR = Path(TOPDIR) / "data"
5✔
40

41

42
###############################################################################
43
class ClassifierBuild(unittest.TestCase):
    """Train the classifier with ``store=True`` for every tested language."""

    @multilang
    def test_training(self, lang="en_US"):
        """Check that training completes and yields a non-None result.

        The run also stores its result; per the original author's note that
        stored result is meant to be included in the package.
        """
        trained = clf.train_classifier(store=True, lang=lang)
        self.assertIsNotNone(trained)
5✔
50

51

52
###############################################################################
53
# pylint: disable=no-self-use
54
class ClassifierTest(unittest.TestCase):
5✔
55
    """Test suite for the quantulum3 project."""
56

57
    def setUp(self):
        """Prepare each test: clear loader caches, reset quantities, and
        register the type equalities on this test case."""
        # Call order is kept exactly as before: caches, then quantities.
        load.clear_caches()
        load.reset_quantities()
        add_type_equalities(self)
5✔
61

62
    def _test_parse_classifier(self, lang="en_US", classifier_path=None):
        """Run the quantity test sets through the parser with the classifier on.

        Two passes are made:

        1. Every case from ``load_quantity_tests(False, lang=lang)`` must
           parse to exactly the expected quantities; each case runs in its
           own ``subTest``.
        2. The cases from ``load_quantity_tests(True, lang=lang)`` are scored,
           and at least 80% of them must parse to the expected result.

        :param lang: language code forwarded to the test loaders and parser.
        :param classifier_path: forwarded unchanged to ``p.parse``.
        """
        clf.USE_CLF = True

        parse_kwargs = {"lang": lang, "classifier_path": classifier_path}

        # Pass 1: exact match, shortest inputs first.
        all_tests = load_quantity_tests(False, lang=lang)
        for test in sorted(all_tests, key=lambda x: len(x["req"])):
            with self.subTest(input=test["req"]):
                quants = p.parse(test["req"], **parse_kwargs)

                self.assertEqual(
                    len(test["res"]),
                    len(quants),
                    msg="Differing amount of quantities parsed, expected {}, "
                    "got {}: {}, {}".format(
                        len(test["res"]), len(quants), test["res"], quants
                    ),
                )
                # Lengths are equal at this point, so zip pairs every item.
                for quant, expected in zip(quants, test["res"]):
                    self.assertEqual(quant, expected)

        # Pass 2: success rate over the second test set.
        classifier_tests = load_quantity_tests(True, lang=lang)
        total = len(classifier_tests)
        # An empty set would otherwise surface as a ZeroDivisionError below.
        self.assertGreater(
            total,
            0,
            "No classifier tests were loaded for language {}".format(lang),
        )

        correct = 0
        error = []
        for test in sorted(classifier_tests, key=lambda x: len(x["req"])):
            quants = p.parse(test["req"], **parse_kwargs)
            if quants == test["res"]:
                correct += 1
            else:
                error.append((test, quants))

        success_rate = correct / total
        print("Classifier success rate at {:.2f}%".format(success_rate * 100))
        self.assertGreaterEqual(
            success_rate,
            0.8,
            "Classifier success rate was at {}%, below 80%.\nFailure at\n{}".format(
                success_rate * 100,
                "\n".join("{}: {}".format(test[0]["req"], test[1]) for test in error),
            ),
        )
103

104
    @multilang
    def test_parse_classifier(self, lang="en_US"):
        """Check that parsing works while the classifier is in use.

        Delegates to ``_test_parse_classifier`` for the given language.
        """
        self._test_parse_classifier(lang=lang)
5✔
108

109
    # @multilang
110
    # this was causing the test to fail, `en_US` got converted to lowercase
111
    # and the path was not found
112
    def test_parse_classifier_custom_classifier(self):
        """Parse with the default model supplied through ``classifier_path``.

        The model file used is the one returned by ``get_classifier_path``
        for ``en_US``; results are expected to match the default behaviour.
        On Python 3.8+ the test also checks that every call to
        ``quantulum3.classifier.classifier`` received that path.
        """

        lang = "en_US"
        classifier_path = get_classifier_path(lang)
        path_exists = classifier_path.exists()
        self.assertTrue(
            path_exists,
            f"Classifier path does not exist: {classifier_path}",
        )

        classifier = clf.classifier(lang=lang, classifier_path=classifier_path)

        # call.args and call.kwargs behave differently before Python 3.8, so
        # the per-call inspection is only done on 3.8 and later; older
        # interpreters just run the parse checks.
        if sys.version_info < (3, 8):  # pragma: no cover
            self._test_parse_classifier(classifier_path=classifier_path)
        else:  # pragma: no cover
            patched = patch(
                "quantulum3.classifier.classifier", return_value=classifier
            )
            with patched as mock_clf_classifier:
                self._test_parse_classifier(classifier_path=classifier_path)

                self.mock_assert_arg_in_all_calls(
                    mock_clf_classifier,
                    "classifier_path",
                    1,
                    classifier_path,
                )
145

146
    def test_parse_classifier_custom_units(self):
        """Parse with units loaded through ``load.load_custom_units``.

        The files come from ``get_unit_paths`` for ``en_US`` and additional
        units are switched off. After loading, only the custom-units flag may
        be set; the parse checks are then expected to pass as usual.
        """

        lang = "en_US"
        unit_paths = get_unit_paths(lang)
        load.load_custom_units(unit_paths, use_additional_units=False)

        # Only the custom source may remain enabled.
        self.assertFalse(load.USE_GENERAL_UNITS)
        self.assertFalse(load.USE_LANGUAGE_UNITS)
        self.assertFalse(load.USE_ADDITIONAL_UNITS)
        self.assertTrue(load.USE_CUSTOM_UNITS)

        self._test_parse_classifier(lang=lang)
5✔
158

159
    def test_parse_classifier_custom_entities(self):
        """Parse with entities loaded through ``load.load_custom_entities``.

        The files come from ``get_entity_paths`` for ``en_US`` and additional
        entities are switched off. After loading, only the custom-entities
        flag may be set; the parse checks are then expected to pass as usual.
        """

        lang = "en_US"
        entity_paths = get_entity_paths(lang)
        load.load_custom_entities(entity_paths, use_additional_entities=False)

        # Only the custom source may remain enabled.
        self.assertFalse(load.USE_GENERAL_ENTITIES)
        self.assertFalse(load.USE_LANGUAGE_ENTITIES)
        self.assertFalse(load.USE_ADDITIONAL_ENTITIES)
        self.assertTrue(load.USE_CUSTOM_ENTITIES)

        self._test_parse_classifier(lang=lang)
5✔
171

172
    @multilang
    def test_expand(self, lang="en_US"):
        """Check ``p.inline_parse_and_expand`` against the expand test set.

        Each case supplies the input under ``"req"`` and the expected output
        under ``"res"``, and runs in its own ``subTest``.
        """
        for case in load_expand_tests(lang=lang):
            with self.subTest(input=case["req"]):
                expanded = p.inline_parse_and_expand(case["req"], lang=lang)
                self.assertEqual(expanded, case["res"])
5✔
180

181
    @multilang
    def test_errors(self, lang="en_US"):
        """Check that edge-case inputs parse without raising.

        Nothing is asserted about the parse result; an exception escaping
        ``p.parse`` is what fails the corresponding ``subTest``.
        """
        for text in load_error_tests(lang=lang):
            with self.subTest(input=text):
                p.parse(text, lang=lang)
5✔
189

190
    @unittest.skip("Not necessary, as classifier is live built")
    @multilang
    def test_classifier_up_to_date(self, lang="en_US"):
        """
        Test that the classifier has been built with the latest version of
        scikit-learn

        Reads ``clf.joblib`` from the language directory, takes its
        ``"scikit-learn_version"`` entry, and compares it with the version
        reported by the PyPI JSON API. Currently skipped.
        """
        path = language.topdir(lang).joinpath("clf.joblib")
        with path.open("rb") as clf_file:
            obj = joblib.load(clf_file)
        clf_version = obj["scikit-learn_version"]

        # Bound the wait: without a timeout a stalled connection would block
        # the whole test run indefinitely.
        with urllib.request.urlopen(
            "https://pypi.org/pypi/scikit-learn/json", timeout=30
        ) as response:
            cur_version = json.loads(response.read().decode("utf-8"))["info"]["version"]

        self.assertEqual(
            clf_version,
            cur_version,
            "Classifier has been built with scikit-learn version {}, while the"
            " newest version is {}. Please update scikit-learn.".format(
                clf_version, cur_version
            ),
        )
213

214
    @multilang
    def test_training(self, lang="en_US"):
        """Train without storing, then run the parse checks on the result.

        The freshly trained object is wrapped in ``clf.Classifier`` and placed
        in ``load._CACHE_DICT`` under ``id(clf.classifier)`` -- presumably so
        that the subsequent parse run picks it up instead of a stored model.
        """
        # Training itself must not raise.
        trained = clf.train_classifier(store=False, lang=lang)

        # Replace the cache entry for clf.classifier with the fresh model.
        fresh_classifier = clf.Classifier(classifier_object=trained, lang=lang)
        load._CACHE_DICT[id(clf.classifier)] = {lang: fresh_classifier}

        self.test_parse_classifier(lang=lang)
5✔
224

225
    @patch("quantulum3.load.training_set")
    def test_training_custom_data(self, mock_load_training_set):
        """Test the classifier training works with custom data

        Trains on the contents of ``data/train.json`` passed through the
        ``training_set`` argument and checks that the patched
        ``quantulum3.load.training_set`` was never called.

        :param mock_load_training_set: mock injected by ``patch`` in place of
            ``quantulum3.load.training_set``.
        """

        # JSON is read as UTF-8 explicitly so the fixture loads the same way
        # regardless of the platform's default encoding.
        with (TEST_DATA_DIR / "train.json").open(encoding="utf-8") as f:
            self.custom_training_data = json.load(f)

        clf.train_classifier(
            store=False,
            training_set=self.custom_training_data,
        )
        mock_load_training_set.assert_not_called()
5✔
237

238
    def test_training_custom_out_path(self):
        """Check that training writes to a caller-supplied ``output_path``.

        Uses a temporary directory and asserts that ``clf.joblib`` exists in
        it once ``clf.train_classifier`` has returned.
        """

        with TemporaryDirectory() as tmpdir:
            out_path = Path(tmpdir, "clf.joblib")
            clf.train_classifier(output_path=out_path)
            self.assertTrue(out_path.exists())
5✔
248

249
    @multilang(["en_US"])
    def test_wikipedia_pages(self, lang):
        """Look up every unit's URI as a Wikipedia page title.

        Units are visited in sorted order of their names. Page and
        disambiguation errors are collected, and the test fails at the end
        with the full list if any occurred.
        """
        wikipedia.set_lang(lang[:2])
        failures = []
        for _name, unit in sorted(load.units(lang).names.items()):
            title = unit.uri.replace("_", " ")
            try:
                wikipedia.page(title, auto_suggest=False)
            except (
                wikipedia.PageError,
                wikipedia.DisambiguationError,
            ) as exc:  # pragma: no cover
                failures.append((unit, exc))
        if failures:  # pragma: no cover
            self.fail("Problematic pages:\n{}".format("\n".join(map(str, failures))))
265

266
    def mock_assert_arg_in_all_calls(
        self, mock: MagicMock, arg_name: str, arg_position: int, arg_value: Any
    ):
        """
        Assert that every recorded call to ``mock`` carried ``arg_value``.

        A call matches when ``arg_name`` is among its keyword arguments with
        an equal value or, if the keyword is absent, when the positional
        argument at ``arg_position`` equals ``arg_value``. A call with too few
        positional arguments counts as a non-match. Fails when the mock was
        never called or when any call does not match.
        """
        recorded = mock.call_args_list

        self.assertGreater(
            len(recorded),
            0,
            msg=f"Expected {arg_name}={arg_value} in all calls to {mock}, but there were no calls.",
        )

        matching = []
        for call in recorded:
            try:
                if arg_name in call.kwargs:
                    matched = call.kwargs[arg_name] == arg_value  # pragma: no cover
                else:
                    matched = arg_value == call.args[arg_position]
            except IndexError:  # pragma: no cover
                # Positional slot missing: this call does not match.
                matched = False
            if matched:
                matching.append(call)

        self.assertEqual(
            len(matching),
            len(recorded),
            msg=f"Expected {arg_name}={arg_value} in all calls to {mock}, but it was not in {len(recorded) - len(matching)} calls.",
        )
296

297

298
###############################################################################
299
if __name__ == "__main__":  # pragma: no cover
    # Run the tests in this module when the file is executed directly.
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc