• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymorphy2-fork / DAWG-Python / 13370339198

17 Feb 2025 12:33PM UTC coverage: 89.367% (+2.0%) from 87.367%
13370339198

Pull #41

github

web-flow
Merge a881cb3b0 into 44e13ac5b
Pull Request #41: Add type annotations

155 of 188 branches covered (82.45%)

Branch coverage included in aggregate %.

82 of 90 new or added lines in 3 files covered. (91.11%)

2 existing lines in 1 file now uncovered.

635 of 696 relevant lines covered (91.24%)

5.47 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.82
/dawg_python/dawgs.py
1
from __future__ import annotations
6✔
2

3
import struct
6✔
4
from binascii import a2b_base64
6✔
5
from typing import TYPE_CHECKING
6✔
6

7
from . import wrapper
6✔
8

9
if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; it is
    # never executed at runtime, so the names below may appear in annotations
    # only (annotations are lazy thanks to ``from __future__ import annotations``).
    from pathlib import Path
    from typing import Any, Iterator, Mapping

    from typing_extensions import Self, TypeAlias

    # User-supplied replacement table: a single character mapped to one
    # replacement character or to a list of replacement characters.
    Replaces: TypeAlias = Mapping[str, str | list[str]]
    # Table produced by ``DAWG.compile_replaces``: keys are the UTF-8 encoded
    # source characters (bytes, not str); each value is a list of
    # (encoded replacement, decoded replacement) pairs.
    CompiledReplaces: TypeAlias = Mapping[bytes, list[tuple[bytes, str]]]
×
17

18

19
class DAWG:
    """
    Base DAWG wrapper.

    Keys are looked up as UTF-8 bytes; ``str`` keys are encoded first.
    ``dct`` is ``None`` until ``load`` is called, so ``load`` must be
    called before any of the query methods.
    """

    dct: wrapper.Dictionary | None

    def __init__(self) -> None:
        self.dct = None

    def __contains__(self, key: str | bytes) -> bool:
        """
        Returns True if ``key`` is stored in this DAWG.
        """
        if not isinstance(key, bytes):
            key = key.encode("utf8")
        return self.dct.contains(key)

    def load(self, path: str | Path) -> Self:
        """
        Loads DAWG from a file.

        Returns ``self`` so that calls can be chained.
        """
        self.dct = wrapper.Dictionary.load(path)
        return self

    def _has_value(self, index: int) -> bool:
        """
        Returns True if the dictionary unit at ``index`` terminates a key.
        """
        return self.dct.has_value(index)

    def _similar_keys(self, current_prefix: str, key: str, index: int, replace_chars: CompiledReplaces) -> list[str]:
        """
        Recursive helper for ``similar_keys``.

        ``current_prefix`` is the (possibly already substituted) text that
        was consumed to reach ``index``; the walk continues over the
        remaining characters of ``key``.
        """
        res = []
        start_pos = len(current_prefix)
        end_pos = len(key)
        word_pos = start_pos

        while word_pos < end_pos:
            b_step = key[word_pos].encode("utf8")

            if b_step in replace_chars:
                # Branch: try each replacement for the current character and
                # collect whatever the rest of the key matches behind it.
                for b_replace_char, u_replace_char in replace_chars[b_step]:
                    next_index = self.dct.follow_bytes(b_replace_char, index)

                    if next_index:
                        prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
                        extra_keys = self._similar_keys(prefix, key, next_index, replace_chars)
                        res += extra_keys

            # Main path: follow the character as written.
            index = self.dct.follow_bytes(b_step, index)
            if index is None:
                break
            word_pos += 1

        else:
            # The whole key was consumed without a failed transition; the
            # unsubstituted variant is put first in the result.
            if self._has_value(index):
                found_key = current_prefix + key[start_pos:]
                res.insert(0, found_key)

        return res

    def similar_keys(self, key: str, replaces: CompiledReplaces) -> list[str]:
        """
        Returns all variants of ``key`` in this DAWG according to
        ``replaces``.

        ``replaces`` is an object obtained from
        ``DAWG.compile_replaces(mapping)`` where mapping is a dict
        that maps single-char unicode strings to (one or more) single-char
        unicode strings.

        This may be useful e.g. for handling single-character umlauts.
        """
        return self._similar_keys("", key, self.dct.ROOT, replaces)

    @staticmethod
    def compile_replaces(replaces: Replaces) -> CompiledReplaces:
        """
        Compiles a replacement mapping into the form expected by
        ``similar_keys`` and the other ``similar_*`` methods.

        Every key must be a single-char unicode string; every value must be
        either a single-char unicode string or a non-empty sequence of such
        strings. The result maps the UTF-8 encoded key to a list of
        ``(encoded_replacement, replacement)`` pairs.

        Raises ``ValueError`` if a key or a value does not satisfy these
        requirements.
        """
        value_msg = "Values must be single-char unicode strings or non-empty lists of such."
        compiled = {}

        for k, v in replaces.items():
            if len(k) != 1:
                msg = "Keys must be single-char unicode strings."
                raise ValueError(msg)

            # Normalise to a list so that str, list and any other sequence
            # go through exactly the same validation.
            variants = [v] if isinstance(v, str) else list(v)
            if not variants or any(not isinstance(entry, str) or len(entry) != 1 for entry in variants):
                raise ValueError(value_msg)

            compiled[k.encode("utf8")] = [(entry.encode("utf8"), entry) for entry in variants]

        return compiled

    def prefixes(self, key: str | bytes) -> list[str]:
        """
        Returns a list with keys of this DAWG that are prefixes of the ``key``.
        """
        res = []
        index = self.dct.ROOT
        if not isinstance(key, bytes):
            key = key.encode("utf8")

        # ``pos`` is the number of bytes consumed so far, i.e. the length of
        # the candidate prefix ending at the current byte.
        for pos, ch in enumerate(key, 1):
            index = self.dct.follow_char(ch, index)
            if not index:
                break

            if self._has_value(index):
                res.append(key[:pos].decode("utf8"))

        return res
126

127

128
class CompletionDAWG(DAWG):
    """
    DAWG with key completion support.
    """

    guide: wrapper.Guide | None

    def __init__(self) -> None:
        super().__init__()
        self.guide = None

    def keys(self, prefix: str = "") -> list[str]:
        """
        Returns a list of all keys that start with ``prefix``.
        """
        return list(self.iterkeys(prefix))

    def iterkeys(self, prefix: str = "") -> Iterator[str]:
        """
        Yields all keys that start with ``prefix``, decoded from UTF-8.

        Yields nothing if ``prefix`` cannot be followed in the dictionary.
        """
        b_prefix = prefix.encode("utf8")
        index = self.dct.follow_bytes(b_prefix, self.dct.ROOT)
        if index is None:
            return

        completer = wrapper.Completer(self.dct, self.guide)
        completer.start(index, b_prefix)

        while completer.next():
            yield completer.key.decode("utf8")

    def load(self, path: str | Path) -> Self:
        """
        Loads DAWG from a file.

        The dictionary and then the guide are read from the same file
        handle. Returns ``self`` so that calls can be chained.
        """
        dct = wrapper.Dictionary()
        guide = wrapper.Guide()

        with open(path, "rb") as f:
            dct.read(f)
            guide.read(f)

        # Publish both parts only after both were read successfully, so a
        # failed load never leaves the instance with a half-read state.
        self.dct = dct
        self.guide = guide

        return self
166

167

168
# Default byte placed between a key and its payload in BytesDAWG entries.
PAYLOAD_SEPARATOR = b"\x01"
# NOTE(review): not referenced anywhere in this listing; presumably an upper
# bound on payload size used by other modules - confirm before relying on it.
MAX_VALUE_SIZE = 32768
6✔
170

171

172
class BytesDAWG(CompletionDAWG):
    """
    DAWG that is able to transparently store extra binary payload in keys;
    there may be several payloads for the same key.

    In other words, this class implements read-only DAWG-based
    {unicode -> list of bytes objects} mapping.

    Internally every entry is ``key + payload_separator + base64(payload)``.
    """

    def __init__(self, payload_separator: bytes = PAYLOAD_SEPARATOR) -> None:
        super().__init__()
        self._payload_separator = payload_separator

    def __contains__(self, key: str | bytes) -> bool:
        """
        Returns True if ``key`` has at least one payload in this DAWG.
        """
        if not isinstance(key, bytes):
            key = key.encode("utf8")
        return bool(self._follow_key(key))

    def __getitem__(self, key: str | bytes) -> list[bytes]:
        """
        Returns the list of payloads for ``key``.

        Raises ``KeyError`` if the key is not found.
        """
        res = self.get(key)
        if res is None:
            raise KeyError(key)
        return res

    def get(self, key: str | bytes, default: list[bytes] | None = None) -> list[bytes] | None:
        """
        Returns a list of payloads (as byte objects) for a given key
        or ``default`` if the key is not found.
        """
        if not isinstance(key, bytes):
            key = key.encode("utf8")

        return self.b_get_value(key) or default

    def _follow_key(self, b_key: bytes) -> int | None:
        """
        Follows ``b_key`` and then the payload separator from the root.

        Returns the resulting index, or ``None`` if either step fails.
        """
        index = self.dct.follow_bytes(b_key, self.dct.ROOT)
        if not index:
            return None

        index = self.dct.follow_bytes(self._payload_separator, index)
        if not index:
            return None

        return index

    def _value_for_index(self, index: int) -> list[bytes]:
        """
        Returns all payloads reachable from ``index``, base64-decoded.

        ``index`` is expected to point just past the payload separator.
        """
        res = []

        completer = wrapper.Completer(self.dct, self.guide)

        completer.start(index)
        while completer.next():
            b64_data = completer.key
            res.append(a2b_base64(b64_data))

        return res

    def b_get_value(self, b_key: bytes) -> list[bytes]:
        """
        Returns the payloads for an already encoded key, or an empty list
        if the key is not found.
        """
        index = self._follow_key(b_key)
        if not index:
            return []
        return self._value_for_index(index)

    def keys(self, prefix: str | bytes = "") -> list[str]:
        """
        Returns a list of keys that start with ``prefix``; a key is listed
        once per stored payload.
        """
        return list(self.iterkeys(prefix))

    def iterkeys(self, prefix: str | bytes = "") -> Iterator[str]:
        """
        Yields the decoded key part of every entry that starts with
        ``prefix``.
        """
        if not isinstance(prefix, bytes):
            prefix = prefix.encode("utf8")

        index = self.dct.ROOT

        if prefix:
            index = self.dct.follow_bytes(prefix, index)
            if not index:
                return

        completer = wrapper.Completer(self.dct, self.guide)
        completer.start(index, prefix)

        while completer.next():
            # Everything before the first separator is the key itself.
            payload_idx = completer.key.index(self._payload_separator)
            u_key = completer.key[:payload_idx].decode("utf8")
            yield u_key

    def items(self, prefix: str | bytes = "") -> list[tuple[str, bytes]]:
        """
        Returns a list of ``(key, payload)`` tuples for keys that start
        with ``prefix``.
        """
        return list(self.iteritems(prefix))

    def iteritems(self, prefix: str | bytes = "") -> Iterator[tuple[str, bytes]]:
        """
        Yields ``(key, payload)`` tuples for every entry that starts with
        ``prefix``; payloads are base64-decoded.
        """
        if not isinstance(prefix, bytes):
            prefix = prefix.encode("utf8")

        index = self.dct.ROOT
        if prefix:
            index = self.dct.follow_bytes(prefix, index)
            if not index:
                return

        completer = wrapper.Completer(self.dct, self.guide)
        completer.start(index, prefix)

        while completer.next():
            # Split on the first separator only, matching ``iterkeys``: the
            # key ends at the first separator, the rest is the payload.
            key, value = completer.key.split(self._payload_separator, 1)
            item = (key.decode("utf8"), a2b_base64(value))
            yield item

    def _has_value(self, index: int) -> int | None:
        """
        Returns a truthy index if a payload separator can be followed from
        ``index`` (i.e. a key ends here), otherwise ``None``.
        """
        # Use the separator configured for this instance rather than the
        # module default, so custom separators work in ``prefixes`` and
        # ``similar_keys`` as well.
        return self.dct.follow_bytes(self._payload_separator, index)

    def _similar_items(
        self,
        current_prefix: str,
        key: str,
        index: int,
        replace_chars: CompiledReplaces,
    ) -> list[tuple[str, list[bytes]]]:
        """
        Recursive helper for ``similar_items``.

        ``current_prefix`` is the (possibly already substituted) text that
        was consumed to reach ``index``.
        """
        res = []
        start_pos = len(current_prefix)
        end_pos = len(key)
        word_pos = start_pos

        while word_pos < end_pos:
            b_step = key[word_pos].encode("utf8")

            if b_step in replace_chars:
                # Branch on every replacement for the current character.
                for b_replace_char, u_replace_char in replace_chars[b_step]:
                    next_index = self.dct.follow_bytes(b_replace_char, index)

                    if next_index:
                        prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
                        extra_items = self._similar_items(prefix, key, next_index, replace_chars)
                        res += extra_items

            index = self.dct.follow_bytes(b_step, index)
            if not index:
                break
            word_pos += 1

        else:
            # Whole key consumed: the unsubstituted variant goes first.
            index = self.dct.follow_bytes(self._payload_separator, index)
            if index:
                found_key = current_prefix + key[start_pos:]
                value = self._value_for_index(index)
                res.insert(0, (found_key, value))

        return res

    def similar_items(self, key: str, replaces: CompiledReplaces) -> list[tuple[str, list[bytes]]]:
        """
        Returns a list of (key, value) tuples for all variants of ``key``
        in this DAWG according to ``replaces``.

        ``replaces`` is an object obtained from
        ``DAWG.compile_replaces(mapping)`` where mapping is a dict
        that maps single-char unicode strings to (one or more) single-char
        unicode strings.
        """
        return self._similar_items("", key, self.dct.ROOT, replaces)

    def _similar_item_values(
        self,
        start_pos: int,
        key: str,
        index: int,
        replace_chars: CompiledReplaces,
    ) -> list[list[bytes]]:
        """
        Recursive helper for ``similar_item_values``.

        ``start_pos`` is the position in ``key`` from which the walk
        continues at ``index``.
        """
        res = []
        end_pos = len(key)
        word_pos = start_pos

        while word_pos < end_pos:
            b_step = key[word_pos].encode("utf8")

            if b_step in replace_chars:
                # Only the encoded replacement is needed here; the matched
                # key text is not part of the result.
                for b_replace_char, _u_replace_char in replace_chars[b_step]:
                    next_index = self.dct.follow_bytes(b_replace_char, index)

                    if next_index:
                        extra_items = self._similar_item_values(word_pos + 1, key, next_index, replace_chars)
                        res += extra_items

            index = self.dct.follow_bytes(b_step, index)
            if not index:
                break
            word_pos += 1

        else:
            # Whole key consumed: the unsubstituted variant goes first.
            index = self.dct.follow_bytes(self._payload_separator, index)
            if index:
                value = self._value_for_index(index)
                res.insert(0, value)

        return res

    def similar_item_values(self, key: str, replaces: CompiledReplaces) -> list[list[bytes]]:
        """
        Returns a list of values for all variants of the ``key``
        in this DAWG according to ``replaces``.

        ``replaces`` is an object obtained from
        ``DAWG.compile_replaces(mapping)`` where mapping is a dict
        that maps single-char unicode strings to (one or more) single-char
        unicode strings.
        """
        return self._similar_item_values(0, key, self.dct.ROOT, replaces)
6✔
381

382

383
class RecordDAWG(BytesDAWG):
    """
    BytesDAWG whose payloads are unpacked with a ``struct`` format.

    Every payload is passed through ``struct.Struct(fmt).unpack`` before it
    is returned, so values are tuples instead of raw bytes.
    """

    def __init__(self, fmt: str | bytes, payload_separator: bytes = PAYLOAD_SEPARATOR) -> None:
        super().__init__(payload_separator)
        # Compile the format once; the original string is kept as ``fmt``.
        self._struct = struct.Struct(fmt)
        self.fmt = fmt

    def _value_for_index(self, index: int) -> list[tuple[Any, ...]]:
        """
        Returns the payloads reachable from ``index``, each unpacked into a
        tuple.
        """
        raw_payloads = super()._value_for_index(index)
        return [self._struct.unpack(raw) for raw in raw_payloads]

    def items(self, prefix: str | bytes = "") -> list[tuple[str, tuple[Any, ...]]]:
        """
        Returns a list of ``(key, record)`` tuples for keys that start with
        ``prefix``.
        """
        return list(self.iteritems(prefix))

    def iteritems(self, prefix: str | bytes = "") -> Iterator[tuple[str, tuple[Any, ...]]]:
        """
        Returns an iterator of ``(key, record)`` tuples for keys that start
        with ``prefix``.
        """
        raw_pairs = super().iteritems(prefix)
        return ((name, self._struct.unpack(blob)) for (name, blob) in raw_pairs)
6✔
399

400

401
# Sentinel that IntDAWG compares lookup results against to detect a missing key.
LOOKUP_ERROR = -1
6✔
402

403

404
class IntDAWG(DAWG):
    """
    Dict-like class based on DAWG.
    It can store integer values for unicode keys.
    """

    def __getitem__(self, key: str | bytes) -> int:
        """
        Returns the value stored for ``key``.

        Raises ``KeyError`` if the key is not found.
        """
        # Pass the sentinel as the default so that a missing key can be told
        # apart from any stored value.
        res = self.get(key, LOOKUP_ERROR)
        if res == LOOKUP_ERROR:
            raise KeyError(key)
        return res

    def get(self, key: str | bytes, default: int | None = None) -> int | None:
        """
        Return value for the given key or ``default`` if the key is not found.
        """
        if not isinstance(key, bytes):
            key = key.encode("utf8")
        res = self.b_get_value(key)
        if res == LOOKUP_ERROR:
            return default
        return res

    def b_get_value(self, key: bytes) -> int:
        """
        Returns the raw lookup result for an already encoded key.

        ``get`` treats a result equal to ``LOOKUP_ERROR`` as "not found".
        """
        return self.dct.find(key)
6✔
429

430

431
class IntCompletionDAWG(CompletionDAWG, IntDAWG):
    """
    Dict-like class based on DAWG.
    It can store integer values for unicode keys and support key completion.
    """

    def items(self, prefix: str | bytes = "") -> list[tuple[str, int]]:
        """
        Returns a list of ``(key, value)`` tuples for keys that start with
        ``prefix``.
        """
        return [*self.iteritems(prefix)]

    def iteritems(self, prefix: str | bytes = "") -> Iterator[tuple[str, int]]:
        """
        Yields ``(key, value)`` tuples for keys that start with ``prefix``.

        Yields nothing if a non-empty ``prefix`` cannot be followed.
        """
        b_prefix = prefix if isinstance(prefix, bytes) else prefix.encode("utf8")
        start = self.dct.ROOT

        # An empty prefix means "start at the root"; otherwise walk to the
        # unit the prefix leads to and stop early if there is none.
        if b_prefix:
            start = self.dct.follow_bytes(b_prefix, start)
            if not start:
                return

        walker = wrapper.Completer(self.dct, self.guide)
        walker.start(start, b_prefix)

        while walker.next():
            yield walker.key.decode("utf8"), walker.value()
6✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc