• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PyThaiNLP / pythainlp / 5337431273

pending completion
5337431273

push

github

wannaphong
Add กาลพฤกษ์ to list words

3573 of 6329 relevant lines covered (56.45%)

0.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pythainlp/parse/transformers_ud.py
1
# -*- coding: utf-8 -*-
2
"""
×
3
TransformersUD
4

5
Author: Prof. Koichi Yasuoka
6

7
This tagger is provided under the terms of the Apache-2.0 License.
8

9
The source: https://huggingface.co/KoichiYasuoka/deberta-base-thai-ud-head
10

11
GitHub: https://github.com/KoichiYasuoka
12
"""
13
import os
×
14
from typing import List, Union
×
15
import numpy
×
16
import torch
×
17
import ufal.chu_liu_edmonds
×
18
from transformers import (
×
19
    AutoTokenizer,
20
    AutoModelForQuestionAnswering,
21
    AutoModelForTokenClassification,
22
    AutoConfig,
23
    TokenClassificationPipeline,
24
)
25
from transformers.utils import cached_file
×
26

27

28
class Parse:
    """Dependency parser backed by three Hugging Face models.

    One model repository (or local directory) supplies:

    * a tokenizer and a question-answering model, whose start/end logits
      are used here as head-attachment scores;
    * a token-classification model under ``deprel/`` that yields the word
      spans and their dependency-relation labels;
    * a token-classification model under ``tagger/`` that yields a
      ``|``-separated tag string per token.

    Calling an instance with a text returns ten-column rows (CoNLL-U-like:
    index, form, ``_``, tag, ``_``, features, head, relation, ``_``, misc).
    """

    def __init__(
        self, model: Union[str, None] = "KoichiYasuoka/deberta-base-thai-ud-head"
    ) -> None:
        """Load the tokenizer and the three models.

        :param model: Hugging Face model id or path to a local directory
            containing ``deprel`` and ``tagger`` sub-directories.
            ``None`` selects the default model.
        """
        if model is None:
            model = "KoichiYasuoka/deberta-base-thai-ud-head"
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model)
        load_classifier = AutoModelForTokenClassification.from_pretrained
        if os.path.isdir(model):
            # Local checkout: the sub-models live in plain sub-directories.
            deprel_model = load_classifier(os.path.join(model, "deprel"))
            tagger_model = load_classifier(os.path.join(model, "tagger"))
        else:
            # Hub model: resolve each sub-model's config and weights file
            # through the cache, then load the weights with that config.
            deprel_config = AutoConfig.from_pretrained(
                cached_file(model, "deprel/config.json")
            )
            deprel_model = load_classifier(
                cached_file(model, "deprel/pytorch_model.bin"),
                config=deprel_config,
            )
            tagger_config = AutoConfig.from_pretrained(
                cached_file(model, "tagger/config.json")
            )
            tagger_model = load_classifier(
                cached_file(model, "tagger/pytorch_model.bin"),
                config=tagger_config,
            )
        self.deprel = TokenClassificationPipeline(
            model=deprel_model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
        )
        self.tagger = TokenClassificationPipeline(
            model=tagger_model, tokenizer=self.tokenizer
        )

    def __call__(
        self, text: str, tag: str = "str"
    ) -> Union[List[List[str]], str]:
        """Parse ``text`` and return one ten-column row per word.

        :param text: the text to parse.
        :param tag: ``"list"`` returns the rows as a list of lists of
            strings; any other value returns them as a single string with
            tab-separated columns, one row per line, followed by a blank
            line.
        :return: the rows in the requested form. When the ``deprel``
            pipeline finds no words, the result is ``[]`` (list form) or
            ``"\\n"`` (string form).
        """
        # (start, end, relation label) for every word span found in text.
        spans = [
            (t["start"], t["end"], t["entity_group"])
            for t in self.deprel(text)
        ]
        # Tag fields keyed by the character offset where the token starts.
        morph = {
            t["start"]: t["entity"].split("|") for t in self.tagger(text)
        }
        n = len(spans)
        if n == 0:
            # Nothing to attach: skip the model and tree decoding, which
            # cannot work on zero tokens, and return an empty parse.
            return [] if tag == "list" else "\n"

        words = [text[start:end] for start, end, _ in spans]
        # scores[d, h] is the score of word d taking head h; column 0 is
        # the artificial root. Unfilled cells stay NaN.
        scores = numpy.full((n + 1, n + 1), numpy.nan)
        word_ids = self.tokenizer(words, add_special_tokens=False)[
            "input_ids"
        ]
        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        mask_id = self.tokenizer.mask_token_id

        # One input per word i: "[CLS] word_i [SEP]" as the first segment,
        # then the sentence with word_i replaced by a mask token, then a
        # closing [SEP]. Kept as a list of segments so offsets can be
        # derived before flattening.
        segmented = []
        for i, ids in enumerate(word_ids):
            query = [cls_id] + ids + [sep_id]
            segmented.append(
                [query]
                + word_ids[0:i]
                + [[mask_id]]
                + word_ids[i + 1 :]
                + [[sep_id]]
            )

        # offsets[i][j] = number of tokens in segments 0..j of input i,
        # i.e. the position where segment j + 1 begins.
        offsets = []
        for segments in segmented:
            running, row_offsets = 0, []
            for segment in segments:
                running += len(segment)
                row_offsets.append(running)
            offsets.append(row_offsets)

        with torch.no_grad():
            output = self.model(
                input_ids=torch.tensor(
                    [
                        [tok for segment in segments for tok in segment]
                        for segments in segmented
                    ]
                ),
                # Type 0 for the query segment, type 1 for the rest.
                token_type_ids=torch.tensor(
                    [
                        [0] * row[0] + [1] * (row[-1] - row[0])
                        for row in offsets
                    ]
                ),
            )
        start_logits = output.start_logits.tolist()
        end_logits = output.end_logits.tolist()

        for i in range(n):
            for j in range(n):
                # Score of candidate j for word i: start logit at the
                # candidate's first token plus end logit at its last. The
                # masked position (j == i) is stored in the root column.
                scores[i + 1, 0 if i == j else j + 1] = (
                    start_logits[i][offsets[i][j]]
                    + end_logits[i][offsets[i][j + 1] - 1]
                )

        heads = ufal.chu_liu_edmonds.chu_liu_edmonds(scores)[0]
        if sum(1 for head in heads if head == 0) != 1:
            # The tree does not have exactly one root. Pick a single root
            # (the first word labelled "root", else the best-scoring root
            # candidate), forbid every other root attachment and decode
            # again.
            first_root = ([rel for _, _, rel in spans] + ["root"]).index(
                "root"
            )
            root = (
                first_root + 1
                if first_root < n
                else numpy.nanargmax(scores[:, 0])
            )
            scores[0:root, 0] = scores[root + 1 :, 0] = numpy.nan
            heads = ufal.chu_liu_edmonds.chu_liu_edmonds(scores)[0]

        rows = []
        for i, (start, end, rel) in enumerate(spans, 1):
            # The decoded tree decides which word is the root; a word that
            # was labelled "root" but got a real head is relabelled "dep".
            rel = "root" if heads[i] == 0 else "dep" if rel == "root" else rel
            rows.append(
                [
                    str(i),
                    words[i - 1],
                    "_",
                    # First tag field without its two-character prefix.
                    morph[start][0][2:],
                    "_",
                    "|".join(morph[start][1:]),
                    str(heads[i]),
                    rel,
                    "_",
                    # "_" only when a following word starts after a gap;
                    # the last word always gets "SpaceAfter=No".
                    "_" if i < n and end < spans[i][0] else "SpaceAfter=No",
                ]
            )

        if tag == "list":
            return rows
        return "".join("\t".join(row) + "\n" for row in rows) + "\n"
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc