• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

divviup / divviup-api / 13021118675

28 Jan 2025 11:00PM UTC coverage: 55.794%. Remained the same
13021118675

push

github

web-flow
Shorten PR 1533 using std::Cow::Borrowed (#1537)

* Shorten PR 1533 using std::Cow::Borrowed

1533 got auto-merged too quickly!

* moo

0 of 2 new or added lines in 1 file covered. (0.0%)

1 existing line in 1 file now uncovered.

3881 of 6956 relevant lines covered (55.79%)

86.02 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.09
/src/entity/task/new_task.rs
1
use super::*;
2
use crate::{
3
    clients::aggregator_client::api_types::{AggregatorVdaf, QueryType},
4
    entity::{
5
        aggregator::{Feature, Role},
6
        Account, CollectorCredential, Protocol,
7
    },
8
    handler::Error,
9
};
10
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
11
use rand::Rng;
12
use sea_orm::{ColumnTrait, ConnectionTrait, EntityTrait, QueryFilter};
13
use sha2::{Digest, Sha256};
14
use std::borrow::Cow;
15
use validator::{ValidationErrors, ValidationErrorsKind};
16
use vdaf::{DpStrategy, DpStrategyKind, SumVec};
17

18
#[derive(Deserialize, Validate, Debug, Clone, Default)]
19
pub struct NewTask {
20
    #[validate(required, length(min = 1))]
21
    pub name: Option<String>,
22

23
    #[validate(required)]
24
    pub leader_aggregator_id: Option<String>,
25

26
    #[validate(required)]
27
    pub helper_aggregator_id: Option<String>,
28

29
    #[validate(required, nested)]
30
    pub vdaf: Option<Vdaf>,
31

32
    #[validate(required, range(min = 100))]
33
    pub min_batch_size: Option<u64>,
34

35
    #[validate(range(min = 0))]
36
    pub max_batch_size: Option<u64>,
37

38
    #[validate(range(min = 0))]
39
    pub batch_time_window_size_seconds: Option<u64>,
40

41
    #[validate(
42
        required,
43
        range(
44
            min = 60,
45
            max = 2592000,
46
            message = "must be between 1 minute and 4 weeks"
47
        )
48
    )]
49
    pub time_precision_seconds: Option<u64>,
50

51
    #[validate(required)]
52
    pub collector_credential_id: Option<String>,
53
}
54

55
async fn load_aggregator(
52✔
56
    account: &Account,
52✔
57
    id: Option<&str>,
52✔
58
    db: &impl ConnectionTrait,
52✔
59
) -> Result<Option<Aggregator>, Error> {
52✔
60
    let Some(id) = id.map(Uuid::parse_str).transpose()? else {
52✔
61
        return Ok(None);
6✔
62
    };
63

64
    let aggregator = Aggregators::find_by_id(id)
46✔
65
        .filter(AggregatorColumn::DeletedAt.is_null())
46✔
66
        .one(db)
46✔
67
        .await?;
46✔
68

69
    let Some(aggregator) = aggregator else {
46✔
70
        return Ok(None);
2✔
71
    };
72

73
    if aggregator.account_id.is_none() || aggregator.account_id == Some(account.id) {
44✔
74
        Ok(Some(aggregator))
44✔
75
    } else {
76
        Ok(None)
×
77
    }
78
}
52✔
79

80
const VDAF_BYTES: usize = 16;
81
fn generate_vdaf_verify_key_and_expected_task_id() -> (String, String) {
7✔
82
    let mut verify_key = [0; VDAF_BYTES];
7✔
83
    rand::thread_rng().fill(&mut verify_key);
7✔
84
    (
7✔
85
        URL_SAFE_NO_PAD.encode(verify_key),
7✔
86
        URL_SAFE_NO_PAD.encode(Sha256::digest(verify_key)),
7✔
87
    )
7✔
88
}
7✔
89

90
impl NewTask {
91
    fn validate_min_lte_max(&self, errors: &mut ValidationErrors) {
26✔
92
        let min = self.min_batch_size;
26✔
93
        let max = self.max_batch_size;
26✔
94
        if matches!((min, max), (Some(min), Some(max)) if min > max) {
26✔
95
            let error = ValidationError::new("min_greater_than_max");
2✔
96
            errors.add("min_batch_size", error.clone());
2✔
97
            errors.add("max_batch_size", error);
2✔
98
        }
24✔
99
    }
26✔
100

101
    fn validate_batch_time_window_size(&self, errors: &mut ValidationErrors) {
26✔
102
        let window = self.batch_time_window_size_seconds;
26✔
103
        if let Some(window) = window {
26✔
104
            if self.max_batch_size.is_none() {
5✔
105
                errors.add(
1✔
106
                    "batch_time_window_size_seconds",
1✔
107
                    ValidationError::new("missing-max-batch-size"),
1✔
108
                );
1✔
109
            }
4✔
110
            if let Some(precision) = self.time_precision_seconds {
5✔
111
                if window % precision != 0 {
5✔
112
                    errors.add(
2✔
113
                        "batch_time_window_size_seconds",
2✔
114
                        ValidationError::new("not-multiple-of-time-precision"),
2✔
115
                    );
2✔
116
                }
3✔
117
            }
×
118
        }
21✔
119
    }
26✔
120

121
    async fn load_collector_credential(
26✔
122
        &self,
26✔
123
        account: &Account,
26✔
124
        db: &impl ConnectionTrait,
26✔
125
    ) -> Option<CollectorCredential> {
26✔
126
        let id = Uuid::parse_str(self.collector_credential_id.as_deref()?).ok()?;
26✔
127
        CollectorCredentials::find_by_id(id)
12✔
128
            .filter(CollectorCredentialColumn::AccountId.eq(account.id))
12✔
129
            .one(db)
12✔
130
            .await
12✔
131
            .ok()
12✔
132
            .flatten()
12✔
133
    }
26✔
134

135
    async fn validate_collector_credential(
26✔
136
        &self,
26✔
137
        account: &Account,
26✔
138
        leader: Option<&Aggregator>,
26✔
139
        db: &impl ConnectionTrait,
26✔
140
        errors: &mut ValidationErrors,
26✔
141
    ) -> Option<CollectorCredential> {
26✔
142
        match self.load_collector_credential(account, db).await {
26✔
143
            None => {
144
                errors.add("collector_credential_id", ValidationError::new("required"));
14✔
145
                None
14✔
146
            }
147

148
            Some(collector_credential) => {
12✔
149
                let leader_needs_token_hash =
12✔
150
                    leader.is_some_and(|leader| leader.features.token_hash_enabled());
12✔
151

12✔
152
                if leader_needs_token_hash && collector_credential.token_hash.is_none() {
12✔
153
                    errors.add(
×
154
                        "collector_credential_id",
×
155
                        ValidationError::new("missing-token-hash"),
×
156
                    );
×
157
                    None
×
158
                } else {
159
                    Some(collector_credential)
12✔
160
                }
161
            }
162
        }
163
    }
26✔
164

165
    async fn validate_aggregators(
26✔
166
        &self,
26✔
167
        account: &Account,
26✔
168
        db: &impl ConnectionTrait,
26✔
169
        errors: &mut ValidationErrors,
26✔
170
    ) -> Option<(Aggregator, Aggregator, Protocol)> {
26✔
171
        let leader = load_aggregator(account, self.leader_aggregator_id.as_deref(), db)
26✔
172
            .await
26✔
173
            .ok()
26✔
174
            .flatten();
26✔
175
        if leader.is_none() {
26✔
176
            errors.add("leader_aggregator_id", ValidationError::new("required"));
4✔
177
        }
22✔
178

179
        let helper = load_aggregator(account, self.helper_aggregator_id.as_deref(), db)
26✔
180
            .await
26✔
181
            .ok()
26✔
182
            .flatten();
26✔
183
        if helper.is_none() {
26✔
184
            errors.add("helper_aggregator_id", ValidationError::new("required"));
4✔
185
        }
22✔
186

187
        let (Some(leader), Some(helper)) = (leader, helper) else {
26✔
188
            return None;
5✔
189
        };
190

191
        if leader == helper {
21✔
192
            errors.add("leader_aggregator_id", ValidationError::new("same"));
×
193
            errors.add("helper_aggregator_id", ValidationError::new("same"));
×
194
        }
21✔
195

196
        if !leader.is_first_party && !helper.is_first_party {
21✔
197
            errors.add(
×
198
                "leader_aggregator_id",
×
199
                ValidationError::new("no-first-party"),
×
200
            );
×
201
            errors.add(
×
202
                "helper_aggregator_id",
×
203
                ValidationError::new("no-first-party"),
×
204
            );
×
205
        }
21✔
206

207
        let resolved_protocol = if leader.protocol == helper.protocol {
21✔
208
            leader.protocol
21✔
209
        } else {
210
            errors.add("leader_aggregator_id", ValidationError::new("protocol"));
×
211
            errors.add("helper_aggregator_id", ValidationError::new("protocol"));
×
212
            return None;
×
213
        };
214

215
        if leader.role == Role::Helper {
21✔
216
            errors.add("leader_aggregator_id", ValidationError::new("role"))
1✔
217
        }
20✔
218

219
        if helper.role == Role::Leader {
21✔
220
            errors.add("helper_aggregator_id", ValidationError::new("role"))
1✔
221
        }
20✔
222

223
        if self.batch_time_window_size_seconds.is_some()
21✔
224
            && !leader.features.contains(&Feature::TimeBucketedFixedSize)
5✔
225
        {
226
            errors.add(
1✔
227
                "leader_aggregator_id",
1✔
228
                ValidationError::new("time-bucketed-fixed-size-unsupported"),
1✔
229
            )
1✔
230
        }
20✔
231

232
        let uses_pure_dp_discrete_laplace = match &self.vdaf {
21✔
233
            Some(Vdaf::SumVec(SumVec {
234
                dp_strategy:
235
                    DpStrategy {
236
                        dp_strategy: DpStrategyKind::PureDpDiscreteLaplace,
237
                        ..
238
                    },
239
                ..
240
            })) => true,
×
241
            Some(Vdaf::Histogram(histogram)) => matches!(
6✔
242
                histogram.dp_strategy().dp_strategy,
6✔
243
                DpStrategyKind::PureDpDiscreteLaplace
244
            ),
245
            _ => false,
15✔
246
        };
247
        if uses_pure_dp_discrete_laplace
21✔
248
            && !leader.features.contains(&Feature::PureDpDiscreteLaplace)
5✔
249
        {
1✔
250
            errors.add(
1✔
251
                "leader_aggregator_id",
1✔
252
                ValidationError::new("pure-dp-discrete-laplace-unsupported"),
1✔
253
            );
1✔
254
        }
20✔
255
        if uses_pure_dp_discrete_laplace
21✔
256
            && !helper.features.contains(&Feature::PureDpDiscreteLaplace)
5✔
257
        {
1✔
258
            errors.add(
1✔
259
                "helper_aggregator_id",
1✔
260
                ValidationError::new("pure-dp-discrete-laplace-unsupported"),
1✔
261
            );
1✔
262
        }
20✔
263

264
        if errors.is_empty() {
21✔
265
            Some((leader, helper, resolved_protocol))
7✔
266
        } else {
267
            None
14✔
268
        }
269
    }
26✔
270

271
    fn validate_vdaf_is_supported(
7✔
272
        &self,
7✔
273
        leader: &Aggregator,
7✔
274
        helper: &Aggregator,
7✔
275
        protocol: &Protocol,
7✔
276
        errors: &mut ValidationErrors,
7✔
277
    ) -> Option<AggregatorVdaf> {
7✔
278
        let vdaf = self.vdaf.as_ref()?;
7✔
279

280
        let name = vdaf.name();
7✔
281
        let aggregator_vdaf = match vdaf.representation_for_protocol(protocol) {
7✔
282
            Ok(vdaf) => vdaf,
7✔
283
            Err(e) => {
×
284
                let errors = errors
×
285
                    .errors_mut()
×
NEW
286
                    .entry(Cow::Borrowed("vdaf"))
×
287
                    .or_insert_with(|| {
×
288
                        ValidationErrorsKind::Struct(Box::new(ValidationErrors::new()))
×
289
                    });
×
290
                match errors {
×
291
                    ValidationErrorsKind::Struct(errors) => {
×
292
                        errors.errors_mut().extend(e.into_errors())
×
293
                    }
294
                    other => *other = ValidationErrorsKind::Struct(Box::new(e)),
×
295
                };
296
                return None;
×
297
            }
298
        };
299

300
        if !leader.vdafs.contains(&name) || !helper.vdafs.contains(&name) {
7✔
301
            let errors = errors
×
302
                .errors_mut()
×
NEW
303
                .entry(Cow::Borrowed("vdaf"))
×
304
                .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new())));
×
305
            match errors {
×
306
                ValidationErrorsKind::Struct(errors) => {
×
307
                    errors.add("type", ValidationError::new("not-supported"));
×
308
                }
×
309
                other => {
×
310
                    let mut e = ValidationErrors::new();
×
311
                    e.add("type", ValidationError::new("not-supported"));
×
312
                    *other = ValidationErrorsKind::Struct(Box::new(e));
×
313
                }
×
314
            };
315
        }
7✔
316

317
        Some(aggregator_vdaf)
7✔
318
    }
7✔
319

320
    fn populate_chunk_length(&mut self, protocol: &Protocol) {
7✔
321
        if let Some(vdaf) = &mut self.vdaf {
7✔
322
            vdaf.populate_chunk_length(protocol);
7✔
323
        }
7✔
324
    }
7✔
325

326
    fn validate_query_type_is_supported(
7✔
327
        &self,
7✔
328
        leader: &Aggregator,
7✔
329
        helper: &Aggregator,
7✔
330
        errors: &mut ValidationErrors,
7✔
331
    ) {
7✔
332
        let name = self.query_type().name();
7✔
333
        if !leader.query_types.contains(&name) || !helper.query_types.contains(&name) {
7✔
334
            errors.add("max_batch_size", ValidationError::new("not-supported"));
×
335
        }
7✔
336
    }
7✔
337

338
    pub async fn normalize_and_validate(
26✔
339
        &mut self,
26✔
340
        account: Account,
26✔
341
        db: &impl ConnectionTrait,
26✔
342
    ) -> Result<ProvisionableTask, ValidationErrors> {
26✔
343
        let mut errors = Validate::validate(self).err().unwrap_or_default();
26✔
344
        self.validate_min_lte_max(&mut errors);
26✔
345
        self.validate_batch_time_window_size(&mut errors);
26✔
346
        let aggregators = self.validate_aggregators(&account, db, &mut errors).await;
26✔
347
        let collector_credential = self
26✔
348
            .validate_collector_credential(
26✔
349
                &account,
26✔
350
                aggregators.as_ref().map(|(leader, ..)| leader),
26✔
351
                db,
26✔
352
                &mut errors,
26✔
353
            )
26✔
354
            .await;
26✔
355

356
        let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() {
26✔
357
            self.validate_query_type_is_supported(leader, helper, &mut errors);
7✔
358
            self.populate_chunk_length(protocol);
7✔
359
            self.validate_vdaf_is_supported(leader, helper, protocol, &mut errors)
7✔
360
        } else {
361
            None
19✔
362
        };
363

364
        if errors.is_empty() {
26✔
365
            // Unwrap safety: All of these unwraps below have previously
366
            // been checked by the above validations. The fact that we
367
            // have to check them twice is a consequence of the
368
            // disharmonious combination of Validate and the fact that we
369
            // need to use options for all fields so serde doesn't bail on
370
            // the first error.
371
            let (leader_aggregator, helper_aggregator, protocol) = aggregators.unwrap();
7✔
372

7✔
373
            let (vdaf_verify_key, id) = generate_vdaf_verify_key_and_expected_task_id();
7✔
374

7✔
375
            Ok(ProvisionableTask {
7✔
376
                account,
7✔
377
                id,
7✔
378
                vdaf_verify_key,
7✔
379
                name: self.name.clone().unwrap(),
7✔
380
                leader_aggregator,
7✔
381
                helper_aggregator,
7✔
382
                vdaf: self.vdaf.clone().unwrap(),
7✔
383
                aggregator_vdaf: aggregator_vdaf.unwrap(),
7✔
384
                min_batch_size: self.min_batch_size.unwrap(),
7✔
385
                max_batch_size: self.max_batch_size,
7✔
386
                batch_time_window_size_seconds: self.batch_time_window_size_seconds,
7✔
387
                expiration: Some(OffsetDateTime::now_utc() + DEFAULT_EXPIRATION_DURATION),
7✔
388
                time_precision_seconds: self.time_precision_seconds.unwrap(),
7✔
389
                collector_credential: collector_credential.unwrap(),
7✔
390
                aggregator_auth_token: None,
7✔
391
                protocol,
7✔
392
            })
7✔
393
        } else {
394
            Err(errors)
19✔
395
        }
396
    }
26✔
397

398
    pub fn query_type(&self) -> QueryType {
7✔
399
        if let Some(max_batch_size) = self.max_batch_size {
7✔
400
            QueryType::FixedSize {
1✔
401
                max_batch_size,
1✔
402
                batch_time_window_size: self.batch_time_window_size_seconds,
1✔
403
            }
1✔
404
        } else {
405
            QueryType::TimeInterval
6✔
406
        }
407
    }
7✔
408
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc