• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wexlergroup / FreeBird.jl / 20478418717

24 Dec 2025 04:42AM UTC coverage: 83.19% (-5.7%) from 88.861%
20478418717

Pull #124

github

web-flow
Merge ee4128ff2 into f4fdc6647
Pull Request #124: MLIPs in FreeBird

18 of 130 new or added lines in 8 files covered. (13.85%)

26 existing lines in 1 file now uncovered.

1742 of 2094 relevant lines covered (83.19%)

61212.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.38
/src/SamplingSchemes/nested_sampling.jl
"""
    mutable struct NestedSamplingParameters <: SamplingParameters

Mutable bundle of settings and running state used by the nested sampling scheme.

# Fields
- `mc_steps::Int64`: Total number of Monte Carlo moves to perform. A parallel routine may
  distribute this number among workers; when `mc_steps` is not divisible by the number of
  workers, each worker performs `ceil(mc_steps / nworkers())` moves.
- `initial_step_size::Float64`: Starting step size; also the fallback value when the MC
  routine fails to accept a move.
- `step_size::Float64`: Step size currently in use; adjusted on the fly during sampling.
- `step_size_lo::Float64`: Lower bound of the step size.
- `step_size_up::Float64`: Upper bound of the step size.
- `accept_range::Tuple{Float64, Float64}`: Target window of acceptance rates. For example,
  `(0.25, 0.75)` means the step size decreases when the acceptance rate drops below 0.25
  and increases when it exceeds 0.75.
- `fail_count::Int64`: Number of consecutive failed MC moves.
- `allowed_fail_count::Int64`: Maximum number of consecutive failures tolerated before the
  step size is reset.
- `energy_perturbation::Float64`: Perturbation value used to adjust the energy of the walkers.
- `random_seed::Int64`: Seed for the random number generator.
"""
mutable struct NestedSamplingParameters <: SamplingParameters
    # Monte Carlo budget
    mc_steps::Int64
    # step-size control
    initial_step_size::Float64
    step_size::Float64
    step_size_lo::Float64
    step_size_up::Float64
    accept_range::Tuple{Float64, Float64}
    # failure bookkeeping
    fail_count::Int64
    allowed_fail_count::Int64
    # miscellaneous
    energy_perturbation::Float64
    random_seed::Int64
end
32

"""
    NestedSamplingParameters(; mc_steps=200, initial_step_size=0.01, step_size=0.1,
                             step_size_lo=1e-6, step_size_up=1.0, accept_range=(0.25, 0.75),
                             fail_count=0, allowed_fail_count=100,
                             energy_perturbation=1e-12, random_seed=1234)

Keyword constructor for `NestedSamplingParameters`. Every field can be given by name; any
field left out takes the default shown in the signature.
"""
function NestedSamplingParameters(;
    mc_steps::Int64=200,
    initial_step_size::Float64=0.01,
    step_size::Float64=0.1,
    step_size_lo::Float64=1e-6,
    step_size_up::Float64=1.0,
    accept_range::Tuple{Float64, Float64}=(0.25, 0.75),
    fail_count::Int64=0,
    allowed_fail_count::Int64=100,
    energy_perturbation::Float64=1e-12,
    random_seed::Int64=1234,
)
    # Forward to the positional constructor; the order must match the field order.
    return NestedSamplingParameters(
        mc_steps,
        initial_step_size,
        step_size,
        step_size_lo,
        step_size_up,
        accept_range,
        fail_count,
        allowed_fail_count,
        energy_perturbation,
        random_seed,
    )
end
47

"""
    LatticeNestedSamplingParameters(;
            mc_steps::Int64=100,
            energy_perturbation::Float64=1e-12,
            fail_count::Int64=0,
            allowed_fail_count::Int64=10,
            random_seed::Int64=1234,
            )

Convenience constructor returning a `NestedSamplingParameters` object with defaults suited
to lattice systems. Only the keywords listed above can be set here; the step-size related
fields keep the defaults of the `NestedSamplingParameters` keyword constructor.
"""
function LatticeNestedSamplingParameters(;
    mc_steps::Int64=100,
    energy_perturbation::Float64=1e-12,
    fail_count::Int64=0,
    allowed_fail_count::Int64=10,
    random_seed::Int64=1234,
)
    return NestedSamplingParameters(;
        mc_steps=mc_steps,
        fail_count=fail_count,
        allowed_fail_count=allowed_fail_count,
        energy_perturbation=energy_perturbation,
        random_seed=random_seed,
    )
end
67

68

"""
    abstract type MCRoutine

Supertype of all Monte Carlo routines that can be passed to `nested_sampling_step!`.

The concrete routines currently available are:
- `MCRandomWalkMaxE`: generates a new walker by performing a random walk for decorrelation
  on the highest-energy walker.
- `MCRandomWalkClone`: generates a new walker by cloning an existing walker and performing
  a random walk for decorrelation.
- `MCNewSample`: generates a new walker from a random configuration. Currently intended for
  lattice gas systems.
- `MCMixedMoves`: generates a new walker by performing random walks and swapping atoms.
  Currently intended for multi-component systems. The actual number of random walks and
  swaps is determined by the weights `walks_freq` and `swaps_freq`.
  See [`MCMixedMoves`](@ref).
- `MCRejectionSampling`: generates a new walker by rejection sampling. Currently intended
  for lattice gas systems.
- `MCDistributed`: generates new walkers by performing random walks for decorrelation in
  parallel using Distributed.jl. Supports multiple culling walkers and multiple
  decorrelation walkers. See [`MCDistributed`](@ref).
"""
abstract type MCRoutine end
90

"""
    abstract type MCRoutineParallel <: MCRoutine

(Internal) Supertype of the Monte Carlo routines that run in parallel.
"""
abstract type MCRoutineParallel <: MCRoutine end
96

"""
    struct MCRandomWalkMaxE <: MCRoutine

Routine that generates a new walker by performing a random walk for decorrelation on the
highest-energy walker.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to perform the random walk. Passed
  positionally to the constructor; defaults to `[1, 2, 3]`.
"""
struct MCRandomWalkMaxE <: MCRoutine
    dims::Vector{Int64}

    # `dims` is an optional positional argument here (not a keyword).
    MCRandomWalkMaxE(dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
107

"""
    struct MCRandomWalkClone <: MCRoutine

Routine that generates a new walker by cloning an existing walker and performing a random
walk for decorrelation.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to perform the random walk. Set with
  the `dims` keyword of the constructor; defaults to `[1, 2, 3]`.
"""
struct MCRandomWalkClone <: MCRoutine
    dims::Vector{Int64}

    MCRandomWalkClone(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
118

"""
    struct MCRandomWalkCloneParallel <: MCRoutineParallel

Parallel routine that generates new walkers by cloning existing walkers and performing
random walks for decorrelation.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to perform the random walks. Set with
  the `dims` keyword of the constructor; defaults to `[1, 2, 3]`.
"""
struct MCRandomWalkCloneParallel <: MCRoutineParallel
    dims::Vector{Int64}

    MCRandomWalkCloneParallel(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
129

"""
    struct MCDistributed <: MCRoutineParallel

A type for generating new walkers by performing random walks for decorrelation in parallel
using Distributed.jl.

# Fields
- `n_cull::Int64`: The number of walkers to cull (replace) in each iteration. Must be at
  least 1. The default is 1.
- `n_decorr::Int64`: The number of walkers to use for decorrelation (random walks). Must be
  non-negative. The default is `nworkers() - n_cull`, i.e. all workers not used for culling.
- `dims::Vector{Int64}`: The dimensions along which to perform the random walks.

# Throws
- `ErrorException` if `n_cull < 1`, if `n_decorr < 0`, or if `n_cull + n_decorr` differs
  from `nworkers()`.
"""
struct MCDistributed <: MCRoutineParallel
    n_cull::Int64
    n_decorr::Int64
    dims::Vector{Int64}
    function MCDistributed(;
        n_cull::Int64=1,
        # Derive the default from `n_cull` so that `MCDistributed(n_cull=k)` satisfies the
        # worker-count constraint below without having to spell out `n_decorr` as well.
        n_decorr::Int64=nworkers() - n_cull,
        dims::Vector{Int64}=[1, 2, 3],
    )
        # The sampling step indexes its energy list with `n_cull`, so zero or negative
        # values can never work; reject them here with a clear message.
        if n_cull < 1 || n_decorr < 0
            error("n_cull must be at least 1 and n_decorr must be non-negative, got n_cull: $n_cull, n_decorr: $n_decorr")
        end
        if n_cull + n_decorr != nworkers()
            error("n_cull + n_decorr must be equal to the number of workers: $(nworkers())")
        end
        @info "Distributed nested sampling initiated: n_cull: $n_cull, n_decorr: $n_decorr, total workers: $(n_cull + n_decorr)"
        return new(n_cull, n_decorr, dims)
    end
end
150

"""
    struct MCRandomWalkMaxEParallel <: MCRoutineParallel

Parallel routine that generates new walkers by performing random walks for decorrelation on
the highest-energy walker(s).

# Fields
- `dims::Vector{Int64}`: The dimensions along which to perform the random walks. Set with
  the `dims` keyword of the constructor; defaults to `[1, 2, 3]`.
"""
struct MCRandomWalkMaxEParallel <: MCRoutineParallel
    dims::Vector{Int64}

    MCRandomWalkMaxEParallel(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
161

"""
    struct MCNewSample <: MCRoutine

Field-less routine that generates a new walker from a random configuration. Currently
intended for lattice gas systems.
"""
struct MCNewSample <: MCRoutine end
167

"""
    mutable struct MCMixedMoves <: MCRoutine

Routine that generates a new walker by mixing random walks with atom swaps. Currently
intended for multi-component systems.

The two fields act as relative weights: with `walks_freq=4` and `swaps_freq=1`, a random
walk is performed with probability 4/5 and a swap with probability 1/5.

# Fields
- `walks_freq::Int`: The frequency (weight) of random walks.
- `swaps_freq::Int`: The frequency (weight) of atom swaps.
"""
mutable struct MCMixedMoves <: MCRoutine
    walks_freq::Int  # weight of random-walk moves
    swaps_freq::Int  # weight of atom-swap moves
end
182

"""
    mutable struct MCMixedMovesParallel <: MCRoutineParallel

Parallel counterpart of `MCMixedMoves`: generates new walkers by mixing random walks with
atom swaps on several workers. Currently intended for multi-component systems.

The two fields act as relative weights: with `walks_freq=4` and `swaps_freq=1`, a random
walk is performed with probability 4/5 and a swap with probability 1/5.

# Fields
- `walks_freq::Int`: The frequency (weight) of random walks.
- `swaps_freq::Int`: The frequency (weight) of atom swaps.
"""
mutable struct MCMixedMovesParallel <: MCRoutineParallel
    walks_freq::Int  # weight of random-walk moves
    swaps_freq::Int  # weight of atom-swap moves
end
197

"""
    struct MCRejectionSampling <: MCRoutine

Field-less routine that generates a new walker by rejection sampling. Currently intended
for lattice gas systems.
"""
struct MCRejectionSampling <: MCRoutine end
203

"""
    sort_by_energy!(liveset::AbstractLiveSet)

Sort the walkers of `liveset` in place by their `energy` field, in descending order, so
that the highest-energy walker comes first.

# Arguments
- `liveset::AbstractLiveSet`: The liveset whose `walkers` are sorted.

# Returns
- `liveset`: The same liveset, now sorted.
"""
function sort_by_energy!(liveset::AbstractLiveSet)
    energy_of(walker) = walker.energy
    sort!(liveset.walkers; by=energy_of, rev=true)
    return liveset
end
220

"""
    update_iter!(liveset::AbstractLiveSet)

Increment the `iter` counter of every walker in `liveset` by one.

# Arguments
- `liveset::AbstractLiveSet`: The set of walkers to update.

# Returns
- `nothing`
"""
function update_iter!(liveset::AbstractLiveSet)
    foreach(liveset.walkers) do walker
        walker.iter += 1
    end
end
235

"""
    estimate_temperature(n_walkers::Integer, n_cull::Integer, ediff::Real, iter::Integer=1)

Estimate the temperature for the nested sampling algorithm from `log(ω) / ediff`, where
`ω = (n_cull / (n_walkers + n_cull)) * (n_walkers / (n_walkers + n_cull))^iter`.

# Arguments
- `n_walkers::Integer`: The number of walkers.
- `n_cull::Integer`: The number of walkers culled per iteration.
- `ediff::Real`: The energy difference, as a plain number; the Boltzmann constant used
  here is in eV/K, so `ediff` is taken to be in eV.
- `iter::Integer=1`: The iteration number entering the exponent of `ω`.

# Returns
- The estimated temperature in Kelvin (a `Float64` for the standard integer and
  floating-point argument types).
"""
function estimate_temperature(n_walkers::Integer, n_cull::Integer, ediff::Real, iter::Integer=1)
    total = n_walkers + n_cull
    ω = (n_cull / total) * (n_walkers / total)^iter
    β = log(ω) / ediff
    kb = 8.617333262145e-5 # eV/K
    T = 1 / (kb * β) # in Kelvin
    return T
end
247

248

"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Perform a single step of the nested sampling algorithm using the Monte Carlo random walk routine.

The liveset is sorted by energy (highest first) and the energy of the first walker is used
as the energy limit. A copy of either the first walker (`MCRandomWalkMaxE`) or a randomly
chosen other walker (`MCRandomWalkClone`) is passed to the random walk. If the walk is
accepted, the first walker is removed, the new walker is appended, and the iteration
counters are advanced; otherwise only `ns_params.fail_count` is incremented. The step size
is adjusted from the acceptance rate in both cases.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutine`: The Monte Carlo routine for generating new samples. See [`MCRoutine`](@ref).

# Returns
- `iter`: The iteration number after the step, or `missing` if the walk was rejected.
- `emax`: The highest energy recorded during the step, or `missing` if the walk was rejected.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.

# Throws
- `ErrorException` if `mc_routine` is neither `MCRandomWalkMaxE` nor `MCRandomWalkClone`,
  or if `mc_routine.dims` does not have exactly 2 or 3 entries.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    iter::Union{Missing,Int} = missing
    # after sorting, the first walker carries the highest energy
    emax::Union{Missing,typeof(0.0u"eV")} = liveset.walkers[1].energy
    if mc_routine isa MCRandomWalkMaxE
        to_walk = deepcopy(ats[1])
    elseif mc_routine isa MCRandomWalkClone
        to_walk = deepcopy(rand(ats[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end
    if length(mc_routine.dims) == 3
        accept, rate, at = MC_random_walk!(ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax)
    elseif length(mc_routine.dims) == 2
        accept, rate, at = MC_random_walk_2D!(ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax; dims=mc_routine.dims)
    else
        # Any other length (0, 1, or more than 3) is unsupported. Using `else` instead of
        # `elseif ... == 1` avoids falling through with `accept`, `rate`, `at` undefined.
        error("Unsupported dimensions: $(mc_routine.dims)")
    end
    if accept
        # replace the highest-energy walker with the newly generated one
        push!(ats, at)
        popfirst!(ats)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end
    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
302

"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCDistributed)

Perform a single step of the nested sampling algorithm using the distributed Monte Carlo
random walk routine.

The liveset is sorted by energy (highest first). `nworkers()` walkers are sampled without
replacement from all but the first walker, copied, and each copy is random-walked on its
own worker for `ceil(mc_steps / nworkers())` moves below the energy of the `n_cull`-th
walker. If fewer than `n_cull` walks are accepted, the step fails and the liveset is left
unchanged. Otherwise `n_cull` accepted results (chosen at random) overwrite the first
`n_cull` walkers and every other accepted result overwrites the walker it was copied from.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCDistributed`: The distributed Monte Carlo routine. See [`MCDistributed`](@ref).

# Returns
- `iter`: The iteration number after the step, or `missing` if the step failed.
- `emax`: The energy limit used during the step, or `missing` if the step failed.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCDistributed)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Vector{Missing},Vector{typeof(0.0u"eV")}} = [liveset.walkers[i].energy for i in 1:nworkers()]

    # choose the walkers to clone and copy them so the liveset is untouched on failure
    source_inds = sample(2:length(walkers), nworkers(); replace=false)
    starts = deepcopy.(walkers[source_inds])

    n_dims = length(mc_routine.dims)
    if n_dims == 3
        walk! = MC_random_walk!
    elseif n_dims == 2
        walk! = MC_random_walk_2D!
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    # distribute the total MC steps among workers
    steps_per_worker = ceil(Int, ns_params.mc_steps / nworkers())

    futures = [
        remotecall(walk!, workers()[k], steps_per_worker, start, potential, ns_params.step_size, emax[mc_routine.n_cull])
        for (k, start) in enumerate(starts)
    ]
    results = fetch.(futures)
    finalize.(futures) # finalize the remote calls, clear the memory

    # each result is indexed as (accepted flag, acceptance rate, walker)
    rate = mean([res[2] for res in results])
    accepted_inds = findall(res -> res[1] == 1, results)

    # not enough accepted walkers: count a failure and leave the liveset as it is
    if length(accepted_inds) < mc_routine.n_cull
        ns_params.fail_count += 1
        return iter, missing, liveset, ns_params
    end

    # accepted results picked at random replace the first `n_cull` walkers
    picked = sample(accepted_inds, mc_routine.n_cull; replace=false)
    for (slot, res_ind) in enumerate(picked)
        walkers[slot] = results[res_ind][3]
    end

    # the remaining accepted results replace the walkers they were copied from
    for res_ind in accepted_inds
        res_ind in picked && continue
        walkers[source_inds[res_ind]] = results[res_ind][3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = liveset.walkers[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax[mc_routine.n_cull], liveset, ns_params
end
379

"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)

Perform a single step of the nested sampling algorithm using a parallel Monte Carlo random
walk routine (`MCRandomWalkMaxEParallel` or `MCRandomWalkCloneParallel`).

The liveset is sorted by energy (highest first). One walker per worker is copied (the first
`nworkers()` walkers for `MCRandomWalkMaxEParallel`; a random sample without replacement
from all but the first walker for `MCRandomWalkCloneParallel`) and random-walked for
`ns_params.mc_steps` moves below the energy of the `nworkers()`-th walker. The step only
succeeds if every walk is accepted, in which case the results overwrite the first
`nworkers()` walkers. On failure the liveset is left unchanged and the step size is not
adjusted.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutineParallel`: The parallel Monte Carlo routine. See [`MCRoutineParallel`](@ref).

# Returns
- `iter`: The iteration number after the step, or `missing` if the step failed.
- `emax`: The energy limit used during the step, or `missing` if the step failed.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.

# Throws
- `ErrorException` if `mc_routine` is not one of the two supported routine types, or if
  `mc_routine.dims` does not have exactly 2 or 3 entries.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Vector{Missing},Vector{typeof(0.0u"eV")}} = [liveset.walkers[i].energy for i in 1:nworkers()]

    if mc_routine isa MCRandomWalkMaxEParallel
        to_walk_inds = 1:nworkers()
    elseif mc_routine isa MCRandomWalkCloneParallel
        to_walk_inds = sample(2:length(ats), nworkers(); replace=false)
    else
        # Without this branch an unsupported routine would only fail later with an
        # `UndefVarError` on `to_walk_inds`; fail early with the same message the serial
        # method uses.
        error("Unsupported MCRoutine type: $mc_routine")
    end

    to_walks = deepcopy.(ats[to_walk_inds])

    if length(mc_routine.dims) == 3
        random_walk_function = MC_random_walk!
    elseif length(mc_routine.dims) == 2
        random_walk_function = MC_random_walk_2D!
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    walking = [remotecall(random_walk_function, workers()[i], ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax[end]) for (i,to_walk) in enumerate(to_walks)]
    walked = fetch.(walking)
    finalize.(walking) # finalize the remote calls, clear the memory

    # each result is indexed as (accepted flag, acceptance rate, walker)
    accepted_rates = [x[2] for x in walked]
    rate = mean(accepted_rates)

    if any(x -> x[1] == 0, walked) # if any of the walkers failed
        ns_params.fail_count += 1
        emax = [missing]
        return iter, emax[end], liveset, ns_params
    end

    # all walks accepted: the results replace the highest-energy walkers
    for (i, at) in enumerate(walked)
        ats[i] = at[3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = liveset.walkers[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax[end], liveset, ns_params
end
431

"""
    nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)

Perform a single step of the nested sampling algorithm for surface walkers using a parallel
Monte Carlo random walk routine (`MCRandomWalkMaxEParallel` or `MCRandomWalkCloneParallel`).

The liveset is sorted by energy (highest first). One walker per worker is copied (the first
`nworkers()` walkers for `MCRandomWalkMaxEParallel`; a random sample without replacement
from all but the first walker for `MCRandomWalkCloneParallel`) and random-walked for
`ns_params.mc_steps` moves below the energy of the `nworkers()`-th walker, with
`liveset.surface` passed on to the random walk. The step only succeeds if every walk is
accepted, in which case the results overwrite the first `nworkers()` walkers. On failure
the liveset is left unchanged and the step size is not adjusted.

# Arguments
- `liveset::LJSurfaceWalkers`: The set of surface walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutineParallel`: The parallel Monte Carlo routine. See [`MCRoutineParallel`](@ref).

# Returns
- `iter`: The iteration number after the step, or `missing` if the step failed.
- `emax`: The energy limit used during the step, or `missing` if the step failed.
- `liveset`: The updated set of walkers.
- `ns_params`: The updated nested sampling parameters.

# Throws
- `ErrorException` if `mc_routine` is not one of the two supported routine types, or if
  `mc_routine.dims` does not have exactly 2 or 3 entries.
"""
function nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Vector{Missing},Vector{typeof(0.0u"eV")}} = [liveset.walkers[i].energy for i in 1:nworkers()]

    if mc_routine isa MCRandomWalkMaxEParallel
        to_walk_inds = 1:nworkers()
    elseif mc_routine isa MCRandomWalkCloneParallel
        to_walk_inds = sample(2:length(ats), nworkers(); replace=false)
    else
        # Without this branch an unsupported routine would only fail later with an
        # `UndefVarError` on `to_walk_inds`; fail early with the same message the serial
        # method uses.
        error("Unsupported MCRoutine type: $mc_routine")
    end

    to_walks = deepcopy.(ats[to_walk_inds])

    if length(mc_routine.dims) == 3
        random_walk_function = MC_random_walk!
    elseif length(mc_routine.dims) == 2
        random_walk_function = MC_random_walk_2D!
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    walking = [remotecall(random_walk_function, workers()[i], ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax[end], liveset.surface) for (i,to_walk) in enumerate(to_walks)]
    walked = fetch.(walking)
    finalize.(walking) # finalize the remote calls, clear the memory

    # each result is indexed as (accepted flag, acceptance rate, walker)
    accepted_rates = [x[2] for x in walked]
    rate = mean(accepted_rates)

    if any(x -> x[1] == 0, walked) # if any of the walkers failed
        ns_params.fail_count += 1
        emax = [missing]
        return iter, emax[end], liveset, ns_params
    end

    # all walks accepted: the results replace the highest-energy walkers
    for (i, at) in enumerate(walked)
        ats[i] = at[3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = liveset.walkers[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax[end], liveset, ns_params
end
483

"""
    nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Perform a single serial nested sampling step for surface walkers.

The liveset is sorted by energy (highest first) and the energy of the first walker is the
energy limit. A copy of the first walker (`MCRandomWalkMaxE`) or of a randomly chosen other
walker (`MCRandomWalkClone`) is random-walked with `liveset.surface` passed along. Only
three-entry `dims` are supported here. On acceptance the first walker is removed and the
new one appended; otherwise `ns_params.fail_count` is incremented. The step size is
adjusted from the acceptance rate in both cases.

# Arguments
- `liveset::LJSurfaceWalkers`: The set of surface walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutine`: The Monte Carlo routine. See [`MCRoutine`](@ref).

# Returns
- `iter`: The iteration number after the step, or `missing` if the walk was rejected.
- `emax`: The highest energy recorded during the step, or `missing` if the walk was rejected.
- `liveset`: The updated set of walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,typeof(0.0u"eV")} = liveset.walkers[1].energy

    # pick the starting configuration according to the routine
    start = if mc_routine isa MCRandomWalkMaxE
        deepcopy(walkers[1])
    elseif mc_routine isa MCRandomWalkClone
        deepcopy(rand(walkers[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end

    length(mc_routine.dims) == 3 || error("Unsupported dimensions: $(mc_routine.dims)")
    accept, rate, candidate = MC_random_walk!(ns_params.mc_steps, start, potential, ns_params.step_size, emax, liveset.surface)

    if accept
        # swap the highest-energy walker for the new one
        push!(walkers, candidate)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end

    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
518

"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMoves)

Perform a single serial nested sampling step using the Monte Carlo mixed moves routine.

The liveset is sorted by energy (highest first) and the energy of the first walker is the
energy limit. A copy of a randomly chosen walker other than the first is passed to
`MC_mixed_moves!` together with the weights `[walks_freq, swaps_freq]`. On acceptance the
first walker is removed and the new one appended; otherwise `ns_params.fail_count` is
incremented. The step size is adjusted from the acceptance rate in both cases.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCMixedMoves`: The Monte Carlo mixed moves routine.

# Returns
- `iter`: The iteration number after the step, or `missing` if the move was rejected.
- `emax`: The highest energy recorded during the step, or `missing` if the move was rejected.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.

# Note
- To invoke the parallel version of this routine, use `MCMixedMovesParallel` as the `mc_routine` argument.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMoves)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,typeof(0.0u"eV")} = liveset.walkers[1].energy

    # start from a copy of a random walker other than the highest-energy one
    start = deepcopy(rand(walkers[2:end]))
    move_weights = [mc_routine.walks_freq, mc_routine.swaps_freq]

    accept, rate, candidate = MC_mixed_moves!(ns_params.mc_steps, start, potential, ns_params.step_size, emax, move_weights)

    if accept
        # swap the highest-energy walker for the new one
        push!(walkers, candidate)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end

    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
566

"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMovesParallel)

Perform a single nested sampling step using the parallel Monte Carlo mixed moves routine.

The liveset is sorted by energy (highest first) and the energy of the first walker is the
energy limit. `nworkers()` walkers are sampled without replacement from all but the first
walker, copied, and each copy is passed to `MC_mixed_moves!` on its own worker. If no walk
is accepted, the step fails and the liveset is left unchanged. Otherwise one accepted
result (chosen at random) overwrites the first walker and every other accepted result
overwrites the walker it was copied from.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCMixedMovesParallel`: The parallel Monte Carlo mixed moves routine.

# Returns
- `iter`: The iteration number after the step, or `missing` if the step failed.
- `emax`: The highest energy recorded during the step, or `missing` if the step failed.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMovesParallel)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Vector{Missing},Vector{typeof(0.0u"eV")}} = [liveset.walkers[i].energy for i in 1:nworkers()]

    # choose the walkers to clone and copy them so the liveset is untouched on failure
    source_inds = sample(2:length(walkers), nworkers(); replace=false)
    starts = deepcopy.(walkers[source_inds])

    futures = [
        remotecall(MC_mixed_moves!, workers()[k], ns_params.mc_steps, start, potential, ns_params.step_size, emax[1], [mc_routine.walks_freq, mc_routine.swaps_freq])
        for (k, start) in enumerate(starts)
    ]
    results = fetch.(futures)
    finalize.(futures) # finalize the remote calls, clear the memory

    # each result is indexed as (accepted flag, acceptance rate, walker)
    rate = mean([res[2] for res in results])
    accepted_inds = findall(res -> res[1] == 1, results)

    # all of the walkers failed: count a failure and leave the liveset as it is
    if isempty(accepted_inds)
        ns_params.fail_count += 1
        return iter, missing, liveset, ns_params
    end

    # one accepted result, picked at random, replaces the highest-energy walker
    picked = rand(accepted_inds)
    walkers[1] = results[picked][3]

    # the remaining accepted results replace the walkers they were copied from
    for res_ind in accepted_inds
        res_ind == picked && continue
        walkers[source_inds[res_ind]] = results[res_ind][3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = liveset.walkers[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax[1], liveset, ns_params
end
618

"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Perform a single step of the nested sampling algorithm on a lattice gas liveset.

The liveset is sorted by energy (highest first) and the bare value of the first walker's
energy is the energy limit. A copy of the first walker (`MCRandomWalkMaxE`) or of a
randomly chosen other walker (`MCRandomWalkClone`) is random-walked using
`liveset.hamiltonian`. On acceptance the first walker is removed and the new one appended;
otherwise `ns_params.fail_count` is incremented. The step size is not adjusted here.

## Arguments
- `liveset::LatticeGasWalkers`: The liveset of lattice gas walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutine`: The Monte Carlo routine for generating new samples.

## Returns
- `iter`: The iteration number after the step, or `missing` if the walk was rejected.
- `emax`: The energy limit of the step, multiplied by the unit of the first walker's
  energy, or `missing` if the walk was rejected.
- `liveset::LatticeGasWalkers`: The updated liveset.
- `ns_params::NestedSamplingParameters`: The updated nested sampling parameters.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers,
                               ns_params::NestedSamplingParameters,
                               mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    hamiltonian = liveset.hamiltonian
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,Float64} = liveset.walkers[1].energy.val

    # pick the starting configuration according to the routine
    start = if mc_routine isa MCRandomWalkMaxE
        deepcopy(walkers[1])
    elseif mc_routine isa MCRandomWalkClone
        deepcopy(rand(walkers[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end

    accept, rate, candidate = MC_random_walk!(ns_params.mc_steps, start, hamiltonian, emax; energy_perturb=ns_params.energy_perturbation)

    if accept
        # swap the highest-energy walker for the new one
        push!(walkers, candidate)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end

    # re-attach the energy unit to the bare value (a `missing` stays `missing`)
    energy_unit = unit(liveset.walkers[1].energy)
    return iter, emax * energy_unit, liveset, ns_params
end
668

"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCNewSample)

Perform a single step of the nested sampling algorithm on a lattice gas liveset, generating
the candidate with `MC_new_sample!`.

The liveset is sorted by energy (highest first) and the bare value of the first walker's
energy is the energy limit. A copy of the first walker is handed to `MC_new_sample!`
together with `liveset.hamiltonian`. On acceptance the first walker is removed and the new
one appended; otherwise `ns_params.fail_count` is incremented.

## Arguments
- `liveset::LatticeGasWalkers`: The liveset of lattice gas walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCNewSample`: The Monte Carlo routine for generating new samples.

## Returns
- `iter`: The iteration number after the step, or `missing` if the sample was rejected.
- `emax`: The energy limit of the step, multiplied by the unit of the first walker's
  energy, or `missing` if the sample was rejected.
- `liveset::LatticeGasWalkers`: The updated liveset after the step.
- `ns_params::NestedSamplingParameters`: The updated nested sampling parameters after the step.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers,
                               ns_params::NestedSamplingParameters,
                               mc_routine::MCNewSample)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    hamiltonian = liveset.hamiltonian
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,Float64} = liveset.walkers[1].energy.val

    start = deepcopy(walkers[1])
    accept, candidate = MC_new_sample!(start, hamiltonian, emax; energy_perturb=ns_params.energy_perturbation)

    if accept
        # swap the highest-energy walker for the new one
        push!(walkers, candidate)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end

    # re-attach the energy unit to the bare value (a `missing` stays `missing`)
    energy_unit = unit(liveset.walkers[1].energy)
    return iter, emax * energy_unit, liveset, ns_params
end
715

716

717
"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRejectionSampling)

Perform a single step of the nested sampling algorithm on a lattice-gas live set,
proposing the replacement walker with `MC_rejection_sampling!`.

The live set is sorted with `sort_by_energy!`, the energy value of its first walker is
taken as the energy limit, and a deep copy of that walker is passed to
`MC_rejection_sampling!`. On acceptance the returned walker is appended, the first
walker is removed, `update_iter!` is called and `ns_params.fail_count` is reset to zero.
On rejection no walker is replaced and `ns_params.fail_count` is incremented.

## Arguments
- `liveset::LatticeGasWalkers`: The liveset of lattice gas walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRejectionSampling`: The Monte Carlo routine (used for dispatch only).

## Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if the
  proposal was rejected.
- `emax`: The energy limit used for this step, carrying the unit of the walker energies,
  or `missing` if the proposal was rejected.
- `liveset::LatticeGasWalkers`: The updated liveset after the step.
- `ns_params::NestedSamplingParameters`: The updated nested sampling parameters after the step.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers, 
                               ns_params::NestedSamplingParameters, 
                               mc_routine::MCRejectionSampling)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    hamiltonian = liveset.hamiltonian
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,Float64} = walkers[1].energy.val

    # Work on a deep copy; the stored walker is only replaced on acceptance.
    candidate = deepcopy(walkers[1])
    accepted, new_walker = MC_rejection_sampling!(
        candidate, hamiltonian, emax;
        energy_perturb=ns_params.energy_perturbation,
    )

    if accepted
        push!(walkers, new_walker)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end

    energy_unit = unit(liveset.walkers[1].energy)
    return iter, emax * energy_unit, liveset, ns_params
end
745

746

747

748
"""
    nested_sampling(liveset::AbstractLiveSet, ns_params::NestedSamplingParameters, n_steps::Int64, mc_routine::MCRoutine, save_strategy::DataSavingStrategy)

Perform a nested sampling loop for a given number of steps.

Each step calls `nested_sampling_step!`. When `ns_params.fail_count` reaches
`ns_params.allowed_fail_count`, a warning is logged, the fail counter is reset to zero
and `ns_params.step_size` is reset to `ns_params.initial_step_size`.

# Arguments
- `liveset::AbstractLiveSet`: The initial set of walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `n_steps::Int64`: The number of steps to perform.
- `mc_routine::MCRoutine`: The Monte Carlo routine to use.
- `save_strategy::DataSavingStrategy`: The data saving strategy. Its `n_info` field sets
  how often (in steps) a progress message is printed, and it is forwarded to
  `write_walker_every_n`, `write_df_every_n` and `write_ls_every_n`.

# Returns
- `df`: A DataFrame with columns `iter` and `emax`. A row is added only for steps where
  `nested_sampling_step!` returns a non-`missing` iteration number.
- `liveset`: The updated set of walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling(liveset::AbstractLiveSet, 
                                ns_params::NestedSamplingParameters, 
                                n_steps::Int64, 
                                mc_routine::MCRoutine,
                                save_strategy::DataSavingStrategy)
    df = DataFrame(iter=Int[], emax=Float64[])
    for i in 1:n_steps
        print_info = i % save_strategy.n_info == 0
        # The first walker is written out before the step is taken.
        write_walker_every_n(liveset.walkers[1], i, save_strategy)
        iter, emax, liveset, ns_params = nested_sampling_step!(liveset, ns_params, mc_routine)
        @debug "n_step $i, iter: $iter, emax: $emax"
        if ns_params.fail_count >= ns_params.allowed_fail_count
            @warn "Failed to accept MC move $(ns_params.allowed_fail_count) times in a row. Reset step size!"
            ns_params.fail_count = 0
            ns_params.step_size = ns_params.initial_step_size
        end
        # A `missing` iteration number marks a failed step; nothing is recorded for it.
        if !ismissing(iter)
            push!(df, (iter, emax.val))
        end
        print_message(i, iter, emax, ns_params.step_size, print_info, liveset)
        write_df_every_n(df, i, save_strategy)
        write_ls_every_n(liveset, i, save_strategy)
    end
    return df, liveset, ns_params
end
789

790
"""
    print_message(i, iter, emax, step_size, print_info, liveset::LatticeWalkers)

Log the progress of nested sampling step `i` for a lattice live set.

Nothing is logged unless `print_info` is `true`. If `iter` is `missing`, a
"MC move failed" message with the step number and the energy of the first walker is
logged; otherwise the iteration number of the first walker and `emax` are logged.
`step_size` is not used by this method.
"""
function print_message(i, iter, emax, step_size, print_info, liveset::LatticeWalkers)
    print_info || return nothing
    if ismissing(iter)
        @info "MC move failed, step: $(i), emax: $(liveset.walkers[1].energy)"
    else
        @info "iter: $(liveset.walkers[1].iter), emax: $(emax)"
    end
    return nothing
end
797

798
"""
    print_message(i, iter, emax, step_size, print_info, liveset::AtomWalkers)

Log the progress of nested sampling step `i` for an atomistic live set.

Nothing is logged unless `print_info` is `true`. If `iter` is `missing`, a
"MC move failed" message with the step number, the energy value of the first walker and
the step size is logged; otherwise the iteration number of the first walker, `emax.val`
and the step size are logged. The step size is rounded to four significant digits.
"""
function print_message(i, iter, emax, step_size, print_info, liveset::AtomWalkers)
    print_info || return nothing
    if ismissing(iter)
        @info "MC move failed, step: $(i), emax: $(liveset.walkers[1].energy.val), step_size: $(round(step_size; sigdigits=4))"
    else
        @info "iter: $(liveset.walkers[1].iter), emax: $(emax.val), step_size: $(round(step_size; sigdigits=4))"
    end
    return nothing
end
805
    
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc