• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In
Build has been canceled!

wildjames / predictive_coding_rs / 26030244354

18 May 2026 11:19AM UTC coverage: 89.597% (+1.0%) from 88.629%
26030244354

Pull #18

github

web-flow
Merge 84aa8ae92 into ab068c7f4
Pull Request #18: Implement the remaining GPU kernels

255 of 275 new or added lines in 12 files covered. (92.73%)

6 existing lines in 2 files now uncovered.

2067 of 2307 relevant lines covered (89.6%)

14.3 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.81
/src/benchmark.rs
1
//! This program is used to benchmark the speed of various training configurations and model architectures. It takes a training config file as input, and times its processes, creating a series of files detailing the results.
2

3
use std::{collections::BTreeMap, path::Path, time::Instant};
4

5
use predictive_coding::{
6
    error::{PredictiveCodingError, Result},
7
    model::{PredictiveCodingModelConfig, save_model_config},
8
    training::{
9
        StepProfile, TrainConfig, TrainingHandler, save_training_config, setup_training_run_handler,
10
    },
11
    utils::{logging, timestamp},
12
};
13

14
use clap::Parser;
15
use tracing::info;
16

17
#[cfg(test)]
18
#[path = "test_utils.rs"]
19
mod test_utils;
20

21
/// This program is used to benchmark the speed of various training configurations and model architectures. It takes a training config file as input, and times its processes, creating a series of files detailing the results.
22
#[derive(Parser)]
23
struct BenchArgs {
24
    /// The model configuration to benchmark.
25
    // #[arg(default_value_t = String::from("benchmark_data/benchmark_minibatch_config.json"))]
26
    #[arg(default_value_t = String::from("benchmark_data/benchmark_gpu_singlethread_config.json"))]
27
    config: String,
28

29
    /// Optional artifact output prefix. Defaults to `benchmark_data/<timestamp>/benchmark`.
30
    #[arg(long, default_value_t = format!("benchmark_data/benchmark_{}/bench_", timestamp()))]
31
    output_prefix: String,
32
}
33

34
fn current_git_commit_hash_with_command(command_name: &str, args: &[&str]) -> Result<String> {
6✔
35
    let command = if args.is_empty() {
6✔
36
        command_name.to_string()
1✔
37
    } else {
38
        format!("{} {}", command_name, args.join(" "))
5✔
39
    };
40
    let output = std::process::Command::new(command_name)
6✔
41
        .args(args)
6✔
42
        .output()
6✔
43
        .map_err(|source| PredictiveCodingError::command_io(command.clone(), source))?;
6✔
44

45
    if !output.status.success() {
5✔
46
        return Err(PredictiveCodingError::command_failed(
1✔
47
            command,
1✔
48
            output.status,
1✔
49
            String::from_utf8_lossy(&output.stderr),
1✔
50
        ));
1✔
51
    }
4✔
52

53
    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
4✔
54
}
6✔
55

56
fn current_git_commit_hash() -> Result<String> {
3✔
57
    current_git_commit_hash_with_command("git", &["rev-parse", "HEAD"])
3✔
58
}
3✔
59

60
fn run_benchmark(args: BenchArgs) -> Result<()> {
3✔
61
    // Detect if this binary was compiled in release mode or not
62
    let release_mode: bool = !cfg!(debug_assertions);
3✔
63
    #[cfg(debug_assertions)]
64
    info!(
3✔
65
        "Running benchmark in debug mode. For more accurate benchmarking, compile with --release"
66
    );
67
    let mut handler: Box<dyn TrainingHandler> =
3✔
68
        setup_training_run_handler(args.config, args.output_prefix.clone())?;
3✔
69

70
    let step_data: Vec<BenchmarkStepData> = run_benchmark_training_loop(
3✔
71
        handler.as_mut(),
3✔
72
        &format!("{}_{}", &args.output_prefix, "bench_run.csv"),
3✔
73
    )?;
×
74

75
    // Compute per-phase summary
76
    let phase_summary = compute_phase_summary(&step_data);
3✔
77

78
    // Print summary to console
79
    info!("--- Benchmark Phase Summary ---");
3✔
80
    let total_wall: f32 = step_data.iter().map(|s| s.total_ms).sum();
3✔
81
    info!(
3✔
82
        "Total wall time: {:.1} ms over {} steps ({:.1} ms/step avg)",
83
        total_wall,
84
        step_data.len(),
2✔
85
        if step_data.is_empty() {
2✔
NEW
86
            0.0
×
87
        } else {
88
            total_wall / step_data.len() as f32
2✔
89
        }
90
    );
91
    for (name, s) in &phase_summary {
12✔
92
        info!(
12✔
93
            "  {:<25} mean={:>8.2}ms  min={:>8.2}ms  max={:>8.2}ms  total={:>10.1}ms  ({:.1}%)",
94
            name, s.mean_ms, s.min_ms, s.max_ms, s.total_ms, s.pct_of_total
95
        );
96
    }
97

98
    // Write the training params to "{output_prefix}/params.json"
99
    let current_commit_hash_str: String = current_git_commit_hash()?;
3✔
100

101
    let result = BenchmarkResult {
3✔
102
        step_data,
3✔
103
        phase_summary,
3✔
104
        git_commit_hash: current_commit_hash_str,
3✔
105
        run_timestamp: chrono::Utc::now().to_rfc3339(),
3✔
106
        release_mode,
3✔
107
    };
3✔
108

109
    // Write the benchmarking parameters, training parameters, and model configuration to files for posterity.
110
    // This is JSON, so probably a bit harder to read than the CSV file, but it does create a single file with all the relevant information for each benchmark run, which is nice.
111
    let result_path: String = format!("{}_{}", args.output_prefix, "result.json");
3✔
112
    let result_file = std::fs::File::create(&result_path).map_err(|source| {
3✔
113
        PredictiveCodingError::io("create benchmark result", &result_path, source)
×
114
    })?;
×
115
    serde_json::to_writer_pretty(result_file, &result)
3✔
116
        .map_err(|source| PredictiveCodingError::json_serialize(&result_path, source))?;
3✔
117

118
    Ok(())
3✔
119
}
3✔
120

121
fn main() -> Result<()> {
2✔
122
    logging::setup_tracing(false);
2✔
123
    info!("Starting benchmark run");
2✔
124
    run_benchmark(BenchArgs::parse())
2✔
125
}
2✔
126

127
#[derive(serde::Serialize)]
128
struct BenchmarkResult {
129
    step_data: Vec<BenchmarkStepData>,
130
    phase_summary: BTreeMap<String, PhaseSummary>,
131
    git_commit_hash: String,
132
    run_timestamp: String,
133
    release_mode: bool,
134
}
135

136
#[derive(serde::Serialize)]
137
struct BenchmarkStepData {
138
    step: u32,
139
    total_ms: f32,
140
    phases: BTreeMap<String, f32>,
141
}
142

143
#[derive(serde::Serialize)]
144
struct PhaseSummary {
145
    mean_ms: f32,
146
    min_ms: f32,
147
    max_ms: f32,
148
    total_ms: f32,
149
    pct_of_total: f32,
150
}
151

152
fn run_benchmark_training_loop(
4✔
153
    handler: &mut dyn TrainingHandler,
4✔
154
    bench_run_outfile: &str,
4✔
155
) -> Result<Vec<BenchmarkStepData>> {
4✔
156
    let mut benchmark_data: Vec<BenchmarkStepData> = Vec::new();
4✔
157

158
    handler.pre_training_hook()?;
4✔
159

160
    // Time each training step and write to a csv file
161
    if let Some(parent) = Path::new(bench_run_outfile)
4✔
162
        .parent()
4✔
163
        .filter(|path| !path.as_os_str().is_empty())
4✔
164
    {
165
        std::fs::create_dir_all(parent).map_err(|source| {
4✔
166
            PredictiveCodingError::io("create benchmark output directory", parent, source)
×
167
        })?;
×
168
    }
×
169
    let mut wtr = csv::Writer::from_path(bench_run_outfile).map_err(|source| {
4✔
170
        PredictiveCodingError::csv("create benchmark writer", bench_run_outfile, source)
×
171
    })?;
×
172

173
    let training_config: &TrainConfig = handler.get_config();
4✔
174
    let training_steps: u32 = training_config.training_steps;
4✔
175

176
    // Write the config and training params to a file
177
    let model_config: PredictiveCodingModelConfig = handler.model_config();
4✔
178
    save_model_config(
4✔
179
        &model_config,
4✔
180
        &format!("{}_model_config.json", &handler.get_file_output_prefix()),
4✔
181
    )?;
×
182
    save_training_config(
4✔
183
        handler.get_config(),
4✔
184
        &format!("{}_training_config.json", &handler.get_file_output_prefix()),
4✔
185
    )?;
×
186

187
    // We discover phase names from the first step's profile, then use a
188
    // consistent column order for the CSV.
189
    let mut phase_names: Vec<String> = Vec::new();
4✔
190
    let mut header_written = false;
4✔
191

192
    for step in 0..training_steps {
8✔
193
        let start_time: Instant = Instant::now();
8✔
194
        handler.pre_step_hook(step)?;
8✔
195
        let profile: StepProfile = handler.profiled_train_step(step)?;
8✔
196
        handler.post_step_hook(step)?;
8✔
197
        let wall_time = start_time.elapsed();
8✔
198

199
        let wall_time_ms: f32 = wall_time.as_secs_f32() * 1000.0;
8✔
200

201
        // On the first step, discover phase names and write the CSV header
202
        if !header_written {
8✔
203
            phase_names = profile
4✔
204
                .phases
4✔
205
                .iter()
4✔
206
                .map(|(name, _)| name.clone())
12✔
207
                .collect();
4✔
208
            let mut header: Vec<String> = vec!["step".into(), "total_ms".into()];
4✔
209
            for name in &phase_names {
12✔
210
                header.push(format!("{}_ms", name));
12✔
211
            }
12✔
212
            wtr.write_record(&header).map_err(|source| {
4✔
NEW
213
                PredictiveCodingError::csv("write benchmark header", bench_run_outfile, source)
×
UNCOV
214
            })?;
×
215
            header_written = true;
4✔
216
        }
4✔
217

218
        // Build phase timing map
219
        let mut phases: BTreeMap<String, f32> = BTreeMap::new();
8✔
220
        for (name, dur) in &profile.phases {
24✔
221
            phases.insert(name.clone(), dur.as_secs_f32() * 1000.0);
24✔
222
        }
24✔
223

224
        // Write CSV row
225
        let mut row: Vec<String> = vec![step.to_string(), format!("{:.3}", wall_time_ms)];
8✔
226
        for name in &phase_names {
24✔
227
            row.push(format!("{:.3}", phases.get(name).copied().unwrap_or(0.0)));
24✔
228
        }
24✔
229
        wtr.write_record(&row).map_err(|source| {
8✔
NEW
230
            PredictiveCodingError::csv("append benchmark row", bench_run_outfile, source)
×
NEW
231
        })?;
×
232
        wtr.flush().map_err(|source| {
8✔
233
            PredictiveCodingError::io("flush benchmark CSV", bench_run_outfile, source)
×
234
        })?;
×
235

236
        // Log with phase breakdown
237
        let phase_str: String = profile
8✔
238
            .phases
8✔
239
            .iter()
8✔
240
            .map(|(name, dur)| format!("{}={:.1}ms", name, dur.as_secs_f32() * 1000.0))
24✔
241
            .collect::<Vec<_>>()
8✔
242
            .join("  ");
8✔
243
        info!(
8✔
244
            "Step {}: total {:.1} ms  [{}]",
245
            step, wall_time_ms, phase_str
246
        );
247

248
        benchmark_data.push(BenchmarkStepData {
8✔
249
            step,
8✔
250
            total_ms: wall_time_ms,
8✔
251
            phases,
8✔
252
        });
8✔
253
    }
254

255
    handler.post_training_hook()?;
4✔
256

257
    Ok(benchmark_data)
4✔
258
}
4✔
259

260
/// Compute per-phase summary statistics from the step data.
261
fn compute_phase_summary(step_data: &[BenchmarkStepData]) -> BTreeMap<String, PhaseSummary> {
3✔
262
    if step_data.is_empty() {
3✔
NEW
263
        return BTreeMap::new();
×
264
    }
3✔
265

266
    // Collect all phase names
267
    let mut all_phases: BTreeMap<String, Vec<f32>> = BTreeMap::new();
3✔
268
    let mut total_wall_ms: f32 = 0.0;
3✔
269

270
    for step in step_data {
6✔
271
        total_wall_ms += step.total_ms;
6✔
272
        for (name, &ms) in &step.phases {
24✔
273
            all_phases.entry(name.clone()).or_default().push(ms);
24✔
274
        }
24✔
275
    }
276

277
    let mut summary = BTreeMap::new();
3✔
278
    for (name, times) in &all_phases {
12✔
279
        let total: f32 = times.iter().sum();
12✔
280
        let mean = total / times.len() as f32;
12✔
281
        let min = times.iter().copied().fold(f32::INFINITY, f32::min);
12✔
282
        let max = times.iter().copied().fold(f32::NEG_INFINITY, f32::max);
12✔
283
        let pct = if total_wall_ms > 0.0 {
12✔
284
            (total / total_wall_ms) * 100.0
12✔
285
        } else {
NEW
286
            0.0
×
287
        };
288
        summary.insert(
12✔
289
            name.clone(),
12✔
290
            PhaseSummary {
12✔
291
                mean_ms: mean,
12✔
292
                min_ms: min,
12✔
293
                max_ms: max,
12✔
294
                total_ms: total,
12✔
295
                pct_of_total: pct,
12✔
296
            },
12✔
297
        );
298
    }
299

300
    summary
3✔
301
}
3✔
302

303
#[cfg(test)]
mod tests {
    use super::test_utils::{RecordingTrainingHandler, TempDir, single_thread_train_config};
    use super::*;

    use std::{fs, path::PathBuf};

    // Drives the helper through `sh`, so it can only run where a POSIX shell exists.
    #[cfg(unix)]
    #[test]
    fn current_git_commit_hash_helper_handles_success_and_failure_cases() {
        let success: String =
            current_git_commit_hash_with_command("sh", &["-c", "printf 'abc123\\n'"]).unwrap();
        assert_eq!(success, "abc123");

        let failure =
            current_git_commit_hash_with_command("sh", &["-c", "printf 'boom\\n' >&2; exit 2"]);
        assert!(matches!(
            failure,
            Err(PredictiveCodingError::CommandFailed { stderr, .. }) if stderr.contains("boom")
        ));

        let spawn_failure: std::result::Result<String, PredictiveCodingError> =
            current_git_commit_hash_with_command("/definitely/missing/git", &[]);
        assert!(matches!(
            spawn_failure,
            Err(PredictiveCodingError::CommandIo { .. })
        ));
    }

    #[test]
    fn benchmark_loop_writes_csv_and_artifacts() {
        let temp_dir: TempDir = TempDir::new("benchmark_loop");
        let output_prefix: String = temp_dir.join("nested/bench").display().to_string();
        let bench_csv: PathBuf = temp_dir.join("nested/bench_run.csv");
        let mut handler = RecordingTrainingHandler::new(
            single_thread_train_config(2, 0, 0),
            output_prefix.clone(),
        );

        let step_data: Vec<BenchmarkStepData> =
            run_benchmark_training_loop(&mut handler, bench_csv.to_str().unwrap()).unwrap();

        assert_eq!(handler.steps, vec![0, 1]);
        assert_eq!(step_data.len(), 2);
        assert!(bench_csv.exists());
        assert!(Path::new(&format!("{}_model_config.json", output_prefix)).exists());
        assert!(Path::new(&format!("{}_training_config.json", output_prefix)).exists());
        assert!(Path::new(&format!("{}_final_model.json", output_prefix)).exists());

        let csv_output = fs::read_to_string(bench_csv).unwrap();
        // The header must be the very first line, not merely somewhere in the file.
        assert!(csv_output.lines().next().unwrap().starts_with("step,total_ms"));
        assert!(csv_output.contains("\n0,"));
        assert!(csv_output.contains("\n1,"));
    }

    #[test]
    fn recording_handler_exposes_dataset_fixture() {
        let handler = RecordingTrainingHandler::new(
            single_thread_train_config(2, 0, 0),
            String::from("unused/bench"),
        );
        let data = handler.get_data();

        assert_eq!(data.get_dataset_size(), 1);
        assert_eq!(data.get_input_size(), 4);
        assert_eq!(data.get_output_size(), 10);
        assert_eq!(data.get_random_input(), data.get_input(0));

        let (input, output) = data.get_random_input_and_output();
        assert_eq!(input, data.get_input(0));
        assert_eq!(output, data.get_output(0));
    }

    #[test]
    fn run_benchmark_writes_result_payload_in_process() {
        let temp_dir: TempDir = TempDir::new("benchmark_run_main_path");
        let output_prefix: String = temp_dir.join("end_to_end/bench").display().to_string();

        run_benchmark(BenchArgs {
            config: String::from("test_data/bench_single_thread_config.json"),
            output_prefix: output_prefix.clone(),
        })
        .unwrap();

        let result_path: String = format!("{}_result.json", output_prefix);
        let result_json: serde_json::Value =
            serde_json::from_str(&fs::read_to_string(result_path).unwrap()).unwrap();
        assert_eq!(result_json["step_data"].as_array().unwrap().len(), 2);
        assert!(!result_json["git_commit_hash"].as_str().unwrap().is_empty());
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc