• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PyThaiNLP / pythainlp / 11626163864

01 Nov 2024 07:49AM UTC coverage: 14.17% (+14.2%) from 0.0%
11626163864

Pull #952

github

web-flow
Merge 8f2551bc9 into 89ea62ebc
Pull Request #952: Specify a limited test suite

44 of 80 new or added lines in 48 files covered. (55.0%)

1048 of 7396 relevant lines covered (14.17%)

0.14 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pythainlp/spell/wanchanberta_thai_grammarly.py
1
# -*- coding: utf-8 -*-
2
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
3
# SPDX-License-Identifier: Apache-2.0
4
"""
5
Two-stage Thai Misspelling Correction based on Pre-trained Language Models
6

7
:See Also:
8
    * Paper: \
9
        https://ieeexplore.ieee.org/abstract/document/10202006
10
    * GitHub: \
11
        https://github.com/bookpanda/Two-stage-Thai-Misspelling-Correction-Based-on-Pre-trained-Language-Models
12
"""
13
import torch
×
NEW
14
from transformers import (
×
15
    AutoModelForMaskedLM,
16
    AutoTokenizer,
17
    BertForTokenClassification,
18
)
19

20
# Choose the compute device once, at import time: CUDA when torch reports
# it as available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Tokenizer shared by every function in this module.
tokenizer = AutoTokenizer.from_pretrained(
    "airesearch/wangchanberta-base-att-spm-uncased"
)
×
23

24
class BertModel(torch.nn.Module):
    """Wrap the pretrained token-classification checkpoint as a module."""

    def __init__(self):
        """Load the tagging checkpoint and store it as ``self.bert``."""
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            'bookpanda/wangchanberta-base-att-spm-uncased-tagging'
        )

    def forward(self, input_id, mask, label):
        """
        Run the wrapped model and return its raw output.

        :param input_id: forwarded as ``input_ids``
        :param mask: forwarded as ``attention_mask``
        :param label: forwarded as ``labels`` (may be ``None``)
        :return: whatever the wrapped model returns with
            ``return_dict=False``
        """
        return self.bert(
            input_ids=input_id,
            attention_mask=mask,
            labels=label,
            return_dict=False,
        )
×
32

33
# Class index -> tag emitted by the tagging model. Tokens tagged 'i' are the
# ones `correct` masks and re-predicts.
ids_to_labels = {0: 'f', 1: 'i'}

# Module-level tagging model, moved to the GPU only when CUDA is available.
tagging_model = BertModel()
if use_cuda:
    tagging_model = tagging_model.to(device=device)
×
37

38
def align_word_ids(texts):
    """
    Build a per-token marker list for ``texts``.

    The text is tokenized with the module-level ``tokenizer``, padded and
    truncated to 512 tokens. Each position whose ``word_ids()`` entry is
    ``None`` is marked ``-100``; every other position is marked ``2``.

    :param texts: input passed straight to the tokenizer
    :return: list of ``-100`` / ``2`` markers, one per token position
    :rtype: list[int]
    """
    tokenized_inputs = tokenizer(
        texts, padding='max_length', max_length=512, truncation=True
    )
    # The original wrapped ``list.append`` in a bare ``except``; appending
    # cannot fail, so the handler was dead code and has been dropped.
    return [
        -100 if word_idx is None else 2
        for word_idx in tokenized_inputs.word_ids()
    ]
×
53

54
def evaluate_one_text(model, sentence):
    """
    Tag every non-special token of ``sentence`` with the tagging model.

    :param model: callable invoked as ``model(input_ids, attention_mask,
        None)``; the first element of its output is used as the logits
    :param sentence: text to tag; tokenized with the module-level
        ``tokenizer``, padded and truncated to 512 tokens
    :return: one label from ``ids_to_labels`` per token position that
        ``align_word_ids`` does not mark ``-100``
    :rtype: list[str]
    """
    encoded = tokenizer(
        sentence,
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    mask = encoded['attention_mask'][0].unsqueeze(0).to(device)
    input_id = encoded['input_ids'][0].unsqueeze(0).to(device)
    label_ids = torch.tensor(
        align_word_ids(sentence), device=device
    ).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # Use the model that was passed in; the original ignored this
        # argument and always called the global ``tagging_model``.
        logits = model(input_id, mask, None)

    # Keep only the rows for positions not marked -100.
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    return [ids_to_labels[i] for i in predictions]
×
66

67

68
# Module-level masked-language model used to predict replacement tokens;
# moved to the GPU only when CUDA is available.
mlm_model = AutoModelForMaskedLM.from_pretrained(
    "bookpanda/wangchanberta-base-att-spm-uncased-masking"
)
if use_cuda:
    mlm_model = mlm_model.to(device=device)
×
71

72
def correct(text):
    """
    Return ``text`` with the tokens tagged ``'i'`` replaced.

    The tagging model labels each token. Every token labelled ``'i'`` is
    masked in turn, and the masked-language model's highest-scoring token
    for that position is recorded. All recorded replacements are applied
    at the end, so each prediction is made against the otherwise original
    token sequence.

    :param str text: input text
    :return: the text rebuilt from the (partly replaced) tokens, with
        ``<s>`` / ``</s>`` removed and ``▁`` turned into spaces
    :rtype: str
    """
    tags = evaluate_one_text(tagging_model, text)
    encoded = tokenizer(text)
    input_ids = list(encoded['input_ids'])
    # The attention mask does not change between iterations; build it once.
    attention_mask = torch.tensor(
        [encoded['attention_mask']], dtype=torch.int64, device=device
    )
    # Use the tokenizer's own mask id instead of a hard-coded number.
    mask_id = tokenizer.mask_token_id

    replacements = []
    with torch.no_grad():  # inference only
        for j, tag in enumerate(tags):
            if tag != 'i':
                continue
            # Tag j belongs to token j + 1: the encoded sequence starts
            # with one special token (presumably ``<s>``).
            position = j + 1
            masked_ids = list(input_ids)
            masked_ids[position] = mask_id
            batch = {
                'input_ids': torch.tensor(
                    [masked_ids], dtype=torch.int64, device=device
                ),
                'attention_mask': attention_mask,
            }
            token_logits = mlm_model(**batch).logits
            # Read the prediction at the position we masked, not at the
            # first mask token found anywhere in the sequence.
            best_token = int(token_logits[0, position].argmax())
            replacements.append((position, best_token))

    for position, token_id in replacements:
        input_ids[position] = token_id

    final_output = tokenizer.convert_ids_to_tokens(input_ids)
    if "<s>" in final_output:
        final_output.remove("<s>")
    if "</s>" in final_output:
        final_output.remove("</s>")
    # NOTE(review): the two empty-string operations below have no effect
    # on the result as written; they may have lost a special-token
    # literal -- confirm against the upstream source.
    if "" in final_output:
        final_output.remove("")
    # Guard against an empty token list before peeking at the first token.
    if final_output and final_output[0] == '▁':
        final_output.pop(0)
    final_output = ''.join(final_output)
    final_output = final_output.replace("▁", " ")
    final_output = final_output.replace("", "")
    return final_output
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc