
NLESC-JCER / QMCTorch, build 14968442546

12 May 2025 09:14AM UTC. First build. Coverage: 83.955%.

Pull Request #187: Clean up Main (github, web-flow)
Merge a67f074c6 into 20fe7ebf9

951 of 1326 branches covered (71.72%); branch coverage is included in the aggregate %.
287 of 362 new or added lines in 47 files covered (79.28%).
4522 of 5193 relevant lines covered (87.08%), 0.87 hits per line.

Source file: /qmctorch/solver/solver.py (75.87% covered)
from time import time
from tqdm import tqdm
from types import SimpleNamespace
from typing import Optional, Dict, List, Tuple, Any
import torch
from ..wavefunction import WaveFunction
from ..sampler import SamplerBase
from ..utils import add_group_attr, dump_to_hdf5, DataLoader

from .. import log
from .solver_base import SolverBase
from .loss import Loss

class Solver(SolverBase):
    def __init__(  # pylint: disable=too-many-arguments
        self,
        wf: Optional[WaveFunction] = None,
        sampler: Optional[SamplerBase] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        output: Optional[str] = None,
        rank: int = 0,
    ) -> None:
        """Basic QMC solver

        Args:
            wf (qmctorch.WaveFunction, optional): wave function. Defaults to None.
            sampler (qmctorch.sampler, optional): Sampler. Defaults to None.
            optimizer (torch.optim, optional): optimizer. Defaults to None.
            scheduler (torch.optim, optional): scheduler. Defaults to None.
            output (str, optional): hdf5 filename. Defaults to None.
            rank (int, optional): rank of the process. Defaults to 0.
        """
        SolverBase.__init__(self, wf, sampler, optimizer, scheduler, output, rank)

        self.set_params_requires_grad()

        self.configure(
            track=["local_energy"],
            freeze=None,
            loss="energy",
            grad="manual",
            ortho_mo=False,
            clip_loss=False,
            resampling={"mode": "update", "resample_every": 1, "nstep_update": 25},
        )

    def configure(
        self,
        track: Optional[List[str]] = None,
        freeze: Optional[List[torch.nn.Parameter]] = None,
        loss: Optional[str] = None,
        grad: Optional[str] = None,
        ortho_mo: Optional[bool] = None,
        clip_loss: bool = False,
        clip_threshold: int = 5,
        resampling: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the solver

        Args:
            track (list, optional): list of observables to track. Defaults to ['local_energy'].
            freeze (list, optional): list of parameters to freeze. Defaults to None.
            loss (str, optional): method to compute the loss: 'variance' or 'energy'.
                                  Defaults to 'energy'.
            grad (str, optional): method to compute the gradients: 'auto' or 'manual'.
                                  Defaults to 'auto'.
            ortho_mo (bool, optional): apply regularization to orthogonalize the MOs.
                                       Defaults to False.
            clip_loss (bool, optional): clip the loss values at +/- clip_threshold std.
                                        Defaults to False.
            clip_threshold (int, optional): number of standard deviations used to clip the loss.
                                            Defaults to 5.
            resampling (dict, optional): resampling options.
        """

        # set the parameters we want to optimize/freeze
        self.set_params_requires_grad()
        self.freeze_params_list = freeze
        self.freeze_parameters(freeze)

        # track the observables we want
        if track is not None:
            self.track_observable(track)

        # define the gradient calculation
        if grad is not None:
            self.grad_method = grad
            self.evaluate_gradient = {
                "auto": self.evaluate_grad_auto,
                "manual": self.evaluate_grad_manual,
            }[grad]

        # resampling of the wave function
        if resampling is not None:
            self.configure_resampling(**resampling)

        # get the loss
        if loss is not None:
            self.loss = Loss(self.wf, method=loss, clip=clip_loss, clip_threshold=clip_threshold)
            self.loss.use_weight = self.resampling_options.resample_every > 1

        # orthogonalization penalty for the MO coeffs
        self.ortho_mo = ortho_mo
        if self.ortho_mo is True:
            log.warning("Orthogonalization of the MO coeffs via loss penalty is deprecated")

    def set_params_requires_grad(self,
                                 wf_params: Optional[bool] = True,
                                 geo_params: Optional[bool] = False):
        """Configure parameters for wf opt."""

        # optimize all wf parameters
        self.wf.ao.bas_exp.requires_grad = wf_params
        self.wf.ao.bas_coeffs.requires_grad = wf_params

        for param in self.wf.mo.parameters():
            param.requires_grad = wf_params

        self.wf.fc.weight.requires_grad = wf_params

        if hasattr(self.wf, "jastrow"):
            if self.wf.jastrow is not None:
                for param in self.wf.jastrow.parameters():
                    param.requires_grad = wf_params

        # do not optimize the atom positions unless geo_params is True
        self.wf.ao.atom_coords.requires_grad = geo_params

    def freeze_parameters(self, freeze: List[str]) -> None:
        """Freeze the optimization of specified params.

        Args:
            freeze (list): list of parameters to freeze
        """
        if freeze is not None:
            if not isinstance(freeze, list):
                freeze = [freeze]

            for name in freeze:
                if name.lower() == "ci":
                    self.wf.fc.weight.requires_grad = False

                elif name.lower() == "mo":
                    for param in self.wf.mo.parameters():
                        param.requires_grad = False

                elif name.lower() == "ao":
                    self.wf.ao.bas_exp.requires_grad = False
                    self.wf.ao.bas_coeffs.requires_grad = False

                elif name.lower() == "jastrow":
                    for param in self.wf.jastrow.parameters():
                        param.requires_grad = False

                elif name.lower() == "backflow":
                    for param in self.wf.ao.backflow_trans.parameters():
                        param.requires_grad = False

                else:
                    opt_freeze = ["ci", "mo", "ao", "jastrow", "backflow"]
                    raise ValueError(f"Valid arguments for freeze are: {opt_freeze}")

    def save_sampling_parameters(self) -> None:
        """Save the sampling parameters."""
        self.sampler._nstep_save = self.sampler.nstep
        self.sampler._ntherm_save = self.sampler.ntherm
        # self.sampler._nwalker_save = self.sampler.walkers.nwalkers

        if self.resampling_options.mode == "update":
            self.sampler.ntherm = self.resampling_options.ntherm_update
            self.sampler.nstep = self.resampling_options.nstep_update
            # self.sampler.walkers.nwalkers = pos.shape[0]

    def restore_sampling_parameters(self) -> None:
        """Restore the sampling parameters to their original values."""
        self.sampler.nstep = self.sampler._nstep_save
        self.sampler.ntherm = self.sampler._ntherm_save
        # self.sampler.walkers.nwalkers = self.sampler._nwalker_save

    def run(
        self,
        nepoch: int,
        batchsize: Optional[int] = None,
        hdf5_group: Optional[str] = "wf_opt",
        chkpt_every: Optional[int] = None,
        tqdm: Optional[bool] = False
    ) -> SimpleNamespace:
        """Run a wave function optimization

        Args:
            nepoch (int): number of optimization steps
            batchsize (int, optional): number of samples in a mini-batch.
                                       If None, all samples are used.
                                       Defaults to None.
            hdf5_group (str, optional): name of the hdf5 group where to store the data.
                                        Defaults to 'wf_opt'.
            chkpt_every (int, optional): save a checkpoint every chkpt_every epochs.
                                         Defaults to None (no checkpoint).
            tqdm (bool, optional): display a progress bar during the initial sampling.
                                   Defaults to False.

        Returns:
            SimpleNamespace: the observables tracked during the optimization
        """
        # prepare the optimization
        self.prepare_optimization(batchsize, chkpt_every, tqdm)
        self.log_data_opt(nepoch, "wave function optimization")

        # run the epochs
        self.run_epochs(nepoch)

        # restore the sampler number of steps
        self.restore_sampling_parameters()

        # dump
        self.save_data(hdf5_group)

        return self.observable

    def prepare_optimization(self, batchsize: Optional[int], chkpt_every: Optional[int], tqdm: Optional[bool] = False):
        """Prepare the optimization process

        Args:
            batchsize (int or None): size of the mini-batches
            chkpt_every (int or None): save a checkpoint file every chkpt_every epochs
            tqdm (bool, optional): display a progress bar during the sampling
        """
        log.info("  Initial Sampling    :")
        tstart = time()

        # sample the wave function
        pos = self.sampler(self.wf.pdf, with_tqdm=tqdm)

        # handle the batch size
        if batchsize is None:
            batchsize = len(pos)

        # change the number of steps/walker size
        self.save_sampling_parameters()

        # create the data loader
        self.dataloader = DataLoader(pos, batch_size=batchsize, pin_memory=self.cuda)

        for ibatch, data in enumerate(self.dataloader):
            self.store_observable(data, ibatch=ibatch)

        # checkpoint
        self.chkpt_every = chkpt_every

        log.info("  done in %1.2f sec." % (time() - tstart))

    def save_data(self, hdf5_group: str):
        """Save the data to hdf5.

        Args:
            hdf5_group (str): name of the group in the hdf5 file
        """
        self.observable.models.last = dict(self.wf.state_dict())

        hdf5_group = dump_to_hdf5(self.observable, self.hdf5file, hdf5_group)

        add_group_attr(self.hdf5file, hdf5_group, {"type": "opt"})

    def run_epochs(self, nepoch: int,
                   with_tqdm: Optional[bool] = False,
                   verbose: Optional[bool] = True) -> float:
        """Run a certain number of epochs

        Args:
            nepoch (int): number of epochs to run
            with_tqdm (bool, optional): display a progress bar. Defaults to False.
            verbose (bool, optional): log the progress of each epoch; mutually exclusive
                                      with with_tqdm. Defaults to True.

        Returns:
            float: cumulative loss of the last epoch
        """

        if with_tqdm and verbose:
            raise ValueError("tqdm and verbose are mutually exclusive")

        # init the loss in case we have nepoch=0
        cumulative_loss = 0
        min_loss = 0  # this is set at n=0

        # the range
        rng = tqdm(
            range(nepoch),
            desc="INFO:QMCTorch|  Optimization",
            disable=not with_tqdm,
        )

        # loop over the epochs
        for n in rng:

            if verbose:
                tstart = time()
                log.info("")
                log.info(
                    "  epoch %d | %d sampling points" % (n, len(self.dataloader.dataset))
                )

            # reset the gradients and loss
            cumulative_loss = 0
            self.opt.zero_grad()
            self.wf.zero_grad()

            # loop over the batches
            for ibatch, data in enumerate(self.dataloader):
                # port data to device
                lpos = data.to(self.device)

                # get the gradient
                loss, eloc = self.evaluate_gradient(lpos)
                cumulative_loss += loss.item()

                # check for NaN
                if torch.isnan(eloc).any():
                    log.info("Error: NaN detected in local energy")
                    return cumulative_loss

                # observable
                self.store_observable(lpos, local_energy=eloc, ibatch=ibatch)

            # optimize the parameters
            self.optimization_step(lpos)

            # save the model if necessary
            if n == 0 or cumulative_loss < min_loss:
                min_loss = cumulative_loss
                self.observable.models.best = dict(self.wf.state_dict())

            # save checkpoint file
            if self.chkpt_every is not None:
                if (n > 0) and (n % self.chkpt_every == 0):
                    self.save_checkpoint(n, cumulative_loss)

            if verbose:
                self.print_observable(cumulative_loss, verbose=False)

            # resample the data
            self.dataloader.dataset = self.resample(n, self.dataloader.dataset)

            # scheduler step
            if self.scheduler is not None:
                self.scheduler.step()

            if verbose:
                log.info("  epoch done in %1.2f sec." % (time() - tstart))

        return cumulative_loss

    def evaluate_grad_auto(self, lpos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the gradient using automatic differentiation

        Args:
            lpos (torch.tensor): sampling points

        Returns:
            tuple: loss values and local energies
        """

        # compute the loss
        loss, eloc = self.loss(lpos)

        # compute local gradients
        loss.backward()

        return loss, eloc

    def evaluate_grad_manual(self, lpos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Evaluate the gradient using the low-variance expression.

        WARNING: this method is not valid to compute forces,
        as it does not include the derivative of the Hamiltonian
        with respect to the atomic positions.

        The gradients are evaluated following:

        .. math::
            \frac{dE}{dk} = \left\langle \frac{\partial_k \Psi}{\Psi}
            \left(E_L - \langle E_L \rangle\right) \right\rangle

        Other estimators are possible:

        .. math::
            \frac{dE}{dk} = 2 \left[ \left\langle \frac{\partial_k \Psi}{\Psi} E_L \right\rangle
            - \left\langle \frac{\partial_k \Psi}{\Psi} \right\rangle \langle E_L \rangle \right]

        given in https://www.cond-mat.de/events/correl19/manuscripts/luechow.pdf eq. 17, or

        .. math::
            \frac{dE}{dk} = \left\langle \left(E_L - \langle E_L \rangle\right)
            \partial_k \ln|\Psi| \right\rangle

        used in PauliNet.

        Args:
            lpos (torch.tensor): sampling points

        Returns:
            tuple: loss values and local energies
        """

        if self.loss.method not in ["energy", "weighted-energy"]:
            raise ValueError("Manual gradient only for energy minimization")

        # compute the local energy
        with torch.no_grad():
            eloc = self.wf.local_energy(lpos)

        # compute the wf values
        psi = self.wf(lpos)
        norm = 1.0 / len(psi)

        # evaluate the prefactor of the grads
        weight = eloc.clone()
        weight -= torch.mean(eloc)
        weight /= psi.clone()
        weight *= 2.0 * norm

        # clip the values
        clip_mask = self.loss.get_clipping_mask(eloc)
        psi = psi[clip_mask]
        weight = weight[clip_mask]

        # compute the gradients
        psi.backward(weight)

        return torch.mean(eloc), eloc

    def compute_forces(self, lpos: torch.Tensor, batch_size: Optional[int] = None, clip: Optional[int] = None) -> torch.Tensor:
        r"""
        Compute the forces using automatic differentiation and a stable estimator

        .. math::
            F = -\left\langle \nabla_\alpha E_L(R) + (E_L(R) - E) \nabla_\alpha \ln|\Psi(R)|^2 \right\rangle

        see e.g. https://arxiv.org/abs/2404.09755

        Args:
            lpos (torch.tensor): sampling points
            batch_size (int, optional): size of the batches used for the automatic differentiation.
                                        Defaults to None (all points in a single batch).
            clip (int, optional): z-score threshold used to exclude outlier local energies.
                                  Defaults to None (no clipping).

        Returns:
            torch.Tensor: the forces on the atoms

        """

        def get_clipping_mask(values: torch.Tensor, clip: Optional[int]) -> torch.Tensor:
            """
            Compute a mask to clip the values based on their z-score

            Parameters
            ----------
            values : torch.Tensor
                the values to clip
            clip : int or None
                z-score threshold above which values are masked out

            Returns
            -------
            mask : torch.Tensor
                boolean mask selecting the values to keep
            """
            if clip is not None:
                median = torch.median(values)
                std = torch.std(values)
                zscore = torch.abs((values - median) / std)
                mask = zscore < clip
            else:
                mask = torch.ones_like(values).type(torch.bool)

            return mask

        # save the grad status of the atomic coordinates
        original_requires_grad = self.wf.ao.atom_coords.requires_grad
        if not original_requires_grad:
            self.wf.ao.atom_coords.requires_grad = True

        if batch_size is None:
            batch_size = lpos.shape[0]
        nbatch = lpos.shape[0] // batch_size

        forces = torch.zeros_like(self.wf.ao.atom_coords).requires_grad_(False)
        for ibatch in range(nbatch):

            # get the batch
            idx_start = ibatch * batch_size
            idx_end = (ibatch + 1) * batch_size
            if idx_end > lpos.shape[0]:
                idx_end = lpos.shape[0]
            lpos_batch = lpos[idx_start:idx_end]

            # compute the local energy and its gradient
            local_energy = self.wf.local_energy(lpos_batch)
            clip_mask = get_clipping_mask(local_energy, clip)
            grad_eloc = torch.autograd.grad(local_energy, self.wf.ao.atom_coords, grad_outputs=clip_mask)[0]

            # compute the log density and its gradient
            wf_val = self.wf.pdf(lpos_batch)
            proba = torch.log(wf_val)
            grad_outputs = ((local_energy - local_energy.mean()) * clip_mask).squeeze()
            grad_proba = torch.autograd.grad(proba, self.wf.ao.atom_coords, grad_outputs=grad_outputs)[0]

            # accumulate into the forces
            forces += 1.0 / batch_size * (grad_eloc + grad_proba)

        if not original_requires_grad:
            self.wf.ao.atom_coords.requires_grad = False

        return forces

    def log_data_opt(self, nepoch, task):
        """Log data for the optimization."""
        log.info("")
        log.info("  Optimization")
        log.info("  Task                :", task)
        log.info("  Number Parameters   : {0}", self.wf.get_number_parameters())
        log.info("  Number of epoch     : {0}", nepoch)
        log.info("  Batch size          : {0}", self.sampler.get_sampling_size())
        log.info("  Loss function       : {0}", self.loss.method)
        log.info("  Clip Loss           : {0}", self.loss.clip)
        log.info("  Gradients           : {0}", self.grad_method)
        log.info("  Resampling mode     : {0}", self.resampling_options.mode)
        log.info("  Resampling every    : {0}", self.resampling_options.resample_every)
        log.info("  Resampling steps    : {0}", self.resampling_options.nstep_update)
        log.info("  Output file         : {0}", self.hdf5file)
        log.info("  Checkpoint every    : {0}", self.chkpt_every)
        log.info("")