NLESC-JCER / QMCTorch: /qmctorch/wavefunction/orbitals/backflow/backflow_transformation.py

import torch
from torch import nn
from typing import Dict, Optional, Callable, Tuple, Type
from ....scf import Molecule
from .kernels.backflow_kernel_base import BackFlowKernelBase
from ...jastrows.distance.electron_electron_distance import ElectronElectronDistance


class BackFlowTransformation(nn.Module):
    def __init__(
        self,
        mol: Molecule,
        backflow_kernel: Type[BackFlowKernelBase],
        backflow_kernel_kwargs: Optional[Dict] = None,
        cuda: Optional[bool] = False,
    ):
        """Transform the electron coordinates into backflow coordinates.
        See: Orbital-dependent backflow wave functions for real-space quantum Monte Carlo
        https://arxiv.org/abs/1910.07167

        .. math::
            \\bold{q}_i = \\bold{r}_i + \\sum_{j\\neq i} \\eta(r_{ij})(\\bold{r}_i - \\bold{r}_j)
        """
        super().__init__()
        self.nao = mol.basis.nao
        self.nelec = mol.nelec
        self.ndim = 3

        # instantiate the kernel class (the class itself is passed in,
        # not an instance); avoid a mutable default for the kwargs dict
        self.backflow_kernel = backflow_kernel(
            mol, cuda, **(backflow_kernel_kwargs or {})
        )

        self.edist = ElectronElectronDistance(mol.nelec)

        self.cuda = cuda
        self.device = torch.device("cpu")
        if self.cuda:
            self.device = torch.device("cuda")
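
    # Construction sketch (hedged): BackFlowKernelInverse below stands for one
    # of the concrete kernels shipped with QMCTorch; any BackFlowKernelBase
    # subclass works, and the class itself is passed, not an instance:
    #
    #     mol = Molecule(atom="H 0. 0. 0.; H 0. 0. 1.", calculator="pyscf", basis="sto-3g")
    #     backflow = BackFlowTransformation(mol, BackFlowKernelInverse)
    #     q = backflow(torch.rand(10, mol.nelec * 3))  # Nbatch x [Nelec*Ndim]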

    def forward(self,
                pos: torch.Tensor,
                derivative: Optional[int] = 0
                ) -> torch.Tensor:
        """Computes the backflow coordinates or their derivatives.

        Args:
            pos (torch.tensor): original positions, Nbatch x [Nelec*Ndim]
            derivative (int): order of the derivative (0, 1 or 2)
        """
        if derivative == 0:
            return self._get_backflow(pos)

        elif derivative == 1:
            return self._get_backflow_derivative(pos)

        elif derivative == 2:
            return self._get_backflow_second_derivative(pos)

        else:
            raise ValueError(
                "derivative of the backflow transformation must be 0, 1 or 2"
            )

    def _get_backflow(self,
                      pos: torch.Tensor
                      ) -> torch.Tensor:
        """Computes the backflow transformation

        .. math::
            \\bold{q}_i = \\bold{r}_i + \\sum_{j\\neq i} \\eta(r_{ij})(\\bold{r}_i - \\bold{r}_j)

        Args:
            pos (torch.tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: transformed positions, Nbatch x [Nelec*Ndim]
        """
        # compute the difference between electron positions
        # Nbatch x Nelec x Nelec x 3
        delta_ee = self.edist.get_difference(pos.reshape(-1, self.nelec, self.ndim))

        # compute the backflow kernel eta(r_ij)
        # Nbatch x Nelec x Nelec
        bf_kernel = self.backflow_kernel(self.edist(pos))

        # update the positions: q_i = r_i + sum_j eta(r_ij) (r_i - r_j)
        pos = pos.reshape(-1, self.nelec, self.ndim) + (
            bf_kernel.unsqueeze(-1) * delta_ee
        ).sum(2)

        return pos.reshape(-1, self.nelec * self.ndim)
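
    # Worked shape example for _get_backflow (a sketch, with Nbatch=10, Nelec=4):
    #   pos        : (10, 12)       flattened Nelec*Ndim input
    #   delta_ee   : (10, 4, 4, 3)  pairwise differences r_i - r_j
    #   bf_kernel  : (10, 4, 4)     eta(r_ij)
    #   bf_kernel.unsqueeze(-1) * delta_ee is summed over j (dim 2), giving a
    #   (10, 4, 3) displacement that is added to the reshaped positions,
    #   and the result is flattened back to (10, 12)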

    def _get_backflow_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Computes the derivative of the backflow transformation
        wrt the original positions of the electrons

        .. math::
            \bold{q}_i = \bold{r}_i + \sum_{j\neq i} \eta(r_{ij})(\bold{r}_i - \bold{r}_j)

        .. math::
            \frac{d q_i}{d x_k} = \delta_{ik}(1 + \sum_{j\neq i} \frac{d \eta(r_{ij})}{d x_i}(x_i-x_j) + \eta(r_{ij}))
                                  + \delta_{i\neq k}(-\frac{d \eta(r_{ik})}{d x_k}(x_i-x_k) - \eta(r_{ik}))

        Args:
            pos (torch.tensor): original positions of the electrons, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: d q_i / d x_k with
                          q_i the backflow position of electron i and
                          x_k the original coordinate of the kth electron,
                          shape Nbatch x Ndim x Ndim x Nelec x Nelec x 1
        """

        # ee dist matrix : Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # derivative of the ee dist matrix : Nbatch x 3 x Nelec x Nelec
        # dr_ij / dx_i = - dr_ij / dx_j
        dree = self.edist(pos, derivative=1)

        # difference between elec positions
        # Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(pos.reshape(nbatch, nelec, 3)).permute(
            0, 3, 1, 2
        )

        # backflow kernel : Nbatch x Nelec x Nelec
        bf = self.backflow_kernel(ree)

        # (d eta(r_ij) / d r_ij) (d r_ij / d beta_i)
        # derivative of the backflow kernel : Nbatch x 3 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1)
        dbf = dbf * dree

        # (d eta(r_ij) / d beta_i) (alpha_i - alpha_j)
        # Nbatch x 3 x 3 x Nelec x Nelec
        dbf_delta_ee = dbf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # compute delta_ij * (1 + sum_{k neq i} eta(r_ik))
        # Nbatch x Nelec x Nelec (diagonal matrix)
        delta_ij_bf = torch.diag_embed(1 + bf.sum(-1), dim1=-1, dim2=-2)

        # 3x3 identity viewed as 1 x 3 x 3 x 1 x 1
        eye_mat = torch.eye(3, 3).view(1, 3, 3, 1, 1).to(self.device)

        # compute delta_ab * delta_ij * (1 + sum_{k neq i} eta(r_ik))
        # Nbatch x Ndim x Ndim x Nelec x Nelec (diagonal matrix)
        delta_ab_delta_ij_bf = eye_mat * delta_ij_bf.view(nbatch, 1, 1, nelec, nelec)

        # compute sum_k d eta(r_ik)/d beta_i (alpha_i - alpha_k)
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        delta_ij_sum = torch.diag_embed(dbf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # compute delta_ab * eta(r_ij)
        delta_ab_bf = eye_mat * bf.view(nbatch, 1, 1, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        # d alpha_i / d beta_j
        out = delta_ab_delta_ij_bf + delta_ij_sum - dbf_delta_ee - delta_ab_bf

        return out.unsqueeze(-1)
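
    # Hedged cross-check sketch: the analytic Jacobian returned above can be
    # validated against autograd (names here are illustrative, not part of the API):
    #
    #     pos = torch.rand(1, nelec * 3, requires_grad=True)
    #     jac = torch.autograd.functional.jacobian(lambda p: backflow(p), pos)
    #     # compare jac, after reshaping, with backflow(pos, derivative=1)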

    def _get_backflow_second_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Computes the second derivative of the backflow transformation
        wrt the original positions of the electrons

        .. math::
            \bold{q}_i = \bold{r}_i + \sum_{j\neq i} \eta(r_{ij})(\bold{r}_i - \bold{r}_j)

        .. math::
            \frac{d q_i}{d x_k} = \delta_{ik}(1 + \sum_{j\neq i} \frac{d \eta(r_{ij})}{d x_i}(x_i-x_j) + \eta(r_{ij}))
                                  + \delta_{i\neq k}(-\frac{d \eta(r_{ik})}{d x_k}(x_i-x_k) - \eta(r_{ik}))

        .. math::
            \frac{d^2 q_i}{d x_k^2} = \delta_{ik}(\sum_{j\neq i} \frac{d^2 \eta(r_{ij})}{d x_i^2}(x_i-x_j) + 2\frac{d \eta(r_{ij})}{d x_i})
                                      - \delta_{i\neq k}(\frac{d^2 \eta(r_{ik})}{d x_k^2}(x_i-x_k) + 2\frac{d \eta(r_{ik})}{d x_k})

        Args:
            pos (torch.tensor): original positions of the electrons, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: d^2 q_i / d x_k^2 with
                          q_i the backflow position of electron i and
                          x_k the original coordinate of the kth electron,
                          shape Nbatch x Ndim x Ndim x Nelec x Nelec x 1
        """
        # ee dist matrix : Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # difference between elec positions
        # Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(pos.reshape(nbatch, nelec, 3)).permute(
            0, 3, 1, 2
        )

        # derivative of the ee dist matrix : d r_ij / d x_i
        # Nbatch x 3 x Nelec x Nelec
        dree = self.edist(pos, derivative=1)

        # second derivative of the ee dist matrix : d2 r_ij / d x_i^2
        # Nbatch x 3 x Nelec x Nelec
        d2ree = self.edist(pos, derivative=2)

        # derivative of the backflow kernel : d eta(r_ij) / d r_ij
        # Nbatch x 1 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1)

        # second derivative of the backflow kernel : d2 eta(r_ij) / d r_ij^2
        # Nbatch x 1 x Nelec x Nelec
        d2bf = self.backflow_kernel(ree, derivative=2).unsqueeze(1)

        # (d2 eta / d r_ij^2) (d r_ij / d x_i)^2 + (d eta / d r_ij) (d2 r_ij / d x_i^2)
        # Nbatch x 3 x Nelec x Nelec
        d2bf = (d2bf * dree * dree) + (dbf * d2ree)

        # (d eta / d r_ij) (d r_ij / d x_i)
        # Nbatch x 3 x Nelec x Nelec
        dbf = dbf * dree

        # 3x3 identity viewed as 1 x 3 x 3 x 1 x 1
        eye_mat = torch.eye(3, 3).reshape(1, 3, 3, 1, 1).to(self.device)

        # compute delta_ij delta_ab 2 sum_k d eta(r_ik) / d beta_i
        term1 = (
            2
            * eye_mat
            * torch.diag_embed(dbf.sum(-1), dim1=-1, dim2=-2).reshape(
                nbatch, 1, 3, nelec, nelec
            )
        )

        # (d2 eta(r_ij) / d beta_i^2) (alpha_i - alpha_j)
        # Nbatch x 3 x 3 x Nelec x Nelec
        d2bf_delta_ee = d2bf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # compute sum_k d2 eta(r_ik)/d beta_i^2 (alpha_i - alpha_k)
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        term2 = torch.diag_embed(d2bf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # compute 2 delta_ab * d eta(r_ij)/d beta_j
        term3 = 2 * eye_mat * dbf.reshape(nbatch, 1, 3, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        # d2 alpha_i / d beta_j^2
        out = term1 + term2 + d2bf_delta_ee + term3

        return out.unsqueeze(-1)
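
    # Layout note (a sketch): derivative=1 and derivative=2 share the output
    # layout Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j) x 1, where
    # alpha indexes the component of q_i and beta the component of x_j:
    #
    #     dq  = backflow(pos, derivative=1)   # d q_i^alpha / d x_j^beta
    #     d2q = backflow(pos, derivative=2)   # d^2 q_i^alpha / d (x_j^beta)^2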

    def fit_kernel(self, lambda_func: Callable,
                   xmin: float = 0.01, xmax: float = 1.0, npts: int = 100,
                   lr: float = 0.001, num_epochs: int = 1000
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Fit the backflow kernel to a given function.

        Args:
            lambda_func (Callable): function to be fit
            xmin (float): minimum x value
            xmax (float): maximum x value
            npts (int): number of points to sample in the interval [xmin, xmax]
            lr (float): learning rate
            num_epochs (int): number of epochs to run the optimization

        Returns:
            xpts (torch.tensor): x values used for fitting
            ground_truth (torch.tensor): y values of the given function
            fit_values (torch.tensor): y values of the fit function
        """
        xpts = torch.linspace(xmin, xmax, npts)
        ground_truth = lambda_func(xpts)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.backflow_kernel.parameters(), lr=lr)

        for epoch in range(num_epochs):
            optimizer.zero_grad()
            outputs = self.backflow_kernel(xpts.unsqueeze(1))
            loss = criterion(outputs.squeeze(), ground_truth)
            loss.backward()
            optimizer.step()

            if epoch % 100 == 0:
                print("Epoch {}: Loss = {}".format(epoch, loss.item()))

        fit_values = self.backflow_kernel(xpts.unsqueeze(1)).squeeze().detach()
        return xpts, ground_truth, fit_values
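
    # fit_kernel usage sketch (the target function below is only an example):
    #
    #     xpts, truth, fit = backflow.fit_kernel(lambda x: 0.5 / x, xmin=0.1)
    #     # plotting xpts vs truth and fit shows the quality of the fit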

    def __repr__(self):
        """representation of the backflow transformation"""
        return self.backflow_kernel.__class__.__name__
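

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module proper. The kernel
    # import path and class name below are assumptions (use whichever
    # BackFlowKernelBase subclass your QMCTorch install ships); the Molecule
    # call follows the signature used in the QMCTorch examples.
    from qmctorch.wavefunction.orbitals.backflow.kernels import (  # assumed path
        BackFlowKernelInverse,
    )

    mol = Molecule(atom="H 0. 0. 0.; H 0. 0. 1.", unit="bohr",
                   calculator="pyscf", basis="sto-3g")
    backflow = BackFlowTransformation(mol, BackFlowKernelInverse)

    pos = torch.rand(5, mol.nelec * 3)
    print(backflow(pos).shape)                # (5, Nelec*3)
    print(backflow(pos, derivative=1).shape)  # (5, 3, 3, Nelec, Nelec, 1)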