PyThaiNLP / pythainlp, build 11625814262

01 Nov 2024 07:14AM UTC. Coverage: 20.782% (+20.8%) from 0.0%

Pull Request #952: Specify a limited test suite
Merge c8385dcae into 515fe7ced (via GitHub web-flow)

45 of 80 new or added lines in 48 files covered (56.25%)
1537 of 7396 relevant lines covered (20.78%)
0.21 hits per line

Source File: /pythainlp/transliterate/thai2rom_onnx.py (0.0% covered)

# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
"""
Romanization of Thai words, based on a machine-learned model running on
ONNX Runtime ("thai2rom").
"""
import json

import numpy as np
from onnxruntime import InferenceSession

from pythainlp.corpus import get_corpus_path

_MODEL_ENCODER_NAME = "thai2rom_encoder_onnx"
_MODEL_DECODER_NAME = "thai2rom_decoder_onnx"
_MODEL_CONFIG_NAME = "thai2rom_config_onnx"


class ThaiTransliterator_ONNX:
    def __init__(self):
        """
        Transliteration of Thai words.

        Currently supports Thai-to-Latin transliteration (romanization).
        """
        # Get the model files; download them if not available locally.
        self.__encoder_filename = get_corpus_path(_MODEL_ENCODER_NAME)
        self.__decoder_filename = get_corpus_path(_MODEL_DECODER_NAME)
        self.__config_filename = get_corpus_path(_MODEL_CONFIG_NAME)

        # Load the model configuration (vocabularies and dimensions) from JSON.
        with open(str(self.__config_filename)) as f:
            loader = json.load(f)

        OUTPUT_DIM = loader["output_dim"]

        self._maxlength = 100

        # Character-index mappings for the source (Thai) and
        # target (Latin) vocabularies.
        self._char_to_ix = loader["char_to_ix"]
        self._ix_to_char = loader["ix_to_char"]
        self._target_char_to_ix = loader["target_char_to_ix"]
        self._ix_to_target_char = loader["ix_to_target_char"]

        # Load the encoder and decoder ONNX models.
        self._encoder = InferenceSession(self.__encoder_filename)

        self._decoder = InferenceSession(self.__decoder_filename)

        self._network = Seq2Seq_ONNX(
            self._encoder,
            self._decoder,
            self._target_char_to_ix["<start>"],
            self._target_char_to_ix["<end>"],
            self._maxlength,
            target_vocab_size=OUTPUT_DIM,
        )

    def _prepare_sequence_in(self, text: str):
        """
        Convert input text to a sequence of character indices
        for the ONNX encoder.
        """
        idxs = []
        for ch in text:
            if ch in self._char_to_ix:
                idxs.append(self._char_to_ix[ch])
            else:
                # Map characters outside the training vocabulary to <UNK>.
                idxs.append(self._char_to_ix["<UNK>"])
        idxs.append(self._char_to_ix["<end>"])
        return np.array(idxs)
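
    # Illustrative sketch (hypothetical indices): with a vocabulary like
    # {"แ": 5, "ม": 9, "ว": 12, "<UNK>": 1, "<end>": 3},
    # _prepare_sequence_in("แมว") would return np.array([5, 9, 12, 3]):
    # every character mapped to its index, with the <end> marker appended.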

    def romanize(self, text: str) -> str:
        """
        :param str text: Thai text to be romanized
        :return: Latin-script text that spells out, more or less, how the
                 Thai text should be pronounced.
        """
        input_tensor = self._prepare_sequence_in(text).reshape(1, -1)
        input_length = [len(text) + 1]  # +1 for the appended <end> token
        target_tensor_logits = self._network.run(input_tensor, input_length)

        # If the seq2seq model returns <END> as its first token,
        # target_tensor_logits is empty (shape (0, ...)).
        if target_tensor_logits.shape[0] == 0:
            target = ["<PAD>"]
        else:
            target_tensor = np.argmax(target_tensor_logits.squeeze(1), 1)
            target = [self._ix_to_target_char[str(t)] for t in target_tensor]

        return "".join(target)
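
# A self-contained sketch (illustrative, not part of the module) of the
# argmax decoding performed in romanize() above, using dummy logits in place
# of real decoder output; the tiny vocabulary here is hypothetical:
#
#   logits = np.zeros((3, 1, 5))             # (steps, batch=1, vocab)
#   logits[0, 0, 2] = logits[1, 0, 0] = logits[2, 0, 4] = 1.0
#   ids = np.argmax(logits.squeeze(1), 1)    # -> array([2, 0, 4])
#   ix_to_target_char = {"2": "k", "0": "a", "4": "o"}
#   "".join(ix_to_target_char[str(t)] for t in ids)  # -> "kao"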


class Seq2Seq_ONNX:
    """Greedy seq2seq decoding over the thai2rom encoder/decoder ONNX sessions."""

    def __init__(
        self,
        encoder,
        decoder,
        target_start_token,
        target_end_token,
        max_length,
        target_vocab_size,
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = 0
        self.target_start_token = target_start_token
        self.target_end_token = target_end_token
        self.max_length = max_length

        self.target_vocab_size = target_vocab_size

    def create_mask(self, source_seq):
        # True for real tokens, False for padding positions.
        mask = source_seq != self.pad_idx
        return mask
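
    # Illustrative: with pad_idx == 0, a padded batch such as
    #   np.array([[5, 9, 12, 0, 0]])
    # yields the mask
    #   np.array([[True, True, True, False, False]])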

    def run(self, source_seq, source_seq_len):
        # source_seq: (batch_size, MAX_LENGTH)
        # source_seq_len: (batch_size, 1)
        # returns: logits of shape (steps, batch_size, target_vocab_size)

        batch_size = source_seq.shape[0]
        start_token = self.target_start_token
        end_token = self.target_end_token
        max_len = self.max_length

        outputs = np.zeros((max_len, batch_size, self.target_vocab_size))

        # Run the encoder once over the whole source sequence.
        expected_encoder_outputs = [
            output.name for output in self.encoder.get_outputs()
        ]
        encoder_outputs, encoder_hidden, _ = self.encoder.run(
            input_feed={
                "input_tensor": source_seq,
                "input_lengths": source_seq_len,
            },
            output_names=expected_encoder_outputs,
        )

        decoder_input = np.array([[start_token] * batch_size]).reshape(
            batch_size, 1
        )
        # Concatenate the encoder's two final hidden states into the
        # decoder's initial hidden state.
        encoder_hidden_h_t = np.expand_dims(
            np.concatenate(
                (encoder_hidden[0], encoder_hidden[1]),
                axis=1,
            ),
            axis=0,
        )
        decoder_hidden = encoder_hidden_h_t

        max_source_len = encoder_outputs.shape[1]
        mask = self.create_mask(source_seq[:, 0:max_source_len])

        # Greedy decoding: feed the argmax of each step back in as the
        # next decoder input, until <end> or max_len is reached.
        for di in range(max_len):
            decoder_output, decoder_hidden = self.decoder.run(
                input_feed={
                    "decoder_input": decoder_input.astype("int32"),
                    "decoder_hidden_1": decoder_hidden,
                    "encoder_outputs": encoder_outputs,
                    "mask": mask.tolist(),
                },
                output_names=[
                    self.decoder.get_outputs()[0].name,
                    self.decoder.get_outputs()[1].name,
                ],
            )

            topi = np.argmax(decoder_output, axis=1)
            outputs[di] = decoder_output

            decoder_input = np.array([topi])

            # Stop at the end-of-sequence token (assumes batch_size == 1).
            if decoder_input == end_token:
                return outputs[:di]

        return outputs
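
# Shape walk-through (illustrative): for a single word of 3 characters,
# source_seq has shape (1, 4) once the <end> index is appended, and
# source_seq_len is [4]. The loop then produces at most max_length steps of
# logits, so run() returns an array of shape (steps, 1, target_vocab_size).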


_THAI_TO_ROM_ONNX = ThaiTransliterator_ONNX()


def romanize(text: str) -> str:
    """Romanize Thai text using the thai2rom ONNX engine."""
    return _THAI_TO_ROM_ONNX.romanize(text)
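

# Usage sketch. Importing this module instantiates ThaiTransliterator_ONNX,
# which downloads the ONNX models on first use via get_corpus_path; the exact
# romanized output is illustrative, not guaranteed.
if __name__ == "__main__":
    print(romanize("แมว"))  # expect a romanization along the lines of "maeo"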