• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zbraniecki / icu4x / 6815798908

09 Nov 2023 05:17PM UTC coverage: 72.607% (-2.4%) from 75.01%
6815798908

push

github

web-flow
Implement `Any/BufferProvider` for some smart pointers (#4255)

Allows storing them as a `Box<dyn Any/BufferProvider>` without using a
wrapper type that implements the trait.

44281 of 60987 relevant lines covered (72.61%)

201375.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

40.65
/components/segmenter/src/provider/lstm.rs
1
// This file is part of ICU4X. For terms of use, please see the file
×
2
// called LICENSE at the top level of the ICU4X source tree
3
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4

5
//! Data provider struct definitions for the LSTM segmentation model.
6

7
// Provider structs must be stable
8
#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)]
9

10
use icu_provider::prelude::*;
11
use zerovec::{ule::UnvalidatedStr, ZeroMap, ZeroVec};
12

13
// We do this instead of const generics because ZeroFrom and Yokeable derives, as well as serde
14
// don't support them
15
macro_rules! lstm_matrix {
16
    ($name:ident, $generic:literal) => {
17
        /// The struct that stores a LSTM's matrix.
18
        ///
19
        /// <div class="stab unstable">
20
        /// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
21
        /// including in SemVer minor releases. While the serde representation of data structs is guaranteed
22
        /// to be stable, their Rust representation might not be. Use with caution.
23
        /// </div>
24
        #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)]
72✔
25
        #[cfg_attr(feature = "datagen", derive(serde::Serialize))]
×
26
        pub struct $name<'data> {
27
            // Invariant: dims.product() == data.len()
28
            #[allow(missing_docs)]
29
            pub(crate) dims: [u16; $generic],
36✔
30
            #[allow(missing_docs)]
31
            pub(crate) data: ZeroVec<'data, f32>,
36✔
32
        }
33

34
        impl<'data> $name<'data> {
35
            #[cfg(any(feature = "serde", feature = "datagen"))]
36
            /// Creates a LstmMatrix with the given dimensions. Fails if the dimensions don't match the data.
37
            pub fn from_parts(
9✔
38
                dims: [u16; $generic],
39
                data: ZeroVec<'data, f32>,
40
            ) -> Result<Self, DataError> {
41
                if dims.iter().map(|&i| i as usize).product::<usize>() != data.len() {
31✔
42
                    Err(DataError::custom("Dimension mismatch"))
×
43
                } else {
44
                    Ok(Self { dims, data })
9✔
45
                }
46
            }
9✔
47

48
            #[doc(hidden)] // databake
49
            pub const fn from_parts_unchecked(
×
50
                dims: [u16; $generic],
51
                data: ZeroVec<'data, f32>,
52
            ) -> Self {
53
                Self { dims, data }
×
54
            }
×
55
        }
56

57
        #[cfg(feature = "serde")]
58
        impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> {
59
            fn deserialize<S>(deserializer: S) -> Result<Self, S::Error>
×
60
            where
61
                S: serde::de::Deserializer<'de>,
62
            {
63
                #[derive(serde::Deserialize)]
×
64
                struct Raw<'data> {
65
                    dims: [u16; $generic],
66
                    #[serde(borrow)]
67
                    data: ZeroVec<'data, f32>,
68
                }
69

70
                let raw = Raw::deserialize(deserializer)?;
×
71

72
                use serde::de::Error;
73
                Self::from_parts(raw.dims, raw.data)
×
74
                    .map_err(|_| S::Error::custom("Dimension mismatch"))
×
75
            }
×
76
        }
77

78
        #[cfg(feature = "datagen")]
79
        impl databake::Bake for $name<'_> {
80
            fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream {
×
81
                let dims = self.dims.bake(env);
×
82
                let data = self.data.bake(env);
×
83
                databake::quote! {
×
84
                    icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data)
85
                }
86
            }
×
87
        }
88
    };
89
}
90

91
lstm_matrix!(LstmMatrix1, 1);
×
92
lstm_matrix!(LstmMatrix2, 2);
×
93
lstm_matrix!(LstmMatrix3, 3);
×
94

95
/// The kind of unit an LSTM model operates on.
///
/// <div class="stab unstable">
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(
    feature = "datagen",
    derive(serde::Serialize, databake::Bake),
    databake(path = icu_segmenter::provider),
)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
pub enum ModelType {
    /// The model works on code points.
    Codepoints,
    /// The model works on grapheme clusters.
    GraphemeClusters,
}
115

116
/// The struct that stores a LSTM model.
117
///
118
/// <div class="stab unstable">
119
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
120
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
121
/// to be stable, their Rust representation might not be. Use with caution.
122
/// </div>
123
#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)]
8✔
124
#[cfg_attr(feature = "datagen", derive(serde::Serialize))]
×
125
#[yoke(prove_covariance_manually)]
126
pub struct LstmDataFloat32<'data> {
127
    /// Type of the model
128
    pub(crate) model: ModelType,
4✔
129
    /// The grapheme cluster dictionary used to train the model
130
    pub(crate) dic: ZeroMap<'data, UnvalidatedStr, u16>,
4✔
131
    /// The embedding layer. Shape (dic.len + 1, e)
132
    pub(crate) embedding: LstmMatrix2<'data>,
4✔
133
    /// The forward layer's first matrix. Shape (h, 4, e)
134
    pub(crate) fw_w: LstmMatrix3<'data>,
4✔
135
    /// The forward layer's second matrix. Shape (h, 4, h)
136
    pub(crate) fw_u: LstmMatrix3<'data>,
4✔
137
    /// The forward layer's bias. Shape (h, 4)
138
    pub(crate) fw_b: LstmMatrix2<'data>,
4✔
139
    /// The backward layer's first matrix. Shape (h, 4, e)
140
    pub(crate) bw_w: LstmMatrix3<'data>,
4✔
141
    /// The backward layer's second matrix. Shape (h, 4, h)
142
    pub(crate) bw_u: LstmMatrix3<'data>,
4✔
143
    /// The backward layer's bias. Shape (h, 4)
144
    pub(crate) bw_b: LstmMatrix2<'data>,
4✔
145
    /// The output layer's weights. Shape (2, 4, h)
146
    pub(crate) time_w: LstmMatrix3<'data>,
4✔
147
    /// The output layer's bias. Shape (4)
148
    pub(crate) time_b: LstmMatrix1<'data>,
4✔
149
}
150

151
impl<'data> LstmDataFloat32<'data> {
152
    #[doc(hidden)] // databake
153
    #[allow(clippy::too_many_arguments)] // constructor
154
    pub const fn from_parts_unchecked(
×
155
        model: ModelType,
156
        dic: ZeroMap<'data, UnvalidatedStr, u16>,
157
        embedding: LstmMatrix2<'data>,
158
        fw_w: LstmMatrix3<'data>,
159
        fw_u: LstmMatrix3<'data>,
160
        fw_b: LstmMatrix2<'data>,
161
        bw_w: LstmMatrix3<'data>,
162
        bw_u: LstmMatrix3<'data>,
163
        bw_b: LstmMatrix2<'data>,
164
        time_w: LstmMatrix3<'data>,
165
        time_b: LstmMatrix1<'data>,
166
    ) -> Self {
167
        Self {
×
168
            model,
169
            dic,
×
170
            embedding,
×
171
            fw_w,
×
172
            fw_u,
×
173
            fw_b,
×
174
            bw_w,
×
175
            bw_u,
×
176
            bw_b,
×
177
            time_w,
×
178
            time_b,
×
179
        }
180
    }
×
181

182
    #[cfg(any(feature = "serde", feature = "datagen"))]
183
    /// Creates a LstmDataFloat32 with the given data. Fails if the matrix dimensions are inconsisent.
184
    #[allow(clippy::too_many_arguments)] // constructor
185
    pub fn try_from_parts(
1✔
186
        model: ModelType,
187
        dic: ZeroMap<'data, UnvalidatedStr, u16>,
188
        embedding: LstmMatrix2<'data>,
189
        fw_w: LstmMatrix3<'data>,
190
        fw_u: LstmMatrix3<'data>,
191
        fw_b: LstmMatrix2<'data>,
192
        bw_w: LstmMatrix3<'data>,
193
        bw_u: LstmMatrix3<'data>,
194
        bw_b: LstmMatrix2<'data>,
195
        time_w: LstmMatrix3<'data>,
196
        time_b: LstmMatrix1<'data>,
197
    ) -> Result<Self, DataError> {
198
        let dic_len = u16::try_from(dic.len())
1✔
199
            .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
×
200

201
        let num_classes = embedding.dims[0];
1✔
202
        let embedd_dim = embedding.dims[1];
1✔
203
        let hunits = fw_u.dims[2];
1✔
204
        if num_classes - 1 != dic_len
9✔
205
            || fw_w.dims != [4, hunits, embedd_dim]
1✔
206
            || fw_u.dims != [4, hunits, hunits]
1✔
207
            || fw_b.dims != [4, hunits]
1✔
208
            || bw_w.dims != [4, hunits, embedd_dim]
1✔
209
            || bw_u.dims != [4, hunits, hunits]
1✔
210
            || bw_b.dims != [4, hunits]
1✔
211
            || time_w.dims != [2, 4, hunits]
1✔
212
            || time_b.dims != [4]
1✔
213
        {
214
            return Err(DataError::custom("LSTM dimension mismatch"));
×
215
        }
216

217
        #[cfg(debug_assertions)]
218
        if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
350✔
219
            return Err(DataError::custom("Invalid cluster id"));
×
220
        }
221

222
        Ok(Self {
1✔
223
            model,
224
            dic,
1✔
225
            embedding,
1✔
226
            fw_w,
1✔
227
            fw_u,
1✔
228
            fw_b,
1✔
229
            bw_w,
1✔
230
            bw_u,
1✔
231
            bw_b,
1✔
232
            time_w,
1✔
233
            time_b,
1✔
234
        })
235
    }
1✔
236
}
237

238
// Hand-written instead of derived so that every deserialized model is routed
// through `try_from_parts` and its dimension checks.
#[cfg(feature = "serde")]
impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        // Unvalidated mirror of `LstmDataFloat32`, used only as a parsing target.
        // This whole impl is already gated on the `serde` feature, so the fields
        // can carry `serde(borrow)` directly.
        #[derive(serde::Deserialize)]
        struct Raw<'data> {
            model: ModelType,
            #[serde(borrow)]
            dic: ZeroMap<'data, UnvalidatedStr, u16>,
            #[serde(borrow)]
            embedding: LstmMatrix2<'data>,
            #[serde(borrow)]
            fw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            fw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            bw_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_u: LstmMatrix3<'data>,
            #[serde(borrow)]
            bw_b: LstmMatrix2<'data>,
            #[serde(borrow)]
            time_w: LstmMatrix3<'data>,
            #[serde(borrow)]
            time_b: LstmMatrix1<'data>,
        }

        let Raw {
            model,
            dic,
            embedding,
            fw_w,
            fw_u,
            fw_b,
            bw_w,
            bw_u,
            bw_b,
            time_w,
            time_b,
        } = <Raw as serde::Deserialize>::deserialize(deserializer)?;

        // The specific validation failure is dropped; callers only see this message.
        match Self::try_from_parts(
            model, dic, embedding, fw_w, fw_u, fw_b, bw_w, bw_u, bw_b, time_w, time_b,
        ) {
            Ok(validated) => Ok(validated),
            Err(_) => Err(<D::Error as serde::de::Error>::custom(
                "Invalid dimensions",
            )),
        }
    }
}
288

289
// Bakes the model as a call to `from_parts_unchecked`, passing the baked form of
// every field in declaration order.
#[cfg(feature = "datagen")]
impl databake::Bake for LstmDataFloat32<'_> {
    fn bake(&self, ctx: &databake::CrateEnv) -> databake::TokenStream {
        let model_tokens = self.model.bake(ctx);
        let dic_tokens = self.dic.bake(ctx);
        let embedding_tokens = self.embedding.bake(ctx);

        // Forward layer.
        let fw_w_tokens = self.fw_w.bake(ctx);
        let fw_u_tokens = self.fw_u.bake(ctx);
        let fw_b_tokens = self.fw_b.bake(ctx);

        // Backward layer.
        let bw_w_tokens = self.bw_w.bake(ctx);
        let bw_u_tokens = self.bw_u.bake(ctx);
        let bw_b_tokens = self.bw_b.bake(ctx);

        // Output layer.
        let time_w_tokens = self.time_w.bake(ctx);
        let time_b_tokens = self.time_b.bake(ctx);

        databake::quote! {
            icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked(
                #model_tokens,
                #dic_tokens,
                #embedding_tokens,
                #fw_w_tokens,
                #fw_u_tokens,
                #fw_b_tokens,
                #bw_w_tokens,
                #bw_u_tokens,
                #bw_b_tokens,
                #time_w_tokens,
                #time_b_tokens,
            )
        }
    }
}
320

321
/// The data to power the LSTM segmentation model.
322
///
323
/// This data enum is extensible: more backends may be added in the future.
324
/// Old data can be used with newer code but not vice versa.
325
///
326
/// Examples of possible future extensions:
327
///
328
/// 1. Variant to store data in 16 instead of 32 bits
329
/// 2. Minor changes to the LSTM model, such as different forward/backward matrix sizes
330
///
331
/// <div class="stab unstable">
332
/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways,
333
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
334
/// to be stable, their Rust representation might not be. Use with caution.
335
/// </div>
336
#[icu_provider::data_struct(LstmForWordLineAutoV1Marker = "segmenter/lstm/wl_auto@1")]
56✔
337
#[derive(Debug, PartialEq, Clone)]
8✔
338
#[cfg_attr(
339
    feature = "datagen", 
340
    derive(serde::Serialize, databake::Bake),
×
341
    databake(path = icu_segmenter::provider),
342
)]
343
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
×
344
#[yoke(prove_covariance_manually)]
345
#[non_exhaustive]
346
pub enum LstmDataV1<'data> {
347
    /// The data as matrices of zerovec f32 values.
348
    Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>),
4✔
349
    // new variants should go BELOW existing ones
350
    // Serde serializes based on variant name and index in the enum
351
    // https://docs.rs/serde/latest/serde/trait.Serializer.html#tymethod.serialize_unit_variant
352
}
353

354
/// Crate-private marker type whose `DataMarker::Yokeable` is `LstmDataV1<'static>`.
pub(crate) struct LstmDataV1Marker;
355

356
// Ties the marker to the `'static` form of `LstmDataV1`.
impl DataMarker for LstmDataV1Marker {
357
    type Yokeable = LstmDataV1<'static>;
358
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc