• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

xd009642 / ndarray-vision / #59

pending completion
#59

push

web-flow
Added manual threshold algorithm (#56)

Co-authored-by: Christopher Field <chris.field@theiascientific.com>

15 of 15 new or added lines in 1 file covered. (100.0%)

774 of 1098 relevant lines covered (70.49%)

1.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.29
/src/processing/threshold.rs
1
use crate::core::PixelBound;
2
use crate::core::{ColourModel, Image, ImageBase};
3
use crate::processing::*;
4
use ndarray::{prelude::*, Data};
5
use ndarray_stats::histogram::{Bins, Edges, Grid};
6
use ndarray_stats::HistogramExt;
7
use ndarray_stats::QuantileExt;
8
use num_traits::cast::FromPrimitive;
9
use num_traits::cast::ToPrimitive;
10
use num_traits::{Num, NumAssignOps};
11
use std::marker::PhantomData;
12

13
/// Runs the Otsu thresholding algorithm on a type `T`.
14
pub trait ThresholdOtsuExt<T> {
15
    /// The Otsu thresholding output is a binary image.
16
    type Output;
17

18
    /// Run the Otsu threshold algorithm.
19
    ///
20
    /// Due to Otsu threshold algorithm specifying a greyscale image, all
21
    /// current implementations assume a single channel image; otherwise, an
22
    /// error is returned.
23
    ///
24
    /// # Errors
25
    ///
26
    /// Returns a `ChannelDimensionMismatch` error if more than one channel
27
    /// exists.
28
    fn threshold_otsu(&self) -> Result<Self::Output, Error>;
29
}
30

31
/// Runs the Mean thresholding algorithm on a type `T`.
32
pub trait ThresholdMeanExt<T> {
33
    /// The Mean thresholding output is a binary image.
34
    type Output;
35

36
    /// Run the Mean threshold algorithm.
37
    ///
38
    /// This assumes the image is a single channel image, i.e., a greyscale
39
    /// image; otherwise, an error is returned.
40
    ///
41
    /// # Errors
42
    ///
43
    /// Returns a `ChannelDimensionMismatch` error if more than one channel
44
    /// exists.
45
    fn threshold_mean(&self) -> Result<Self::Output, Error>;
46
}
47

48
/// Applies an upper and lower limit threshold on a type `T`.
49
pub trait ThresholdApplyExt<T> {
50
    /// The output is a binary image.
51
    type Output;
52

53
    /// Apply the threshold with the given limits.
54
    ///
55
    /// An image is segmented into background and foreground
56
    /// elements, where any pixel value within the limits are considered
57
    /// foreground elements and any pixels with a value outside the limits are
58
    /// considered part of the background. The upper and lower limits are
59
    /// inclusive.
60
    ///
61
    /// If only a lower limit threshold is to be applied, the `f64::INFINITY`
62
    /// value can be used for the upper limit.
63
    ///
64
    /// # Errors
65
    ///
66
    /// The current implementation assumes a single channel image, i.e.,
67
    /// greyscale image. Thus, if more than one channel is present, then
68
    /// a `ChannelDimensionMismatch` error occurs.
69
    ///
70
    /// An `InvalidParameter` error occurs if the `lower` limit is greater than
71
    /// the `upper` limit.
72
    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error>;
73
}
74

75
impl<T, U, C> ThresholdOtsuExt<T> for ImageBase<U, C>
76
where
77
    U: Data<Elem = T>,
78
    Image<U, C>: Clone,
79
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
80
    C: ColourModel,
81
{
82
    type Output = Image<bool, C>;
83

84
    fn threshold_otsu(&self) -> Result<Self::Output, Error> {
×
85
        let data = self.data.threshold_otsu()?;
×
86
        Ok(Self::Output {
×
87
            data,
×
88
            model: PhantomData,
×
89
        })
90
    }
91
}
92

93
impl<T, U> ThresholdOtsuExt<T> for ArrayBase<U, Ix3>
94
where
95
    U: Data<Elem = T>,
96
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
97
{
98
    type Output = Array3<bool>;
99

100
    fn threshold_otsu(&self) -> Result<Self::Output, Error> {
×
101
        if self.shape()[2] > 1 {
×
102
            Err(Error::ChannelDimensionMismatch)
×
103
        } else {
104
            let value = calculate_threshold_otsu(self)?;
×
105
            self.threshold_apply(value, f64::INFINITY)
×
106
        }
107
    }
108
}
109

110
/// Calculates Otsu's threshold.
111
///
112
/// Works per channel, but currently assumes greyscale.
113
///
114
/// See the Errors section for the `ThresholdOtsuExt` trait if the number of
115
/// channels is greater than one (1), i.e., single channel; otherwise, we would
116
/// need to output all three threshold values.
117
///
118
/// TODO: Add optional nbins
119
fn calculate_threshold_otsu<T, U>(mat: &ArrayBase<U, Ix3>) -> Result<f64, Error>
2✔
120
where
121
    U: Data<Elem = T>,
122
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
123
{
124
    let mut threshold = 0.0;
2✔
125
    let n_bins = 255;
2✔
126
    for c in mat.axis_iter(Axis(2)) {
6✔
127
        let scale_factor = (n_bins) as f64 / (c.max().unwrap().to_f64().unwrap());
2✔
128
        let edges_vec: Vec<u8> = (0..n_bins).collect();
2✔
129
        let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]);
6✔
130

131
        // get the histogram
132
        let flat = Array::from_iter(c.iter()).insert_axis(Axis(1));
4✔
133
        let flat2 = flat.mapv(|x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap());
6✔
134
        let hist = flat2.histogram(grid);
2✔
135
        // Straight out of wikipedia:
136
        let counts = hist.counts();
2✔
137
        let total = counts.sum().to_f64().unwrap();
4✔
138
        let counts = Array::from_iter(counts.iter());
2✔
139
        // NOTE: Could use the cdf generation for skimage-esque implementation
140
        // which entails a cumulative sum of the standard histogram
141
        let mut sum_b = 0.0;
2✔
142
        let mut weight_b = 0.0;
2✔
143
        let mut maximum = 0.0;
2✔
144
        let mut level = 0.0;
2✔
145
        let mut sum_intensity = 0.0;
2✔
146
        for (index, count) in counts.indexed_iter() {
8✔
147
            sum_intensity += (index as f64) * (*count).to_f64().unwrap();
2✔
148
        }
149
        for (index, count) in counts.indexed_iter() {
6✔
150
            weight_b = weight_b + count.to_f64().unwrap();
2✔
151
            sum_b = sum_b + (index as f64) * count.to_f64().unwrap();
2✔
152
            let weight_f = total - weight_b;
2✔
153
            if (weight_b > 0.0) && (weight_f > 0.0) {
2✔
154
                let mean_f = (sum_intensity - sum_b) / weight_f;
2✔
155
                let val = weight_b
6✔
156
                    * weight_f
×
157
                    * ((sum_b / weight_b) - mean_f)
2✔
158
                    * ((sum_b / weight_b) - mean_f);
2✔
159
                if val > maximum {
4✔
160
                    level = 1.0 + (index as f64);
2✔
161
                    maximum = val;
2✔
162
                }
163
            }
164
        }
165
        threshold = level as f64 / scale_factor;
2✔
166
    }
167
    Ok(threshold)
2✔
168
}
169

170
impl<T, U, C> ThresholdMeanExt<T> for ImageBase<U, C>
171
where
172
    U: Data<Elem = T>,
173
    Image<U, C>: Clone,
174
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
175
    C: ColourModel,
176
{
177
    type Output = Image<bool, C>;
178

179
    fn threshold_mean(&self) -> Result<Self::Output, Error> {
×
180
        let data = self.data.threshold_mean()?;
×
181
        Ok(Self::Output {
×
182
            data,
×
183
            model: PhantomData,
×
184
        })
185
    }
186
}
187

188
impl<T, U> ThresholdMeanExt<T> for ArrayBase<U, Ix3>
189
where
190
    U: Data<Elem = T>,
191
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
192
{
193
    type Output = Array3<bool>;
194

195
    fn threshold_mean(&self) -> Result<Self::Output, Error> {
×
196
        if self.shape()[2] > 1 {
×
197
            Err(Error::ChannelDimensionMismatch)
×
198
        } else {
199
            let value = calculate_threshold_mean(self)?;
×
200
            self.threshold_apply(value, f64::INFINITY)
×
201
        }
202
    }
203
}
204

205
fn calculate_threshold_mean<T, U>(array: &ArrayBase<U, Ix3>) -> Result<f64, Error>
2✔
206
where
207
    U: Data<Elem = T>,
208
    T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
209
{
210
    Ok(array.sum().to_f64().unwrap() / array.len() as f64)
2✔
211
}
212

213
impl<T, U, C> ThresholdApplyExt<T> for ImageBase<U, C>
214
where
215
    U: Data<Elem = T>,
216
    Image<U, C>: Clone,
217
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
218
    C: ColourModel,
219
{
220
    type Output = Image<bool, C>;
221

222
    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
×
223
        let data = self.data.threshold_apply(lower, upper)?;
×
224
        Ok(Self::Output {
×
225
            data,
×
226
            model: PhantomData,
×
227
        })
228
    }
229
}
230

231
impl<T, U> ThresholdApplyExt<T> for ArrayBase<U, Ix3>
232
where
233
    U: Data<Elem = T>,
234
    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
235
{
236
    type Output = Array3<bool>;
237

238
    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
×
239
        if self.shape()[2] > 1 {
×
240
            Err(Error::ChannelDimensionMismatch)
×
241
        } else if lower > upper {
×
242
            Err(Error::InvalidParameter)
×
243
        } else {
244
            Ok(apply_threshold(self, lower, upper))
×
245
        }
246
    }
247
}
248

249
fn apply_threshold<T, U>(data: &ArrayBase<U, Ix3>, lower: f64, upper: f64) -> Array3<bool>
1✔
250
where
251
    U: Data<Elem = T>,
252
    T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
253
{
254
    data.mapv(|x| x.to_f64().unwrap() >= lower && x.to_f64().unwrap() <= upper)
3✔
255
}
256

257
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use ndarray::arr3;
    use noisy_float::types::n64;

    #[test]
    fn threshold_apply_threshold() {
        let data = arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ]);

        let expected = arr3(&[
            [[false], [false], [false]],
            [[true], [true], [true]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&data, 0.5, f64::INFINITY);

        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_apply_threshold_range() {
        let data = arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ]);
        let expected = arr3(&[
            [[false], [true], [false]],
            [[true], [true], [false]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&data, 0.25, 0.75);

        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_apply_rejects_inverted_limits() {
        let data: Array3<i32> = arr3(&[[[1], [2]], [[3], [4]]]);

        assert!(matches!(
            data.threshold_apply(3.0, 2.0),
            Err(Error::InvalidParameter)
        ));
    }

    #[test]
    fn threshold_apply_rejects_multiple_channels() {
        let data: Array3<i32> = arr3(&[[[1, 2], [3, 4]]]);

        assert!(matches!(
            data.threshold_apply(0.0, f64::INFINITY),
            Err(Error::ChannelDimensionMismatch)
        ));
    }

    #[test]
    fn threshold_calculate_threshold_otsu_ints() {
        let data = arr3(&[[[2], [4], [0]], [[7], [5], [8]], [[1], [6], [0]]]);
        let result = calculate_threshold_otsu(&data).unwrap();

        // Calculated using Python's skimage.filters.threshold_otsu
        // on int input array. Float array returns 2.0156...
        let expected = 2.0;

        assert_approx_eq!(result, expected, 5e-1);
    }

    #[test]
    fn threshold_calculate_threshold_otsu_floats() {
        let data = arr3(&[
            [[n64(2.0)], [n64(4.0)], [n64(0.0)]],
            [[n64(7.0)], [n64(5.0)], [n64(8.0)]],
            [[n64(1.0)], [n64(6.0)], [n64(0.0)]],
        ]);

        let result = calculate_threshold_otsu(&data).unwrap();

        // Calculated using Python's skimage.filters.threshold_otsu
        // on the equivalent float input array.
        let expected = 2.0156;

        assert_approx_eq!(result, expected, 5e-1);
    }

    #[test]
    fn threshold_otsu_segments_single_channel() {
        let data: Array3<i32> = arr3(&[[[2], [4], [0]], [[7], [5], [8]], [[1], [6], [0]]]);
        // The Otsu threshold for this data lies just above 2.0.
        let expected = arr3(&[
            [[false], [true], [false]],
            [[true], [true], [true]],
            [[false], [true], [false]],
        ]);

        assert_eq!(data.threshold_otsu().unwrap(), expected);
    }

    #[test]
    fn threshold_calculate_threshold_mean_ints() {
        let data = arr3(&[[[4], [4], [4]], [[5], [5], [5]], [[6], [6], [6]]]);

        let result = calculate_threshold_mean(&data).unwrap();
        let expected = 5.0;

        assert_approx_eq!(result, expected, 1e-16);
    }

    #[test]
    fn threshold_calculate_threshold_mean_floats() {
        let data = arr3(&[
            [[4.0], [4.0], [4.0]],
            [[5.0], [5.0], [5.0]],
            [[6.0], [6.0], [6.0]],
        ]);

        let result = calculate_threshold_mean(&data).unwrap();
        let expected = 5.0;

        assert_approx_eq!(result, expected, 1e-16);
    }

    #[test]
    fn threshold_mean_keeps_pixels_at_or_above_mean() {
        let data: Array3<i32> = arr3(&[[[4], [4], [4]], [[5], [5], [5]], [[6], [6], [6]]]);
        let expected = arr3(&[
            [[false], [false], [false]],
            [[true], [true], [true]],
            [[true], [true], [true]],
        ]);

        assert_eq!(data.threshold_mean().unwrap(), expected);
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc