PyThaiNLP / pythainlp / 11626163864

01 Nov 2024 07:49AM UTC coverage: 14.17% (+14.2%) from 0.0%

Pull Request #952: Specify a limited test suite (github / web-flow)
Merge 8f2551bc9 into 89ea62ebc

44 of 80 new or added lines in 48 files covered (55.0%)
1048 of 7396 relevant lines covered (14.17%)
0.14 hits per line

Source File

/pythainlp/transliterate/thai2rom.py: 0.0% of lines covered (all lines below are uncovered in this build)
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
"""
Romanization of Thai words based on a machine-learnt engine ("thai2rom")
"""
import random

import torch
import torch.nn.functional as F
from torch import nn

from pythainlp.corpus import get_corpus_path

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

_MODEL_NAME = "thai2rom-pytorch-attn"


class ThaiTransliterator:
    def __init__(self):
        """
        Transliteration of Thai words.

        Now supports Thai to Latin (romanization)
        """
        # Get the model; download it if it's not available locally.
        self.__model_filename = get_corpus_path(_MODEL_NAME)

        loader = torch.load(self.__model_filename, map_location=device)

        INPUT_DIM, E_EMB_DIM, E_HID_DIM, E_DROPOUT = loader["encoder_params"]
        OUTPUT_DIM, D_EMB_DIM, D_HID_DIM, D_DROPOUT = loader["decoder_params"]

        self._maxlength = 100

        self._char_to_ix = loader["char_to_ix"]
        self._ix_to_char = loader["ix_to_char"]
        self._target_char_to_ix = loader["target_char_to_ix"]
        self._ix_to_target_char = loader["ix_to_target_char"]

        # Restore the model and construct the encoder and decoder.
        self._encoder = Encoder(INPUT_DIM, E_EMB_DIM, E_HID_DIM, E_DROPOUT)

        self._decoder = AttentionDecoder(
            OUTPUT_DIM, D_EMB_DIM, D_HID_DIM, D_DROPOUT
        )

        self._network = Seq2Seq(
            self._encoder,
            self._decoder,
            self._target_char_to_ix["<start>"],
            self._target_char_to_ix["<end>"],
            self._maxlength,
        ).to(device)

        self._network.load_state_dict(loader["model_state_dict"])
        self._network.eval()

    def _prepare_sequence_in(self, text: str):
        """
        Prepare input sequence for PyTorch
        """
        idxs = []
        for ch in text:
            if ch in self._char_to_ix:
                idxs.append(self._char_to_ix[ch])
            else:
                idxs.append(self._char_to_ix["<UNK>"])
        idxs.append(self._char_to_ix["<end>"])
        tensor = torch.tensor(idxs, dtype=torch.long)
        return tensor.to(device)

    def romanize(self, text: str) -> str:
        """
        :param str text: Thai text to be romanized
        :return: English (more or less) text that spells out how the Thai text
                 should be pronounced.
        """
        input_tensor = self._prepare_sequence_in(text).view(1, -1)
        input_length = torch.Tensor([len(text) + 1]).int()
        target_tensor_logits = self._network(
            input_tensor, input_length, None, 0
        )

        # The seq2seq model may return <END> as the first token;
        # in that case, target_tensor_logits.size() is torch.Size([0]).
        if target_tensor_logits.size(0) == 0:
            target = ["<PAD>"]
        else:
            target_tensor = (
                torch.argmax(target_tensor_logits.squeeze(1), 1)
                .cpu()
                .detach()
                .numpy()
            )
            target = [self._ix_to_target_char[t] for t in target_tensor]

        return "".join(target)


class Encoder(nn.Module):
    def __init__(
        self, vocabulary_size, embedding_size, hidden_size, dropout=0.5
    ):
        """Constructor"""
        super().__init__()
        self.hidden_size = hidden_size
        self.character_embedding = nn.Embedding(
            vocabulary_size, embedding_size
        )
        self.rnn = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size // 2,
            bidirectional=True,
            batch_first=True,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, sequences, sequences_lengths):
        # sequences: (batch_size, sequence_length=MAX_LENGTH)
        # sequences_lengths: (batch_size)

        batch_size = sequences.size(0)
        self.hidden = self.init_hidden(batch_size)

        sequences_lengths = torch.flip(
            torch.sort(sequences_lengths).values, dims=(0,)
        )
        index_sorted = torch.sort(-1 * sequences_lengths).indices
        index_unsort = torch.sort(index_sorted).indices  # to unsort sequences
        sequences = sequences.index_select(0, index_sorted.to(device))

        sequences = self.character_embedding(sequences)
        sequences = self.dropout(sequences)

        sequences_packed = nn.utils.rnn.pack_padded_sequence(
            sequences, sequences_lengths.clone(), batch_first=True
        )

        sequences_output, self.hidden = self.rnn(sequences_packed, self.hidden)

        sequences_output, _ = nn.utils.rnn.pad_packed_sequence(
            sequences_output, batch_first=True
        )

        sequences_output = sequences_output.index_select(
            0, index_unsort.clone().detach()
        )
        return sequences_output, self.hidden

    def init_hidden(self, batch_size):
        h_0 = torch.zeros(
            [2, batch_size, self.hidden_size // 2], requires_grad=True
        ).to(device)
        c_0 = torch.zeros(
            [2, batch_size, self.hidden_size // 2], requires_grad=True
        ).to(device)

        return (h_0, c_0)


class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super().__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == "general":
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == "concat":
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.other = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs, mask):
        # Calculate energies for each encoder output
        if self.method == "dot":
            attn_energies = torch.bmm(
                encoder_outputs, hidden.transpose(1, 2)
            ).squeeze(2)
        elif self.method == "general":
            attn_energies = self.attn(
                encoder_outputs.view(-1, encoder_outputs.size(-1))
            )  # (batch_size * sequence_len, hidden_size)
            attn_energies = torch.bmm(
                attn_energies.view(*encoder_outputs.size()),
                hidden.transpose(1, 2),
            ).squeeze(2)  # (batch_size, sequence_len)
        elif self.method == "concat":
            attn_energies = self.attn(
                torch.cat(
                    (hidden.expand(*encoder_outputs.size()), encoder_outputs),
                    2,
                )
            )  # (batch_size, sequence_len, hidden_size)
            attn_energies = torch.bmm(
                attn_energies,
                self.other.unsqueeze(0).expand(*hidden.size()).transpose(1, 2),
            ).squeeze(2)

        attn_energies = attn_energies.masked_fill(mask == 0, -1e10)

        # Normalize energies to weights in range 0 to 1
        return F.softmax(attn_energies, 1)


class AttentionDecoder(nn.Module):
    def __init__(
        self, vocabulary_size, embedding_size, hidden_size, dropout=0.5
    ):
        """Constructor"""
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.hidden_size = hidden_size
        self.character_embedding = nn.Embedding(
            vocabulary_size, embedding_size
        )
        self.rnn = nn.LSTM(
            input_size=embedding_size + self.hidden_size,
            hidden_size=hidden_size,
            bidirectional=False,
            batch_first=True,
        )

        self.attn = Attn(method="general", hidden_size=self.hidden_size)
        self.linear = nn.Linear(hidden_size, vocabulary_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_character, last_hidden, encoder_outputs, mask):
        """Defines the forward computation of the decoder"""

        # input_character: (batch_size, 1)
        # last_hidden: (batch_size, hidden_dim)
        # encoder_outputs: (batch_size, sequence_len, hidden_dim)
        # mask: (batch_size, sequence_len)

        hidden = last_hidden.permute(1, 0, 2)
        attn_weights = self.attn(hidden, encoder_outputs, mask)

        context_vector = attn_weights.unsqueeze(1).bmm(encoder_outputs)
        context_vector = torch.sum(context_vector, dim=1)
        context_vector = context_vector.unsqueeze(1)

        embedded = self.character_embedding(input_character)
        embedded = self.dropout(embedded)

        rnn_input = torch.cat((context_vector, embedded), -1)

        output, hidden = self.rnn(rnn_input)
        output = output.view(-1, output.size(2))

        x = self.linear(output)

        return x, hidden[0], attn_weights


class Seq2Seq(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        target_start_token,
        target_end_token,
        max_length,
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = 0
        self.target_start_token = target_start_token
        self.target_end_token = target_end_token
        self.max_length = max_length

        assert encoder.hidden_size == decoder.hidden_size

    def create_mask(self, source_seq):
        mask = source_seq != self.pad_idx
        return mask

    def forward(
        self, source_seq, source_seq_len, target_seq, teacher_forcing_ratio=0.5
    ):
        # source_seq: (batch_size, MAX_LENGTH)
        # source_seq_len: (batch_size, 1)
        # target_seq: (batch_size, MAX_LENGTH)

        batch_size = source_seq.size(0)
        start_token = self.target_start_token
        end_token = self.target_end_token
        max_len = self.max_length
        target_vocab_size = self.decoder.vocabulary_size

        outputs = torch.zeros(max_len, batch_size, target_vocab_size).to(
            device
        )

        if target_seq is None:
            assert teacher_forcing_ratio == 0, "Must be zero during inference"
            inference = True
        else:
            inference = False

        encoder_outputs, encoder_hidden = self.encoder(
            source_seq, source_seq_len
        )

        decoder_input = (
            torch.tensor([[start_token] * batch_size])
            .view(batch_size, 1)
            .to(device)
        )

        encoder_hidden_h_t = torch.cat(
            [encoder_hidden[0][0], encoder_hidden[0][1]], dim=1
        ).unsqueeze(dim=0)
        decoder_hidden = encoder_hidden_h_t

        max_source_len = encoder_outputs.size(1)
        mask = self.create_mask(source_seq[:, 0:max_source_len])

        for di in range(max_len):
            decoder_output, decoder_hidden, _ = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs, mask
            )

            _, topi = decoder_output.topk(1)
            outputs[di] = decoder_output.to(device)

            teacher_force = random.random() < teacher_forcing_ratio

            decoder_input = (
                target_seq[:, di].reshape(batch_size, 1)
                if teacher_force
                else topi.detach()
            )

            # Note: the next line always feeds the model's own prediction
            # back in, overriding the teacher-forced input chosen above.
            decoder_input = topi.detach()

            if inference and decoder_input == end_token:
                return outputs[:di]

        return outputs


_THAI_TO_ROM = ThaiTransliterator()


def romanize(text: str) -> str:
    return _THAI_TO_ROM.romanize(text)
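
For reference, a minimal usage sketch of the module-level romanize() function defined above (not part of thai2rom.py). It assumes PyTorch is installed and that get_corpus_path() can fetch the "thai2rom-pytorch-attn" model, which is loaded the first time this module is imported; the example word and its Latin spelling are illustrative only.

# Hypothetical usage example; the exact output spelling depends on the trained model.
from pythainlp.transliterate.thai2rom import romanize

print(romanize("แมว"))  # something like "maeo"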