• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

wexlergroup / FreeBird.jl / 19114191965

05 Nov 2025 07:41PM UTC coverage: 88.861% (-2.2%) from 91.098%
19114191965

Pull #120

github

yangmr04
tag a new version 0.2.2
Pull Request #120: Dev

62 of 107 new or added lines in 2 files covered. (57.94%)

9 existing lines in 1 file now uncovered.

1747 of 1966 relevant lines covered (88.86%)

65535.06 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.5
/src/SamplingSchemes/nested_sampling.jl
1
"""
    mutable struct NestedSamplingParameters <: SamplingParameters

The `NestedSamplingParameters` struct represents the parameters used in the nested sampling scheme.

# Fields
- `mc_steps::Int64`: The number of Monte Carlo moves to perform per decorrelation walk. Only the
  `MCDistributed` routine splits this number among workers, giving each worker
  `ceil(mc_steps / nworkers())` moves. The other parallel routines (`MCRandomWalkMaxEParallel`,
  `MCRandomWalkCloneParallel`, `MCMixedMoves`) give every worker the full `mc_steps`.
- `initial_step_size::Float64`: The initial step size. The step size is reset to this value
  after `allowed_fail_count` consecutive failed steps.
- `step_size::Float64`: The on-the-fly step size used in the sampling process.
- `step_size_lo::Float64`: The lower bound of the step size.
- `step_size_up::Float64`: The upper bound of the step size.
- `accept_range::Tuple{Float64, Float64}`: The range of acceptance rates for adjusting the step size,
  e.g. (0.25, 0.75) means that the step size will decrease if the acceptance rate is below 0.25 and
  increase if it is above 0.75.
- `fail_count::Int64`: The number of consecutive nested sampling steps whose MC move was not accepted.
  It is reset to 0 on success.
- `allowed_fail_count::Int64`: The number of consecutive failures after which the step size is reset
  to `initial_step_size`.
- `energy_perturbation::Float64`: The perturbation value passed (as `energy_perturb`) to the lattice
  MC routines.
- `random_seed::Int64`: The seed for the random number generator.
"""
mutable struct NestedSamplingParameters <: SamplingParameters
    mc_steps::Int64
    initial_step_size::Float64
    step_size::Float64
    step_size_lo::Float64
    step_size_up::Float64
    accept_range::Tuple{Float64, Float64}
    fail_count::Int64
    allowed_fail_count::Int64
    energy_perturbation::Float64
    random_seed::Int64
end
32

33
"""
    NestedSamplingParameters(; mc_steps=200, initial_step_size=0.01, step_size=0.1,
                             step_size_lo=1e-6, step_size_up=1.0, accept_range=(0.25, 0.75),
                             fail_count=0, allowed_fail_count=100, energy_perturbation=1e-12,
                             random_seed=1234)

Keyword constructor for `NestedSamplingParameters`. Every field can be given by name, and any
field that is omitted takes the default shown above.
"""
NestedSamplingParameters(;
    mc_steps::Int64=200,
    initial_step_size::Float64=0.01,
    step_size::Float64=0.1,
    step_size_lo::Float64=1e-6,
    step_size_up::Float64=1.0,
    accept_range::Tuple{Float64, Float64}=(0.25, 0.75),
    fail_count::Int64=0,
    allowed_fail_count::Int64=100,
    energy_perturbation::Float64=1e-12,
    random_seed::Int64=1234,
) = NestedSamplingParameters(
    mc_steps,
    initial_step_size,
    step_size,
    step_size_lo,
    step_size_up,
    accept_range,
    fail_count,
    allowed_fail_count,
    energy_perturbation,
    random_seed,
)
47

48
"""
    LatticeNestedSamplingParameters(;
            mc_steps::Int64=100,
            energy_perturbation::Float64=1e-12,
            fail_count::Int64=0,
            allowed_fail_count::Int64=10,
            random_seed::Int64=1234,
            )

Build a `NestedSamplingParameters` with defaults suited to lattice systems.

Only the five keywords above are exposed. The step-size related fields (`initial_step_size`,
`step_size`, `step_size_lo`, `step_size_up`, `accept_range`) keep the defaults of the
`NestedSamplingParameters` keyword constructor.
"""
function LatticeNestedSamplingParameters(;
        mc_steps::Int64=100,
        energy_perturbation::Float64=1e-12,
        fail_count::Int64=0,
        allowed_fail_count::Int64=10,
        random_seed::Int64=1234,
    )
    # Forward by keyword (implicit-name form) so the remaining fields use their defaults.
    return NestedSamplingParameters(;
        mc_steps,
        fail_count,
        allowed_fail_count,
        energy_perturbation,
        random_seed,
    )
end
67

68

69
"""
    abstract type MCRoutine

An abstract type representing a Monte Carlo routine. The concrete routine passed to
`nested_sampling_step!` selects how new walkers are generated.

Currently, the following serial concrete types are supported:
- `MCRandomWalkMaxE`: A type for generating a new walker by performing a random walk for decorrelation
  on the highest-energy walker.
- `MCRandomWalkClone`: A type for generating a new walker by cloning an existing walker and performing
  a random walk for decorrelation.
- `MCNewSample`: A type for generating a new walker from a random configuration. Currently, it is
  intended to use this routine for lattice gas systems.
- `MCMixedMoves`: A type for generating a new walker by performing random walks and swapping atoms.
  Currently, it is intended to use this routine for multi-component systems. The actual number of
  random walks and swaps to perform is determined by the weights of the fields `walks_freq` and
  `swaps_freq`. See [`MCMixedMoves`](@ref).
- `MCRejectionSampling`: A type for generating a new walker by performing rejection sampling.
  Currently, it is intended to use this routine for lattice gas systems.

Parallel routines subtype `MCRoutineParallel` (itself a subtype of `MCRoutine`):
- `MCRandomWalkMaxEParallel`: Random walks for decorrelation on the highest-energy walkers, in parallel.
- `MCRandomWalkCloneParallel`: Random walks for decorrelation on cloned walkers, in parallel.
- `MCDistributed`: A type for generating new walkers by performing random walks for decorrelation in
  parallel using Distributed.jl. This routine supports multiple culling walkers and multiple
  decorrelation walkers. See [`MCDistributed`](@ref).
"""
abstract type MCRoutine end
90

91
"""
    abstract type MCRoutineParallel <: MCRoutine

(Internal) An abstract type representing a parallel Monte Carlo routine.

Subtypes are dispatched to the `nested_sampling_step!` methods that launch their random walks on
Distributed.jl workers via `remotecall`.
"""
abstract type MCRoutineParallel <: MCRoutine end
96

97
"""
    MCRandomWalkMaxE(dims::Vector{Int64}=[1, 2, 3]; dims::Vector{Int64}=dims)

A type for generating a new walker by performing a random walk for decorrelation on the
highest-energy walker.

`dims` may be given positionally (original form) or as a keyword, which matches the other
routines such as `MCRandomWalkClone(; dims=...)`. When both are given, the keyword wins.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to walk. Its length (3 or 2) selects the 3D or
  2D random-walk function in `nested_sampling_step!`.
"""
struct MCRandomWalkMaxE <: MCRoutine
    dims::Vector{Int64}
    # The keyword default refers to the positional argument, so both call styles work.
    function MCRandomWalkMaxE(dims_positional::Vector{Int64}=[1, 2, 3];
                              dims::Vector{Int64}=dims_positional)
        return new(dims)
    end
end
107

108
"""
    MCRandomWalkClone(; dims::Vector{Int64}=[1, 2, 3])

Routine that builds a new walker by copying a randomly chosen existing walker and random-walking
the copy for decorrelation.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to walk. Its length (3 or 2) selects the 3D or
  2D random-walk function.
"""
struct MCRandomWalkClone <: MCRoutine
    dims::Vector{Int64}
    MCRandomWalkClone(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
118

119
"""
    MCRandomWalkCloneParallel(; dims::Vector{Int64}=[1, 2, 3])

Parallel counterpart of `MCRandomWalkClone`. Each worker random-walks a copy of a distinct,
randomly chosen walker for decorrelation.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to walk. Its length (3 or 2) selects the 3D or
  2D random-walk function.
"""
struct MCRandomWalkCloneParallel <: MCRoutineParallel
    dims::Vector{Int64}
    MCRandomWalkCloneParallel(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
129

130
"""
    MCDistributed(; n_cull::Int64=1, n_decorr::Int64=nworkers()-1, dims::Vector{Int64}=[1, 2, 3])

A type for generating new walkers by performing random walks for decorrelation in parallel using
Distributed.jl.

# Fields
- `n_cull::Int64`: The number of highest-energy walkers to cull (replace) in each iteration. Walkers
  are sorted by descending energy, so these are walkers `1:n_cull`. The default is 1; must be ≥ 1.
- `n_decorr::Int64`: The number of walkers to use for decorrelation (random walks). The default is
  `nworkers() - 1`; must be ≥ 0.
- `dims::Vector{Int64}`: The dimensions along which to perform the random walks. Its length
  (3 or 2) selects the 3D or 2D random-walk function.

# Throws
- `ErrorException` if `n_cull + n_decorr != nworkers()`, if `n_cull < 1`, or if `n_decorr < 0`.
"""
struct MCDistributed <: MCRoutineParallel
    n_cull::Int64
    n_decorr::Int64
    dims::Vector{Int64}
    function MCDistributed(; n_cull::Int64=1, n_decorr::Int64=nworkers()-1, dims::Vector{Int64}=[1, 2, 3])
        if n_cull + n_decorr != nworkers()
            error("n_cull + n_decorr must be equal to the number of workers: $(nworkers())")
        end
        # The sampling step indexes the walker energies with `n_cull`, so it must be a valid
        # position. A negative `n_decorr` would let `n_cull` exceed `nworkers()`.
        n_cull >= 1 || error("n_cull must be at least 1, got $n_cull")
        n_decorr >= 0 || error("n_decorr must be non-negative, got $n_decorr")
        @info "Distributed nested sampling initiated: n_cull: $n_cull, n_decorr: $n_decorr, total workers: $(n_cull + n_decorr)"
        return new(n_cull, n_decorr, dims)
    end
end
150

151
"""
    struct MCRandomWalkMaxEParallel <: MCRoutineParallel

Parallel counterpart of `MCRandomWalkMaxE`. The highest-energy walkers, one per worker, are
random-walked for decorrelation.

# Fields
- `dims::Vector{Int64}`: The dimensions along which to walk (keyword, default `[1, 2, 3]`). Its
  length (3 or 2) selects the 3D or 2D random-walk function.
"""
struct MCRandomWalkMaxEParallel <: MCRoutineParallel
    dims::Vector{Int64}
    MCRandomWalkMaxEParallel(; dims::Vector{Int64}=[1, 2, 3]) = new(dims)
end
161

162
"""
    struct MCNewSample <: MCRoutine

A type for generating a new walker from a random configuration. Currently, it is intended to use
this routine for lattice gas systems.

For `LatticeGasWalkers`, selecting this routine dispatches `nested_sampling_step!` to the method that
calls `MC_new_sample!`. It carries no parameters.
"""
struct MCNewSample <: MCRoutine end
16✔
167

168
"""
    mutable struct MCMixedMoves <: MCRoutine

A type for generating a new walker by performing random walks and swapping atoms. Currently, it is
intended to use this routine for multi-component systems. The actual number of random walks and
swaps to perform is determined by the weights of the fields `walks_freq` and `swaps_freq`.
For example, if `walks_freq=4` and `swaps_freq=1`, then the probability of performing a random walk
is 4/5, and the probability of performing a swap is 1/5.

The struct is mutable, so the weights can be changed between steps. `nested_sampling_step!` passes
them to `MC_mixed_moves!` as the vector `[walks_freq, swaps_freq]`.

# Fields
- `walks_freq::Int`: The relative weight of random-walk moves.
- `swaps_freq::Int`: The relative weight of atom-swap moves.
"""
mutable struct MCMixedMoves <: MCRoutine
    walks_freq::Int
    swaps_freq::Int
end
182

183
"""
    struct MCRejectionSampling <: MCRoutine

A type for generating a new walker by performing rejection sampling. Currently, it is intended to use
this routine for lattice gas systems.

For `LatticeGasWalkers`, selecting this routine dispatches `nested_sampling_step!` to the method that
calls `MC_rejection_sampling!`. It carries no parameters.
"""
struct MCRejectionSampling <: MCRoutine end
4✔
188

189
"""
    sort_by_energy!(liveset::AbstractLiveSet)

Sort `liveset.walkers` in place by their `energy` field in descending order, so the
highest-energy walker comes first.

# Arguments
- `liveset::AbstractLiveSet`: The liveset whose walkers are reordered.

# Returns
- The same `liveset` object, now sorted.
"""
function sort_by_energy!(liveset::AbstractLiveSet)
    walkers = liveset.walkers
    sort!(walkers; by=walker -> walker.energy, rev=true)
    return liveset
end
205

206
"""
    update_iter!(liveset::AbstractLiveSet)

Increment the `iter` counter of every walker in `liveset` by one, in place.

# Arguments
- `liveset::AbstractLiveSet`: The set of walkers to update.

# Returns
- `nothing`.
"""
function update_iter!(liveset::AbstractLiveSet)
    foreach(walker -> walker.iter += 1, liveset.walkers)
    return nothing
end
220

221
"""
    estimate_temperature(n_walkers::Integer, n_cull::Integer, ediff::Real, iter::Integer=1)

Estimate the temperature (in Kelvin) for the nested sampling algorithm from dlog(ω)/dE,
approximated as `log(ω) / ediff`.

The computation is:
- `ω = (n_cull / (n_walkers + n_cull)) * (n_walkers / (n_walkers + n_cull))^iter`
- `β = log(ω) / ediff`
- `T = 1 / (k_B * β)` with `k_B = 8.617333262145e-5` eV/K.

# Arguments
- `n_walkers::Integer`: The number of walkers.
- `n_cull::Integer`: The number of walkers culled per iteration.
- `ediff::Real`: The energy difference, as a plain number in eV.
- `iter::Integer=1`: The iteration index used as the exponent in `ω`.

# Returns
- The estimated temperature in Kelvin.

NOTE(review): a finite difference of `log(ω)` between consecutive iterations would be
`log(n_walkers / (n_walkers + n_cull)) / ediff`, which does not depend on `iter`. This function
divides the full `log(ω)` by `ediff` instead. Confirm which definition of `ediff` callers use.
"""
function estimate_temperature(n_walkers::Integer, n_cull::Integer, ediff::Real, iter::Integer=1)
    kb = 8.617333262145e-5 # Boltzmann constant, eV/K
    ω = (n_cull / (n_walkers + n_cull)) * (n_walkers / (n_walkers + n_cull))^iter
    β = log(ω) / ediff
    return 1 / (kb * β) # in Kelvin
end
232

233

234
"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Perform a single step of the nested sampling algorithm using a serial Monte Carlo random walk.

Procedure:
1. Sort the walkers by descending energy; the first walker's energy becomes `emax`.
2. Choose a trial walker:
   - `MCRandomWalkMaxE`: a copy of the highest-energy walker;
   - `MCRandomWalkClone`: a copy of a random other walker.
3. Random-walk the trial walker with `emax` passed to `MC_random_walk!` (3 dims) or
   `MC_random_walk_2D!` (2 dims).
4. If the walk is accepted, append the new walker and drop the highest-energy walker.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutine`: The Monte Carlo routine for generating new samples. See [`MCRoutine`](@ref).

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if the walk was rejected.
- `emax`: The energy of the culled (highest-energy) walker, or `missing` if the walk was rejected.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters (`fail_count` and step size).

# Throws
- `ErrorException` if `mc_routine` is neither `MCRandomWalkMaxE` nor `MCRandomWalkClone`, or if
  `length(mc_routine.dims)` is not 2 or 3.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    iter::Union{Missing,Int} = missing
    emax::Union{Missing,typeof(0.0u"eV")} = liveset.walkers[1].energy
    if mc_routine isa MCRandomWalkMaxE
        to_walk = deepcopy(ats[1])
    elseif mc_routine isa MCRandomWalkClone
        to_walk = deepcopy(rand(ats[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end
    if length(mc_routine.dims) == 3
        accept, rate, at = MC_random_walk!(ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax)
    elseif length(mc_routine.dims) == 2
        accept, rate, at = MC_random_walk_2D!(ns_params.mc_steps, to_walk, lj, ns_params.step_size, emax; dims=mc_routine.dims)
    else
        # Reject every other length explicitly. Previously only length 1 was caught, and any
        # other length fell through and hit an UndefVarError on `accept` below.
        error("Unsupported dimensions: $(mc_routine.dims)")
    end
    if accept
        push!(ats, at)
        popfirst!(ats)          # drop the highest-energy walker (first after sorting)
        update_iter!(liveset)
        ns_params.fail_count = 0
        iter = liveset.walkers[1].iter
    else
        emax = missing
        ns_params.fail_count += 1
    end
    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
287

288
"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCDistributed)

Perform a single step of the nested sampling algorithm with the `MCDistributed` routine. The
`mc_routine.n_cull` highest-energy walkers are replaced, and the remaining accepted walks
decorrelate surviving walkers.

Procedure:
1. Sort the walkers by descending energy; the energy of walker `n_cull` is the ceiling `emax`.
2. Copy `nworkers()` distinct survivors (chosen from walkers `n_cull+1:end`) and random-walk them in
   parallel, one per worker, with `ceil(Int, ns_params.mc_steps / nworkers())` moves each.
3. If at least `n_cull` walks are accepted:
   - `n_cull` of them, picked at random, overwrite walkers `1:n_cull`;
   - every other accepted walk overwrites the survivor it started from.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCDistributed`: The distributed routine. See [`MCDistributed`](@ref).

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if fewer than
  `n_cull` walks were accepted.
- `emax`: The ceiling energy used for this step, or `missing` on failure.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated parameters. `fail_count` is updated; the step size is adjusted only on success.

# Throws
- `ErrorException` if `length(mc_routine.dims)` is not 2 or 3.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCDistributed)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    n_cull = mc_routine.n_cull
    n_par = nworkers()
    iter::Union{Missing,Int} = missing
    emax::typeof(0.0u"eV") = ats[n_cull].energy

    # Clone only from survivors. Walkers 1:n_cull are culled this step: a walk started from one
    # of them begins at or above the ceiling, and writing its result back to its index would
    # overwrite one of the freshly inserted replacements. For n_cull == 1 this is the
    # previous range 2:length(ats).
    to_walk_inds = sample((n_cull + 1):length(ats), n_par; replace=false)
    to_walks = deepcopy.(ats[to_walk_inds])

    if length(mc_routine.dims) == 3
        random_walk_function = MC_random_walk!
        walk_kwargs = NamedTuple()
    elseif length(mc_routine.dims) == 2
        random_walk_function = MC_random_walk_2D!
        # Forward the requested plane, as the serial 2D walk does.
        walk_kwargs = (dims=mc_routine.dims,)
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    mc_steps_per_worker = ceil(Int, ns_params.mc_steps / n_par) # split the total MC steps among workers

    worker_ids = workers()
    walking = [remotecall(random_walk_function, worker_ids[i], mc_steps_per_worker, to_walk, lj,
                          ns_params.step_size, emax; walk_kwargs...)
               for (i, to_walk) in enumerate(to_walks)]
    walked = fetch.(walking)
    finalize.(walking) # release the remote references

    rate = mean([x[2] for x in walked])
    accepted_inds = findall(x -> x[1] == 1, walked)

    if length(accepted_inds) < n_cull # not enough accepted walkers to replace the culled ones
        ns_params.fail_count += 1
        return iter, missing, liveset, ns_params
    end

    picked = sample(accepted_inds, n_cull; replace=false)
    for (i, ind) in enumerate(picked)
        ats[i] = walked[ind][3]
    end
    # Remaining accepted walks replace the survivor they were cloned from (all indices > n_cull).
    for i in accepted_inds
        i in picked && continue
        ats[to_walk_inds[i]] = walked[i][3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = ats[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
364

365
"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)

Perform a single step of the nested sampling algorithm using a parallel Monte Carlo random walk
routine. The `nworkers()` highest-energy walkers are replaced at once.

Procedure:
1. Sort the walkers by descending energy; the energy of walker `nworkers()` is the ceiling `emax`
   passed to every walk.
2. Choose the starting walkers:
   - `MCRandomWalkMaxEParallel`: the `nworkers()` highest-energy walkers;
   - `MCRandomWalkCloneParallel`: `nworkers()` distinct walkers sampled from all but the
     highest-energy one.
3. Run one walk per worker, each with the full `ns_params.mc_steps` moves.
4. The step succeeds only if every walk is accepted.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCRoutineParallel`: The parallel Monte Carlo routine. See [`MCRoutine`](@ref).

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if any walk failed.
- `emax`: The ceiling energy used for this step, or `missing` if any walk failed.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated parameters. `fail_count` is updated; the step size is adjusted only on success.

# Throws
- `ErrorException` if `mc_routine` is not `MCRandomWalkMaxEParallel` or `MCRandomWalkCloneParallel`,
  or if `length(mc_routine.dims)` is not 2 or 3.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    n_par = nworkers()
    iter::Union{Missing,Int} = missing
    emax::typeof(0.0u"eV") = ats[n_par].energy

    if mc_routine isa MCRandomWalkMaxEParallel
        to_walk_inds = 1:n_par
    elseif mc_routine isa MCRandomWalkCloneParallel
        to_walk_inds = sample(2:length(ats), n_par; replace=false)
    else
        # Without this branch `to_walk_inds` would be undefined and the call would fail with an
        # UndefVarError instead of a descriptive message.
        error("Unsupported MCRoutine type: $mc_routine")
    end
    to_walks = deepcopy.(ats[to_walk_inds])

    if length(mc_routine.dims) == 3
        random_walk_function = MC_random_walk!
        walk_kwargs = NamedTuple()
    elseif length(mc_routine.dims) == 2
        random_walk_function = MC_random_walk_2D!
        # Forward the requested plane, as the serial 2D walk does.
        walk_kwargs = (dims=mc_routine.dims,)
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    worker_ids = workers()
    walking = [remotecall(random_walk_function, worker_ids[i], ns_params.mc_steps, to_walk, lj,
                          ns_params.step_size, emax; walk_kwargs...)
               for (i, to_walk) in enumerate(to_walks)]
    walked = fetch.(walking)
    finalize.(walking) # release the remote references

    rate = mean([x[2] for x in walked])

    if any(x -> iszero(x[1]), walked) # at least one walk was rejected
        ns_params.fail_count += 1
        return iter, missing, liveset, ns_params
    end

    # Every walk succeeded: replace the n_par highest-energy walkers.
    for (i, result) in enumerate(walked)
        ats[i] = result[3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = ats[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
416

417
"""
    nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)

Parallel nested sampling step for walkers with a surface (`liveset.surface`, passed as an extra
argument to the walk function). The `nworkers()` highest-energy walkers are replaced at once.

Procedure:
1. Sort the walkers by descending energy; the energy of walker `nworkers()` is the ceiling.
2. Choose the starting walkers:
   - `MCRandomWalkMaxEParallel`: the `nworkers()` highest-energy walkers;
   - `MCRandomWalkCloneParallel`: `nworkers()` distinct walkers sampled from all but the first.
3. Run one walk per worker, each with the full `ns_params.mc_steps` moves.
4. The step succeeds only if every walk is accepted.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if any walk failed.
- `emax`: The ceiling energy, or `missing` if any walk failed.
- `liveset`, `ns_params`: The updated liveset and parameters.

# Throws
- `ErrorException` if `mc_routine` is not `MCRandomWalkMaxEParallel` or `MCRandomWalkCloneParallel`,
  or if `length(mc_routine.dims)` is not 2 or 3.
"""
function nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutineParallel)
    sort_by_energy!(liveset)
    ats = liveset.walkers
    lj = liveset.potential
    n_par = nworkers()
    iter::Union{Missing,Int} = missing
    emax::typeof(0.0u"eV") = ats[n_par].energy

    if mc_routine isa MCRandomWalkMaxEParallel
        to_walk_inds = 1:n_par
    elseif mc_routine isa MCRandomWalkCloneParallel
        to_walk_inds = sample(2:length(ats), n_par; replace=false)
    else
        # Without this branch `to_walk_inds` would be undefined and the call would fail with an
        # UndefVarError instead of a descriptive message.
        error("Unsupported MCRoutine type: $mc_routine")
    end
    to_walks = deepcopy.(ats[to_walk_inds])

    if length(mc_routine.dims) == 3
        random_walk_function = MC_random_walk!
    elseif length(mc_routine.dims) == 2
        random_walk_function = MC_random_walk_2D!
    else
        error("Unsupported dimensions: $(mc_routine.dims)")
    end

    worker_ids = workers()
    walking = [remotecall(random_walk_function, worker_ids[i], ns_params.mc_steps, to_walk, lj,
                          ns_params.step_size, emax, liveset.surface)
               for (i, to_walk) in enumerate(to_walks)]
    walked = fetch.(walking)
    finalize.(walking) # release the remote references

    rate = mean([x[2] for x in walked])

    if any(x -> iszero(x[1]), walked) # at least one walk was rejected
        ns_params.fail_count += 1
        return iter, missing, liveset, ns_params
    end

    for (i, result) in enumerate(walked)
        ats[i] = result[3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    iter = ats[1].iter

    adjust_step_size(ns_params, rate)
    return iter, emax, liveset, ns_params
end
468

469
"""
    nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Serial nested sampling step for walkers with a surface.

Procedure:
1. After sorting by descending energy, the top energy is the ceiling.
2. A copy of the top walker (`MCRandomWalkMaxE`) or of a random other walker (`MCRandomWalkClone`)
   is random-walked in 3D with `liveset.surface` passed to `MC_random_walk!`.
3. On acceptance the new walker is appended and the top walker is dropped.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if rejected.
- `emax`: The ceiling energy, or `missing` if rejected.
- `liveset`, `ns_params`: The updated liveset and parameters.

# Throws
- `ErrorException` for an unsupported routine type or any `dims` length other than 3.
"""
function nested_sampling_step!(liveset::LJSurfaceWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    ceiling::Union{Missing,typeof(0.0u"eV")} = walkers[1].energy

    trial = if mc_routine isa MCRandomWalkMaxE
        deepcopy(walkers[1])
    elseif mc_routine isa MCRandomWalkClone
        deepcopy(rand(walkers[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end

    # Only 3D walks are available with a surface.
    length(mc_routine.dims) == 3 || error("Unsupported dimensions: $(mc_routine.dims)")
    accepted, rate, walked = MC_random_walk!(ns_params.mc_steps, trial, potential,
                                             ns_params.step_size, ceiling, liveset.surface)

    new_iter::Union{Missing,Int} = missing
    if accepted
        push!(walkers, walked)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        new_iter = walkers[1].iter
    else
        ceiling = missing
        ns_params.fail_count += 1
    end
    adjust_step_size(ns_params, rate)
    return new_iter, ceiling, liveset, ns_params
end
503

504
"""
    nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMoves)

Perform a single nested sampling step using mixed moves (random walks and atom swaps), decorrelating
`nworkers()` walkers in parallel.

Procedure:
1. Sort the walkers by descending energy; the top energy is the ceiling.
2. Copy `nworkers()` distinct walkers chosen from all but the top one and run `MC_mixed_moves!` on
   each worker. Each run gets `ns_params.mc_steps` moves and the weights
   `[walks_freq, swaps_freq]`.
3. If any run is accepted:
   - one accepted result, chosen at random, replaces the top walker;
   - every other accepted result replaces the walker it started from.

# Arguments
- `liveset::AtomWalkers`: The set of atom walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `mc_routine::MCMixedMoves`: The Monte Carlo mixed moves routine.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if no run was accepted.
- `emax`: The ceiling energy, or `missing` if no run was accepted.
- `liveset`: The updated set of atom walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling_step!(liveset::AtomWalkers, ns_params::NestedSamplingParameters, mc_routine::MCMixedMoves)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    potential = liveset.potential
    top_energies::Union{Vector{Missing},Vector{typeof(0.0u"eV")}} = [walkers[k].energy for k in 1:nworkers()]
    ceiling = top_energies[1]

    start_inds = sample(2:length(walkers), nworkers(); replace=false)
    starts = deepcopy.(walkers[start_inds])

    futures = [remotecall(MC_mixed_moves!, workers()[k], ns_params.mc_steps, start, potential,
                          ns_params.step_size, ceiling,
                          [mc_routine.walks_freq, mc_routine.swaps_freq])
               for (k, start) in enumerate(starts)]
    results = fetch.(futures)
    finalize.(futures) # release the remote references

    rate = mean([r[2] for r in results])
    ok = findall(r -> r[1] == 1, results)

    if isempty(ok) # every run was rejected
        ns_params.fail_count += 1
        return missing, missing, liveset, ns_params
    end

    chosen = rand(ok)
    walkers[1] = results[chosen][3]
    for k in ok
        k == chosen && continue
        walkers[start_inds[k]] = results[k][3]
    end

    update_iter!(liveset)
    ns_params.fail_count = 0
    new_iter::Union{Missing,Int} = walkers[1].iter

    adjust_step_size(ns_params, rate)
    return new_iter, ceiling, liveset, ns_params
end
572

573
"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRoutine)

Perform one nested sampling step on a lattice gas liveset with a random-walk routine.

Procedure:
1. Sort the walkers by descending energy; the top energy, as a plain number, is the ceiling.
2. Pass a copy of the top walker (`MCRandomWalkMaxE`) or of a random other walker
   (`MCRandomWalkClone`) to `MC_random_walk!`, together with the Hamiltonian, the ceiling and
   `ns_params.energy_perturbation`.
3. If the result is accepted, append it and drop the top walker.

The step size is not adjusted.

# Arguments
- `liveset::LatticeGasWalkers`: The liveset of lattice gas walkers.
- `ns_params::NestedSamplingParameters`: The nested sampling parameters, e.g. as built by
  `LatticeNestedSamplingParameters`.
- `mc_routine::MCRoutine`: `MCRandomWalkMaxE` or `MCRandomWalkClone`; any other routine raises an error.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if rejected.
- `emax`: The ceiling with the walkers' energy unit reattached, or `missing` if rejected.
- `liveset`, `ns_params`: The updated liveset and parameters.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers,
                               ns_params::NestedSamplingParameters,
                               mc_routine::MCRoutine)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    hamiltonian = liveset.hamiltonian
    ceiling::Union{Missing,Float64} = walkers[1].energy.val

    trial = if mc_routine isa MCRandomWalkMaxE
        deepcopy(walkers[1])
    elseif mc_routine isa MCRandomWalkClone
        deepcopy(rand(walkers[2:end]))
    else
        error("Unsupported MCRoutine type: $mc_routine")
    end

    accepted, _, walked = MC_random_walk!(ns_params.mc_steps, trial, hamiltonian, ceiling;
                                          energy_perturb=ns_params.energy_perturbation)

    new_iter::Union{Missing,Int} = missing
    if accepted
        push!(walkers, walked)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        new_iter = walkers[1].iter
    else
        ceiling = missing
        ns_params.fail_count += 1
    end
    return new_iter, ceiling * unit(walkers[1].energy), liveset, ns_params
end
622

623
"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCNewSample)

Perform one nested sampling step on a lattice gas liveset by drawing a fresh configuration.

Procedure:
1. Sort the walkers by descending energy; the top energy, as a plain number, is the ceiling.
2. Pass a copy of the top walker to `MC_new_sample!` with the Hamiltonian, the ceiling and
   `ns_params.energy_perturbation`.
3. If the new sample is accepted, append it and drop the top walker.

# Arguments
- `liveset::LatticeGasWalkers`: The liveset of lattice gas walkers.
- `ns_params::NestedSamplingParameters`: The nested sampling parameters.
- `mc_routine::MCNewSample`: Selects this method.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if rejected.
- `emax`: The ceiling with the walkers' energy unit reattached, or `missing` if rejected.
- `liveset::LatticeGasWalkers`: The updated liveset after the step.
- `ns_params::NestedSamplingParameters`: The updated nested sampling parameters after the step.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers,
                               ns_params::NestedSamplingParameters,
                               mc_routine::MCNewSample)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    ceiling::Union{Missing,Float64} = first(walkers).energy.val
    candidate = deepcopy(first(walkers))

    accepted, sampled = MC_new_sample!(candidate, liveset.hamiltonian, ceiling;
                                       energy_perturb=ns_params.energy_perturbation)

    new_iter::Union{Missing,Int} = missing
    if !accepted
        ceiling = missing
        ns_params.fail_count += 1
    else
        push!(walkers, sampled)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        new_iter = first(walkers).iter
    end
    return new_iter, ceiling * unit(first(walkers).energy), liveset, ns_params
end
669

670

671
"""
    nested_sampling_step!(liveset::LatticeGasWalkers, ns_params::NestedSamplingParameters, mc_routine::MCRejectionSampling)

Perform one nested sampling step on a lattice gas liveset using rejection sampling.

Procedure:
1. Sort the walkers by descending energy; the top energy, as a plain number, is the ceiling.
2. Pass a copy of the top walker to `MC_rejection_sampling!` with the Hamiltonian, the ceiling and
   `ns_params.energy_perturbation`.
3. If the result is accepted, append it and drop the top walker.

# Returns
- `iter`: The iteration number of the first walker after the step, or `missing` if rejected.
- `emax`: The ceiling with the walkers' energy unit reattached, or `missing` if rejected.
- `liveset`, `ns_params`: The updated liveset and parameters.
"""
function nested_sampling_step!(liveset::LatticeGasWalkers,
                               ns_params::NestedSamplingParameters,
                               mc_routine::MCRejectionSampling)
    sort_by_energy!(liveset)
    walkers = liveset.walkers
    ceiling::Union{Missing,Float64} = first(walkers).energy.val
    candidate = deepcopy(first(walkers))

    accepted, drawn = MC_rejection_sampling!(candidate, liveset.hamiltonian, ceiling;
                                             energy_perturb=ns_params.energy_perturbation)

    new_iter::Union{Missing,Int} = missing
    if accepted
        push!(walkers, drawn)
        popfirst!(walkers)
        update_iter!(liveset)
        ns_params.fail_count = 0
        new_iter = first(walkers).iter
    else
        ceiling = missing
        ns_params.fail_count += 1
    end
    return new_iter, ceiling * unit(first(walkers).energy), liveset, ns_params
end
699

700

701

702
"""
    nested_sampling(liveset::AbstractLiveSet, ns_params::NestedSamplingParameters, n_steps::Int64,
                    mc_routine::MCRoutine, save_strategy::DataSavingStrategy)

Perform a nested sampling loop for a given number of steps.

Each step calls `nested_sampling_step!`. After `ns_params.allowed_fail_count` consecutive failed
steps, a warning is logged, `fail_count` is reset and the step size returns to
`ns_params.initial_step_size`.

# Arguments
- `liveset::AbstractLiveSet`: The initial set of walkers.
- `ns_params::NestedSamplingParameters`: The parameters for nested sampling.
- `n_steps::Int64`: The number of steps to perform.
- `mc_routine::MCRoutine`: The Monte Carlo routine to use.
- `save_strategy::DataSavingStrategy`: Controls output:
  - `n_info` sets how often progress is logged;
  - it is also passed to `write_walker_every_n`, `write_df_every_n` and `write_ls_every_n`.

# Returns
- `df`: A DataFrame with columns `iter` and `emax`, where `emax` is the unit-stripped value. It has
  one row per successful step; failed steps are not recorded.
- `liveset`: The updated set of walkers.
- `ns_params`: The updated nested sampling parameters.
"""
function nested_sampling(liveset::AbstractLiveSet,
                         ns_params::NestedSamplingParameters,
                         n_steps::Int64,
                         mc_routine::MCRoutine,
                         save_strategy::DataSavingStrategy)
    df = DataFrame(iter=Int[], emax=Float64[])
    for i in 1:n_steps
        print_info = i % save_strategy.n_info == 0
        write_walker_every_n(liveset.walkers[1], i, save_strategy)
        iter, emax, liveset, ns_params = nested_sampling_step!(liveset, ns_params, mc_routine)
        @debug "n_step $i, iter: $iter, emax: $emax"
        if ns_params.fail_count >= ns_params.allowed_fail_count
            @warn "Failed to accept MC move $(ns_params.allowed_fail_count) times in a row. Reset step size!"
            ns_params.fail_count = 0
            ns_params.step_size = ns_params.initial_step_size
        end
        # A missing `iter` marks a failed step; only successful steps are recorded.
        if !ismissing(iter)
            push!(df, (iter, emax.val))
        end
        print_message(i, iter, emax, ns_params.step_size, print_info, liveset)
        write_df_every_n(df, i, save_strategy)
        write_ls_every_n(liveset, i, save_strategy)
    end
    return df, liveset, ns_params
end
743

744
# Log progress for a lattice liveset when `print_info` is set. A missing `iter` marks a failed step.
function print_message(i, iter, emax, step_size, print_info, liveset::LatticeWalkers)
    print_info || return nothing
    if ismissing(iter)
        @info "MC move failed, step: $(i), emax: $(liveset.walkers[1].energy)"
    else
        @info "iter: $(liveset.walkers[1].iter), emax: $(emax)"
    end
    return nothing
end
751

752
# Log progress for an atomistic liveset (unit-stripped energies, rounded step size) when
# `print_info` is set. A missing `iter` marks a failed step.
function print_message(i, iter, emax, step_size, print_info, liveset::AtomWalkers)
    print_info || return nothing
    if ismissing(iter)
        @info "MC move failed, step: $(i), emax: $(liveset.walkers[1].energy.val), step_size: $(round(step_size; sigdigits=4))"
    else
        @info "iter: $(liveset.walkers[1].iter), emax: $(emax.val), step_size: $(round(step_size; sigdigits=4))"
    end
    return nothing
end
759
    
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc