• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PyThaiNLP / pythainlp / 11625814262

01 Nov 2024 07:14AM UTC coverage: 20.782% (+20.8%) from 0.0%
11625814262

Pull #952

github

web-flow
Merge c8385dcae into 515fe7ced
Pull Request #952: Specify a limited test suite

45 of 80 new or added lines in 48 files covered. (56.25%)

1537 of 7396 relevant lines covered (20.78%)

0.21 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

3.6
/pythainlp/wangchanberta/core.py
1
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
import re
import warnings
from typing import List, Tuple, Union

from transformers import (
    CamembertTokenizer,
    pipeline,
)

from pythainlp.tokenize import word_tokenize

# Name of the wangchanberta checkpoint this module is built around.
_model_name = "wangchanberta-base-att-spm-uncased"
# Shared SentencePiece tokenizer, loaded once at import time (network/disk
# side effect on first import).
_tokenizer = CamembertTokenizer.from_pretrained(
    f"airesearch/{_model_name}", revision="main"
)
if _model_name == "wangchanberta-base-att-spm-uncased":
    # "<_>" is this model's placeholder token for a space character; the
    # NOTUSED tokens mirror the CamemBERT vocabulary layout.
    _tokenizer.additional_special_tokens = ["<s>NOTUSED", "</s>NOTUSED", "<_>"]
21

22

23
class ThaiNameTagger:
    """Named-entity tagger backed by a wangchanberta NER pipeline."""

    def __init__(
        self, dataset_name: str = "thainer", grouped_entities: bool = True
    ):
        """
        This class tags named entities in text in IOB format.

        Powered by wangchanberta from VISTEC-depa\
             AI Research Institute of Thailand

        :param str dataset_name:
            * *thainer* - ThaiNER dataset
        :param bool grouped_entities: group consecutive subwords that belong
            to the same entity into a single span
        """
        self.dataset_name = dataset_name
        self.grouped_entities = grouped_entities
        # Hugging Face token-classification pipeline fine-tuned for NER on
        # the selected dataset revision.
        self.classify_tokens = pipeline(
            task="ner",
            tokenizer=_tokenizer,
            model=f"airesearch/{_model_name}",
            revision=f"finetuned@{self.dataset_name}-ner",
            ignore_labels=[],
            grouped_entities=self.grouped_entities,
        )

    def _IOB(self, tag):
        """Prefix a grouped-entity label with "B-" (grouped output has no
        IOB prefix; every non-"O" group starts a new entity)."""
        if tag != "O":
            return "B-" + tag
        return "O"

    def _clear_tag(self, tag):
        """Strip the IOB prefix, leaving only the entity class name."""
        return tag.replace("B-", "").replace("I-", "")

    def _render_tags(self, sent_ner):
        """Render (word, IOB-tag) pairs as a string with HTML-like entity
        tags, e.g. ``<PERSON>...</PERSON>``."""
        temp = ""
        sent = ""
        for idx, (word, ner) in enumerate(sent_ner):
            if ner.startswith("B-") and temp != "":
                # New entity begins while another is open: close, then open.
                sent += "</" + temp + ">"
                temp = ner[2:]
                sent += "<" + temp + ">"
            elif ner.startswith("B-"):
                temp = ner[2:]
                sent += "<" + temp + ">"
            elif ner == "O" and temp != "":
                sent += "</" + temp + ">"
                temp = ""
            sent += word

            # Close a still-open entity at the end of the sentence.
            if idx == len(sent_ner) - 1 and temp != "":
                sent += "</" + temp + ">"

        return sent

    def get_ner(
        self, text: str, pos: bool = False, tag: bool = False
    ) -> Union[List[Tuple[str, str]], str]:
        """
        This function tags named entities in text in IOB format.
        Powered by wangchanberta from VISTEC-depa\
             AI Research Institute of Thailand

        :param str text: text in Thai to be tagged
        :param bool tag: output HTML-like tags.
        :return: a list of tuples associated with tokenized word groups,\
            NER tags, and output HTML-like tags (if the parameter `tag` is \
            specified as `True`). \
            Otherwise, return a list of tuples associated with tokenized \
            words and NER tags
        :rtype: Union[List[Tuple[str, str]], str]
        """
        if pos:
            warnings.warn(
                "This model doesn't support output of POS tags and it doesn't output the POS tags."
            )
        # The tokenizer represents spaces with the special token "<_>".
        text = re.sub(" ", "<_>", text)
        self.json_ner = self.classify_tokens(text)
        self.output = ""  # kept for backward compatibility; not used
        if self.grouped_entities and self.dataset_name == "thainer":
            self.sent_ner = [
                (
                    i["word"].replace("<_>", " ").replace("▁", ""),
                    self._IOB(i["entity_group"]),
                )
                for i in self.json_ner
            ]
        elif self.dataset_name == "thainer":
            self.sent_ner = [
                (i["word"].replace("<_>", " ").replace("▁", ""), i["entity"])
                for i in self.json_ner
                if i["word"] != "▁"
            ]
        else:
            self.sent_ner = [
                (
                    i["word"].replace("<_>", " ").replace("▁", ""),
                    i["entity"].replace("_", "-").replace("E-", "I-"),
                )
                for i in self.json_ner
            ]
        # Drop a leading empty token left over from subword segmentation.
        # Guard against an empty pipeline result to avoid IndexError.
        if (
            self.sent_ner
            and self.sent_ner[0][0] == ""
            and len(self.sent_ner) > 1
        ):
            self.sent_ner = self.sent_ner[1:]
        # Two adjacent "B-" spans of the same class are one entity: demote
        # the second to "I-".
        for idx, (word, ner) in enumerate(self.sent_ner):
            if idx > 0 and ner.startswith("B-"):
                if self._clear_tag(ner) == self._clear_tag(
                    self.sent_ner[idx - 1][1]
                ):
                    self.sent_ner[idx] = (word, ner.replace("B-", "I-"))
        if tag:
            return self._render_tags(self.sent_ner)
        return self.sent_ner
132

133

134
class NamedEntityRecognition:
    """Named-entity tagger using a wangchanberta token-classification model."""

    def __init__(
        self, model: str = "pythainlp/thainer-corpus-v2-base-model"
    ) -> None:
        """
        This class tags named entities in text in IOB format.

        Powered by wangchanberta from VISTEC-depa\
             AI Research Institute of Thailand
        :param str model: The model that use wangchanberta pretrained.
        """
        # Imported lazily so the module can be imported without pulling in
        # the full Auto* machinery until this class is actually used.
        from transformers import AutoModelForTokenClassification, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForTokenClassification.from_pretrained(model)

    def _fix_span_error(self, words, ner):
        """Decode token ids back to text and align them with their tags.

        Drops special tokens ("", "<s>", "</s>"), converts the "<_>" space
        placeholder back to a real space, and demotes a "B-" tag on a
        whitespace-only token to "O" (whitespace cannot start an entity).
        """
        _new_tag = []
        for token_id, label in zip(words, ner):
            token = self.tokenizer.decode(token_id)
            if token.isspace() and label.startswith("B-"):
                label = "O"
            if token in ("", "<s>", "</s>"):
                continue
            if token == "<_>":
                token = " "
            _new_tag.append((token, label))
        return _new_tag

    def _render_tags(self, ner_tag):
        """Render (word, IOB-tag) pairs as a string with HTML-like entity
        tags, e.g. ``<PERSON>...</PERSON>``."""
        temp = ""
        sent = ""
        for idx, (word, ner) in enumerate(ner_tag):
            if ner.startswith("B-") and temp != "":
                # New entity begins while another is open: close, then open.
                sent += "</" + temp + ">"
                temp = ner[2:]
                sent += "<" + temp + ">"
            elif ner.startswith("B-"):
                temp = ner[2:]
                sent += "<" + temp + ">"
            elif ner == "O" and temp != "":
                sent += "</" + temp + ">"
                temp = ""
            sent += word

            # Close a still-open entity at the end of the sentence.
            if idx == len(ner_tag) - 1 and temp != "":
                sent += "</" + temp + ">"

        return sent

    def get_ner(
        self, text: str, pos: bool = False, tag: bool = False
    ) -> Union[List[Tuple[str, str]], str]:
        """
        This function tags named entities in text in IOB format.
        Powered by wangchanberta from VISTEC-depa\
             AI Research Institute of Thailand

        :param str text: text in Thai to be tagged
        :param bool tag: output HTML-like tags.
        :return: a list of tuples associated with tokenized word groups, NER tags, \
                 and output HTML-like tags (if the parameter `tag` is \
                 specified as `True`). \
                 Otherwise, return a list of tuples associated with tokenized \
                 words and NER tags
        :rtype: Union[List[Tuple[str, str]], str]
        """
        import torch

        if pos:
            warnings.warn(
                "This model doesn't support output of POS tags and it doesn't output the POS tags."
            )
        # Spaces are represented by the "<_>" placeholder token.
        words_token = word_tokenize(text.replace(" ", "<_>"))
        inputs = self.tokenizer(
            words_token, is_split_into_words=True, return_tensors="pt"
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        # Inference only: disable gradient tracking to save memory.
        with torch.no_grad():
            outputs = self.model(ids, attention_mask=mask)
        logits = outputs[0]
        predictions = torch.argmax(logits, dim=2)
        predicted_token_class = [
            self.model.config.id2label[t.item()] for t in predictions[0]
        ]
        ner_tag = self._fix_span_error(
            inputs["input_ids"][0], predicted_token_class
        )
        if tag:
            return self._render_tags(ner_tag)
        return ner_tag
225

226

227
def segment(text: str) -> List[str]:
    """
    Subword tokenize. SentencePiece from wangchanberta model.

    :param str text: text to be tokenized
    :return: list of subwords
    :rtype: list[str]
    """
    # Non-string or empty input yields no subwords.
    if not isinstance(text, str) or not text:
        return []

    return _tokenizer.tokenize(text)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc