• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zbraniecki / icu4x / 6815798908

09 Nov 2023 05:17PM CUT coverage: 72.607% (-2.4%) from 75.01%
6815798908

push

github

web-flow
Implement `Any/BufferProvider` for some smart pointers (#4255)

Allows storing them as a `Box<dyn Any/BufferProvider>` without using a
wrapper type that implements the trait.

44281 of 60987 relevant lines covered (72.61%)

201375.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.88
/components/segmenter/src/complex/lstm/mod.rs
1
// This file is part of ICU4X. For terms of use, please see the file
84✔
2
// called LICENSE at the top level of the ICU4X source tree
3
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4

5
use crate::grapheme::GraphemeClusterSegmenter;
6
use crate::provider::*;
7
use alloc::vec::Vec;
8
use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
9
use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};
10

11
mod matrix;
12
use matrix::*;
13

14
// A word break iterator using an LSTM model. Input strings have to be in the same language.
15

16
struct LstmSegmenterIterator<'s> {
17
    input: &'s str,
18
    pos_utf8: usize,
19
    bies: BiesIterator<'s>,
20
}
21

22
impl Iterator for LstmSegmenterIterator<'_> {
23
    type Item = usize;
24

25
    fn next(&mut self) -> Option<Self::Item> {
107✔
26
        #[allow(clippy::indexing_slicing)] // pos_utf8 in range
27
        loop {
107✔
28
            let is_e = self.bies.next()?;
347✔
29
            self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8();
324✔
30
            if is_e || self.bies.len() == 0 {
324✔
31
                return Some(self.pos_utf8);
84✔
32
            }
33
        }
34
    }
107✔
35
}
36

37
struct LstmSegmenterIteratorUtf16<'s> {
38
    bies: BiesIterator<'s>,
39
    pos: usize,
40
}
41

42
impl Iterator for LstmSegmenterIteratorUtf16<'_> {
43
    type Item = usize;
44

45
    fn next(&mut self) -> Option<Self::Item> {
111✔
46
        loop {
111✔
47
            self.pos += 1;
364✔
48
            if self.bies.next()? || self.bies.len() == 0 {
364✔
49
                return Some(self.pos);
87✔
50
            }
51
        }
52
    }
111✔
53
}
54

55
pub(super) struct LstmSegmenter<'l> {
56
    dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
57
    embedding: MatrixZero<'l, 2>,
58
    fw_w: MatrixZero<'l, 3>,
59
    fw_u: MatrixZero<'l, 3>,
60
    fw_b: MatrixZero<'l, 2>,
61
    bw_w: MatrixZero<'l, 3>,
62
    bw_u: MatrixZero<'l, 3>,
63
    bw_b: MatrixZero<'l, 2>,
64
    timew_fw: MatrixZero<'l, 2>,
65
    timew_bw: MatrixZero<'l, 2>,
66
    time_b: MatrixZero<'l, 1>,
67
    grapheme: Option<&'l RuleBreakDataV1<'l>>,
68
}
69

70
impl<'l> LstmSegmenter<'l> {
71
    /// Returns `Err` if grapheme data is required but not present
72
    pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
48✔
73
        let LstmDataV1::Float32(lstm) = lstm;
48✔
74
        let time_w = MatrixZero::from(&lstm.time_w);
48✔
75
        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
76
        let timew_fw = time_w.submatrix(0).unwrap();
48✔
77
        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
78
        let timew_bw = time_w.submatrix(1).unwrap();
48✔
79
        Self {
48✔
80
            dic: lstm.dic.as_borrowed(),
48✔
81
            embedding: MatrixZero::from(&lstm.embedding),
48✔
82
            fw_w: MatrixZero::from(&lstm.fw_w),
48✔
83
            fw_u: MatrixZero::from(&lstm.fw_u),
48✔
84
            fw_b: MatrixZero::from(&lstm.fw_b),
48✔
85
            bw_w: MatrixZero::from(&lstm.bw_w),
48✔
86
            bw_u: MatrixZero::from(&lstm.bw_u),
48✔
87
            bw_b: MatrixZero::from(&lstm.bw_b),
48✔
88
            timew_fw,
48✔
89
            timew_bw,
48✔
90
            time_b: MatrixZero::from(&lstm.time_b),
48✔
91
            grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme),
48✔
92
        }
93
    }
48✔
94

95
    /// Create an LSTM based break iterator for an `str` (a UTF-8 string).
96
    pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
23✔
97
        self.segment_str_p(input)
23✔
98
    }
23✔
99

100
    // For unit testing as we cannot inspect the opaque type's bies
101
    fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
53✔
102
        let input_seq = if let Some(grapheme) = self.grapheme {
53✔
103
            GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
1✔
104
                .collect::<Vec<usize>>()
105
                .windows(2)
106
                .map(|chunk| {
15✔
107
                    let range = if let [first, second, ..] = chunk {
14✔
108
                        *first..*second
14✔
109
                    } else {
110
                        unreachable!()
×
111
                    };
112
                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
14✔
113
                        grapheme_cluster
14✔
114
                    } else {
115
                        return self.dic.len() as u16;
×
116
                    };
117

118
                    self.dic
42✔
119
                        .get_copied(UnvalidatedStr::from_str(grapheme_cluster))
14✔
120
                        .unwrap_or_else(|| self.dic.len() as u16)
14✔
121
                })
14✔
122
                .collect()
123
        } else {
1✔
124
            input
104✔
125
                .chars()
126
                .map(|c| {
1,322✔
127
                    self.dic
3,810✔
128
                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
1,270✔
129
                        .unwrap_or_else(|| self.dic.len() as u16)
1,270✔
130
                })
1,270✔
131
                .collect()
132
        };
133
        LstmSegmenterIterator {
53✔
134
            input,
135
            pos_utf8: 0,
136
            bies: BiesIterator::new(self, input_seq),
53✔
137
        }
138
    }
53✔
139

140
    /// Create an LSTM based break iterator for a UTF-16 string.
141
    pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
24✔
142
        let input_seq = if let Some(grapheme) = self.grapheme {
24✔
143
            GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
1✔
144
                .collect::<Vec<usize>>()
145
                .windows(2)
146
                .map(|chunk| {
15✔
147
                    let range = if let [first, second, ..] = chunk {
14✔
148
                        *first..*second
14✔
149
                    } else {
150
                        unreachable!()
×
151
                    };
152
                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
14✔
153
                        grapheme_cluster
14✔
154
                    } else {
155
                        return self.dic.len() as u16;
×
156
                    };
157

158
                    self.dic
42✔
159
                        .get_copied_by(|key| {
122✔
160
                            key.as_bytes().iter().copied().cmp(
216✔
161
                                decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| {
216✔
162
                                    let mut buf = [0; 4];
108✔
163
                                    let len = c
108✔
164
                                        .unwrap_or(REPLACEMENT_CHARACTER)
165
                                        .encode_utf8(&mut buf)
166
                                        .len();
167
                                    buf.into_iter().take(len)
108✔
168
                                }),
108✔
169
                            )
170
                        })
108✔
171
                        .unwrap_or_else(|| self.dic.len() as u16)
14✔
172
                })
14✔
173
                .collect()
174
        } else {
1✔
175
            decode_utf16(input.iter().copied())
46✔
176
                .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER))
326✔
177
                .map(|c| {
348✔
178
                    self.dic
975✔
179
                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
325✔
180
                        .unwrap_or_else(|| self.dic.len() as u16)
325✔
181
                })
325✔
182
                .collect()
183
        };
184
        LstmSegmenterIteratorUtf16 {
24✔
185
            bies: BiesIterator::new(self, input_seq),
24✔
186
            pos: 0,
187
        }
188
    }
24✔
189
}
190

191
struct BiesIterator<'l> {
192
    segmenter: &'l LstmSegmenter<'l>,
193
    input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
194
    h_bw: MatrixOwned<2>,
195
    curr_fw: MatrixOwned<1>,
196
    c_fw: MatrixOwned<1>,
197
}
198

199
impl<'l> BiesIterator<'l> {
200
    // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later
201
    // in the embedding layer of the model.
202
    fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
77✔
203
        let hunits = segmenter.fw_u.dim().1;
77✔
204

205
        // Backward LSTM
206
        let mut c_bw = MatrixOwned::<1>::new_zero([hunits]);
77✔
207
        let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]);
77✔
208
        for (i, &g_id) in input_seq.iter().enumerate().rev() {
1,701✔
209
            if i + 1 < input_seq.len() {
1,624✔
210
                h_bw.as_mut().copy_submatrix::<1>(i + 1, i);
1,547✔
211
            }
212
            #[allow(clippy::unwrap_used)]
213
            compute_hc(
1,624✔
214
                segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), /* shape (dict.len() + 1, hunit), g_id is at most dict.len() */
1,624✔
215
                h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits)
1,624✔
216
                c_bw.as_mut(),
1,624✔
217
                segmenter.bw_w,
1,624✔
218
                segmenter.bw_u,
1,624✔
219
                segmenter.bw_b,
1,624✔
220
            );
221
        }
222

223
        Self {
77✔
224
            input_seq: input_seq.into_iter().enumerate(),
77✔
225
            h_bw,
77✔
226
            c_fw: MatrixOwned::<1>::new_zero([hunits]),
77✔
227
            curr_fw: MatrixOwned::<1>::new_zero([hunits]),
77✔
228
            segmenter,
229
        }
×
230
    }
77✔
231
}
232

233
impl ExactSizeIterator for BiesIterator<'_> {
234
    fn len(&self) -> usize {
502✔
235
        self.input_seq.len()
502✔
236
    }
502✔
237
}
238

239
impl Iterator for BiesIterator<'_> {
240
    type Item = bool;
241

242
    fn next(&mut self) -> Option<Self::Item> {
1,701✔
243
        let (i, g_id) = self.input_seq.next()?;
1,701✔
244

245
        #[allow(clippy::unwrap_used)]
246
        compute_hc(
1,624✔
247
            self.segmenter
3,248✔
248
                .embedding
249
                .submatrix::<1>(g_id as usize)
1,624✔
250
                .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len()
251
            self.curr_fw.as_mut(),
1,624✔
252
            self.c_fw.as_mut(),
1,624✔
253
            self.segmenter.fw_w,
1,624✔
254
            self.segmenter.fw_u,
1,624✔
255
            self.segmenter.fw_b,
1,624✔
256
        );
257

258
        #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits)
259
        let curr_bw = self.h_bw.submatrix::<1>(i).unwrap();
1,624✔
260
        let mut weights = [0.0; 4];
1,624✔
261
        let mut curr_est = MatrixBorrowedMut {
1,624✔
262
            data: &mut weights,
263
            dims: [4],
1,624✔
264
        };
265
        curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw);
1,624✔
266
        curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw);
1,624✔
267
        #[allow(clippy::unwrap_used)] // both shape (4)
268
        curr_est.add(self.segmenter.time_b).unwrap();
1,624✔
269
        // For correct BIES weight calculation we'd now have to apply softmax, however
270
        // we're only doing a naive argmax, so a monotonic function doesn't make a difference.
271

272
        Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
1,624✔
273
    }
1,701✔
274
}
275

276
/// `compute_hc1` implemens the evaluation of one LSTM layer.
277
fn compute_hc<'a>(
3,248✔
278
    x_t: MatrixZero<'a, 1>,
279
    mut h_tm1: MatrixBorrowedMut<'a, 1>,
280
    mut c_tm1: MatrixBorrowedMut<'a, 1>,
281
    w: MatrixZero<'a, 3>,
282
    u: MatrixZero<'a, 3>,
283
    b: MatrixZero<'a, 2>,
284
) {
285
    #[cfg(debug_assertions)]
286
    {
287
        let hunits = h_tm1.dim();
3,248✔
288
        let embedd_dim = x_t.dim();
3,248✔
289
        c_tm1.as_borrowed().debug_assert_dims([hunits]);
3,248✔
290
        w.debug_assert_dims([4, hunits, embedd_dim]);
3,248✔
291
        u.debug_assert_dims([4, hunits, hunits]);
3,248✔
292
        b.debug_assert_dims([4, hunits]);
3,248✔
293
    }
294

295
    let mut s_t = b.to_owned();
3,248✔
296

297
    s_t.as_mut().add_dot_3d_2(x_t, w);
3,248✔
298
    s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
3,248✔
299

300
    #[allow(clippy::unwrap_used)] // first dimension is 4
301
    s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
3,248✔
302
    #[allow(clippy::unwrap_used)] // first dimension is 4
303
    s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
3,248✔
304
    #[allow(clippy::unwrap_used)] // first dimension is 4
305
    s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
3,248✔
306
    #[allow(clippy::unwrap_used)] // first dimension is 4
307
    s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();
3,248✔
308

309
    #[allow(clippy::unwrap_used)] // first dimension is 4
310
    c_tm1.convolve(
3,248✔
311
        s_t.as_borrowed().submatrix(0).unwrap(),
3,248✔
312
        s_t.as_borrowed().submatrix(2).unwrap(),
3,248✔
313
        s_t.as_borrowed().submatrix(1).unwrap(),
3,248✔
314
    );
315

316
    #[allow(clippy::unwrap_used)] // first dimension is 4
317
    h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
3,248✔
318
}
3,248✔
319

320
#[cfg(test)]
mod tests {
    use super::*;
    use icu_locid::locale;
    use icu_provider::prelude::*;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::BufReader;

    /// A single test case with three attributes:
    /// - `unseg`: the unsegmented line;
    /// - `expected_bies`: the BIES sequence the model is expected to produce for `unseg`;
    /// - `true_bies`: the BIES sequence representing the true segmentation.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestCase {
        unseg: String,
        expected_bies: String,
        true_bies: String,
    }

    /// The contents of a test text file: a list of [`TestCase`]s.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestTextData {
        testcases: Vec<TestCase>,
    }

    /// Loads and parses a JSON test text file, panicking if it is missing or malformed.
    fn load_test_text(filename: &str) -> TestTextData {
        let file = File::open(filename).expect("File should be present");
        let reader = BufReader::new(file);
        serde_json::from_reader(reader).expect("JSON syntax error")
    }

    #[test]
    fn segment_file_by_lstm() {
        let lstm: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
            .load(DataRequest {
                locale: &locale!("th").into(),
                metadata: Default::default(),
            })
            .unwrap()
            .take_payload()
            .unwrap();
        let lstm = LstmSegmenter::new(
            lstm.get(),
            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
        );

        // Importing the test data; the file depends on the unit the model operates on.
        let test_text_data = load_test_text(&format!(
            "tests/testdata/test_text_{}.json",
            if lstm.grapheme.is_some() {
                "grapheme"
            } else {
                "codepoints"
            }
        ));

        // Testing
        for test_case in &test_text_data.testcases {
            let lstm_output = lstm
                .segment_str_p(&test_case.unseg)
                .bies
                .map(|is_e| if is_e { 'e' } else { '?' })
                .collect::<String>();
            println!("Test case      : {}", test_case.unseg);
            println!("Expected bies  : {}", test_case.expected_bies);
            println!("Estimated bies : {lstm_output}");
            println!("True bies      : {}", test_case.true_bies);
            println!("****************************************************");
            // The BIES iterator only reports whether each unit is an `e`, so every other
            // label in the expectation is masked to `?` before comparing.
            assert_eq!(
                test_case.expected_bies.replace(['b', 'i', 's'], "?"),
                lstm_output
            );
        }
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc