• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nielstron / quantulum3 / 897

pending completion
897

push

travis-ci-com

web-flow
Merge pull request #218 from adhardydm/negative-ranges

Enhance parsing of ranges to better handle negative values

22 of 22 new or added lines in 2 files covered. (100.0%)

1513 of 1546 relevant lines covered (97.87%)

4.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.15
/quantulum3/classifier.py
1
# -*- coding: utf-8 -*-
2
"""
5✔
3
:mod:`Quantulum` classifier functions.
4
"""
5

6
import json
5✔
7
import logging
5✔
8
import multiprocessing
5✔
9
import os
5✔
10

11
import pkg_resources
5✔
12

13
from . import language, load
5✔
14
from .load import cached
5✔
15

16
# Semi-dependencies
17
try:
5✔
18
    import joblib
5✔
19
    from sklearn.feature_extraction.text import TfidfVectorizer
5✔
20
    from sklearn.linear_model import SGDClassifier
5✔
21

22
    USE_CLF = True
5✔
23
except ImportError:
5✔
24
    SGDClassifier, TfidfVectorizer = None, None
5✔
25
    USE_CLF = False
5✔
26

27
try:
5✔
28
    import wikipedia
5✔
29
except ImportError:
5✔
30
    wikipedia = None
5✔
31

32

33
_LOGGER = logging.getLogger(__name__)
5✔
34

35

36
@cached
def _get_classifier(lang="en_US"):
    """Return the language-specific classifier helper module (memoized via @cached)."""
    classifier_module = language.get("classifier", lang)
    return classifier_module
39

40

41
###############################################################################
42
def ambiguous_units(lang="en_US"):  # pragma: no cover
    """
    Determine ambiguous units
    :return: list ( tuple( key, list (Unit) ) )
    """
    # An entry is ambiguous when the same surface/symbol/dimensionality
    # maps to more than one unit or entity.
    sources = (
        load.units(lang).surfaces_all,
        load.units(lang).symbols,
        load.entities(lang).derived,
    )
    ambiguous = []
    for mapping in sources:
        ambiguous.extend(item for item in mapping.items() if len(item[1]) > 1)
    return ambiguous
53

54

55
###############################################################################
56
def download_wiki(store=True, lang="en_US"):  # pragma: no cover
    """
    Download WikiPedia pages of ambiguous units.

    :param store: (bool) store wikipedia data in wiki.json file
    :param lang: (str) language whose ambiguous units are looked up and
        whose Wikipedia (first two letters of the locale) is queried
    :return: list of page dicts, or None when the wikipedia package is missing
    """
    if not wikipedia:
        print("Cannot download wikipedia pages. Install package wikipedia first.")
        return

    wikipedia.set_lang(lang[:2])

    # Bug fix: previously called ambiguous_units() with no argument, so a
    # non-default language still collected ambiguities from the en_US tables.
    ambiguous = ambiguous_units(lang)
    pages = set([(j.name, j.uri) for i in ambiguous for j in i[1]])

    print()
    objs = []
    for num, page in enumerate(pages):
        obj = {
            "_id": page[1],
            "url": "https://{}.wikipedia.org/wiki/{}".format(lang[:2], page[1]),
            "clean": page[1].replace("_", " "),
        }

        print("---> Downloading %s (%d of %d)" % (obj["clean"], num + 1, len(pages)))

        # auto_suggest=False keeps wikipedia from silently redirecting to a
        # similarly named but unrelated article.
        obj["text"] = wikipedia.page(obj["clean"], auto_suggest=False).content
        obj["unit"] = page[0]
        objs.append(obj)

    path = language.topdir(lang).joinpath("train/wiki.json")
    if store:
        with path.open("w") as wiki_file:
            json.dump(objs, wiki_file, indent=4, sort_keys=True)

    print("\n---> All done.\n")
    return objs
92

93

94
###############################################################################
95
def clean_text(text, lang="en_US"):
    """
    Clean text for TFIDF
    """
    cleaner = _get_classifier(lang).clean_text
    return cleaner(text)
100

101

102
def _clean_text_lang(lang):
    """Return the bound clean_text callable for *lang* (usable with Pool.map)."""
    helper = _get_classifier(lang)
    return helper.clean_text
104

105

106
###############################################################################
107
def train_classifier(
    parameters=None,
    ngram_range=(1, 1),
    store=True,
    lang="en_US",
    n_jobs=None,
    training_set=None,
    output_path=None,
):
    """
    Train the intent classifier
    TODO auto invoke if sklearn version is new or first install or sth
    @:param store (bool) store classifier in clf.joblib

    Parameters
    ----------
    parameters : dict
        Parameters for the SGDClassifier (see sklearn.linear_model.SGDClassifier)
    ngram_range : tuple
        Range of ngrams to use (see sklearn.feature_extraction.text.TfidfVectorizer)
    store : bool
        Save the classifier as a joblib file
    lang : str
        Language to use
    n_jobs : int
        Number of CPU jobs to use for training
    training_set : list
        Training data as a list of dicts with keys "text" and "unit". If None,
        the training set will be loaded from the training set file. See
        quantulum3._lang.en_US.train for examples.
    output_path : str
        Path to save the classifier to. If None, the classifier will be saved
        to the default location for the given language.

    Returns
    -------
    dict
        Trained model with keys "tfidf_model", "clf", "target_names" and
        "scikit-learn_version".
    """
    _LOGGER.info("Started training, parallelized with {} jobs".format(n_jobs))
    _LOGGER.info("Loading training set")
    if training_set is None:
        training_set = load.training_set(lang)

    # Sort the label set so class indices are reproducible across runs.
    # A bare frozenset iterates in hash order, which varies with
    # PYTHONHASHSEED, so otherwise identical trainings could assign
    # different indices to the same unit despite random_state=0 below.
    target_names = sorted(set(example["unit"] for example in training_set))

    _LOGGER.info("Preparing training set")

    if n_jobs is None:
        try:
            # Retrieve the number of cpus that can be used
            n_jobs = len(os.sched_getaffinity(0))
        except AttributeError:
            # n_jobs stays None such that Pool will try to
            # automatically set the number of processes appropriately
            pass
    with multiprocessing.Pool(processes=n_jobs) as p:
        train_data = p.map(_clean_text_lang(lang), [ex["text"] for ex in training_set])

    # O(1) label lookup instead of repeated O(n) list.index calls.
    label_index = {name: idx for idx, name in enumerate(target_names)}
    train_target = [label_index[example["unit"]] for example in training_set]

    tfidf_model = TfidfVectorizer(
        sublinear_tf=True,
        ngram_range=ngram_range,
        stop_words=_get_classifier(lang).stop_words(),
    )

    _LOGGER.info("Fit TFIDF Model")
    matrix = tfidf_model.fit_transform(train_data)

    if parameters is None:
        parameters = {
            "loss": "log",
            "penalty": "l2",
            "tol": 1e-3,
            "n_jobs": n_jobs,
            "alpha": 0.0001,
            "fit_intercept": True,
            "random_state": 0,
        }

    _LOGGER.info("Fit SGD Classifier")
    clf = SGDClassifier(**parameters).fit(matrix, train_target)
    obj = {
        "scikit-learn_version": pkg_resources.get_distribution("scikit-learn").version,
        "tfidf_model": tfidf_model,
        "clf": clf,
        "target_names": target_names,
    }
    if store:  # pragma: no cover
        if output_path is not None:
            path = output_path
        else:
            # legacy behavior
            path = language.topdir(lang).joinpath("clf.joblib")

        _LOGGER.info("Store classifier at {}".format(path))
        with path.open("wb") as file:
            joblib.dump(obj, file)
    return obj
202

203

204
###############################################################################
205
class Classifier(object):
    """
    Wrapper around a trained unit-disambiguation model.

    Attributes (all None when sklearn/joblib are not installed):
        tfidf_model: fitted TfidfVectorizer
        classifier: fitted SGDClassifier
        target_names: list of unit names, indexed by classifier class id
    """

    def __init__(self, obj=None, lang="en_US"):
        """
        Load the intent classifier

        :param obj: pre-built model dict (as produced by train_classifier);
            when falsy, the model is read from the language's clf.joblib file
        :param lang: language whose default model should be loaded
        """
        self.tfidf_model = None
        self.classifier = None
        self.target_names = None

        # Without the sklearn semi-dependency there is nothing to load;
        # the attributes simply stay None.
        if not USE_CLF:
            return

        if not obj:
            path = language.topdir(lang).joinpath("clf.joblib")
            with path.open("rb") as file:
                obj = joblib.load(file)

        # Warn when the stored model was pickled with a different
        # scikit-learn version, since unpickled estimators may misbehave.
        cur_sklearn_version = pkg_resources.get_distribution("scikit-learn").version
        if cur_sklearn_version != obj.get("scikit-learn_version"):  # pragma: no cover
            _LOGGER.warning(
                "The classifier was built using a different scikit-learn "
                "version (={}, !={}). The disambiguation tool could behave "
                "unexpectedly. Consider running classifier.train_classifier()".format(
                    obj.get("scikit-learn_version"), cur_sklearn_version
                )
            )

        self.tfidf_model = obj["tfidf_model"]
        self.classifier = obj["clf"]
        self.target_names = obj["target_names"]
235

236

237
@cached
def classifier(lang="en_US"):
    """
    Cached classifier object
    :param lang:
    :return:
    """
    instance = Classifier(lang=lang)
    return instance
245

246

247
###############################################################################
248
def disambiguate_entity(key, text, lang="en_US"):
    """
    Resolve ambiguity between entities with same dimensionality.
    """

    candidates = load.entities(lang).derived[key]
    # Default to an arbitrary candidate; overridden below when the
    # classifier can pick a better one.
    new_ent = next(iter(candidates))
    if len(candidates) > 1:
        clf = classifier(lang)
        features = clf.tfidf_model.transform([clean_text(text, lang)])
        probabilities = clf.classifier.predict_proba(features).tolist()[0]

        # Keep only scores whose label is a possible entity name for this key
        allowed = [entity.name for entity in candidates]
        ranked = [
            (score, name)
            for score, name in zip(probabilities, clf.target_names)
            if name in allowed
        ]

        # Highest score first
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        try:
            new_ent = load.entities(lang).names[ranked[0][1]]
        except IndexError:
            _LOGGER.debug('\tAmbiguity not resolved for "%s"', str(key))

    return new_ent
271

272

273
###############################################################################
274
def disambiguate_unit(unit, text, lang="en_US"):
    """
    Resolve ambiguity between units with same names, symbols or abbreviations.
    """

    units = load.units(lang)
    candidates = (
        units.symbols.get(unit)
        or units.surfaces.get(unit)
        or units.surfaces_lower.get(unit.lower())
        or units.symbols_lower.get(unit.lower())
    )
    if not candidates:
        # Unknown surface form: fall back to the "unknown" unit.
        return units.names.get("unk")

    if len(candidates) <= 1:
        return next(iter(candidates))

    clf = classifier(lang)
    features = clf.tfidf_model.transform([clean_text(text, lang)])
    probabilities = clf.classifier.predict_proba(features).tolist()[0]

    # Keep only scores whose label is one of the candidate unit names
    allowed = [candidate.name for candidate in candidates]
    ranked = [
        (score, name)
        for score, name in zip(probabilities, clf.target_names)
        if name in allowed
    ]

    # Highest score first
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    try:
        final = units.names[ranked[0][1]]
        _LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, ranked))
    except IndexError:
        _LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
        final = next(iter(candidates))

    return final
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc