• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

NLESC-JCER / QMCTorch / 14995306316

13 May 2025 11:21AM UTC coverage: 83.844%. Remained the same
14995306316

Pull #194

github

web-flow
Merge d8ac4723f into dd0c5094e
Pull Request #194: apply black to code base

955 of 1334 branches covered (71.59%)

Branch coverage included in aggregate %.

240 of 268 new or added lines in 50 files covered. (89.55%)

5 existing lines in 4 files now uncovered.

4515 of 5190 relevant lines covered (86.99%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.74
/qmctorch/wavefunction/orbitals/backflow/backflow_transformation.py
1
import torch
1✔
2
from torch import nn
1✔
3
from typing import Dict, Optional, Callable, Tuple
1✔
4
from ....scf import Molecule
1✔
5
from .kernels.backflow_kernel_base import BackFlowKernelBase
1✔
6
from ...jastrows.distance.electron_electron_distance import ElectronElectronDistance
1✔
7

8

9
class BackFlowTransformation(nn.Module):
    def __init__(
        self,
        mol: Molecule,
        backflow_kernel: BackFlowKernelBase,
        backflow_kernel_kwargs: Optional[Dict] = None,
        cuda: Optional[bool] = False,
    ):
        r"""Transform the electron coordinates into backflow coordinates.

        See: Orbital-dependent backflow wave functions for real-space
        quantum Monte Carlo, https://arxiv.org/abs/1910.07167

        .. math::
            \mathbf{q}_i = \mathbf{r}_i + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        Args:
            mol (Molecule): molecule; provides the number of electrons
                and the size of the atomic-orbital basis.
            backflow_kernel (BackFlowKernelBase): NOTE(review) — despite the
                annotation this is the kernel *class*, not an instance: it is
                instantiated below as
                ``backflow_kernel(mol, cuda, **backflow_kernel_kwargs)``.
            backflow_kernel_kwargs (dict, optional): extra keyword arguments
                forwarded to the kernel constructor. Defaults to an empty dict.
            cuda (bool, optional): if True, run on GPU. Defaults to False.
        """
        super().__init__()
        self.nao = mol.basis.nao
        self.nelec = mol.nelec
        self.ndim = 3

        # BUGFIX: the original signature used a mutable default ({});
        # default to None and create a fresh dict on every call instead.
        if backflow_kernel_kwargs is None:
            backflow_kernel_kwargs = {}

        # instantiate the kernel class passed by the caller
        self.backflow_kernel = backflow_kernel(mol, cuda, **backflow_kernel_kwargs)

        self.edist = ElectronElectronDistance(mol.nelec)

        self.cuda = cuda
        self.device = torch.device("cpu")
        if self.cuda:
            self.device = torch.device("cuda")

    def forward(self, pos: torch.Tensor, derivative: Optional[int] = 0) -> torch.Tensor:
        """Dispatch to the backflow transformation or one of its derivatives.

        Args:
            pos (torch.Tensor): original positions, Nbatch x [Nelec*Ndim]
            derivative (int, optional): 0 (positions), 1 (first derivative)
                or 2 (second derivative). Defaults to 0.

        Returns:
            torch.Tensor: transformed positions or their derivative.

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.
        """
        if derivative == 0:
            return self._get_backflow(pos)

        elif derivative == 1:
            return self._get_backflow_derivative(pos)

        elif derivative == 2:
            return self._get_backflow_second_derivative(pos)

        else:
            raise ValueError(
                "derivative of the backflow transformation must be 0, 1 or 2"
            )

    def _get_backflow(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Computes the backflow transformation

        .. math::
            \mathbf{q}_i = \mathbf{r}_i + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        Args:
            pos (torch.Tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: transformed positions, Nbatch x [Nelec*Ndim]
        """
        # pairwise position differences r_i - r_j
        # Nbatch x Nelec x Nelec x 3
        delta_ee = self.edist.get_difference(pos.reshape(-1, self.nelec, self.ndim))

        # backflow kernel eta(r_ij) evaluated on the distance matrix
        # Nbatch x Nelec x Nelec
        bf_kernel = self.backflow_kernel(self.edist(pos))

        # q_i = r_i + sum_j eta(r_ij) (r_i - r_j)
        pos = pos.reshape(-1, self.nelec, self.ndim) + (
            bf_kernel.unsqueeze(-1) * delta_ee
        ).sum(2)

        return pos.reshape(-1, self.nelec * self.ndim)

    def _get_backflow_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Computes the derivative of the backflow transformation
           wrt the original positions of the electrons

        .. math::
            \mathbf{q}_i = \mathbf{r}_i + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        .. math::
            \frac{d q_i}{d x_k} = \delta_{ik}(1 + \sum_{j\neq i} \frac{d \eta(r_{ij})}{d x_i}(x_i-x_j) + \eta(r_{ij})) +
                                  \delta_{i\neq k}(-\frac{d \eta(r_{ik})}{d x_k}(x_i-x_k) - \eta(r_{ik}))

        Args:
            pos (torch.Tensor): original positions of the electrons,
                Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: d q_i / d x_k,
                Nbatch x Ndim x Ndim x Nelec x Nelec x 1
        """

        # ee dist matrix : Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # derivative ee dist matrix : Nbatch x 3 x Nelec x Nelec
        # dr_ij / dx_i = - dr_ij / dx_j
        dree = self.edist(pos, derivative=1)

        # difference between elec pos
        # Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(pos.reshape(nbatch, nelec, 3)).permute(
            0, 3, 1, 2
        )

        # backflow kernel : Nbatch x Nelec x Nelec
        # (original comment claimed a singleton dim that is not there)
        bf = self.backflow_kernel(ree)

        # (d eta(r_ij) / d r_ij) (d r_ij / d beta_i)
        # derivative of the backflow kernel : Nbatch x 3 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1)
        dbf = dbf * dree

        # (d eta(r_ij) / d beta_i) (alpha_i - alpha_j)
        # Nbatch x 3 x 3 x Nelec x Nelec
        dbf_delta_ee = dbf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # delta_ij * (1 + sum_{k != i} eta(r_ik))
        # Nbatch x Nelec x Nelec (diagonal matrix)
        delta_ij_bf = torch.diag_embed(1 + bf.sum(-1), dim1=-1, dim2=-2)

        # identity 3x3 broadcast as 1 x 3 x 3 x 1 x 1
        eye_mat = torch.eye(3, 3).view(1, 3, 3, 1, 1).to(self.device)

        # delta_ab * delta_ij * (1 + sum_{k != i} eta(r_ik))
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        delta_ab_delta_ij_bf = eye_mat * delta_ij_bf.view(nbatch, 1, 1, nelec, nelec)

        # sum_k d eta(r_ik)/d beta_i (alpha_i - alpha_k), placed on the diagonal
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        delta_ij_sum = torch.diag_embed(dbf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # delta_ab * eta(r_ij)
        delta_ab_bf = eye_mat * bf.view(nbatch, 1, 1, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        # d alpha_i / d beta_j
        out = delta_ab_delta_ij_bf + delta_ij_sum - dbf_delta_ee - delta_ab_bf

        return out.unsqueeze(-1)

    def _get_backflow_second_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Computes the second derivative of the backflow transformation
           wrt the original positions of the electrons

        .. math::
            \mathbf{q}_i = \mathbf{r}_i + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        .. math::
            \frac{d^2 q_i}{d x_k^2} = \delta_{ik}(\sum_{j\neq i} \frac{d^2 \eta(r_{ij})}{d x_i^2} + 2 \frac{d \eta(r_{ij})}{d x_i}) +
                                      \delta_{i\neq k}(\frac{d^2 \eta(r_{ik})}{d x_k^2} + \frac{d \eta(r_{ik})}{d x_k})

        Args:
            pos (torch.Tensor): original positions of the electrons,
                Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: d^2 q_i / d x_k^2,
                Nbatch x Ndim x Ndim x Nelec x Nelec x 1
        """
        # ee dist matrix :
        # Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # difference between elec pos
        # Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(pos.reshape(nbatch, nelec, 3)).permute(
            0, 3, 1, 2
        )

        # derivative ee dist matrix  d r_{ij} / d x_i
        # Nbatch x 3 x Nelec x Nelec
        dree = self.edist(pos, derivative=1)

        # second derivative ee dist matrix  d2 r_{ij} / d2 x_i
        # Nbatch x 3 x Nelec x Nelec
        d2ree = self.edist(pos, derivative=2)

        # derivative of the backflow kernel : d eta(r_ij)/d r_ij
        # Nbatch x 1 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1)

        # second derivative of the backflow kernel : d2 eta(r_ij)/d2 r_ij
        # Nbatch x 1 x Nelec x Nelec
        d2bf = self.backflow_kernel(ree, derivative=2).unsqueeze(1)

        # chain rule:
        # (d^2 eta / d r^2)(d r/d x)^2 + (d eta / d r)(d^2 r/d x^2)
        # Nbatch x 3 x Nelec x Nelec
        d2bf = (d2bf * dree * dree) + (dbf * d2ree)

        # (d eta(r_ij) / d r_ij) (d r_ij/d x_i)
        # Nbatch x 3 x Nelec x Nelec
        dbf = dbf * dree

        # identity 3x3 broadcast as 1 x 3 x 3 x 1 x 1
        eye_mat = torch.eye(3, 3).reshape(1, 3, 3, 1, 1).to(self.device)

        # delta_ij delta_ab 2 sum_k d eta(r_ik) / d beta_i
        term1 = (
            2
            * eye_mat
            * torch.diag_embed(dbf.sum(-1), dim1=-1, dim2=-2).reshape(
                nbatch, 1, 3, nelec, nelec
            )
        )

        # (d2 eta(r_ij) / d2 beta_i) (alpha_i - alpha_j)
        # Nbatch x 3 x 3 x Nelec x Nelec
        d2bf_delta_ee = d2bf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # sum_k d2 eta(r_ik)/d2 beta_i (alpha_i - alpha_k), on the diagonal
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        term2 = torch.diag_embed(d2bf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # 2 delta_ab * d eta(r_ij)/d beta_j
        term3 = 2 * eye_mat * dbf.reshape(nbatch, 1, 3, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        # d2 alpha_i / d2 beta_j
        out = term1 + term2 + d2bf_delta_ee + term3

        return out.unsqueeze(-1)

    def fit_kernel(
        self,
        lambda_func: Callable,
        xmin: float = 0.01,
        xmax: float = 1.0,
        npts: int = 100,
        lr: float = 0.001,
        num_epochs: int = 1000,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Fit the backflow kernel to a given function.

        Args:
            lambda_func (Callable): function to be fit
            xmin (float): minimum x value
            xmax (float): maximum x value
            npts (int): number of points to sample in the interval [xmin, xmax]
            lr (float): learning rate
            num_epochs (int): number of epochs to run the optimization

        Returns:
            xpts (torch.Tensor): x values used for fitting
            ground_truth (torch.Tensor): y values of the given function
            fit_values (torch.Tensor): y values of the fit function
        """
        # BUGFIX: place the sample points on the same device as the kernel
        # parameters; the original kept them on CPU, which fails on CUDA.
        xpts = torch.linspace(xmin, xmax, npts).to(self.device)
        ground_truth = lambda_func(xpts)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.backflow_kernel.parameters(), lr=lr)

        for epoch in range(num_epochs):
            optimizer.zero_grad()
            outputs = self.backflow_kernel(xpts.unsqueeze(1))
            loss = criterion(outputs.squeeze(), ground_truth)
            loss.backward()
            optimizer.step()

            if epoch % 100 == 0:
                # BUGFIX: .numpy() raises on CUDA tensors; .item() is
                # device-safe. (Also dropped the unused running_loss local.)
                print("Epoch {}: Loss = {}".format(epoch, loss.item()))

        fit_values = self.backflow_kernel(xpts.unsqueeze(1)).squeeze()
        return xpts, ground_truth, fit_values

    def __repr__(self):
        """Representation of the backflow transformation: the kernel class name."""
        return self.backflow_kernel.__class__.__name__
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc