PyThaiNLP / pythainlp — build 11625814262
01 Nov 2024 07:14AM UTC — coverage: 20.782% (+20.8%) from 0.0%

Pull Request #952: Specify a limited test suite
Merge c8385dcae into 515fe7ced (committed via GitHub web-flow)

45 of 80 new or added lines in 48 files covered (56.25%)
1537 of 7396 relevant lines covered (20.78%)
0.21 hits per line

Source File
/pythainlp/tag/wangchanberta_onnx.py — 0.0% covered (every line in this file is uncovered in this build)

# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
import json
from typing import List

import numpy as np

from pythainlp.corpus import get_path_folder_corpus


class WngchanBerta_ONNX:
    # ONNX Runtime wrapper around a WangchanBERTa token-classification model
    # whose files (ONNX graph, SentencePiece model, config.json) are loaded
    # from the PyThaiNLP corpus folder.
    def __init__(
        self,
        model_name: str,
        model_version: str,
        file_onnx: str,
        providers: List[str] = ["CPUExecutionProvider"],
    ) -> None:
        import sentencepiece as spm
        from onnxruntime import (
            GraphOptimizationLevel,
            InferenceSession,
            SessionOptions,
        )

        self.model_name = model_name
        self.model_version = model_version
        self.options = SessionOptions()
        self.options.graph_optimization_level = (
            GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        self.session = InferenceSession(
            get_path_folder_corpus(
                self.model_name, self.model_version, file_onnx
            ),
            sess_options=self.options,
            providers=providers,
        )
        self.session.disable_fallback()
        self.outputs_name = self.session.get_outputs()[0].name
        self.sp = spm.SentencePieceProcessor(
            model_file=get_path_folder_corpus(
                self.model_name, self.model_version, "sentencepiece.bpe.model"
            )
        )
        with open(
            get_path_folder_corpus(
                self.model_name, self.model_version, "config.json"
            ),
            encoding="utf-8-sig",
        ) as fh:
            self._json = json.load(fh)
            self.id2tag = self._json["id2label"]

    def build_tokenizer(self, sent):
        # Wrap the SentencePiece ids in the model's start/end token ids
        # (5 and 6) and shift them by the model's vocabulary offset (+4).
        _t = [5] + [i + 4 for i in self.sp.encode(sent)] + [6]
        model_inputs = {}
        model_inputs["input_ids"] = np.array([_t], dtype=np.int64)
        model_inputs["attention_mask"] = np.array(
            [[1] * len(_t)], dtype=np.int64
        )
        return model_inputs

    def postprocess(self, logits_data):
        # Numerically stable softmax over the label dimension.
        logits_t = logits_data[0]
        maxes = np.max(logits_t, axis=-1, keepdims=True)
        shifted_exp = np.exp(logits_t - maxes)
        scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
        return scores

    def clean_output(self, list_text):
        # Pass-through hook; returns the tagged pieces unchanged.
        return list_text

    def totag(self, post, sent):
        # Pair each SentencePiece piece with its highest-scoring label,
        # skipping the leading special token (hence post[i + 1]).
        tag = []
        _s = self.sp.EncodeAsPieces(sent)
        for i in range(len(_s)):
            tag.append(
                (
                    _s[i],
                    self.id2tag[
                        str(list(post[i + 1]).index(max(list(post[i + 1]))))
                    ],
                )
            )
        return tag

    def _config(self, list_ner):
        # Pass-through hook; returns the tag list unchanged.
        return list_ner

    def get_ner(self, text: str, tag: bool = False):
        self._s = self.build_tokenizer(text)
        logits = self.session.run(
            output_names=[self.outputs_name], input_feed=self._s
        )[0]
        _tag = self.clean_output(self.totag(self.postprocess(logits), text))
        if tag:
            # Convert B-/O labels into inline <TAG>...</TAG> markup.
            _tag = self._config(_tag)
            temp = ""
            sent = ""
            for idx, (word, ner) in enumerate(_tag):
                if ner.startswith("B-") and temp != "":
                    sent += "</" + temp + ">"
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner.startswith("B-"):
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner == "O" and temp != "":
                    sent += "</" + temp + ">"
                    temp = ""
                sent += word

                if idx == len(_tag) - 1 and temp != "":
                    sent += "</" + temp + ">"

            return sent
        else:
            return _tag