• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jzombie / rust-triplets / 22312649065

23 Feb 2026 03:25PM UTC coverage: 87.827%. First build
22312649065

Pull #1

github

web-flow
Merge 6900fe67b into 4027cbc09
Pull Request #1: Add more tests and badges; update deps

173 of 187 new or added lines in 5 files covered. (92.51%)

8492 of 9669 relevant lines covered (87.83%)

4063.13 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

50.41
/src/example_apps.rs
1
use std::collections::HashMap;
2
use std::error::Error;
3
use std::path::PathBuf;
4
use std::sync::Arc;
5

6
use clap::{Parser, ValueEnum, error::ErrorKind};
7

8
use crate::config::{ChunkingStrategy, SamplerConfig, TripletRecipe};
9
use crate::data::ChunkView;
10
use crate::heuristics::{
11
    CapacityTotals, EFFECTIVE_NEGATIVES_PER_ANCHOR, EFFECTIVE_POSITIVES_PER_ANCHOR,
12
    estimate_source_split_capacity_from_counts, format_replay_factor, format_u128_with_commas,
13
    resolve_text_recipes_for_source, split_counts_for_total,
14
};
15
use crate::metrics::source_skew;
16
use crate::sampler::chunk_weight;
17
use crate::source::DataSource;
18
use crate::splits::{FileSplitStore, SplitLabel, SplitRatios, SplitStore};
19
use crate::{
20
    PairSampler, RecordChunk, SampleBatch, Sampler, SamplerError, SourceId, TextBatch, TextRecipe,
21
    TripletBatch,
22
};
23

24
/// Owned, type-erased data source as returned by the `build_sources` callbacks.
///
/// `Box<dyn DataSource>` already defaults the trait-object lifetime to `'static`.
type DynSource = Box<dyn DataSource>;
25

26
#[derive(Debug, Clone, Copy, ValueEnum)]
27
enum SplitArg {
28
    Train,
29
    Validation,
30
    Test,
31
}
32

33
impl From<SplitArg> for SplitLabel {
34
    fn from(value: SplitArg) -> Self {
×
35
        match value {
×
36
            SplitArg::Train => SplitLabel::Train,
×
37
            SplitArg::Validation => SplitLabel::Validation,
×
38
            SplitArg::Test => SplitLabel::Test,
×
39
        }
40
    }
×
41
}
42

43
#[derive(Debug, Parser)]
44
#[command(
45
    name = "estimate_capacity",
46
    disable_help_subcommand = true,
47
    about = "Metadata-only capacity estimation",
48
    long_about = "Estimate record, pair, triplet, and text-sample capacity using source-reported counts only (no data refresh).",
49
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
50
)]
51
struct EstimateCapacityCli {
52
    #[arg(
53
        long,
54
        default_value_t = 99,
55
        help = "Deterministic seed used for split allocation"
56
    )]
57
    seed: u64,
58
    #[arg(
59
        long = "split-ratios",
60
        value_name = "TRAIN,VALIDATION,TEST",
61
        value_parser = parse_split_ratios_arg,
62
        default_value = "0.8,0.1,0.1",
63
        help = "Comma-separated split ratios that must sum to 1.0"
64
    )]
65
    split: SplitRatios,
66
    #[arg(
67
        long = "source-root",
68
        value_name = "PATH",
69
        help = "Optional source root override, repeat as needed in source order"
70
    )]
71
    source_roots: Vec<String>,
72
}
73

74
#[derive(Debug, Parser)]
75
#[command(
76
    name = "multi_source_demo",
77
    disable_help_subcommand = true,
78
    about = "Run sampled batches from multiple sources",
79
    long_about = "Sample triplet, pair, or text batches from multiple sources and persist split/epoch state.",
80
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
81
)]
82
/// CLI for `multi_source_demo`.
83
///
84
/// Common usage:
85
/// - Keep default persistence file location: `.sampler_store/split_store.bin`
86
/// - Set an explicit file path: `--split-store-path /tmp/split_store.bin`
87
/// - Set a custom directory and keep default filename: `--split-store-dir /tmp/sampler_store`
88
/// - Repeat `--source-root <PATH>` to override source roots in order
89
struct MultiSourceDemoCli {
90
    #[arg(
91
        long = "text-recipes",
92
        help = "Emit a text batch instead of a triplet batch"
93
    )]
94
    show_text_samples: bool,
95
    #[arg(
96
        long = "pair-batch",
97
        help = "Emit a pair batch instead of a triplet batch"
98
    )]
99
    show_pair_samples: bool,
100
    #[arg(
101
        long = "list-text-recipes",
102
        help = "Print registered text recipes and exit"
103
    )]
104
    list_text_recipes: bool,
105
    #[arg(
106
        long = "batch-size",
107
        default_value_t = 4,
108
        value_parser = parse_positive_usize,
109
        help = "Batch size used for sampling"
110
    )]
111
    batch_size: usize,
112
    #[arg(long, help = "Optional deterministic seed override")]
113
    seed: Option<u64>,
114
    #[arg(long, value_enum, help = "Target split to sample from")]
115
    split: Option<SplitArg>,
116
    #[arg(
117
        long = "source-root",
118
        value_name = "PATH",
119
        help = "Optional source root override, repeat as needed in source order"
120
    )]
121
    source_roots: Vec<String>,
122
    #[arg(
123
        long = "split-store-path",
124
        value_name = "SPLIT_STORE_PATH",
125
        help = "Optional path for persisted split/epoch state file"
126
    )]
127
    split_store_path: Option<PathBuf>,
128
    #[arg(
129
        long = "split-store-dir",
130
        value_name = "DIR",
131
        conflicts_with = "split_store_path",
132
        help = "Optional directory for persisted split/epoch state file (uses split_store.bin filename)"
133
    )]
134
    split_store_dir: Option<PathBuf>,
135
}
136

137
#[derive(Debug, Clone)]
138
struct SourceInventory {
139
    source_id: String,
140
    reported_records: u128,
141
    triplet_recipes: Vec<TripletRecipe>,
142
}
143

144
pub fn run_estimate_capacity<R, Resolve, Build, I>(
2✔
145
    args_iter: I,
2✔
146
    resolve_roots: Resolve,
2✔
147
    build_sources: Build,
2✔
148
) -> Result<(), Box<dyn Error>>
2✔
149
where
2✔
150
    Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
2✔
151
    Build: FnOnce(&R) -> Vec<DynSource>,
2✔
152
    I: Iterator<Item = String>,
2✔
153
{
154
    let Some(cli) = parse_cli::<EstimateCapacityCli, _>(
2✔
155
        std::iter::once("estimate_capacity".to_string()).chain(args_iter),
2✔
156
    )?
×
157
    else {
158
        return Ok(());
×
159
    };
160

161
    let roots = resolve_roots(cli.source_roots)?;
2✔
162

163
    let config = SamplerConfig {
2✔
164
        seed: cli.seed,
2✔
165
        split: cli.split,
2✔
166
        ..SamplerConfig::default()
2✔
167
    };
2✔
168

169
    let sources = build_sources(&roots);
2✔
170

171
    let mut inventories = Vec::new();
2✔
172
    for source in &sources {
2✔
173
        let recipes = if config.recipes.is_empty() {
2✔
174
            source.default_triplet_recipes()
2✔
175
        } else {
176
            config.recipes.clone()
×
177
        };
178
        let reported_records = source.reported_record_count().ok_or_else(|| {
2✔
179
            format!(
1✔
180
                "source '{}' did not report a record count; metadata-only capacity estimation requires DataSource::reported_record_count",
181
                source.id()
1✔
182
            )
183
        })?;
1✔
184
        inventories.push(SourceInventory {
1✔
185
            source_id: source.id().to_string(),
1✔
186
            reported_records,
1✔
187
            triplet_recipes: recipes,
1✔
188
        });
1✔
189
    }
190

191
    let mut per_source_split_counts: HashMap<(String, SplitLabel), u128> = HashMap::new();
1✔
192
    let mut split_record_counts: HashMap<SplitLabel, u128> = HashMap::new();
1✔
193

194
    for source in &inventories {
1✔
195
        let counts = split_counts_for_total(source.reported_records, cli.split);
1✔
196
        for (label, count) in counts {
3✔
197
            per_source_split_counts.insert((source.source_id.clone(), label), count);
3✔
198
            *split_record_counts.entry(label).or_insert(0) += count;
3✔
199
        }
3✔
200
    }
201

202
    let mut totals_by_split: HashMap<SplitLabel, CapacityTotals> = HashMap::new();
1✔
203
    let mut totals_by_source_and_split: HashMap<(String, SplitLabel), CapacityTotals> =
1✔
204
        HashMap::new();
1✔
205

206
    for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
3✔
207
        let mut totals = CapacityTotals::default();
3✔
208

209
        for source in &inventories {
3✔
210
            let source_split_records = per_source_split_counts
3✔
211
                .get(&(source.source_id.clone(), split_label))
3✔
212
                .copied()
3✔
213
                .unwrap_or(0);
3✔
214

3✔
215
            let triplet_recipes = &source.triplet_recipes;
3✔
216
            let text_recipes = resolve_text_recipes_for_source(&config, triplet_recipes);
3✔
217

3✔
218
            let capacity = estimate_source_split_capacity_from_counts(
3✔
219
                source_split_records,
3✔
220
                triplet_recipes,
3✔
221
                &text_recipes,
3✔
222
            );
3✔
223

3✔
224
            totals_by_source_and_split.insert((source.source_id.clone(), split_label), capacity);
3✔
225

3✔
226
            totals.triplets += capacity.triplets;
3✔
227
            totals.effective_triplets += capacity.effective_triplets;
3✔
228
            totals.pairs += capacity.pairs;
3✔
229
            totals.text_samples += capacity.text_samples;
3✔
230
        }
3✔
231

232
        totals_by_split.insert(split_label, totals);
3✔
233
    }
234

235
    println!("=== capacity estimate (length-only) ===");
1✔
236
    println!("mode: metadata-only (no source.refresh calls)");
1✔
237
    println!("classification: heuristic approximation (not exact)");
1✔
238
    println!("split seed: {}", cli.seed);
1✔
239
    println!(
1✔
240
        "split ratios: train={:.4}, validation={:.4}, test={:.4}",
241
        cli.split.train, cli.split.validation, cli.split.test
242
    );
243
    println!();
1✔
244

245
    println!("[SOURCES]");
1✔
246
    for source in &inventories {
1✔
247
        println!(
1✔
248
            "  {} => reported records: {}",
1✔
249
            source.source_id,
1✔
250
            format_u128_with_commas(source.reported_records)
1✔
251
        );
1✔
252
    }
1✔
253
    println!();
1✔
254

255
    println!("[PER SOURCE BREAKDOWN]");
1✔
256
    for source in &inventories {
1✔
257
        println!("  {}", source.source_id);
1✔
258
        let mut source_grand = CapacityTotals::default();
1✔
259
        let mut source_total_records = 0u128;
1✔
260
        for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
3✔
261
            let split_records = per_source_split_counts
3✔
262
                .get(&(source.source_id.clone(), split_label))
3✔
263
                .copied()
3✔
264
                .unwrap_or(0);
3✔
265
            source_total_records = source_total_records.saturating_add(split_records);
3✔
266
            let split_longest_records = inventories
3✔
267
                .iter()
3✔
268
                .map(|candidate| {
3✔
269
                    per_source_split_counts
3✔
270
                        .get(&(candidate.source_id.clone(), split_label))
3✔
271
                        .copied()
3✔
272
                        .unwrap_or(0)
3✔
273
                })
3✔
274
                .max()
3✔
275
                .unwrap_or(0);
3✔
276
            let totals = totals_by_source_and_split
3✔
277
                .get(&(source.source_id.clone(), split_label))
3✔
278
                .copied()
3✔
279
                .unwrap_or_default();
3✔
280
            source_grand.triplets += totals.triplets;
3✔
281
            source_grand.effective_triplets += totals.effective_triplets;
3✔
282
            source_grand.pairs += totals.pairs;
3✔
283
            source_grand.text_samples += totals.text_samples;
3✔
284
            println!("    [{:?}]", split_label);
3✔
285
            println!("      records: {}", format_u128_with_commas(split_records));
3✔
286
            println!(
3✔
287
                "      triplet combinations: {}",
288
                format_u128_with_commas(totals.triplets)
3✔
289
            );
290
            println!(
3✔
291
                "      effective sampled triplets (p={}, k={}): {}",
292
                EFFECTIVE_POSITIVES_PER_ANCHOR,
293
                EFFECTIVE_NEGATIVES_PER_ANCHOR,
294
                format_u128_with_commas(totals.effective_triplets)
3✔
295
            );
296
            println!(
3✔
297
                "      pair combinations:    {}",
298
                format_u128_with_commas(totals.pairs)
3✔
299
            );
300
            println!(
3✔
301
                "      text samples:         {}",
302
                format_u128_with_commas(totals.text_samples)
3✔
303
            );
304
            println!(
3✔
305
                "      replay factor vs longest source: {}",
306
                format_replay_factor(split_longest_records, split_records)
3✔
307
            );
308
        }
309
        let longest_source_total = inventories
1✔
310
            .iter()
1✔
311
            .map(|candidate| candidate.reported_records)
1✔
312
            .max()
1✔
313
            .unwrap_or(0);
1✔
314
        println!("    [ALL SPLITS FOR SOURCE]");
1✔
315
        println!(
1✔
316
            "      triplet combinations: {}",
317
            format_u128_with_commas(source_grand.triplets)
1✔
318
        );
319
        println!(
1✔
320
            "      effective sampled triplets (p={}, k={}): {}",
321
            EFFECTIVE_POSITIVES_PER_ANCHOR,
322
            EFFECTIVE_NEGATIVES_PER_ANCHOR,
323
            format_u128_with_commas(source_grand.effective_triplets)
1✔
324
        );
325
        println!(
1✔
326
            "      pair combinations:    {}",
327
            format_u128_with_commas(source_grand.pairs)
1✔
328
        );
329
        println!(
1✔
330
            "      text samples:         {}",
331
            format_u128_with_commas(source_grand.text_samples)
1✔
332
        );
333
        println!(
1✔
334
            "      replay factor vs longest source: {}",
335
            format_replay_factor(longest_source_total, source_total_records)
1✔
336
        );
337
        println!();
1✔
338
    }
339

340
    let mut grand = CapacityTotals::default();
1✔
341
    for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
3✔
342
        let record_count = split_record_counts.get(&split_label).copied().unwrap_or(0);
3✔
343
        let totals = totals_by_split
3✔
344
            .get(&split_label)
3✔
345
            .copied()
3✔
346
            .unwrap_or_default();
3✔
347

3✔
348
        grand.triplets += totals.triplets;
3✔
349
        grand.effective_triplets += totals.effective_triplets;
3✔
350
        grand.pairs += totals.pairs;
3✔
351
        grand.text_samples += totals.text_samples;
3✔
352

3✔
353
        println!("[{:?}]", split_label);
3✔
354
        println!("  records: {}", format_u128_with_commas(record_count));
3✔
355
        println!(
3✔
356
            "  triplet combinations: {}",
3✔
357
            format_u128_with_commas(totals.triplets)
3✔
358
        );
3✔
359
        println!(
3✔
360
            "  effective sampled triplets (p={}, k={}): {}",
3✔
361
            EFFECTIVE_POSITIVES_PER_ANCHOR,
3✔
362
            EFFECTIVE_NEGATIVES_PER_ANCHOR,
3✔
363
            format_u128_with_commas(totals.effective_triplets)
3✔
364
        );
3✔
365
        println!(
3✔
366
            "  pair combinations:    {}",
3✔
367
            format_u128_with_commas(totals.pairs)
3✔
368
        );
3✔
369
        println!(
3✔
370
            "  text samples:         {}",
3✔
371
            format_u128_with_commas(totals.text_samples)
3✔
372
        );
3✔
373
        println!();
3✔
374
    }
3✔
375

376
    println!("[ALL SPLITS TOTAL]");
1✔
377
    println!(
1✔
378
        "  triplet combinations: {}",
379
        format_u128_with_commas(grand.triplets)
1✔
380
    );
381
    println!(
1✔
382
        "  effective sampled triplets (p={}, k={}): {}",
383
        EFFECTIVE_POSITIVES_PER_ANCHOR,
384
        EFFECTIVE_NEGATIVES_PER_ANCHOR,
385
        format_u128_with_commas(grand.effective_triplets)
1✔
386
    );
387
    println!(
1✔
388
        "  pair combinations:    {}",
389
        format_u128_with_commas(grand.pairs)
1✔
390
    );
391
    println!(
1✔
392
        "  text samples:         {}",
393
        format_u128_with_commas(grand.text_samples)
1✔
394
    );
395
    println!();
1✔
396
    println!(
1✔
397
        "Note: counts are heuristic, length-based estimates from source-reported totals and recipe structure. They are approximate, not exact, and assume anchor-positive pairs=records (one positive per anchor by default), negatives=source_records_in_split-1 (anchor excluded as its own negative), and at most one chunk/window realization per sample. In real-world chunked sampling, practical combinations are often higher, so treat this as a floor-like baseline."
398
    );
399
    println!(
1✔
400
        "Effective sampled triplets apply a bounded training assumption: effective_triplets = records * p * k per triplet recipe, with defaults p={} positives per anchor and k={} negatives per anchor.",
401
        EFFECTIVE_POSITIVES_PER_ANCHOR, EFFECTIVE_NEGATIVES_PER_ANCHOR
402
    );
403
    println!(
1✔
404
        "Oversample loops are not inferred from this static report. To measure true oversampling (how many times sampling loops through the combination space), use observed sampled draw counts from an actual run."
405
    );
406

407
    Ok(())
1✔
408
}
2✔
409

410
pub fn run_multi_source_demo<R, Resolve, Build, I>(
×
411
    args_iter: I,
×
412
    resolve_roots: Resolve,
×
413
    build_sources: Build,
×
414
) -> Result<(), Box<dyn Error>>
×
415
where
×
416
    Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
×
417
    Build: FnOnce(&R) -> Vec<DynSource>,
×
418
    I: Iterator<Item = String>,
×
419
{
420
    let _ = tracing_subscriber::fmt()
×
421
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
×
422
        .try_init();
×
423

424
    let Some(cli) = parse_cli::<MultiSourceDemoCli, _>(
×
425
        std::iter::once("multi_source_demo".to_string()).chain(args_iter),
×
426
    )?
×
427
    else {
428
        return Ok(());
×
429
    };
430

431
    let roots = resolve_roots(cli.source_roots)?;
×
432

433
    let mut config = SamplerConfig::default();
×
434
    config.seed = cli.seed.unwrap_or(config.seed);
×
435
    config.batch_size = cli.batch_size;
×
436
    config.chunking = Default::default();
×
437
    let selected_split = cli.split.map(Into::into).unwrap_or(SplitLabel::Train);
×
438
    config.split = SplitRatios::default();
×
439
    config.allowed_splits = vec![selected_split];
×
440
    let chunking = config.chunking.clone();
×
441

442
    let split_store_path = if let Some(path) = cli.split_store_path {
×
443
        path
×
444
    } else if let Some(dir) = cli.split_store_dir {
×
445
        FileSplitStore::default_path_in_dir(dir)
×
446
    } else {
447
        FileSplitStore::default_path()
×
448
    };
449

450
    println!(
×
451
        "Persisting split assignments and epoch state to {}",
452
        split_store_path.display()
×
453
    );
454
    let split_store = Arc::new(FileSplitStore::open(&split_store_path, config.split, 99)?);
×
455
    let sampler = PairSampler::new(config, split_store.clone());
×
456
    for source in build_sources(&roots) {
×
457
        sampler.register_source(source);
×
458
    }
×
459

460
    if cli.show_pair_samples {
×
461
        match sampler.next_pair_batch(selected_split) {
×
462
            Ok(pair_batch) => {
×
463
                if pair_batch.pairs.is_empty() {
×
464
                    println!("Pair sampling produced no results.");
×
465
                } else {
×
466
                    print_pair_batch(&chunking, &pair_batch, split_store.as_ref());
×
467
                }
×
468
                sampler.persist_state()?;
×
469
            }
470
            Err(SamplerError::Exhausted(name)) => {
×
471
                eprintln!(
×
472
                    "Pair sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
×
473
                    name
×
474
                );
×
475
            }
×
476
            Err(err) => return Err(err.into()),
×
477
        }
478
    } else if cli.show_text_samples {
×
479
        match sampler.next_text_batch(selected_split) {
×
480
            Ok(text_batch) => {
×
481
                if text_batch.samples.is_empty() {
×
482
                    println!(
×
483
                        "Text sampling produced no results. Ensure each source has eligible sections."
×
484
                    );
×
485
                } else {
×
486
                    print_text_batch(&chunking, &text_batch, split_store.as_ref());
×
487
                }
×
488
                sampler.persist_state()?;
×
489
            }
490
            Err(SamplerError::Exhausted(name)) => {
×
491
                eprintln!(
×
492
                    "Text sampler exhausted selector '{}'. Ensure matching sections exist.",
×
493
                    name
×
494
                );
×
495
            }
×
496
            Err(err) => return Err(err.into()),
×
497
        }
498
    } else if cli.list_text_recipes {
×
499
        let recipes = sampler.text_recipes();
×
500
        if recipes.is_empty() {
×
501
            println!(
×
502
                "No text recipes registered. Ensure your sources expose triplet selectors or configure text_recipes explicitly."
×
503
            );
×
504
        } else {
×
505
            print_text_recipes(&recipes);
×
506
        }
×
507
    } else {
508
        match sampler.next_triplet_batch(selected_split) {
×
509
            Ok(triplet_batch) => {
×
510
                if triplet_batch.triplets.is_empty() {
×
511
                    println!(
×
512
                        "Triplet sampling produced no results. Ensure multiple records per source exist."
×
513
                    );
×
514
                } else {
×
515
                    print_triplet_batch(&chunking, &triplet_batch, split_store.as_ref());
×
516
                }
×
517
                sampler.persist_state()?;
×
518
            }
519
            Err(SamplerError::Exhausted(name)) => {
×
520
                eprintln!(
×
521
                    "Triplet sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
×
522
                    name
×
523
                );
×
524
            }
×
525
            Err(err) => return Err(err.into()),
×
526
        }
527
    }
528

529
    Ok(())
×
530
}
×
531

532
/// Clap value parser for `--batch-size`: accepts any `usize` except zero.
///
/// Returns a user-facing message when `raw` is not an unsigned integer or is zero.
fn parse_positive_usize(raw: &str) -> Result<usize, String> {
    match raw.parse::<usize>() {
        Ok(0) => Err("--batch-size must be greater than zero".to_string()),
        Ok(value) => Ok(value),
        Err(_) => Err(format!(
            "Could not parse --batch-size value '{}' as a positive integer",
            raw
        )),
    }
}
544

545
fn parse_cli<T, I>(args: I) -> Result<Option<T>, Box<dyn Error>>
4✔
546
where
4✔
547
    T: Parser,
4✔
548
    I: IntoIterator,
4✔
549
    I::Item: Into<std::ffi::OsString> + Clone,
4✔
550
{
551
    match T::try_parse_from(args) {
4✔
552
        Ok(cli) => Ok(Some(cli)),
2✔
553
        Err(err) => match err.kind() {
2✔
554
            ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
555
                err.print()?;
1✔
556
                Ok(None)
1✔
557
            }
558
            _ => Err(err.into()),
1✔
559
        },
560
    }
561
}
4✔
562

563
fn parse_split_ratios_arg(raw: &str) -> Result<SplitRatios, String> {
6✔
564
    let parts: Vec<&str> = raw.split(',').collect();
6✔
565
    if parts.len() != 3 {
6✔
566
        return Err("--split-ratios expects exactly 3 comma-separated values".to_string());
1✔
567
    }
5✔
568
    let train = parts[0]
5✔
569
        .trim()
5✔
570
        .parse::<f32>()
5✔
571
        .map_err(|_| format!("invalid train ratio '{}': must be a float", parts[0].trim()))?;
5✔
572
    let validation = parts[1].trim().parse::<f32>().map_err(|_| {
5✔
573
        format!(
×
574
            "invalid validation ratio '{}': must be a float",
575
            parts[1].trim()
×
576
        )
577
    })?;
×
578
    let test = parts[2]
5✔
579
        .trim()
5✔
580
        .parse::<f32>()
5✔
581
        .map_err(|_| format!("invalid test ratio '{}': must be a float", parts[2].trim()))?;
5✔
582
    let ratios = SplitRatios {
5✔
583
        train,
5✔
584
        validation,
5✔
585
        test,
5✔
586
    };
5✔
587
    let sum = ratios.train + ratios.validation + ratios.test;
5✔
588
    if (sum - 1.0).abs() > 1e-5 {
5✔
589
        return Err(format!(
1✔
590
            "split ratios must sum to 1.0, got {:.6} (train={}, validation={}, test={})",
1✔
591
            sum, ratios.train, ratios.validation, ratios.test
1✔
592
        ));
1✔
593
    }
4✔
594
    if ratios.train < 0.0 || ratios.validation < 0.0 || ratios.test < 0.0 {
4✔
595
        return Err("split ratios must be non-negative".to_string());
1✔
596
    }
3✔
597
    Ok(ratios)
3✔
598
}
6✔
599

600
fn print_triplet_batch(
×
601
    strategy: &ChunkingStrategy,
×
602
    batch: &TripletBatch,
×
603
    split_store: &impl SplitStore,
×
604
) {
×
605
    println!("=== triplet batch ===");
×
606
    for (idx, triplet) in batch.triplets.iter().enumerate() {
×
607
        println!("--- triplet #{} ---", idx);
×
608
        println!("recipe       : {}", triplet.recipe);
×
609
        println!("sample_weight: {:.4}", triplet.weight);
×
610
        if let Some(instr) = &triplet.instruction {
×
611
            println!("instruction shown to model:\n{}\n", instr);
×
612
        }
×
613
        print_chunk_block("ANCHOR", &triplet.anchor, strategy, split_store);
×
614
        print_chunk_block("POSITIVE", &triplet.positive, strategy, split_store);
×
615
        print_chunk_block("NEGATIVE", &triplet.negative, strategy, split_store);
×
616
    }
617
    print_source_summary(
×
618
        "triplet anchors",
×
619
        batch
×
620
            .triplets
×
621
            .iter()
×
622
            .map(|triplet| triplet.anchor.record_id.as_str()),
×
623
    );
624
    print_recipe_summary_by_source(
×
625
        "triplet recipes by source",
×
626
        batch
×
627
            .triplets
×
628
            .iter()
×
629
            .map(|triplet| (triplet.anchor.record_id.as_str(), triplet.recipe.as_str())),
×
630
    );
631
}
×
632

633
fn print_text_batch(strategy: &ChunkingStrategy, batch: &TextBatch, split_store: &impl SplitStore) {
×
634
    println!("=== text batch ===");
×
635
    for (idx, sample) in batch.samples.iter().enumerate() {
×
636
        println!("--- sample #{} ---", idx);
×
637
        println!("recipe       : {}", sample.recipe);
×
638
        println!("sample_weight: {:.4}", sample.weight);
×
639
        if let Some(instr) = &sample.instruction {
×
640
            println!("instruction shown to model:\n{}\n", instr);
×
641
        }
×
642
        print_chunk_block("TEXT", &sample.chunk, strategy, split_store);
×
643
    }
644
    print_source_summary(
×
645
        "text samples",
×
646
        batch
×
647
            .samples
×
648
            .iter()
×
649
            .map(|sample| sample.chunk.record_id.as_str()),
×
650
    );
651
    print_recipe_summary_by_source(
×
652
        "text recipes by source",
×
653
        batch
×
654
            .samples
×
655
            .iter()
×
656
            .map(|sample| (sample.chunk.record_id.as_str(), sample.recipe.as_str())),
×
657
    );
658
}
×
659

660
fn print_pair_batch(
×
661
    strategy: &ChunkingStrategy,
×
662
    batch: &SampleBatch,
×
663
    split_store: &impl SplitStore,
×
664
) {
×
665
    println!("=== pair batch ===");
×
666
    for (idx, pair) in batch.pairs.iter().enumerate() {
×
667
        println!("--- pair #{} ---", idx);
×
668
        println!("recipe       : {}", pair.recipe);
×
669
        println!("label        : {:?}", pair.label);
×
670
        if let Some(reason) = &pair.reason {
×
671
            println!("reason       : {}", reason);
×
672
        }
×
673
        print_chunk_block("ANCHOR", &pair.anchor, strategy, split_store);
×
674
        print_chunk_block("OTHER", &pair.positive, strategy, split_store);
×
675
    }
676
    print_source_summary(
×
677
        "pair anchors",
×
678
        batch
×
679
            .pairs
×
680
            .iter()
×
681
            .map(|pair| pair.anchor.record_id.as_str()),
×
682
    );
683
    print_recipe_summary_by_source(
×
684
        "pair recipes by source",
×
685
        batch
×
686
            .pairs
×
687
            .iter()
×
688
            .map(|pair| (pair.anchor.record_id.as_str(), pair.recipe.as_str())),
×
689
    );
690
}
×
691

692
fn print_text_recipes(recipes: &[TextRecipe]) {
×
693
    println!("=== available text recipes ===");
×
694
    for recipe in recipes {
×
695
        println!(
×
696
            "- {} (weight: {:.3}) selector={:?}",
697
            recipe.name, recipe.weight, recipe.selector
698
        );
699
        if let Some(instr) = &recipe.instruction {
×
700
            println!("  instruction: {}", instr);
×
701
        }
×
702
    }
703
}
×
704

705
/// Debug-output helper describing which view of a record a chunk represents.
trait ChunkDebug {
    /// Returns a human-readable label for the chunk's view.
    fn view_name(&self) -> String;
}
708

709
impl ChunkDebug for RecordChunk {
710
    fn view_name(&self) -> String {
×
711
        match &self.view {
×
712
            ChunkView::Window {
713
                index,
×
714
                span,
×
715
                overlap,
×
716
                start_ratio,
×
717
            } => format!(
×
718
                "window#index={} span={} overlap={} start_ratio={:.3} tokens={}",
719
                index, span, overlap, start_ratio, self.tokens_estimate
720
            ),
721
            ChunkView::SummaryFallback { strategy, .. } => {
×
722
                format!("summary:{} tokens={}", strategy, self.tokens_estimate)
×
723
            }
724
        }
725
    }
×
726
}
727

728
fn print_chunk_block(
×
729
    title: &str,
×
730
    chunk: &RecordChunk,
×
731
    strategy: &ChunkingStrategy,
×
732
    split_store: &impl SplitStore,
×
733
) {
×
734
    let chunk_weight = chunk_weight(strategy, chunk);
×
735
    let split = split_store
×
736
        .label_for(&chunk.record_id)
×
737
        .map(|label| format!("{:?}", label))
×
738
        .unwrap_or_else(|| "Unknown".to_string());
×
739
    println!("--- {} ---", title);
×
740
    println!("split        : {}", split);
×
741
    println!("view         : {}", chunk.view_name());
×
742
    println!("chunk_weight : {:.4}", chunk_weight);
×
743
    println!("record_id    : {}", chunk.record_id);
×
744
    println!("section_idx  : {}", chunk.section_idx);
×
745
    println!("token_est    : {}", chunk.tokens_estimate);
×
746
    println!("model_input (exact text sent to the model):");
×
747
    println!(
×
748
        "<<< BEGIN MODEL TEXT >>>\n{}\n<<< END MODEL TEXT >>>\n",
749
        chunk.text
750
    );
751
}
×
752

753
fn print_source_summary<'a, I>(label: &str, ids: I)
×
754
where
×
755
    I: Iterator<Item = &'a str>,
×
756
{
757
    let mut counts: HashMap<SourceId, usize> = HashMap::new();
×
758
    for id in ids {
×
759
        let source = extract_source(id);
×
760
        *counts.entry(source).or_insert(0) += 1;
×
761
    }
×
762
    if counts.is_empty() {
×
763
        return;
×
764
    }
×
765
    let skew = source_skew(&counts);
×
766
    let mut entries: Vec<(String, usize)> = counts.into_iter().collect();
×
767
    entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
×
768
    println!("--- {} by source ---", label);
×
769
    if let Some(skew) = skew {
×
770
        for entry in &skew.per_source {
×
771
            println!(
×
772
                "{}: count={} share={:.2}",
×
773
                entry.source, entry.count, entry.share
×
774
            );
×
775
        }
×
776
        println!(
×
777
            "skew: sources={} total={} min={} max={} mean={:.2} ratio={:.2}",
778
            skew.sources, skew.total, skew.min, skew.max, skew.mean, skew.ratio
779
        );
780
    } else {
781
        for (source, count) in &entries {
×
782
            println!("{source}: count={count}");
×
783
        }
×
784
    }
785
}
×
786

787
fn print_recipe_summary_by_source<'a, I>(label: &str, entries: I)
×
788
where
×
789
    I: Iterator<Item = (&'a str, &'a str)>,
×
790
{
791
    let mut counts: HashMap<SourceId, HashMap<String, usize>> = HashMap::new();
×
792
    for (record_id, recipe) in entries {
×
793
        let source = extract_source(record_id);
×
794
        let entry = counts
×
795
            .entry(source)
×
796
            .or_default()
×
797
            .entry(recipe.to_string())
×
798
            .or_insert(0);
×
799
        *entry += 1;
×
800
    }
×
801
    if counts.is_empty() {
×
802
        return;
×
803
    }
×
804
    let mut sources: Vec<(SourceId, HashMap<String, usize>)> = counts.into_iter().collect();
×
805
    sources.sort_by(|a, b| a.0.cmp(&b.0));
×
806
    println!("--- {} ---", label);
×
807
    for (source, recipes) in sources {
×
808
        println!("{source}");
×
809
        let mut entries: Vec<(String, usize)> = recipes.into_iter().collect();
×
810
        entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
×
811
        for (recipe, count) in entries {
×
812
            println!("  - {recipe}={count}");
×
813
        }
×
814
    }
815
}
×
816

817
fn extract_source(record_id: &str) -> SourceId {
×
818
    record_id
×
819
        .split_once("::")
×
820
        .map(|(source, _)| source.to_string())
×
821
        .unwrap_or_else(|| "unknown".to_string())
×
822
}
×
823

824
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::SectionRole;
    use crate::source::{SourceCursor, SourceSnapshot};
    use chrono::Utc;

    /// Minimal `DataSource` fixture for driving `run_estimate_capacity`.
    ///
    /// `count` is returned verbatim from `reported_record_count`, so `None`
    /// simulates a source that does not report its size. `recipes` is returned
    /// from `default_triplet_recipes`.
    struct TestSource {
        id: String,
        count: Option<u128>,
        recipes: Vec<TripletRecipe>,
    }

    impl DataSource for TestSource {
        fn id(&self) -> &str {
            &self.id
        }

        // Always returns an empty snapshot; these tests do not depend on
        // record contents.
        fn refresh(
            &self,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            })
        }

        fn reported_record_count(&self) -> Option<u128> {
            self.count
        }

        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
            self.recipes.clone()
        }
    }

    /// Builds a weight-1.0 recipe named `name`.
    ///
    /// The recipe selects the `Anchor` role as anchor and the `Context` role for
    /// both positives and negatives, using the `WrongArticle` negative strategy.
    fn default_recipe(name: &str) -> TripletRecipe {
        TripletRecipe {
            name: name.to_string().into(),
            anchor: crate::config::Selector::Role(SectionRole::Anchor),
            positive_selector: crate::config::Selector::Role(SectionRole::Context),
            negative_selector: crate::config::Selector::Role(SectionRole::Context),
            negative_strategy: crate::config::NegativeStrategy::WrongArticle,
            weight: 1.0,
            instruction: None,
        }
    }

    #[test]
    fn parse_helpers_validate_inputs() {
        assert_eq!(parse_positive_usize("2").unwrap(), 2);
        assert!(parse_positive_usize("0").is_err());
        assert!(parse_positive_usize("abc").is_err());

        let split = parse_split_ratios_arg("0.8,0.1,0.1").unwrap();
        assert!((split.train - 0.8).abs() < 1e-6);
        assert!(parse_split_ratios_arg("0.8,0.1").is_err());
        assert!(parse_split_ratios_arg("1.0,0.0,0.1").is_err());
        assert!(parse_split_ratios_arg("-0.1,0.6,0.5").is_err());
    }

    #[test]
    fn parse_cli_handles_help_and_invalid_args() {
        let help = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--help"]).unwrap();
        assert!(help.is_none());

        let err = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--unknown"]);
        assert!(err.is_err());
    }

    #[test]
    fn run_estimate_capacity_succeeds_with_reported_counts() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |roots| {
                assert!(roots.is_empty());
                Ok(())
            },
            |_| {
                vec![Box::new(TestSource {
                    id: "source_a".into(),
                    count: Some(12),
                    recipes: vec![default_recipe("r1")],
                }) as DynSource]
            },
        );

        // Report the actual error text on failure instead of a bare
        // `assertion failed: result.is_ok()`.
        if let Err(err) = result {
            panic!("run_estimate_capacity failed unexpectedly: {err}");
        }
    }

    #[test]
    fn run_estimate_capacity_errors_when_source_count_missing() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |_| Ok(()),
            |_| {
                vec![Box::new(TestSource {
                    id: "source_missing".into(),
                    count: None,
                    recipes: vec![default_recipe("r1")],
                }) as DynSource]
            },
        );

        let err = result.unwrap_err().to_string();
        assert!(err.contains("did not report a record count"));
    }

    #[test]
    fn extract_source_uses_prefix_before_first_separator() {
        assert_eq!(extract_source("source_a::record_1"), "source_a");
        assert_eq!(extract_source("a::b::c"), "a");
        assert_eq!(extract_source("no_separator"), "unknown");
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc