• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

NLESC-JCER / QMCTorch / 14968442546

12 May 2025 09:14AM UTC coverage: 83.955%. First build
14968442546

Pull #187

github

web-flow
Merge a67f074c6 into 20fe7ebf9
Pull Request #187: Clean up Main

951 of 1326 branches covered (71.72%)

Branch coverage included in aggregate %.

287 of 362 new or added lines in 47 files covered. (79.28%)

4522 of 5193 relevant lines covered (87.08%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.66
/qmctorch/wavefunction/pooling/slater_pooling.py
1
import torch
1✔
2
from torch import nn
1✔
3
import operator as op
1✔
4
from typing import Tuple, Callable, Optional, List, Union
1✔
5
from ...scf import Molecule
1✔
6
from ...utils import bdet2, btrace
1✔
7
from .orbital_configurations import get_excitation, get_unique_excitation
1✔
8
from .orbital_projector import ExcitationMask, OrbitalProjector
1✔
9

10

11
class SlaterPooling(nn.Module):
    """Applies a Slater determinant pooling in the active space.

    Maps a batch of molecular orbital (MO) matrices to the values of the
    Slater determinants of the configurations in ``configs`` (the product of
    a spin up and a spin down determinant for each configuration). It also
    evaluates traces of the form ``Tr(A^{-1} B)`` / ``Tr((A^{-1} B)^2)`` of an
    operator matrix ``B`` for each configuration.
    """

    def __init__(
        self,
        config_method: str,
        configs: Tuple[torch.LongTensor, torch.LongTensor],
        mol: "Molecule",
        cuda: bool = False,
    ) -> None:
        """Computes the Slater determinants.

        Args:
            config_method (str): method used to define the configurations;
                this class dispatches on ``"ground_state"``, ``"single..."``,
                ``"cas(..."`` and ``"explicit"``
            configs (Tuple[torch.LongTensor, torch.LongTensor]):
                configuration of the electrons: orbital indices of the spin up
                (``configs[0]``) and spin down (``configs[1]``) electrons
            mol (Molecule): Molecule instance
            cuda (bool, optional): Turns GPU ON/OFF. Defaults to False.
        """
        super().__init__()

        self.config_method = config_method

        self.configs = configs
        self.nconfs = len(configs[0])
        # number of MOs spanned by the configurations (highest index used + 1)
        self.index_max_orb_up = self.configs[0].max().item() + 1
        self.index_max_orb_down = self.configs[1].max().item() + 1

        self.excitation_index = get_excitation(configs)
        self.unique_excitation, self.index_unique_excitation = get_unique_excitation(
            configs
        )

        self.nmo = mol.basis.nmo
        self.nup = mol.nup
        self.ndown = mol.ndown
        self.nelec = self.nup + self.ndown
        # if True, forward() and operator() (for "single..." configs) use the
        # explicit calculation instead of the single/double excitation formulas
        self.use_explicit_operator = False

        self.orb_proj = OrbitalProjector(configs, mol, cuda=cuda)
        self.exc_mask = ExcitationMask(
            self.unique_excitation,
            mol,
            (self.index_max_orb_up, self.index_max_orb_down),
            cuda=cuda,
        )

        self.device = torch.device("cuda") if cuda else torch.device("cpu")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Computes the values of the determinants.

        Args:
            input (torch.Tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.Tensor: Slater determinants of the configurations
        """
        use_explicit = (
            self.config_method.startswith("cas(")
            or self.config_method == "explicit"
            or self.use_explicit_operator
        )
        if use_explicit:
            return self.det_explicit(input)
        return self.det_single_double(input)

    def get_slater_matrices(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the Slater matrices.

        Args:
            input (torch.Tensor): MO matrices nbatch x nelec x nmo

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                Slater matrices of spin up/down for the unique configurations
        """
        return self.orb_proj.split_orbitals(input, unique_configs=True)

    def det_explicit(self, input: torch.Tensor) -> torch.Tensor:
        """Computes the values of the determinants from the Slater matrices.

        Every unique spin up/down Slater matrix is built and its determinant
        is computed directly.

        Args:
            input (torch.Tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.Tensor: Slater determinants (product of the spin up and spin
            down determinants of each configuration), with the configuration
            axis moved to the last position by ``transpose(0, 1)``
        """
        mo_up, mo_down = self.get_slater_matrices(input)
        det_up = torch.det(mo_up)
        det_down = torch.det(mo_down)

        # map the unique up/down determinants back onto every configuration
        index_up = self.orb_proj.index_unique_configs[0]
        index_down = self.orb_proj.index_unique_configs[1]
        return (det_up[index_up, ...] * det_down[index_down, ...]).transpose(0, 1)

    def det_single_double(self, input: torch.Tensor) -> torch.Tensor:
        """Computes the determinant of ground state + single + double excitations.

        Args:
            input (torch.Tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.Tensor: Slater determinants for the configurations
        """
        # determinants of the unique single and double excitations
        det_unique_up, det_unique_down = self.det_unique_single_double(input)

        # product of the spin up/down determinants required by each excitation
        return (
            det_unique_up[:, self.index_unique_excitation[0]]
            * det_unique_down[:, self.index_unique_excitation[1]]
        )

    def det_ground_state(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the Slater determinants of the ground state.

        Args:
            input (torch.Tensor): Molecular orbital matrices of shape (nbatch, nelec, nmo).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Slater determinants for spin up and spin down configurations.
        """
        return (
            torch.det(input[:, : self.nup, : self.nup]),
            torch.det(input[:, self.nup :, : self.ndown]),
        )

    def det_unique_single_double(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes the Slater determinants of the unique single/double excitations.

        The determinant of an excitation is obtained from the ground state
        determinant and the ground state Slater matrix with one (single) or
        two (double) columns replaced by virtual orbitals. It is computed as
        det(A) times the determinant of the matching 1x1 / 2x2 block of
        ``A^{-1} B``.
        See: Monte Carlo Methods in ab initio quantum chemistry,
        B.L. Hammond, appendix B1.

        Note: if the states in the configs are specified in order, an
        excitation from a deep orbital yields a Slater matrix with one column
        changed (the new orbital) and several permutations, which would call
        for a (-1)^nperm factor. NOTE(review): this method applies no such
        factor; confirm that the excitation indexing makes it unnecessary.

        .. math::

            MO = [ A | B ]
            det(Exc_{ij}) = (det(A) * A^{-1} * B)_{i,j}

        Args:
            input (torch.Tensor): MO matrices nbatch x nelec x nmo

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: spin up / spin down determinants
            of shape (nbatch, n): the ground state first, then the unique
            single excitations, then the unique double excitations. Only the
            ground state column is returned for ``"ground_state"``.
        """
        nbatch = input.shape[0]

        # lazily build the excitation index masks
        if not hasattr(self.exc_mask, "index_unique_single_up"):
            self.exc_mask.get_index_unique_single()

        if not hasattr(self.exc_mask, "index_unique_double_up"):
            self.exc_mask.get_index_unique_double()

        do_single = len(self.exc_mask.index_unique_single_up) != 0
        do_double = len(self.exc_mask.index_unique_double_up) != 0

        # occupied orbital matrices and their determinants
        Aup = input[:, : self.nup, : self.nup]
        Adown = input[:, self.nup :, : self.ndown]
        detAup = torch.det(Aup)
        detAdown = torch.det(Adown)

        # accumulators for all the determinants we need (ground state first)
        det_out_up = detAup.unsqueeze(-1).clone()
        det_out_down = detAdown.unsqueeze(-1).clone()

        if self.config_method == "ground_state":
            return det_out_up, det_out_down

        # inverse of the occupied orbital matrices
        invAup = torch.inverse(Aup)
        invAdown = torch.inverse(Adown)

        # virtual orbital matrices spin up/down
        Bup = input[:, : self.nup, self.nup : self.index_max_orb_up]
        Bdown = input[:, self.nup :, self.ndown : self.index_max_orb_down]

        # A^{-1} B, flattened so the excitation masks can index it
        flat_exc_up = (invAup @ Bup).view(nbatch, -1)
        flat_exc_down = (invAdown @ Bdown).view(nbatch, -1)

        if do_single:
            # single excitations: det(A) * (A^{-1} B)_{ij}
            det_single_up = flat_exc_up[:, self.exc_mask.index_unique_single_up]
            det_single_down = flat_exc_down[:, self.exc_mask.index_unique_single_down]

            det_single_up = detAup.unsqueeze(-1) * det_single_up
            det_single_down = detAdown.unsqueeze(-1) * det_single_down

            det_out_up = torch.cat((det_out_up, det_single_up), dim=1)
            det_out_down = torch.cat((det_out_down, det_single_down), dim=1)

        if do_double:
            # double excitations: det(A) * det of the 2x2 block of A^{-1} B
            det_double_up = flat_exc_up[:, self.exc_mask.index_unique_double_up]
            det_double_up = bdet2(det_double_up.view(nbatch, -1, 2, 2))
            det_double_up = detAup.unsqueeze(-1) * det_double_up

            det_double_down = flat_exc_down[:, self.exc_mask.index_unique_double_down]
            det_double_down = bdet2(det_double_down.view(nbatch, -1, 2, 2))
            det_double_down = detAdown.unsqueeze(-1) * det_double_down

            det_out_up = torch.cat((det_out_up, det_double_up), dim=1)
            det_out_down = torch.cat((det_out_down, det_double_down), dim=1)

        return det_out_up, det_out_down

    def operator(
        self,
        mo: torch.Tensor,
        bop: torch.Tensor,
        op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = op.add,
        op_squared: bool = False,
        inv_mo: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Computes the values of an operator applied to the products of determinants.

        Args:
            mo (torch.Tensor): matrix of MO vals (Nbatch, Nelec, Nmo)
            bop (torch.Tensor): operator applied to the MOs (Nbatch, Nelec, Nmo)
            op (Callable, optional): how to combine the up/down contributions
                (default ``operator.add``). If None, the (up, down) tuple is
                returned unchanged.
            op_squared (bool, optional): return the trace of the square of the
                product if True
            inv_mo (tuple, optional): precomputed inverse of the occupied spin
                up & down MO matrices; used by the ground state and
                single/double paths only

        Returns:
            torch.Tensor: combined operator values, or the (up, down) tuple
            when ``op`` is None

        Raises:
            ValueError: if ``config_method`` is not recognized
        """
        method = self.config_method

        if method == "ground_state":
            op_vals = self.operator_ground_state(mo, bop, op_squared, inv_mo)

        elif method.startswith("single"):
            if self.use_explicit_operator:
                op_vals = self.operator_explicit(mo, bop, op_squared)
            else:
                op_vals = self.operator_single_double(mo, bop, op_squared, inv_mo)

        elif method.startswith("cas(") or method == "explicit":
            op_vals = self.operator_explicit(mo, bop, op_squared)

        else:
            raise ValueError(f"Configuration {method} not recognized")

        # combine the spin up/down values if requested
        if op is None:
            return op_vals
        return op(*op_vals)

    def operator_ground_state(
        self,
        mo: torch.Tensor,
        bop: torch.Tensor,
        op_squared: bool = False,
        inv_mo: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the values of any operator on the ground state only.

        Args:
            mo (torch.Tensor): matrix of molecular orbitals
            bop (torch.Tensor): matrix of the operator applied to the MOs
            op_squared (bool, optional): return the trace of the square of the product if True
            inv_mo (tuple, optional): precomputed inverse of the up/down MO matrices

        Returns:
            tuple: spin up / spin down operator values, each with a trailing
            singleton dimension
        """
        if inv_mo is None:
            invAup, invAdown = self.compute_inverse_occupied_mo_matrix(mo)
        else:
            invAup, invAdown = inv_mo

        # precompute the product A^{-1} B
        op_ground_up = invAup @ bop[..., : self.nup, : self.nup]
        op_ground_down = invAdown @ bop[..., self.nup :, : self.ndown]

        if op_squared:
            op_ground_up = op_ground_up @ op_ground_up
            op_ground_down = op_ground_down @ op_ground_down

        # ground state operator: trace of the (squared) product
        op_ground_up = btrace(op_ground_up).unsqueeze(-1)
        op_ground_down = btrace(op_ground_down).unsqueeze(-1)

        return op_ground_up, op_ground_down

    def operator_explicit(
        self,
        mo: torch.Tensor,
        bkin: torch.Tensor,
        op_squared: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes the value of any operator using the trace trick for a product
        of spin up/down determinants.

        .. math::
            -\frac{1}{2} \Delta \Psi = -\frac{1}{2}  D_{up} D_{down}
            ( \Delta_{up} D_{up} / D_{up} + \Delta_{down} D_{down}  / D_{down} )

        Args:
            mo: matrix of MO vals (Nbatch, Nelec, Nmo)
            bkin: operator applied to the MOs (Nbatch, Nelec, Nmo); a leading
                operator dimension is handled when the split matrices are 5D
            op_squared: return the trace of the square of the product if True

        Returns:
            tuple: spin up / spin down operator values for every configuration
        """
        # spin up/down Slater matrices of the unique configurations
        Aup, Adown = self.orb_proj.split_orbitals(mo, unique_configs=True)
        Bup, Bdown = self.orb_proj.split_orbitals(bkin, unique_configs=True)

        # check if we have 1 or multiple ops
        multiple_op = Bup.ndim == 5

        # inverse of MO matrices
        iAup = torch.inverse(Aup)
        iAdown = torch.inverse(Adown)

        # broadcast the inverses over the operator dimension
        if multiple_op:
            iAup = iAup.unsqueeze(1)
            iAdown = iAdown.unsqueeze(1)

        # precompute product invA x B
        op_val_up = iAup @ Bup
        op_val_down = iAdown @ Bdown

        if op_squared:
            op_val_up = op_val_up @ op_val_up
            op_val_down = op_val_down @ op_val_down

        # trace of the (squared) products
        op_val_up = btrace(op_val_up)
        op_val_down = btrace(op_val_down)

        # move the configuration axis to the last position
        if multiple_op:
            op_val_up = op_val_up.permute(1, 2, 0)
            op_val_down = op_val_down.permute(1, 2, 0)
        else:
            op_val_up = op_val_up.transpose(0, 1)
            op_val_down = op_val_down.transpose(0, 1)

        return (
            op_val_up[..., self.orb_proj.index_unique_configs[0]],
            op_val_down[..., self.orb_proj.index_unique_configs[1]],
        )

    def operator_single_double(
        self,
        mo: torch.Tensor,
        bop: torch.Tensor,
        op_squared: bool = False,
        inv_mo: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the value of any operator on gs + single + double.

        Args:
            mo: matrix of molecular orbitals (torch.Tensor)
            bop: matrix of the operator applied to the MOs (torch.Tensor)
            op_squared: return the trace of the square of the product if True (bool)
            inv_mo: precomputed inverse of the up/down MO matrices (tuple, optional)

        Returns:
            tuple: spin up / spin down operator values for every configuration
        """
        op_up, op_down = self.operator_unique_single_double(mo, bop, op_squared, inv_mo)

        return (
            op_up[..., self.index_unique_excitation[0]],
            op_down[..., self.index_unique_excitation[1]],
        )

    def operator_unique_single_double(
        self,
        mo: torch.Tensor,
        bop: torch.Tensor,
        op_squared: bool,
        inv_mo: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the operator values of the unique single/double configurations.

        Args:
            mo (torch.Tensor): matrix of molecular orbitals (nbatch, nelec, nmo)
            bop (torch.Tensor): matrix of the operator applied to the MOs; the
                ``...`` indexing allows extra leading dimensions
            op_squared (bool): return the trace of the square of the product if True
            inv_mo (tuple, optional): precomputed inverse of the up/down MO matrices

        Returns:
            tuple: spin up / spin down values. Along the last dimension the
            ground state comes first, then the unique single excitations, then
            the unique double excitations.
        """
        nbatch = mo.shape[0]

        # lazily build the excitation index masks
        if not hasattr(self.exc_mask, "index_unique_single_up"):
            self.exc_mask.get_index_unique_single()

        if not hasattr(self.exc_mask, "index_unique_double_up"):
            self.exc_mask.get_index_unique_double()

        do_single = len(self.exc_mask.index_unique_single_up) != 0
        do_double = len(self.exc_mask.index_unique_double_up) != 0

        # compute or retrieve the inverse of the up/down MO matrices
        if inv_mo is None:
            invAup, invAdown = self.compute_inverse_occupied_mo_matrix(mo)
        else:
            invAup, invAdown = inv_mo

        # precompute invA @ B on the occupied block
        invAB_up = invAup @ bop[..., : self.nup, : self.nup]
        invAB_down = invAdown @ bop[..., self.nup :, : self.ndown]

        # ground state operator
        if op_squared:
            op_ground_up = btrace(invAB_up @ invAB_up)
            op_ground_down = btrace(invAB_down @ invAB_down)
        else:
            op_ground_up = btrace(invAB_up)
            op_ground_down = btrace(invAB_down)

        op_ground_up = op_ground_up.unsqueeze(-1)
        op_ground_down = op_ground_down.unsqueeze(-1)

        # accumulators of the terms we need (ground state first)
        op_out_up = op_ground_up.clone()
        op_out_down = op_ground_down.clone()

        # virtual orbital matrices spin up/down
        Avirt_up = mo[:, : self.nup, self.nup : self.index_max_orb_up]
        Avirt_down = mo[:, self.nup :, self.ndown : self.index_max_orb_down]

        # products of invA and the virtual orbitals
        mat_exc_up = invAup @ Avirt_up
        mat_exc_down = invAdown @ Avirt_down

        # operator applied to the virtual orbitals
        bop_virt_up = bop[..., : self.nup, self.nup : self.index_max_orb_up]
        bop_virt_down = bop[..., self.nup :, self.ndown : self.index_max_orb_down]

        # M = A^{-1} Bbar - A^{-1} B A^{-1} Abar, reusing the precomputed
        # A^{-1} B (invAB) and A^{-1} Abar (mat_exc) products
        Mup = invAup @ bop_virt_up - invAB_up @ mat_exc_up
        Mdown = invAdown @ bop_virt_down - invAB_down @ mat_exc_down

        if not op_squared:
            # flatten the last two dims so the excitation masks can index M
            Mup = Mup.view(*Mup.shape[:-2], -1)
            Mdown = Mdown.view(*Mdown.shape[:-2], -1)

            if do_single:
                op_sin_up = self.op_single(
                    op_ground_up,
                    mat_exc_up,
                    Mup,
                    self.exc_mask.index_unique_single_up,
                    nbatch,
                )
                op_sin_down = self.op_single(
                    op_ground_down,
                    mat_exc_down,
                    Mdown,
                    self.exc_mask.index_unique_single_down,
                    nbatch,
                )
                op_out_up = torch.cat((op_out_up, op_sin_up), dim=-1)
                op_out_down = torch.cat((op_out_down, op_sin_down), dim=-1)

            if do_double:
                op_dbl_up = self.op_multiexcitation(
                    op_ground_up,
                    mat_exc_up,
                    Mup,
                    self.exc_mask.index_unique_double_up,
                    2,
                    nbatch,
                )
                op_dbl_down = self.op_multiexcitation(
                    op_ground_down,
                    mat_exc_down,
                    Mdown,
                    self.exc_mask.index_unique_double_down,
                    2,
                    nbatch,
                )
                op_out_up = torch.cat((op_out_up, op_dbl_up), dim=-1)
                op_out_down = torch.cat((op_out_down, op_dbl_down), dim=-1)

            return op_out_up, op_out_down

        # squared operator, typically trace(ABAB)
        # Y = A^-1 B M
        Yup = invAB_up @ Mup
        Ydown = invAB_down @ Mdown

        # flatten the last two dims so the excitation masks can index M and Y
        Mup = Mup.view(*Mup.shape[:-2], -1)
        Mdown = Mdown.view(*Mdown.shape[:-2], -1)
        Yup = Yup.view(*Yup.shape[:-2], -1)
        Ydown = Ydown.view(*Ydown.shape[:-2], -1)

        if do_single:
            op_sin_up = self.op_squared_single(
                op_ground_up,
                mat_exc_up,
                Mup,
                Yup,
                self.exc_mask.index_unique_single_up,
                nbatch,
            )
            op_sin_down = self.op_squared_single(
                op_ground_down,
                mat_exc_down,
                Mdown,
                Ydown,
                self.exc_mask.index_unique_single_down,
                nbatch,
            )
            op_out_up = torch.cat((op_out_up, op_sin_up), dim=-1)
            op_out_down = torch.cat((op_out_down, op_sin_down), dim=-1)

        if do_double:
            # spin up values use the spin *up* double excitation indices
            op_dbl_up = self.op_squared_multiexcitation(
                op_ground_up,
                mat_exc_up,
                Mup,
                Yup,
                self.exc_mask.index_unique_double_up,
                2,
                nbatch,
            )
            op_dbl_down = self.op_squared_multiexcitation(
                op_ground_down,
                mat_exc_down,
                Mdown,
                Ydown,
                self.exc_mask.index_unique_double_down,
                2,
                nbatch,
            )
            op_out_up = torch.cat((op_out_up, op_dbl_up), dim=-1)
            op_out_down = torch.cat((op_out_down, op_dbl_down), dim=-1)

        return op_out_up, op_out_down

    @staticmethod
    def op_single(
        baseterm: torch.Tensor,
        mat_exc: torch.Tensor,
        M: torch.Tensor,
        index: List[int],
        nbatch: int,
    ) -> torch.Tensor:
        r"""Computes the operator values for single excitations.

        .. math::
            Tr( \bar{A}^{-1} \bar{B}) = Tr(A^{-1} B) + Tr( T M )
            T = P ( A^{-1} \bar{A})^{-1} P
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}

        For a single excitation the blocks are 1x1, so the inverse is a
        reciprocal and the trace is an element-wise product.

        Args:
            baseterm (torch.Tensor): trace(A B)
            mat_exc (torch.Tensor): invA @ Abar
            M (torch.Tensor): invA Bbar - inv A B inv A Abar (last two dims flattened)
            index (List[int]): list of index of the excitations
            nbatch (int): batch size

        Returns:
            torch.Tensor: trace(T M) + trace(A B)
        """
        # values of T: reciprocal of the 1x1 blocks of invA @ Abar
        T = 1.0 / mat_exc.view(nbatch, -1)[:, index]

        # trace(T M) for 1x1 blocks
        op_vals = T * M[..., index]

        # add the base term
        op_vals += baseterm

        return op_vals

    @staticmethod
    def op_multiexcitation(
        baseterm: torch.Tensor,
        mat_exc: torch.Tensor,
        M: torch.Tensor,
        index: List[int],
        size: int,
        nbatch: int,
    ) -> torch.Tensor:
        r"""Computes the operator values for multiple excitations.

        .. math::
            Tr( \bar{A}^{-1} \bar{B}) = Tr(A^{-1} B) + Tr( T M )
            T = P ( A^{-1} \bar{A})^{-1} P
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}

        Args:
            baseterm (torch.Tensor): trace(A B)
            mat_exc (torch.Tensor): invA @ Abar
            M (torch.Tensor): invA Bbar - inv A B inv A Abar (last two dims flattened)
            index (List[int]): list of index of the excitations
            size (int): number of excitations (size of the square blocks)
            nbatch (int): batch size

        Returns:
            torch.Tensor: trace(A B) + trace(T M)
        """
        # values of the excitation matrix invA Abar
        T = mat_exc.view(nbatch, -1)[:, index]

        # shapes of the size x size blocks
        _ext_shape = (*T.shape[:-1], -1, size, size)
        _m_shape = (*M.shape[:-1], -1, size, size)

        # inverse of the invA Abar blocks
        T = torch.inverse(T.view(_ext_shape))

        # T @ M (after reshaping M as size x size blocks)
        # NOTE: this is surprisingly the computational bottleneck
        m_tmp = M[..., index].view(_m_shape)
        op_vals = T @ m_tmp

        # compute the trace
        op_vals = btrace(op_vals)

        # add the base term
        op_vals += baseterm

        return op_vals

    @staticmethod
    def op_squared_single(
        baseterm: torch.Tensor,
        mat_exc: torch.Tensor,
        M: torch.Tensor,
        Y: torch.Tensor,
        index: List[int],
        nbatch: int,
    ) -> torch.Tensor:
        r"""Computes the operator squared for single excitations.

        .. math::
            Tr( (\bar{A}^{-1} \bar{B})^2) = Tr((A^{-1} B)^2) + Tr( (T M)^2 ) + 2 Tr(T Y)
            T = P ( A^{-1} \bar{A})^{-1} P -> mat_exc in the code
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}
            Y = A^{-1} B M

        Args:
            baseterm (torch.Tensor): trace(A B A B)
            mat_exc (torch.Tensor): invA @ Abar
            M (torch.Tensor): invA Bbar - inv A B inv A Abar (last two dims flattened)
            Y (torch.Tensor): invA B M (last two dims flattened)
            index (List[int]): list of index of the excitations
            nbatch (int): batch size

        Returns:
            torch.Tensor: trace((A^{-1} B)^2) + trace((T M)^2) + 2 trace(T Y)
        """
        # reciprocal of the 1x1 blocks of invA @ Abar
        T = 1.0 / (mat_exc.view(nbatch, -1)[:, index])

        # trace((T M)^2)
        tmp = T * M[..., index]
        op_vals = tmp * tmp

        # 2 trace(T Y)
        tmp = T * Y[..., index]
        op_vals += 2 * tmp

        # add the base term
        op_vals += baseterm

        return op_vals

    @staticmethod
    def op_squared_multiexcitation(
        baseterm: torch.Tensor,
        mat_exc: torch.Tensor,
        M: torch.Tensor,
        Y: torch.Tensor,
        index: List[int],
        size: int,
        nbatch: int,
    ) -> torch.Tensor:
        r"""Computes the operator squared for multiple excitations.

        .. math::
            Tr( (\bar{A}^{-1} \bar{B})^2) = Tr((A^{-1} B)^2) + Tr( (T M)^2 ) + 2 Tr(T Y)
            T = P ( A^{-1} \bar{A})^{-1} P -> mat_exc in the code
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}
            Y = A^{-1} B M

        Args:
            baseterm (torch.Tensor): trace(A B A B)
            mat_exc (torch.Tensor): invA @ Abar
            M (torch.Tensor): invA Bbar - inv A B inv A Abar (last two dims flattened)
            Y (torch.Tensor): invA B M (last two dims flattened)
            index (List[int]): list of index of the excitations
            size (int): number of excitations (size of the square blocks)
            nbatch (int): batch size

        Returns:
            torch.Tensor: trace((A^{-1} B)^2) + trace((T M)^2) + 2 trace(T Y)
        """
        # values of the excitation matrix invA Abar
        T = mat_exc.view(nbatch, -1)[:, index]

        # shapes as a series of size x size blocks
        _ext_shape = (*T.shape[:-1], -1, size, size)
        _m_shape = (*M.shape[:-1], -1, size, size)
        _y_shape = (*Y.shape[:-1], -1, size, size)

        # reshape T and invert the blocks
        T = torch.inverse(T.view(_ext_shape))

        # trace((T M)^2)
        tmp = T @ (M[..., index]).view(_m_shape)
        op_vals = btrace(tmp @ tmp)

        # 2 trace(T Y)
        tmp = btrace(T @ (Y[..., index]).view(_y_shape))
        op_vals += 2 * tmp

        # add the base term
        op_vals += baseterm

        return op_vals

    def compute_inverse_occupied_mo_matrix(
        self, mo: torch.Tensor
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Precomputes the inverse of the occupied MO matrices.

        Args:
            mo (torch.Tensor): matrix of the molecular orbitals

        Returns:
            Optional[Tuple[torch.Tensor, torch.Tensor]]: inverses of the
            occupied spin up (``mo[:, :nup, :nup]``) and spin down
            (``mo[:, nup:, :ndown]``) blocks, or None when the explicit
            calculation of all determinants is used (``cas(...)`` configs or
            ``use_explicit_operator``).
        """
        # NOTE(review): "explicit" configs also go through operator_explicit,
        # which does not use these inverses; they are still computed here.
        if self.config_method.startswith("cas(") or self.use_explicit_operator:
            return None

        return (
            torch.inverse(mo[:, : self.nup, : self.nup]),
            torch.inverse(mo[:, self.nup :, : self.ndown]),
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc