• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nielstron / quantulum3 / 946

pending completion
946

cron

travis-ci-com

nielstron
Merge branch 'dev'

467 of 467 new or added lines in 14 files covered. (100.0%)

1812 of 1847 relevant lines covered (98.11%)

4.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.45
/quantulum3/classifier.py
1
# -*- coding: utf-8 -*-
2
"""
5✔
3
:mod:`Quantulum` classifier functions.
4
"""
5

6
import json
5✔
7
import logging
5✔
8
import multiprocessing
5✔
9
import os
5✔
10
import warnings
5✔
11

12
import pkg_resources
5✔
13

14
from . import language, load
5✔
15
from .load import cached
5✔
16

17
# Semi-dependencies: the classifier works without sklearn/joblib (ambiguity
# resolution is simply disabled) and without the wikipedia package (training
# data download is disabled).
try:
    import joblib
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import SGDClassifier

    USE_CLF = True
except ImportError:
    SGDClassifier, TfidfVectorizer = None, None
    USE_CLF = False

    # Fixed typos in the user-facing message: "classifer" -> "classifier",
    # "dissambiguate" -> "disambiguate".
    warnings.warn(
        "Classifier dependencies not installed. Run pip install quantulum3[classifier] "
        "to install them. The classifier helps to disambiguate units."
    )

try:
    import wikipedia
except ImportError:
    # Only needed by download_wiki(); None signals "not available".
    wikipedia = None
5✔
37

38

39
_LOGGER = logging.getLogger(__name__)
5✔
40

41

42
@cached
def _get_classifier(lang="en_US"):
    """Return the language-specific classifier submodule, memoized per lang."""
    module = language.get("classifier", lang)
    return module
45

46

47
###############################################################################
48
def ambiguous_units(lang="en_US"):  # pragma: no cover
    """
    Determine ambiguous units
    :return: list ( tuple( key, list (Unit) ) )
    """
    # An entry is ambiguous when the same surface/symbol/derived key maps to
    # more than one unit or entity. Collect from all three mappings, in the
    # same order as before: surfaces, then symbols, then derived entities.
    result = []
    for mapping in (
        load.units(lang).surfaces_all,
        load.units(lang).symbols,
        load.entities(lang).derived,
    ):
        result.extend(item for item in mapping.items() if len(item[1]) > 1)
    return result
59

60

61
###############################################################################
62
def download_wiki(store=True, lang="en_US"):  # pragma: no cover
    """
    Download WikiPedia pages of ambiguous units.

    Parameters
    ----------
    store : bool
        Store the downloaded pages in the language's train/wiki.json file.
    lang : str
        Language whose ambiguous units are collected; its first two letters
        select the wikipedia language edition.

    Returns
    -------
    list of dict or None
        One dict per downloaded page with keys "_id", "url", "clean",
        "text" and "unit"; None if the wikipedia package is missing.
    """
    if not wikipedia:
        print("Cannot download wikipedia pages. Install package wikipedia first.")
        return

    wikipedia.set_lang(lang[:2])

    # BUG FIX: previously called ambiguous_units() without arguments, so a
    # non-default lang still downloaded the en_US ambiguity set.
    ambiguous = ambiguous_units(lang)
    pages = set([(j.name, j.uri) for i in ambiguous for j in i[1]])

    print()
    objs = []
    for num, page in enumerate(pages):
        obj = {
            "_id": page[1],
            "url": "https://{}.wikipedia.org/wiki/{}".format(lang[:2], page[1]),
            "clean": page[1].replace("_", " "),
        }

        print("---> Downloading %s (%d of %d)" % (obj["clean"], num + 1, len(pages)))

        # auto_suggest=False prevents wikipedia from silently redirecting to
        # a similarly named (possibly wrong) page.
        obj["text"] = wikipedia.page(obj["clean"], auto_suggest=False).content
        obj["unit"] = page[0]
        objs.append(obj)

    path = language.topdir(lang).joinpath("train/wiki.json")
    if store:
        with path.open("w") as wiki_file:
            json.dump(objs, wiki_file, indent=4, sort_keys=True)

    print("\n---> All done.\n")
    return objs
98

99

100
###############################################################################
101
def clean_text(text, lang="en_US"):
    """
    Clean text for TFIDF
    """
    classifier_module = _get_classifier(lang)
    return classifier_module.clean_text(text)
106

107

108
def _clean_text_lang(lang):
    """Return the language-specific clean_text callable.

    Used as the mapper for multiprocessing.Pool in train_classifier.
    """
    cleaner = _get_classifier(lang).clean_text
    return cleaner
110

111

112
###############################################################################
113
def train_classifier(
    parameters=None,
    ngram_range=(1, 1),
    store=True,
    lang="en_US",
    n_jobs=None,
    training_set=None,
    output_path=None,
):
    """
    Train the intent classifier
    TODO auto invoke if sklearn version is new or first install or sth
    @:param store (bool) store classifier in clf.joblib

    Parameters
    ----------
    parameters : dict
        Parameters for the SGDClassifier (see sklearn.linear_model.SGDClassifier)
    ngram_range : tuple
        Range of ngrams to use (see sklearn.feature_extraction.text.TfidfVectorizer)
    store : bool
        Save the classifier as a joblib file
    lang : str
        Language to use
    n_jobs : int
        Number of CPU jobs to use for training
    training_set : list
        Training data as a list of dicts with keys "text" and "unit". If None,
        the training set will be loaded from the training set file. See
        quantulum3._lang.en_US.train for examples.
    output_path : str
        Path to save the classifier to. If None, the classifier will be saved
        to the default location for the given language.
    """
    _LOGGER.info("Started training, parallelized with {} jobs".format(n_jobs))
    _LOGGER.info("Loading training set")
    if training_set is None:
        training_set = load.training_set(lang)

    # NOTE: frozenset iteration order is arbitrary, so the label indices below
    # are only meaningful together with this exact target_names list.
    target_names = list(frozenset([example["unit"] for example in training_set]))

    _LOGGER.info("Preparing training set")

    if n_jobs is None:
        try:
            # Number of CPUs this process may use (POSIX only).
            n_jobs = len(os.sched_getaffinity(0))
        except AttributeError:
            # Not available on this platform; leaving n_jobs as None lets
            # Pool choose the process count automatically.
            pass
    with multiprocessing.Pool(processes=n_jobs) as pool:
        cleaned_texts = pool.map(
            _clean_text_lang(lang), [example["text"] for example in training_set]
        )

    labels = [target_names.index(example["unit"]) for example in training_set]

    tfidf_model = TfidfVectorizer(
        sublinear_tf=True,
        ngram_range=ngram_range,
        stop_words=_get_classifier(lang).stop_words(),
    )

    _LOGGER.info("Fit TFIDF Model")
    feature_matrix = tfidf_model.fit_transform(cleaned_texts)

    if parameters is None:
        parameters = {
            "loss": "log",
            "penalty": "l2",
            "tol": 1e-3,
            "n_jobs": n_jobs,
            "alpha": 0.0001,
            "fit_intercept": True,
            "random_state": 0,
        }

    _LOGGER.info("Fit SGD Classifier")
    clf = SGDClassifier(**parameters).fit(feature_matrix, labels)
    obj = {
        "scikit-learn_version": pkg_resources.get_distribution("scikit-learn").version,
        "tfidf_model": tfidf_model,
        "clf": clf,
        "target_names": target_names,
    }
    if store:  # pragma: no cover
        if output_path is not None:
            path = output_path
        else:
            # legacy behavior
            path = language.topdir(lang).joinpath("clf.joblib")

        _LOGGER.info("Store classifier at {}".format(path))
        with open(path, "wb") as file:
            joblib.dump(obj, file)
    return obj
208

209

210
###############################################################################
211
class Classifier(object):
    """Holds a loaded TFIDF model, trained classifier and its target names."""

    def __init__(self, classifier_object=None, lang="en_US", classifier_path=None):
        """
        Load the intent classifier

        Parameters
        ----------
        classifier_object : dict
            Classifier object as returned by train_classifier. If falsy, it is
            loaded from classifier_path instead.
        lang : str
            Language to use (ignored if a classifier object or path is given)
        classifier_path : str
            Path to a joblib file containing the classifier. If None, the
            classifier will be loaded from the default location for the given
            language.
        """
        self.tfidf_model = None
        self.classifier = None
        self.target_names = None

        # Without sklearn/joblib installed the classifier stays empty and
        # disambiguation falls back to the first candidate.
        if not USE_CLF:
            return

        if not classifier_object:
            if classifier_path is None:
                classifier_path = language.topdir(lang).joinpath("clf.joblib")
            with open(classifier_path, "rb") as file:
                classifier_object = joblib.load(file)

        cur_scipy_version = pkg_resources.get_distribution("scikit-learn").version
        if cur_scipy_version != classifier_object.get(
            "scikit-learn_version"
        ):  # pragma: no cover
            # BUG FIX: message previously suggested the nonexistent
            # "train_classfier()" (typo).
            _LOGGER.warning(
                "The classifier was built using a different scikit-learn "
                "version (={}, !={}). The disambiguation tool could behave "
                "unexpectedly. Consider running classifier.train_classifier()".format(
                    classifier_object.get("scikit-learn_version"), cur_scipy_version
                )
            )

        self.tfidf_model = classifier_object["tfidf_model"]
        self.classifier = classifier_object["clf"]
        self.target_names = classifier_object["target_names"]
255

256

257
@cached
def classifier(lang: str = "en_US", classifier_path: str = None) -> Classifier:
    """
    Cached classifier object
    :param lang: language
    :param classifier_path: path to a joblib file containing the classifier
    :return: Classifier object
    """
    instance = Classifier(lang=lang, classifier_path=classifier_path)
    return instance
266

267

268
###############################################################################
269
def disambiguate_entity(key, text, lang="en_US", classifier_path=None):
    """
    Resolve ambiguity between entities with same dimensionality.
    """
    entities_ = load.entities(lang)
    candidates = entities_.derived[key]

    # Default: first candidate wins if the classifier cannot decide.
    new_ent = next(iter(candidates))
    if len(candidates) > 1:
        classifier_: Classifier = classifier(lang, classifier_path)

        features = classifier_.tfidf_model.transform([clean_text(text, lang)])
        probabilities = classifier_.classifier.predict_proba(features).tolist()[0]

        # Keep only scores belonging to the candidate entities, best first.
        candidate_names = [entity.name for entity in candidates]
        ranked = sorted(
            (
                (score, name)
                for score, name in zip(probabilities, classifier_.target_names)
                if name in candidate_names
            ),
            key=lambda pair: pair[0],
            reverse=True,
        )
        try:
            new_ent = entities_.names[ranked[0][1]]
        except IndexError:
            # None of the candidates appear in the classifier's targets.
            _LOGGER.debug('\tAmbiguity not resolved for "%s"', str(key))

    return new_ent
296

297

298
###############################################################################
299
def disambiguate_unit(unit, text, lang="en_US", classifier_path=None):
    """
    Resolve ambiguity between units with same names, symbols or abbreviations.
    """
    units_ = load.units(lang)

    # Look the surface form up with decreasing strictness: exact symbol,
    # exact surface, then case-insensitive variants.
    new_unit = (
        units_.symbols.get(unit)
        or units_.surfaces.get(unit)
        or units_.surfaces_lower.get(unit.lower())
        or units_.symbols_lower.get(unit.lower())
    )
    if not new_unit:
        return units_.names.get("unk")

    if len(new_unit) <= 1:
        return next(iter(new_unit))

    classifier_: Classifier = classifier(lang, classifier_path)

    features = classifier_.tfidf_model.transform([clean_text(text, lang)])
    probabilities = classifier_.classifier.predict_proba(features).tolist()[0]

    # Keep only scores belonging to the candidate units, best first.
    candidate_names = [candidate.name for candidate in new_unit]
    scores = sorted(
        (
            (score, name)
            for score, name in zip(probabilities, classifier_.target_names)
            if name in candidate_names
        ),
        key=lambda pair: pair[0],
        reverse=True,
    )
    try:
        final = units_.names[scores[0][1]]
        _LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, scores))
    except IndexError:
        # No candidate appears in the classifier's targets: keep the first.
        _LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
        final = next(iter(new_unit))

    return final
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc