• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jzombie / rust-triplets / 23177332371

17 Mar 2026 03:41AM UTC coverage: 94.165% (-0.5%) from 94.685%
23177332371

push

github

jzombie
Prepare for 0.5.0-alpha

18899 of 20070 relevant lines covered (94.17%)

2323.16 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.24
/src/example_apps.rs
1
use std::collections::HashMap;
2
use std::error::Error;
3
use std::path::PathBuf;
4
use std::sync::Arc;
5
use std::sync::Once;
6

7
use cache_manager::CacheRoot;
8
use clap::{Parser, ValueEnum, error::ErrorKind};
9

10
use crate::config::{ChunkingStrategy, SamplerConfig, TripletRecipe};
11
use crate::constants::cache::{MULTI_SOURCE_DEMO_GROUP, MULTI_SOURCE_DEMO_STORE_FILENAME};
12
use crate::data::ChunkView;
13
use crate::heuristics::{
14
    CapacityTotals, EFFECTIVE_NEGATIVES_PER_ANCHOR, EFFECTIVE_POSITIVES_PER_ANCHOR,
15
    estimate_source_split_capacity_from_counts, format_replay_factor, format_u128_with_commas,
16
    resolve_text_recipes_for_source, split_counts_for_total,
17
};
18
use crate::metrics::source_skew;
19
use crate::sampler::chunk_weight;
20
use crate::source::DataSource;
21
use crate::splits::{FileSplitStore, SplitLabel, SplitRatios, SplitStore};
22
use crate::{
23
    RecordChunk, SampleBatch, Sampler, SamplerError, SourceId, TextBatch, TextRecipe, TripletBatch,
24
    TripletSampler,
25
};
26

27
/// Boxed, type-erased data source as returned by the `build_sources` callbacks.
///
/// `Box<dyn Trait>` defaults to a `'static` object bound, so this is the same
/// type as `Box<dyn DataSource + 'static>`.
type DynSource = Box<dyn DataSource>;
28

29
fn managed_demo_split_store_path() -> Result<PathBuf, String> {
×
30
    let cache_root = CacheRoot::from_discovery()
×
31
        .map_err(|err| format!("failed discovering managed cache root: {err}"))?;
×
32
    let group = PathBuf::from(MULTI_SOURCE_DEMO_GROUP);
×
33
    let dir = cache_root.ensure_group(&group).map_err(|err| {
×
34
        format!(
×
35
            "failed creating managed demo cache group '{}': {err}",
36
            group.display()
×
37
        )
38
    })?;
×
39
    Ok(dir.join(MULTI_SOURCE_DEMO_STORE_FILENAME))
×
40
}
×
41

42
fn init_example_tracing() {
16✔
43
    static INIT: Once = Once::new();
44
    INIT.call_once(|| {
16✔
45
        let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
1✔
46
            .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("triplets=info"));
1✔
47
        let _ = tracing_subscriber::fmt()
1✔
48
            .with_env_filter(env_filter)
1✔
49
            .try_init();
1✔
50
    });
1✔
51
}
16✔
52

53
#[derive(Debug, Clone, Copy, ValueEnum)]
54
/// CLI split selector mapped onto `SplitLabel`.
55
enum SplitArg {
56
    Train,
57
    Validation,
58
    Test,
59
}
60

61
impl From<SplitArg> for SplitLabel {
62
    fn from(value: SplitArg) -> Self {
6✔
63
        match value {
6✔
64
            SplitArg::Train => SplitLabel::Train,
1✔
65
            SplitArg::Validation => SplitLabel::Validation,
4✔
66
            SplitArg::Test => SplitLabel::Test,
1✔
67
        }
68
    }
6✔
69
}
70

71
#[derive(Debug, Parser)]
#[command(
    name = "estimate_capacity",
    disable_help_subcommand = true,
    about = "Metadata-only capacity estimation",
    long_about = "Estimate record, pair, triplet, and text-sample capacity using source-reported counts only (no data refresh).",
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
)]
/// CLI arguments for metadata-only capacity estimation.
// Field notes use plain `//` comments on purpose: clap's derive turns `///`
// doc comments into help text, and every field already sets `help`.
struct EstimateCapacityCli {
    // Copied into `SamplerConfig::seed` and echoed as "split seed" in the report.
    #[arg(
        long,
        default_value_t = 99,
        help = "Deterministic seed used for split allocation"
    )]
    seed: u64,
    // Validated by `parse_split_ratios_arg` (the default goes through it too).
    // Used as `SamplerConfig::split` and to divide each source's reported
    // record count across train/validation/test.
    #[arg(
        long = "split-ratios",
        value_name = "TRAIN,VALIDATION,TEST",
        value_parser = parse_split_ratios_arg,
        default_value = "0.8,0.1,0.1",
        help = "Comma-separated split ratios that must sum to 1.0"
    )]
    split: SplitRatios,
    // Passed verbatim (possibly empty) to the caller-supplied `resolve_roots`.
    #[arg(
        long = "source-root",
        value_name = "PATH",
        help = "Optional source root override, repeat as needed in source order"
    )]
    source_roots: Vec<String>,
}
102

103
#[derive(Debug, Parser)]
104
#[command(
105
    name = "multi_source_demo",
106
    disable_help_subcommand = true,
107
    about = "Run sampled batches from multiple sources",
108
    long_about = "Sample triplet, pair, or text batches from multiple sources and persist split/epoch state.",
109
    after_help = "Source roots are optional and resolved in order by explicit arg, environment variables, then project defaults."
110
)]
111
/// CLI for `multi_source_demo`.
112
///
113
/// Common usage:
114
/// - Use managed cache-group default path (no flag)
115
/// - Set an explicit file path: `--split-store-path /tmp/split_store.bin`
116
/// - Repeat `--source-root <PATH>` to override source roots in order
117
struct MultiSourceDemoCli {
118
    #[arg(
119
        long = "text-recipes",
120
        help = "Emit a text batch instead of a triplet batch"
121
    )]
122
    show_text_samples: bool,
123
    #[arg(
124
        long = "pair-batch",
125
        help = "Emit a pair batch instead of a triplet batch"
126
    )]
127
    show_pair_samples: bool,
128
    #[arg(
129
        long = "list-text-recipes",
130
        help = "Print registered text recipes and exit"
131
    )]
132
    list_text_recipes: bool,
133
    #[arg(
134
        long = "batch-size",
135
        default_value_t = 4,
136
        value_parser = parse_positive_usize,
137
        help = "Batch size used for sampling"
138
    )]
139
    batch_size: usize,
140
    #[arg(long, help = "Optional deterministic seed override")]
141
    seed: Option<u64>,
142
    #[arg(long, value_enum, help = "Target split to sample from")]
143
    split: Option<SplitArg>,
144
    #[arg(
145
        long = "source-root",
146
        value_name = "PATH",
147
        help = "Optional source root override, repeat as needed in source order"
148
    )]
149
    source_roots: Vec<String>,
150
    #[arg(
151
        long = "split-store-path",
152
        value_name = "SPLIT_STORE_PATH",
153
        help = "Optional explicit path for persisted split/epoch state file"
154
    )]
155
    split_store_path: Option<PathBuf>,
156
}
157

158
#[derive(Debug, Clone)]
159
/// Source-level inventory used by capacity estimation output.
160
struct SourceInventory {
161
    source_id: String,
162
    reported_records: u128,
163
    triplet_recipes: Vec<TripletRecipe>,
164
}
165

166
/// Run the capacity-estimation CLI with injectable root resolution/source builders.
167
///
168
/// `build_sources` is construction-only; sampler configuration is applied
169
/// centrally by this function before any source calls.
170
pub fn run_estimate_capacity<R, Resolve, Build, I>(
4✔
171
    args_iter: I,
4✔
172
    resolve_roots: Resolve,
4✔
173
    build_sources: Build,
4✔
174
) -> Result<(), Box<dyn Error>>
4✔
175
where
4✔
176
    Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
4✔
177
    Build: FnOnce(&R) -> Vec<DynSource>,
4✔
178
    I: Iterator<Item = String>,
4✔
179
{
180
    init_example_tracing();
4✔
181

182
    let Some(cli) = parse_cli::<EstimateCapacityCli, _>(
4✔
183
        std::iter::once("estimate_capacity".to_string()).chain(args_iter),
4✔
184
    )?
×
185
    else {
186
        return Ok(());
×
187
    };
188

189
    let roots = resolve_roots(cli.source_roots)?;
4✔
190

191
    let config = SamplerConfig {
3✔
192
        seed: cli.seed,
3✔
193
        split: cli.split,
3✔
194
        ..SamplerConfig::default()
3✔
195
    };
3✔
196

197
    let sources = build_sources(&roots);
3✔
198

199
    let mut inventories = Vec::new();
3✔
200
    for source in &sources {
3✔
201
        let recipes = if config.recipes.is_empty() {
3✔
202
            source.default_triplet_recipes()
3✔
203
        } else {
204
            config.recipes.clone()
×
205
        };
206
        let reported_records = source.reported_record_count(&config).map_err(|err| {
3✔
207
            format!(
1✔
208
                "source '{}' failed to report exact record count: {err}",
209
                source.id()
1✔
210
            )
211
        })?;
1✔
212
        inventories.push(SourceInventory {
2✔
213
            source_id: source.id().to_string(),
2✔
214
            reported_records,
2✔
215
            triplet_recipes: recipes,
2✔
216
        });
2✔
217
    }
218

219
    let mut per_source_split_counts: HashMap<(String, SplitLabel), u128> = HashMap::new();
2✔
220
    let mut split_record_counts: HashMap<SplitLabel, u128> = HashMap::new();
2✔
221

222
    for source in &inventories {
2✔
223
        let counts = split_counts_for_total(source.reported_records, cli.split);
2✔
224
        for (label, count) in counts {
6✔
225
            per_source_split_counts.insert((source.source_id.clone(), label), count);
6✔
226
            *split_record_counts.entry(label).or_insert(0) += count;
6✔
227
        }
6✔
228
    }
229

230
    let mut totals_by_split: HashMap<SplitLabel, CapacityTotals> = HashMap::new();
2✔
231
    let mut totals_by_source_and_split: HashMap<(String, SplitLabel), CapacityTotals> =
2✔
232
        HashMap::new();
2✔
233

234
    for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
6✔
235
        let mut totals = CapacityTotals::default();
6✔
236

237
        for source in &inventories {
6✔
238
            let source_split_records = per_source_split_counts
6✔
239
                .get(&(source.source_id.clone(), split_label))
6✔
240
                .copied()
6✔
241
                .unwrap_or(0);
6✔
242

6✔
243
            let triplet_recipes = &source.triplet_recipes;
6✔
244
            let text_recipes = resolve_text_recipes_for_source(&config, triplet_recipes);
6✔
245

6✔
246
            let capacity = estimate_source_split_capacity_from_counts(
6✔
247
                source_split_records,
6✔
248
                triplet_recipes,
6✔
249
                &text_recipes,
6✔
250
            );
6✔
251

6✔
252
            totals_by_source_and_split.insert((source.source_id.clone(), split_label), capacity);
6✔
253

6✔
254
            totals.triplets += capacity.triplets;
6✔
255
            totals.effective_triplets += capacity.effective_triplets;
6✔
256
            totals.pairs += capacity.pairs;
6✔
257
            totals.text_samples += capacity.text_samples;
6✔
258
        }
6✔
259

260
        totals_by_split.insert(split_label, totals);
6✔
261
    }
262

263
    let min_nonzero_records_by_split: HashMap<SplitLabel, u128> =
2✔
264
        [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test]
2✔
265
            .into_iter()
2✔
266
            .map(|split_label| {
6✔
267
                let min_nonzero = inventories
6✔
268
                    .iter()
6✔
269
                    .filter_map(|source| {
6✔
270
                        per_source_split_counts
6✔
271
                            .get(&(source.source_id.clone(), split_label))
6✔
272
                            .copied()
6✔
273
                    })
6✔
274
                    .filter(|&records| records > 0)
6✔
275
                    .min()
6✔
276
                    .unwrap_or(0);
6✔
277
                (split_label, min_nonzero)
6✔
278
            })
6✔
279
            .collect();
2✔
280

281
    let min_nonzero_records_all_splits = inventories
2✔
282
        .iter()
2✔
283
        .map(|source| source.reported_records)
2✔
284
        .filter(|&records| records > 0)
2✔
285
        .min()
2✔
286
        .unwrap_or(0);
2✔
287

288
    println!("=== capacity estimate (length-only) ===");
2✔
289
    println!("mode: metadata-only (no source.refresh calls)");
2✔
290
    println!("classification: heuristic approximation (not exact)");
2✔
291
    println!("split seed: {}", cli.seed);
2✔
292
    println!(
2✔
293
        "split ratios: train={:.4}, validation={:.4}, test={:.4}",
294
        cli.split.train, cli.split.validation, cli.split.test
295
    );
296
    println!();
2✔
297

298
    println!("[SOURCES]");
2✔
299
    for source in &inventories {
2✔
300
        println!(
2✔
301
            "  {} => reported records: {}",
2✔
302
            source.source_id,
2✔
303
            format_u128_with_commas(source.reported_records)
2✔
304
        );
2✔
305
    }
2✔
306
    println!();
2✔
307

308
    println!("[PER SOURCE BREAKDOWN]");
2✔
309
    for source in &inventories {
2✔
310
        println!("  {}", source.source_id);
2✔
311
        let mut source_grand = CapacityTotals::default();
2✔
312
        let mut source_total_records = 0u128;
2✔
313
        for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
6✔
314
            let split_records = per_source_split_counts
6✔
315
                .get(&(source.source_id.clone(), split_label))
6✔
316
                .copied()
6✔
317
                .unwrap_or(0);
6✔
318
            source_total_records = source_total_records.saturating_add(split_records);
6✔
319
            let split_longest_records = inventories
6✔
320
                .iter()
6✔
321
                .map(|candidate| {
6✔
322
                    per_source_split_counts
6✔
323
                        .get(&(candidate.source_id.clone(), split_label))
6✔
324
                        .copied()
6✔
325
                        .unwrap_or(0)
6✔
326
                })
6✔
327
                .max()
6✔
328
                .unwrap_or(0);
6✔
329
            let totals = totals_by_source_and_split
6✔
330
                .get(&(source.source_id.clone(), split_label))
6✔
331
                .copied()
6✔
332
                .unwrap_or_default();
6✔
333
            source_grand.triplets += totals.triplets;
6✔
334
            source_grand.effective_triplets += totals.effective_triplets;
6✔
335
            source_grand.pairs += totals.pairs;
6✔
336
            source_grand.text_samples += totals.text_samples;
6✔
337
            println!("    [{:?}]", split_label);
6✔
338
            println!("      records: {}", format_u128_with_commas(split_records));
6✔
339
            println!(
6✔
340
                "      triplet combinations: {}",
341
                format_u128_with_commas(totals.triplets)
6✔
342
            );
343
            println!(
6✔
344
                "      effective sampled triplets (p={}, k={}): {}",
345
                EFFECTIVE_POSITIVES_PER_ANCHOR,
346
                EFFECTIVE_NEGATIVES_PER_ANCHOR,
347
                format_u128_with_commas(totals.effective_triplets)
6✔
348
            );
349
            println!(
6✔
350
                "      pair combinations:    {}",
351
                format_u128_with_commas(totals.pairs)
6✔
352
            );
353
            println!(
6✔
354
                "      text samples:         {}",
355
                format_u128_with_commas(totals.text_samples)
6✔
356
            );
357
            println!(
6✔
358
                "      replay factor vs longest source: {}",
359
                format_replay_factor(split_longest_records, split_records)
6✔
360
            );
361
            println!(
6✔
362
                "      suggested proportional-size batch weight (0-1): {:.4}",
363
                suggested_balancing_weight(split_longest_records, split_records)
6✔
364
            );
365
            let split_smallest_nonzero = min_nonzero_records_by_split
6✔
366
                .get(&split_label)
6✔
367
                .copied()
6✔
368
                .unwrap_or(0);
6✔
369
            println!(
6✔
370
                "      suggested small-source-boost batch weight (0-1): {:.4}",
371
                suggested_oversampling_weight(split_smallest_nonzero, split_records)
6✔
372
            );
373
            println!();
6✔
374
        }
375
        let longest_source_total = inventories
2✔
376
            .iter()
2✔
377
            .map(|candidate| candidate.reported_records)
2✔
378
            .max()
2✔
379
            .unwrap_or(0);
2✔
380
        println!("    [ALL SPLITS FOR SOURCE]");
2✔
381
        println!(
2✔
382
            "      triplet combinations: {}",
383
            format_u128_with_commas(source_grand.triplets)
2✔
384
        );
385
        println!(
2✔
386
            "      effective sampled triplets (p={}, k={}): {}",
387
            EFFECTIVE_POSITIVES_PER_ANCHOR,
388
            EFFECTIVE_NEGATIVES_PER_ANCHOR,
389
            format_u128_with_commas(source_grand.effective_triplets)
2✔
390
        );
391
        println!(
2✔
392
            "      pair combinations:    {}",
393
            format_u128_with_commas(source_grand.pairs)
2✔
394
        );
395
        println!(
2✔
396
            "      text samples:         {}",
397
            format_u128_with_commas(source_grand.text_samples)
2✔
398
        );
399
        println!(
2✔
400
            "      replay factor vs longest source: {}",
401
            format_replay_factor(longest_source_total, source_total_records)
2✔
402
        );
403
        println!(
2✔
404
            "      suggested proportional-size batch weight (0-1): {:.4}",
405
            suggested_balancing_weight(longest_source_total, source_total_records)
2✔
406
        );
407
        println!(
2✔
408
            "      suggested small-source-boost batch weight (0-1): {:.4}",
409
            suggested_oversampling_weight(min_nonzero_records_all_splits, source_total_records)
2✔
410
        );
411
        println!();
2✔
412
    }
413

414
    let mut grand = CapacityTotals::default();
2✔
415
    for split_label in [SplitLabel::Train, SplitLabel::Validation, SplitLabel::Test] {
6✔
416
        let record_count = split_record_counts.get(&split_label).copied().unwrap_or(0);
6✔
417
        let totals = totals_by_split
6✔
418
            .get(&split_label)
6✔
419
            .copied()
6✔
420
            .unwrap_or_default();
6✔
421

6✔
422
        grand.triplets += totals.triplets;
6✔
423
        grand.effective_triplets += totals.effective_triplets;
6✔
424
        grand.pairs += totals.pairs;
6✔
425
        grand.text_samples += totals.text_samples;
6✔
426

6✔
427
        println!("[{:?}]", split_label);
6✔
428
        println!("  records: {}", format_u128_with_commas(record_count));
6✔
429
        println!(
6✔
430
            "  triplet combinations: {}",
6✔
431
            format_u128_with_commas(totals.triplets)
6✔
432
        );
6✔
433
        println!(
6✔
434
            "  effective sampled triplets (p={}, k={}): {}",
6✔
435
            EFFECTIVE_POSITIVES_PER_ANCHOR,
6✔
436
            EFFECTIVE_NEGATIVES_PER_ANCHOR,
6✔
437
            format_u128_with_commas(totals.effective_triplets)
6✔
438
        );
6✔
439
        println!(
6✔
440
            "  pair combinations:    {}",
6✔
441
            format_u128_with_commas(totals.pairs)
6✔
442
        );
6✔
443
        println!(
6✔
444
            "  text samples:         {}",
6✔
445
            format_u128_with_commas(totals.text_samples)
6✔
446
        );
6✔
447
        println!();
6✔
448
    }
6✔
449

450
    println!("[ALL SPLITS TOTAL]");
2✔
451
    println!(
2✔
452
        "  triplet combinations: {}",
453
        format_u128_with_commas(grand.triplets)
2✔
454
    );
455
    println!(
2✔
456
        "  effective sampled triplets (p={}, k={}): {}",
457
        EFFECTIVE_POSITIVES_PER_ANCHOR,
458
        EFFECTIVE_NEGATIVES_PER_ANCHOR,
459
        format_u128_with_commas(grand.effective_triplets)
2✔
460
    );
461
    println!(
2✔
462
        "  pair combinations:    {}",
463
        format_u128_with_commas(grand.pairs)
2✔
464
    );
465
    println!(
2✔
466
        "  text samples:         {}",
467
        format_u128_with_commas(grand.text_samples)
2✔
468
    );
469
    println!();
2✔
470
    println!(
2✔
471
        "Note: counts are heuristic, length-based estimates from source-reported totals and recipe structure. They are approximate, not exact, and assume anchor-positive pairs=records (one positive per anchor by default), negatives=source_records_in_split-1 (anchor excluded as its own negative), and at most one chunk/window realization per sample. In real-world chunked sampling, practical combinations are often higher, so treat this as a floor-like baseline."
472
    );
473
    println!();
2✔
474
    println!(
2✔
475
        "Effective sampled triplets apply a bounded training assumption: effective_triplets = records * p * k per triplet recipe, with defaults p={} positives per anchor and k={} negatives per anchor.",
476
        EFFECTIVE_POSITIVES_PER_ANCHOR, EFFECTIVE_NEGATIVES_PER_ANCHOR
477
    );
478
    println!();
2✔
479
    println!(
2✔
480
        "Oversample loops are not inferred from this static report. To measure true oversampling (how many times sampling loops through the combination space), use observed sampled draw counts from an actual run."
481
    );
482
    println!();
2✔
483
    println!(
2✔
484
        "Suggested proportional-size batch weight (0-1) is source/max_source by record count: 1.0 for the largest source in scope, smaller values for smaller sources."
485
    );
486
    println!();
2✔
487
    println!(
2✔
488
        "Suggested small-source-boost batch weight (0-1) is min_nonzero_source/source by record count: 1.0 for the smallest non-zero source in scope, smaller values for larger sources."
489
    );
490
    println!();
2✔
491
    println!(
2✔
492
        "When passed to next_*_batch_with_weights, higher weight means that source is sampled more often relative to lower-weight sources."
493
    );
494

495
    Ok(())
2✔
496
}
4✔
497

498
/// Run the multi-source demo CLI with injectable root resolution/source builders.
499
///
500
/// `build_sources` is construction-only. Source sampler configuration is owned
501
/// by sampler registration (`TripletSampler::register_source`).
502
pub fn run_multi_source_demo<R, Resolve, Build, I>(
12✔
503
    args_iter: I,
12✔
504
    resolve_roots: Resolve,
12✔
505
    build_sources: Build,
12✔
506
) -> Result<(), Box<dyn Error>>
12✔
507
where
12✔
508
    Resolve: FnOnce(Vec<String>) -> Result<R, Box<dyn Error>>,
12✔
509
    Build: FnOnce(&R) -> Vec<DynSource>,
12✔
510
    I: Iterator<Item = String>,
12✔
511
{
512
    init_example_tracing();
12✔
513

514
    let Some(cli) = parse_cli::<MultiSourceDemoCli, _>(
12✔
515
        std::iter::once("multi_source_demo".to_string()).chain(args_iter),
12✔
516
    )?
×
517
    else {
518
        return Ok(());
×
519
    };
520

521
    let roots = resolve_roots(cli.source_roots)?;
12✔
522

523
    let mut config = SamplerConfig::default();
11✔
524
    config.seed = cli.seed.unwrap_or(config.seed);
11✔
525
    config.batch_size = cli.batch_size;
11✔
526
    config.chunking = Default::default();
11✔
527
    let selected_split = cli.split.map(Into::into).unwrap_or(SplitLabel::Train);
11✔
528
    config.split = SplitRatios::default();
11✔
529
    config.allowed_splits = vec![selected_split];
11✔
530
    let chunking = config.chunking.clone();
11✔
531

532
    let split_store_path = if let Some(path) = cli.split_store_path {
11✔
533
        path
11✔
534
    } else {
535
        managed_demo_split_store_path().map_err(|err| {
×
536
            Box::<dyn Error>::from(format!("failed to resolve demo split-store path: {err}"))
×
537
        })?
×
538
    };
539

540
    println!(
11✔
541
        "Persisting split assignments and epoch state to {}",
542
        split_store_path.display()
11✔
543
    );
544
    let sources = build_sources(&roots);
11✔
545
    let split_store = Arc::new(FileSplitStore::open(&split_store_path, config.split, 99)?);
11✔
546
    let sampler = TripletSampler::new(config, split_store.clone());
11✔
547
    for source in sources {
11✔
548
        sampler.register_source(source);
11✔
549
    }
11✔
550

551
    if cli.show_pair_samples {
11✔
552
        match sampler.next_pair_batch(selected_split) {
3✔
553
            Ok(pair_batch) => {
×
554
                if pair_batch.pairs.is_empty() {
×
555
                    println!("Pair sampling produced no results.");
×
556
                } else {
×
557
                    print_pair_batch(&chunking, &pair_batch, split_store.as_ref());
×
558
                }
×
559
                sampler.save_sampler_state(None)?;
×
560
            }
561
            Err(SamplerError::Exhausted(name)) => {
3✔
562
                eprintln!(
3✔
563
                    "Pair sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
3✔
564
                    name
3✔
565
                );
3✔
566
            }
3✔
567
            Err(err) => return Err(err.into()),
×
568
        }
569
    } else if cli.show_text_samples {
8✔
570
        match sampler.next_text_batch(selected_split) {
3✔
571
            Ok(text_batch) => {
1✔
572
                if text_batch.samples.is_empty() {
1✔
573
                    println!(
×
574
                        "Text sampling produced no results. Ensure each source has eligible sections."
×
575
                    );
×
576
                } else {
1✔
577
                    print_text_batch(&chunking, &text_batch, split_store.as_ref());
1✔
578
                }
1✔
579
                sampler.save_sampler_state(None)?;
1✔
580
            }
581
            Err(SamplerError::Exhausted(name)) => {
2✔
582
                eprintln!(
2✔
583
                    "Text sampler exhausted selector '{}'. Ensure matching sections exist.",
2✔
584
                    name
2✔
585
                );
2✔
586
            }
2✔
587
            Err(err) => return Err(err.into()),
×
588
        }
589
    } else if cli.list_text_recipes {
5✔
590
        let recipes = sampler.text_recipes();
2✔
591
        if recipes.is_empty() {
2✔
592
            println!(
1✔
593
                "No text recipes registered. Ensure your sources expose triplet selectors or configure text_recipes explicitly."
1✔
594
            );
1✔
595
        } else {
1✔
596
            print_text_recipes(&recipes);
1✔
597
        }
1✔
598
    } else {
599
        match sampler.next_triplet_batch(selected_split) {
3✔
600
            Ok(triplet_batch) => {
×
601
                if triplet_batch.triplets.is_empty() {
×
602
                    println!(
×
603
                        "Triplet sampling produced no results. Ensure multiple records per source exist."
×
604
                    );
×
605
                } else {
×
606
                    print_triplet_batch(&chunking, &triplet_batch, split_store.as_ref());
×
607
                }
×
608
                sampler.save_sampler_state(None)?;
×
609
            }
610
            Err(SamplerError::Exhausted(name)) => {
3✔
611
                eprintln!(
3✔
612
                    "Triplet sampler exhausted recipe '{}'. Ensure both positive and negative examples exist.",
3✔
613
                    name
3✔
614
                );
3✔
615
            }
3✔
616
            Err(err) => return Err(err.into()),
×
617
        }
618
    }
619

620
    Ok(())
11✔
621
}
12✔
622

623
/// Clap value parser for `--batch-size`: accepts a decimal `usize` greater than zero.
///
/// # Errors
///
/// Returns a message naming the flag when `raw` is not a valid `usize` or is `0`.
fn parse_positive_usize(raw: &str) -> Result<usize, String> {
    match raw.parse::<usize>() {
        Ok(0) => Err("--batch-size must be greater than zero".to_string()),
        Ok(value) => Ok(value),
        Err(_) => Err(format!(
            "Could not parse --batch-size value '{}' as a positive integer",
            raw
        )),
    }
}
17✔
635

636
/// Proportional-size weight: `source_baseline / max_baseline`, clamped to `[0, 1]`.
///
/// Returns `0.0` when either baseline is zero.
fn suggested_balancing_weight(max_baseline: u128, source_baseline: u128) -> f32 {
    let ratio = match (max_baseline, source_baseline) {
        (0, _) | (_, 0) => return 0.0,
        (max, source) => source as f64 / max as f64,
    };
    ratio.clamp(0.0, 1.0) as f32
}
13✔
642

643
/// Small-source-boost weight: `min_nonzero_baseline / source_baseline`,
/// clamped to `[0, 1]`.
///
/// Returns `0.0` when either baseline is zero.
fn suggested_oversampling_weight(min_nonzero_baseline: u128, source_baseline: u128) -> f32 {
    let ratio = match (min_nonzero_baseline, source_baseline) {
        (0, _) | (_, 0) => return 0.0,
        (smallest, source) => smallest as f64 / source as f64,
    };
    ratio.clamp(0.0, 1.0) as f32
}
13✔
649

650
fn parse_cli<T, I>(args: I) -> Result<Option<T>, Box<dyn Error>>
22✔
651
where
22✔
652
    T: Parser,
22✔
653
    I: IntoIterator,
22✔
654
    I::Item: Into<std::ffi::OsString> + Clone,
22✔
655
{
656
    match T::try_parse_from(args) {
22✔
657
        Ok(cli) => Ok(Some(cli)),
17✔
658
        Err(err) => match err.kind() {
5✔
659
            ErrorKind::DisplayHelp | ErrorKind::DisplayVersion => {
660
                err.print()?;
3✔
661
                Ok(None)
3✔
662
            }
663
            _ => Err(err.into()),
2✔
664
        },
665
    }
666
}
22✔
667

668
fn parse_split_ratios_arg(raw: &str) -> Result<SplitRatios, String> {
11✔
669
    let parts: Vec<&str> = raw.split(',').collect();
11✔
670
    if parts.len() != 3 {
11✔
671
        return Err("--split-ratios expects exactly 3 comma-separated values".to_string());
1✔
672
    }
10✔
673
    let train = parts[0]
10✔
674
        .trim()
10✔
675
        .parse::<f32>()
10✔
676
        .map_err(|_| format!("invalid train ratio '{}': must be a float", parts[0].trim()))?;
10✔
677
    let validation = parts[1].trim().parse::<f32>().map_err(|_| {
9✔
678
        format!(
1✔
679
            "invalid validation ratio '{}': must be a float",
680
            parts[1].trim()
1✔
681
        )
682
    })?;
1✔
683
    let test = parts[2]
8✔
684
        .trim()
8✔
685
        .parse::<f32>()
8✔
686
        .map_err(|_| format!("invalid test ratio '{}': must be a float", parts[2].trim()))?;
8✔
687
    let ratios = SplitRatios {
7✔
688
        train,
7✔
689
        validation,
7✔
690
        test,
7✔
691
    };
7✔
692
    let sum = ratios.train + ratios.validation + ratios.test;
7✔
693
    if (sum - 1.0).abs() > 1e-5 {
7✔
694
        return Err(format!(
1✔
695
            "split ratios must sum to 1.0, got {:.6} (train={}, validation={}, test={})",
1✔
696
            sum, ratios.train, ratios.validation, ratios.test
1✔
697
        ));
1✔
698
    }
6✔
699
    if ratios.train < 0.0 || ratios.validation < 0.0 || ratios.test < 0.0 {
6✔
700
        return Err("split ratios must be non-negative".to_string());
1✔
701
    }
5✔
702
    Ok(ratios)
5✔
703
}
11✔
704

705
fn print_triplet_batch(
1✔
706
    strategy: &ChunkingStrategy,
1✔
707
    batch: &TripletBatch,
1✔
708
    split_store: &impl SplitStore,
1✔
709
) {
1✔
710
    println!("=== triplet batch ===");
1✔
711
    for (idx, triplet) in batch.triplets.iter().enumerate() {
1✔
712
        println!("--- triplet #{} ---", idx);
1✔
713
        println!("recipe       : {}", triplet.recipe);
1✔
714
        println!("sample_weight: {:.4}", triplet.weight);
1✔
715
        if let Some(instr) = &triplet.instruction {
1✔
716
            println!("instruction shown to model:\n{}\n", instr);
1✔
717
        }
1✔
718
        print_chunk_block("ANCHOR", &triplet.anchor, strategy, split_store);
1✔
719
        print_chunk_block("POSITIVE", &triplet.positive, strategy, split_store);
1✔
720
        print_chunk_block("NEGATIVE", &triplet.negative, strategy, split_store);
1✔
721
    }
722
    print_source_summary(
1✔
723
        "triplet anchors",
1✔
724
        batch
1✔
725
            .triplets
1✔
726
            .iter()
1✔
727
            .map(|triplet| triplet.anchor.record_id.as_str()),
1✔
728
    );
729
    print_recipe_context_by_source(
1✔
730
        "triplet recipes by source",
1✔
731
        batch
1✔
732
            .triplets
1✔
733
            .iter()
1✔
734
            .map(|triplet| (triplet.anchor.record_id.as_str(), triplet.recipe.as_str())),
1✔
735
    );
736
}
1✔
737

738
fn print_text_batch(strategy: &ChunkingStrategy, batch: &TextBatch, split_store: &impl SplitStore) {
2✔
739
    println!("=== text batch ===");
2✔
740
    for (idx, sample) in batch.samples.iter().enumerate() {
5✔
741
        println!("--- sample #{} ---", idx);
5✔
742
        println!("recipe       : {}", sample.recipe);
5✔
743
        println!("sample_weight: {:.4}", sample.weight);
5✔
744
        if let Some(instr) = &sample.instruction {
5✔
745
            println!("instruction shown to model:\n{}\n", instr);
1✔
746
        }
4✔
747
        print_chunk_block("TEXT", &sample.chunk, strategy, split_store);
5✔
748
    }
749
    print_source_summary(
2✔
750
        "text samples",
2✔
751
        batch
2✔
752
            .samples
2✔
753
            .iter()
2✔
754
            .map(|sample| sample.chunk.record_id.as_str()),
5✔
755
    );
756
    print_recipe_context_by_source(
2✔
757
        "text recipes by source",
2✔
758
        batch
2✔
759
            .samples
2✔
760
            .iter()
2✔
761
            .map(|sample| (sample.chunk.record_id.as_str(), sample.recipe.as_str())),
5✔
762
    );
763
}
2✔
764

765
fn print_pair_batch(
1✔
766
    strategy: &ChunkingStrategy,
1✔
767
    batch: &SampleBatch,
1✔
768
    split_store: &impl SplitStore,
1✔
769
) {
1✔
770
    println!("=== pair batch ===");
1✔
771
    for (idx, pair) in batch.pairs.iter().enumerate() {
1✔
772
        println!("--- pair #{} ---", idx);
1✔
773
        println!("recipe       : {}", pair.recipe);
1✔
774
        println!("label        : {:?}", pair.label);
1✔
775
        if let Some(reason) = &pair.reason {
1✔
776
            println!("reason       : {}", reason);
1✔
777
        }
1✔
778
        print_chunk_block("ANCHOR", &pair.anchor, strategy, split_store);
1✔
779
        print_chunk_block("OTHER", &pair.positive, strategy, split_store);
1✔
780
    }
781
    print_source_summary(
1✔
782
        "pair anchors",
1✔
783
        batch
1✔
784
            .pairs
1✔
785
            .iter()
1✔
786
            .map(|pair| pair.anchor.record_id.as_str()),
1✔
787
    );
788
    print_recipe_context_by_source(
1✔
789
        "pair recipes by source",
1✔
790
        batch
1✔
791
            .pairs
1✔
792
            .iter()
1✔
793
            .map(|pair| (pair.anchor.record_id.as_str(), pair.recipe.as_str())),
1✔
794
    );
795
}
1✔
796

797
fn print_text_recipes(recipes: &[TextRecipe]) {
2✔
798
    println!("=== available text recipes ===");
2✔
799
    for recipe in recipes {
4✔
800
        println!(
4✔
801
            "- {} (weight: {:.3}) selector={:?}",
802
            recipe.name, recipe.weight, recipe.selector
803
        );
804
        if let Some(instr) = &recipe.instruction {
4✔
805
            println!("  instruction: {}", instr);
1✔
806
        }
3✔
807
    }
808
}
2✔
809

810
trait ChunkDebug {
811
    fn view_name(&self) -> String;
812
}
813

814
impl ChunkDebug for RecordChunk {
815
    fn view_name(&self) -> String {
10✔
816
        match &self.view {
10✔
817
            ChunkView::Window {
818
                index,
8✔
819
                span,
8✔
820
                overlap,
8✔
821
                start_ratio,
8✔
822
            } => format!(
8✔
823
                "window#index={} span={} overlap={} start_ratio={:.3} tokens={}",
824
                index, span, overlap, start_ratio, self.tokens_estimate
825
            ),
826
            ChunkView::SummaryFallback { strategy, .. } => {
2✔
827
                format!("summary:{} tokens={}", strategy, self.tokens_estimate)
2✔
828
            }
829
        }
830
    }
10✔
831
}
832

833
fn print_chunk_block(
10✔
834
    title: &str,
10✔
835
    chunk: &RecordChunk,
10✔
836
    strategy: &ChunkingStrategy,
10✔
837
    split_store: &impl SplitStore,
10✔
838
) {
10✔
839
    let chunk_weight = chunk_weight(strategy, chunk);
10✔
840
    let split = split_store
10✔
841
        .label_for(&chunk.record_id)
10✔
842
        .map(|label| format!("{:?}", label))
10✔
843
        .unwrap_or_else(|| "Unknown".to_string());
10✔
844
    println!("--- {} ---", title);
10✔
845
    println!("split        : {}", split);
10✔
846
    println!("view         : {}", chunk.view_name());
10✔
847
    println!("chunk_weight : {:.4}", chunk_weight);
10✔
848
    println!("record_id    : {}", chunk.record_id);
10✔
849
    println!("section_idx  : {}", chunk.section_idx);
10✔
850
    println!("token_est    : {}", chunk.tokens_estimate);
10✔
851
    println!("model_input (exact text sent to the model):");
10✔
852
    println!(
10✔
853
        "<<< BEGIN MODEL TEXT >>>\n{}\n<<< END MODEL TEXT >>>\n",
854
        chunk.text
855
    );
856
}
10✔
857

858
fn print_source_summary<'a, I>(label: &str, ids: I)
4✔
859
where
4✔
860
    I: Iterator<Item = &'a str>,
4✔
861
{
862
    let mut counts: HashMap<SourceId, usize> = HashMap::new();
4✔
863
    for id in ids {
7✔
864
        let source = extract_source(id);
7✔
865
        *counts.entry(source).or_insert(0) += 1;
7✔
866
    }
7✔
867
    if counts.is_empty() {
4✔
868
        return;
×
869
    }
4✔
870
    let skew = source_skew(&counts);
4✔
871
    let mut entries: Vec<(String, usize)> = counts.into_iter().collect();
4✔
872
    entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
4✔
873
    println!("--- {} by source ---", label);
4✔
874
    if let Some(skew) = skew {
4✔
875
        for entry in &skew.per_source {
4✔
876
            println!(
4✔
877
                "{}: count={} share={:.2}",
4✔
878
                entry.source, entry.count, entry.share
4✔
879
            );
4✔
880
        }
4✔
881
        println!(
4✔
882
            "skew: sources={} total={} min={} max={} mean={:.2} ratio={:.2}",
883
            skew.sources, skew.total, skew.min, skew.max, skew.mean, skew.ratio
884
        );
885
    } else {
886
        for (source, count) in &entries {
×
887
            println!("{source}: count={count}");
×
888
        }
×
889
    }
890
}
4✔
891

892
fn print_recipe_context_by_source<'a, I>(label: &str, entries: I)
4✔
893
where
4✔
894
    I: Iterator<Item = (&'a str, &'a str)>,
4✔
895
{
896
    let mut counts: HashMap<SourceId, HashMap<String, usize>> = HashMap::new();
4✔
897
    for (record_id, recipe) in entries {
7✔
898
        let source = extract_source(record_id);
7✔
899
        let entry = counts
7✔
900
            .entry(source)
7✔
901
            .or_default()
7✔
902
            .entry(recipe.to_string())
7✔
903
            .or_insert(0);
7✔
904
        *entry += 1;
7✔
905
    }
7✔
906
    if counts.is_empty() {
4✔
907
        return;
×
908
    }
4✔
909
    let mut sources: Vec<(SourceId, HashMap<String, usize>)> = counts.into_iter().collect();
4✔
910
    sources.sort_by(|a, b| a.0.cmp(&b.0));
4✔
911
    println!("--- {} ---", label);
4✔
912
    for (source, recipes) in sources {
4✔
913
        println!("{source}");
4✔
914
        let mut entries: Vec<(String, usize)> = recipes.into_iter().collect();
4✔
915
        entries.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
4✔
916
        for (recipe, count) in entries {
5✔
917
            println!("  - {recipe}={count}");
5✔
918
        }
5✔
919
    }
920
}
4✔
921

922
fn extract_source(record_id: &str) -> SourceId {
16✔
923
    record_id
16✔
924
        .split_once("::")
16✔
925
        .map(|(source, _)| source.to_string())
16✔
926
        .unwrap_or_else(|| "unknown".to_string())
16✔
927
}
16✔
928

929
// Unit tests for the example-app helpers. They cover:
// - CLI and split-ratio parsing;
// - the balancing and oversampling weight suggestions;
// - `run_estimate_capacity` and `run_multi_source_demo`, driven through
//   injected root resolvers and source builders;
// - the stdout print helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DataRecord;
    use crate::DeterministicSplitStore;
    use crate::data::{QualityScore, RecordSection, SectionRole};
    use crate::source::{SourceCursor, SourceSnapshot};
    use chrono::Utc;
    use tempfile::tempdir;

    /// Minimal in-memory `DataSource` test double for example app tests.
    ///
    /// `refresh` always returns an empty snapshot. `reported_record_count`
    /// returns `count`. When `count` is `None` it returns a
    /// `SamplerError::SourceInconsistent`, which lets tests exercise the
    /// missing-count error path. `default_triplet_recipes` returns a clone of
    /// `recipes`.
    struct TestSource {
        // Id returned by `DataSource::id`.
        id: String,
        // Exact record count to report; `None` makes counting fail.
        count: Option<u128>,
        // Triplet recipes handed back as the source's defaults.
        recipes: Vec<TripletRecipe>,
    }

    impl DataSource for TestSource {
        fn id(&self) -> &str {
            &self.id
        }

        // Always returns an empty snapshot with a fresh cursor at revision 0.
        fn refresh(
            &self,
            _config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            })
        }

        fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
            self.count.ok_or_else(|| SamplerError::SourceInconsistent {
                source_id: self.id.clone(),
                details: "test source has no configured exact count".to_string(),
            })
        }

        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
            self.recipes.clone()
        }
    }

    /// `DataSource` double whose record count depends on the sampler config.
    ///
    /// `reported_record_count` returns `Ok(1)` only when `config.seed` equals
    /// `expected_seed`, and a `SourceInconsistent` error otherwise. This lets a
    /// test detect whether the caller passed in a configured `SamplerConfig`.
    struct ConfigRequiredSource {
        id: String,
        // Seed the caller's `SamplerConfig` must carry for counting to succeed.
        expected_seed: u64,
    }

    impl DataSource for ConfigRequiredSource {
        fn id(&self) -> &str {
            &self.id
        }

        // Always returns an empty snapshot with a fresh cursor at revision 0.
        fn refresh(
            &self,
            _config: &SamplerConfig,
            _cursor: Option<&SourceCursor>,
            _limit: Option<usize>,
        ) -> Result<SourceSnapshot, SamplerError> {
            Ok(SourceSnapshot {
                records: Vec::new(),
                cursor: SourceCursor {
                    last_seen: Utc::now(),
                    revision: 0,
                },
            })
        }

        fn reported_record_count(&self, config: &SamplerConfig) -> Result<u128, SamplerError> {
            if config.seed == self.expected_seed {
                Ok(1)
            } else {
                Err(SamplerError::SourceInconsistent {
                    source_id: self.id.clone(),
                    details: format!(
                        "expected sampler seed {} but got {}",
                        self.expected_seed, config.seed
                    ),
                })
            }
        }

        fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
            Vec::new()
        }
    }

    /// Builds a weight-1.0 triplet recipe named `name`, with no instruction.
    ///
    /// The anchor is selected from `SectionRole::Anchor` sections; positive and
    /// negative are both selected from `SectionRole::Context` sections. The
    /// negative strategy is `WrongArticle`.
    fn default_recipe(name: &str) -> TripletRecipe {
        TripletRecipe {
            name: name.to_string().into(),
            anchor: crate::config::Selector::Role(SectionRole::Anchor),
            positive_selector: crate::config::Selector::Role(SectionRole::Context),
            negative_selector: crate::config::Selector::Role(SectionRole::Context),
            negative_strategy: crate::config::NegativeStrategy::WrongArticle,
            weight: 1.0,
            instruction: None,
        }
    }

    /// `parse_positive_usize` rejects zero and non-numeric input.
    /// `parse_split_ratios_arg` requires exactly three values that sum to 1.0
    /// and are all non-negative.
    #[test]
    fn parse_helpers_validate_inputs() {
        assert_eq!(parse_positive_usize("2").unwrap(), 2);
        assert!(parse_positive_usize("0").is_err());
        assert!(parse_positive_usize("abc").is_err());

        let split = parse_split_ratios_arg("0.8,0.1,0.1").unwrap();
        assert!((split.train - 0.8).abs() < 1e-6);
        // Too few fields.
        assert!(parse_split_ratios_arg("0.8,0.1").is_err());
        // Sums to 1.1.
        assert!(parse_split_ratios_arg("1.0,0.0,0.1").is_err());
        // Sums to 1.0 but contains a negative ratio.
        assert!(parse_split_ratios_arg("-0.1,0.6,0.5").is_err());
    }

    /// Pins `suggested_balancing_weight`: equal inputs give 1.0, a 4:1 ratio
    /// gives 0.25, and a zero in either argument gives 0.0.
    #[test]
    fn suggested_balancing_weight_is_longest_normalized_and_bounded() {
        assert!((suggested_balancing_weight(100, 100) - 1.0).abs() < 1e-6);
        assert!((suggested_balancing_weight(400, 100) - 0.25).abs() < 1e-6);
        assert!((suggested_balancing_weight(400, 400) - 1.0).abs() < 1e-6);
        assert_eq!(suggested_balancing_weight(0, 100), 0.0);
        assert_eq!(suggested_balancing_weight(100, 0), 0.0);
    }

    /// Pins `suggested_oversampling_weight` as the ratio `a / b` for the cases
    /// shown (1.0, 0.25, 0.1). A zero in either argument gives 0.0.
    #[test]
    fn suggested_oversampling_weight_is_inverse_in_unit_interval() {
        assert!((suggested_oversampling_weight(100, 100) - 1.0).abs() < 1e-6);
        assert!((suggested_oversampling_weight(100, 400) - 0.25).abs() < 1e-6);
        assert!((suggested_oversampling_weight(100, 1000) - 0.1).abs() < 1e-6);
        assert_eq!(suggested_oversampling_weight(0, 100), 0.0);
        assert_eq!(suggested_oversampling_weight(100, 0), 0.0);
    }

    /// `--help` makes `parse_cli` return `Ok(None)`; an unknown flag is an `Err`.
    #[test]
    fn parse_cli_handles_help_and_invalid_args() {
        let help = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--help"]).unwrap();
        assert!(help.is_none());

        let err = parse_cli::<EstimateCapacityCli, _>(["estimate_capacity", "--unknown"]);
        assert!(err.is_err());
    }

    /// With no CLI args, the root resolver is called with an empty root list.
    /// Capacity estimation then succeeds for a single source reporting 12
    /// records.
    #[test]
    fn run_estimate_capacity_succeeds_with_reported_counts() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |roots| {
                assert!(roots.is_empty());
                Ok(())
            },
            |_| {
                vec![Box::new(TestSource {
                    id: "source_a".into(),
                    count: Some(12),
                    recipes: vec![default_recipe("r1")],
                }) as DynSource]
            },
        );

        assert!(result.is_ok());
    }

    /// A source that cannot report an exact count makes estimation fail. The
    /// error message mentions "failed to report exact record count".
    #[test]
    fn run_estimate_capacity_errors_when_source_count_missing() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |_| Ok(()),
            |_| {
                vec![Box::new(TestSource {
                    id: "source_missing".into(),
                    count: None,
                    recipes: vec![default_recipe("r1")],
                }) as DynSource]
            },
        );

        let err = result.unwrap_err().to_string();
        assert!(err.contains("failed to report exact record count"));
    }

    /// An error returned by the root resolver propagates out of
    /// `run_estimate_capacity`.
    #[test]
    fn run_estimate_capacity_propagates_root_resolution_error() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |_| Err("root resolution failed".into()),
            // The `&()` annotation is presumably needed because the resolver
            // above never returns `Ok`, so its success type cannot be inferred.
            |_: &()| Vec::<DynSource>::new(),
        );

        let err = result.unwrap_err().to_string();
        assert!(err.contains("root resolution failed"));
    }

    /// Succeeds only if the sources are counted with a `SamplerConfig` whose
    /// `seed` is 99.
    /// NOTE(review): 99 is presumably the seed `run_estimate_capacity`
    /// configures before counting; confirm against its implementation.
    #[test]
    fn run_estimate_capacity_configures_sources_centrally_before_counting() {
        let result = run_estimate_capacity(
            std::iter::empty::<String>(),
            |_| Ok(()),
            |_| {
                vec![Box::new(ConfigRequiredSource {
                    id: "requires_config".into(),
                    expected_seed: 99,
                }) as DynSource]
            },
        );

        assert!(result.is_ok());
    }

    /// Exercises `ConfigRequiredSource` directly. `refresh` returns no
    /// records, a seed mismatch yields `SourceInconsistent`, and the source
    /// has no default triplet recipes.
    #[test]
    fn config_required_source_refresh_and_seed_mismatch_are_exercised() {
        let source = ConfigRequiredSource {
            id: "cfg-source".to_string(),
            expected_seed: 42,
        };

        let refreshed = source
            .refresh(&SamplerConfig::default(), None, None)
            .unwrap();
        assert!(refreshed.records.is_empty());

        let mismatched = source.reported_record_count(&SamplerConfig {
            seed: 7,
            ..SamplerConfig::default()
        });
        assert!(matches!(
            mismatched,
            Err(SamplerError::SourceInconsistent { .. })
        ));

        assert!(source.default_triplet_recipes().is_empty());
    }

    /// Runs the demo in `--pair-batch` mode, `--text-recipes` mode and with no
    /// mode flag, each against a source that yields exactly one record. Every
    /// run must return `Ok`; presumably the single record is quickly exhausted.
    #[test]
    fn run_multi_source_demo_exhausted_paths_return_ok() {
        // Source with one record holding one anchor and one context section.
        struct OneRecordSource;

        impl DataSource for OneRecordSource {
            fn id(&self) -> &str {
                "one_record"
            }

            fn refresh(
                &self,
                _config: &SamplerConfig,
                _cursor: Option<&SourceCursor>,
                _limit: Option<usize>,
            ) -> Result<SourceSnapshot, SamplerError> {
                let now = Utc::now();
                Ok(SourceSnapshot {
                    records: vec![DataRecord {
                        id: "one_record::r1".to_string(),
                        source: "one_record".to_string(),
                        created_at: now,
                        updated_at: now,
                        quality: QualityScore { trust: 1.0 },
                        taxonomy: Vec::new(),
                        sections: vec![
                            RecordSection {
                                role: SectionRole::Anchor,
                                heading: Some("title".to_string()),
                                text: "anchor".to_string(),
                                sentences: vec!["anchor".to_string()],
                            },
                            RecordSection {
                                role: SectionRole::Context,
                                heading: Some("body".to_string()),
                                text: "context".to_string(),
                                sentences: vec!["context".to_string()],
                            },
                        ],
                        meta_prefix: None,
                    }],
                    cursor: SourceCursor {
                        last_seen: now,
                        revision: 0,
                    },
                })
            }

            fn reported_record_count(&self, _config: &SamplerConfig) -> Result<u128, SamplerError> {
                Ok(1)
            }

            fn default_triplet_recipes(&self) -> Vec<TripletRecipe> {
                vec![default_recipe("single_record_recipe")]
            }
        }

        // An empty mode string means "no mode flag".
        for mode in ["--pair-batch", "--text-recipes", ""] {
            let dir = tempdir().unwrap();
            let split_store_path = dir.path().join("split_store.bin");
            let mut args = vec![
                "--split-store-path".to_string(),
                split_store_path.to_string_lossy().to_string(),
            ];
            if !mode.is_empty() {
                args.push(mode.to_string());
            }

            let result = run_multi_source_demo(
                args.into_iter(),
                |_| Ok(()),
                |_| vec![Box::new(OneRecordSource) as DynSource],
            );
            assert!(result.is_ok());
        }
    }

    /// `--help` makes `parse_cli` return `Ok(None)`, `--batch-size 0` is
    /// rejected, and running with no arguments parses successfully.
    #[test]
    fn parse_multi_source_cli_handles_help_and_batch_size_validation() {
        let help = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--help"]).unwrap();
        assert!(help.is_none());

        let err = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo", "--batch-size", "0"]);
        assert!(err.is_err());

        let parsed = parse_cli::<MultiSourceDemoCli, _>(["multi_source_demo"]);
        assert!(parsed.is_ok());
    }

    /// `--version` is handled like `--help`: `parse_cli` returns `Ok(None)`.
    #[test]
    fn parse_cli_handles_display_version_path() {
        #[derive(Debug, Parser)]
        #[command(name = "version_test", version = "1.0.0")]
        struct VersionCli {}

        let parsed = parse_cli::<VersionCli, _>(["version_test", "--version"]).unwrap();
        assert!(parsed.is_none());
    }

    /// `--list-text-recipes` succeeds for a source with one triplet recipe.
    #[test]
    fn run_multi_source_demo_list_text_recipes_path_succeeds() {
        let dir = tempdir().unwrap();
        let split_store_path = dir.path().join("recipes_split_store.bin");
        let mut args = vec![
            "--list-text-recipes".to_string(),
            "--split-store-path".to_string(),
            split_store_path.to_string_lossy().to_string(),
        ];
        let result = run_multi_source_demo(
            args.drain(..),
            |_| Ok(()),
            |_| {
                vec![Box::new(TestSource {
                    id: "source_for_recipes".into(),
                    count: Some(10),
                    recipes: vec![default_recipe("recipe_a")],
                }) as DynSource]
            },
        );

        assert!(result.is_ok());
    }

    /// `--list-text-recipes` with an explicit `--split-store-path` succeeds,
    /// even for a source without triplet recipes.
    /// NOTE(review): only `is_ok()` is asserted; the test does not check that
    /// the given path was actually used.
    #[test]
    fn run_multi_source_demo_list_text_recipes_uses_explicit_split_store_path() {
        let dir = tempdir().unwrap();
        let split_store_path = dir.path().join("custom_split_store.bin");
        let args = vec![
            "--list-text-recipes".to_string(),
            "--split-store-path".to_string(),
            split_store_path.to_string_lossy().to_string(),
        ];

        let result = run_multi_source_demo(
            args.into_iter(),
            |_| Ok(()),
            |_| {
                vec![Box::new(TestSource {
                    id: "source_without_text_recipes".into(),
                    count: Some(1),
                    recipes: Vec::new(),
                }) as DynSource]
            },
        );

        assert!(result.is_ok());
    }

    /// On the validation split, each sampling mode (pair, text recipes, none)
    /// returns `Ok` for a source that reports zero records.
    #[test]
    fn run_multi_source_demo_sampling_modes_handle_empty_sources() {
        for mode in [
            vec!["--pair-batch".to_string()],
            vec!["--text-recipes".to_string()],
            vec![],
        ] {
            let dir = tempdir().unwrap();
            let split_store_path = dir.path().join("empty_sources_split_store.bin");
            let mut args = mode;
            args.push("--split-store-path".to_string());
            args.push(split_store_path.to_string_lossy().to_string());
            args.push("--split".to_string());
            args.push("validation".to_string());

            let result = run_multi_source_demo(
                args.into_iter(),
                |_| Ok(()),
                |_| {
                    vec![Box::new(TestSource {
                        id: "source_empty".into(),
                        count: Some(0),
                        recipes: vec![default_recipe("recipe_empty")],
                    }) as DynSource]
                },
            );

            assert!(result.is_ok());
        }
    }

    /// An error returned by the root resolver propagates out of
    /// `run_multi_source_demo`.
    #[test]
    fn run_multi_source_demo_propagates_root_resolution_error() {
        let dir = tempdir().unwrap();
        let split_store_path = dir.path().join("root_resolution_error_store.bin");
        let result = run_multi_source_demo(
            [
                "--split-store-path".to_string(),
                split_store_path.to_string_lossy().to_string(),
            ]
            .into_iter(),
            |_| Err("demo root resolution failed".into()),
            |_: &()| Vec::<DynSource>::new(),
        );

        let err = result.unwrap_err().to_string();
        assert!(err.contains("demo root resolution failed"));
    }

    /// Smoke-tests the triplet, pair, text and recipe printers using window
    /// and summary-fallback chunks from two sources. The printers only write
    /// to stdout, so passing means they did not panic. The test then pins
    /// `extract_source` for ids with and without the `::` delimiter.
    #[test]
    fn print_helpers_and_extract_source_cover_paths() {
        let split = SplitRatios::default();
        let store = DeterministicSplitStore::new(split, 42).unwrap();
        let strategy = ChunkingStrategy::default();

        let anchor = RecordChunk {
            record_id: "source_a::rec1".to_string(),
            section_idx: 0,
            view: ChunkView::Window {
                index: 1,
                overlap: 2,
                span: 12,
                start_ratio: 0.25,
            },
            text: "anchor text".to_string(),
            tokens_estimate: 8,
            quality: crate::data::QualityScore { trust: 0.9 },
        };
        let positive = RecordChunk {
            record_id: "source_a::rec2".to_string(),
            section_idx: 1,
            view: ChunkView::SummaryFallback {
                strategy: "summary".to_string(),
                weight: 0.7,
            },
            text: "positive text".to_string(),
            tokens_estimate: 6,
            quality: crate::data::QualityScore { trust: 0.8 },
        };
        // Comes from a different source, so the per-source summaries see two sources.
        let negative = RecordChunk {
            record_id: "source_b::rec3".to_string(),
            section_idx: 2,
            view: ChunkView::Window {
                index: 0,
                overlap: 0,
                span: 16,
                start_ratio: 0.0,
            },
            text: "negative text".to_string(),
            tokens_estimate: 7,
            quality: crate::data::QualityScore { trust: 0.5 },
        };

        let triplet_batch = TripletBatch {
            triplets: vec![crate::SampleTriplet {
                recipe: "triplet_recipe".to_string(),
                anchor: anchor.clone(),
                positive: positive.clone(),
                negative: negative.clone(),
                weight: 1.0,
                instruction: Some("triplet instruction".to_string()),
            }],
        };
        print_triplet_batch(&strategy, &triplet_batch, &store);

        let pair_batch = SampleBatch {
            pairs: vec![crate::SamplePair {
                recipe: "pair_recipe".to_string(),
                anchor: anchor.clone(),
                positive: positive.clone(),
                weight: 1.0,
                instruction: None,
                label: crate::PairLabel::Positive,
                reason: Some("same topic".to_string()),
            }],
        };
        print_pair_batch(&strategy, &pair_batch, &store);

        let text_batch = TextBatch {
            samples: vec![crate::TextSample {
                recipe: "text_recipe".to_string(),
                chunk: negative,
                weight: 0.8,
                instruction: Some("text instruction".to_string()),
            }],
        };
        print_text_batch(&strategy, &text_batch, &store);

        let recipes = vec![TextRecipe {
            name: "recipe_name".into(),
            selector: crate::config::Selector::Role(SectionRole::Context),
            instruction: Some("instruction".into()),
            weight: 1.0,
        }];
        print_text_recipes(&recipes);

        assert_eq!(extract_source("source_a::record"), "source_a");
        assert_eq!(extract_source("record-without-delimiter"), "unknown");
    }

    /// Each `SplitArg` converts to the matching `SplitLabel`.
    /// NOTE(review): despite the name, no version parsing is exercised here;
    /// `--version` is covered by `parse_cli_handles_display_version_path`.
    #[test]
    fn split_arg_conversion_and_version_parse_paths_are_covered() {
        assert!(matches!(
            SplitLabel::from(SplitArg::Train),
            SplitLabel::Train
        ));
        assert!(matches!(
            SplitLabel::from(SplitArg::Validation),
            SplitLabel::Validation
        ));
        assert!(matches!(SplitLabel::from(SplitArg::Test), SplitLabel::Test));
    }

    /// A non-numeric field produces an error naming the ratio (train,
    /// validation or test) that failed to parse.
    #[test]
    fn parse_split_ratios_reports_per_field_parse_errors() {
        assert!(
            parse_split_ratios_arg("x,0.1,0.9")
                .unwrap_err()
                .contains("invalid train ratio")
        );
        assert!(
            parse_split_ratios_arg("0.1,y,0.8")
                .unwrap_err()
                .contains("invalid validation ratio")
        );
        assert!(
            parse_split_ratios_arg("0.1,0.2,z")
                .unwrap_err()
                .contains("invalid test ratio")
        );
    }

    /// Runs the pair, text-recipe and default (no flag) modes against a
    /// `TestSource` with no triplet recipes. That source reports one record,
    /// but its `refresh` returns none. Every run must return `Ok`.
    /// NOTE(review): overlaps `run_multi_source_demo_exhausted_paths_return_ok`;
    /// consider merging the two.
    #[test]
    fn run_multi_source_demo_exhausted_paths_are_handled() {
        for mode in [
            vec!["--pair-batch".to_string()],
            vec!["--text-recipes".to_string()],
            Vec::new(),
        ] {
            let dir = tempdir().unwrap();
            let split_store_path = dir.path().join("exhausted_split_store.bin");
            let mut args = mode;
            args.push("--split-store-path".to_string());
            args.push(split_store_path.to_string_lossy().to_string());

            let result = run_multi_source_demo(
                args.into_iter(),
                |_| Ok(()),
                |_| {
                    vec![Box::new(TestSource {
                        id: "source_without_recipes".into(),
                        count: Some(1),
                        recipes: Vec::new(),
                    }) as DynSource]
                },
            );

            assert!(result.is_ok());
        }
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc