LSSTDESC / Spectractor / build 4890731301 (pending completion)

GitHub Pull Request #125: Fitparams
Merge 79e1b14f6 into 3549ae5c3

892 of 892 new or added lines in 16 files covered. (100.0%)
7127 of 7963 relevant lines covered (89.5%)
0.9 hits per line

Source file: /spectractor/fit/fitter.py (90.45% of lines covered)
from iminuit import Minuit
from scipy import optimize
from schwimmbad import MPIPool
import emcee
import time
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import sys
import os
import json
import multiprocessing
from dataclasses import dataclass
from typing import Optional, Union

from spectractor import parameters
from spectractor.config import set_logger
from spectractor.tools import (formatting_numbers, compute_correlation_matrix, plot_correlation_matrix_simple,
                               NumpyArrayEncoder)
from spectractor.fit.statistics import Likelihood


@dataclass()
class FitParameters:
    """Container for the parameters to fit on data with FitWorkspace.

    Attributes
    ----------
    p: np.ndarray
        Array containing the parameter values.
    input_labels: list, optional
        List of the parameter labels for screen print.
        If None, make a default list with parameters labelled par (default: None).
    axis_names: list, optional
        List of the parameter labels for plot print.
        If None, make a default list with parameters labelled par (default: None).
    bounds: list, optional
        List of 2-element lists giving the (lower, upper) bounds for every parameter.
        If None, make a default list with np.inf boundaries (default: None).
    fixed: list, optional
        List of booleans: True to fix a parameter, False to let it free.
        If None, make a default list with False values: all parameters are free (default: None).
    truth: np.ndarray, optional
        Array of truth parameters (default: None).
    filename: str, optional
        File name associated with the fitted parameters (usually the _spectrum.fits file name) (default: '').

    Examples
    --------
    >>> from spectractor.fit.fitter import FitParameters
    >>> params = FitParameters(p=[1, 1, 1, 1, 1])
    >>> params.ndim
    5
    >>> params.p
    array([1., 1., 1., 1., 1.])
    >>> params.input_labels
    ['par0', 'par1', 'par2', 'par3', 'par4']
    >>> params.bounds
    [[-inf, inf], [-inf, inf], [-inf, inf], [-inf, inf], [-inf, inf]]
    """
    p: Union[np.ndarray, list]
    input_labels: Optional[list] = None
    axis_names: Optional[list] = None
    bounds: Optional[list] = None
    fixed: Optional[list] = None
    truth: Optional[list] = None
    filename: Optional[str] = ""
    extra: Optional[dict] = None

    def __post_init__(self):
        self.p = np.asarray(self.p, dtype=float)
        if not self.input_labels:
            self.input_labels = [f"par{k}" for k in range(self.ndim)]
        else:
            if len(self.input_labels) != self.ndim:
                raise ValueError("input_labels argument must have same size as values argument.")
        if not self.axis_names:
            self.axis_names = [f"$p_{k}$" for k in range(self.ndim)]
        else:
            if len(self.axis_names) != self.ndim:
                raise ValueError("axis_names argument must have same size as values argument.")
        if self.bounds is None:
            self.bounds = [[-np.inf, np.inf]] * self.ndim
        else:
            if np.array(self.bounds).shape != (self.ndim, 2):
                raise ValueError(f"bounds argument shape {np.array(self.bounds).shape} "
                                 f"must be {(self.ndim, 2)} to match the values argument.")
        if not self.fixed:
            self.fixed = [False] * self.ndim
        else:
            if len(list(self.fixed)) != self.ndim:
                raise ValueError("fixed argument must have same size as values argument.")
        self.cov = np.zeros((self.nfree, self.nfree))

    @property
    def rho(self):
        """Correlation matrix computed from the covariance matrix.

        Returns
        -------
        rho: np.ndarray
            The correlation matrix array.

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> params = FitParameters(p=[1, 1, 1], axis_names=["x", "y", "z"])
        >>> params.cov = np.array([[2,-0.5,0],[-0.5,2,-1],[0,-1,2]])
        >>> params.rho
        array([[ 1.  , -0.25,  0.  ],
               [-0.25,  1.  , -0.5 ],
               [ 0.  , -0.5 ,  1.  ]])

        """
        return compute_correlation_matrix(self.cov)

    @property
    def err(self):
        """Uncertainties on fitted parameters, as the square root of the covariance matrix diagonal.

        Returns
        -------
        err: np.ndarray
            The uncertainty array.

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> import numpy as np
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=[False, True, False, True, False])
        >>> params.cov = np.array([[1,-0.5,0],[-0.5,4,-1],[0,-1,9]])
        >>> params.err
        array([1., 0., 2., 0., 3.])
        """
        err = np.zeros_like(self.p, dtype=float)
        if np.sum(self.fixed) != len(self.fixed):
            err[~np.asarray(self.fixed)] = np.sqrt(np.diag(self.cov))
        return err

    def __eq__(self, other):
        if not isinstance(other, FitParameters):
            return NotImplemented
        out = True
        for key in self.__dict__.keys():
            if isinstance(getattr(self, key), np.ndarray):
                out *= np.all(np.equal(getattr(self, key).flatten(), getattr(other, key).flatten()))
            else:
                out *= getattr(self, key) == getattr(other, key)
        return out

    @property
    def ndim(self):
        """Number of parameters.

        Returns
        -------
        ndim: int

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> params = FitParameters(p=[1, 1, 1, 1, 1])
        >>> params.ndim
        5
        """
        return len(self.p)

    @property
    def nfree(self):
        """Number of free parameters.

        Returns
        -------
        nfree: int

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=[True, False, True, False, True])
        >>> params.nfree
        2
        """
        return len(self.get_free_parameters())

    @property
    def nfixed(self):
        """Number of fixed parameters.

        Returns
        -------
        nfixed: int

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=[True, False, True, False, True])
        >>> params.nfixed
        3
        """
        return len(self.get_fixed_parameters())

    def get_free_parameters(self):
        """Return the array of indices of the free parameters.

        Examples
        --------
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=None)
        >>> params.fixed
        [False, False, False, False, False]
        >>> params.get_free_parameters()
        array([0, 1, 2, 3, 4])
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=[True, False, True, False, True])
        >>> params.fixed
        [True, False, True, False, True]
        >>> params.get_free_parameters()
        array([1, 3])

        """
        return np.array(np.where(np.array(self.fixed).astype(int) == 0)[0])

    def get_fixed_parameters(self):
        """Return the array of indices of the fixed parameters.

        Examples
        --------
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=None)
        >>> params.fixed
        [False, False, False, False, False]
        >>> params.get_fixed_parameters()
        array([], dtype=int64)
        >>> params = FitParameters(p=[1, 1, 1, 1, 1], fixed=[True, False, True, False, True])
        >>> params.fixed
        [True, False, True, False, True]
        >>> params.get_fixed_parameters()
        array([0, 2, 4])

        """
        return np.array(np.where(np.array(self.fixed).astype(int) == 1)[0])

    def print_parameters_summary(self):
        """Return a text summary of the best fitting parameters.
        Labels are from self.input_labels.

        Returns
        -------
        txt: str
            The summary text.

        Examples
        --------
        >>> parameters.VERBOSE = True
        >>> params = FitParameters(p=[1, 2, 3, 4], input_labels=["x", "y", "z", "t"], fixed=[True, False, True, False])
        >>> params.cov = np.array([[1, -0.5], [-0.5, 4]])
        >>> _ = params.print_parameters_summary()
        """
        txt = ""
        ifree = self.get_free_parameters()
        icov = 0
        for ip in range(self.ndim):
            if ip in ifree:
                txt += "%s: %s +%s -%s\n\t" % formatting_numbers(self.p[ip], np.sqrt(self.cov[icov, icov]),
                                                                 np.sqrt(self.cov[icov, icov]),
                                                                 label=self.input_labels[ip])
                icov += 1
            else:
                txt += f"{self.input_labels[ip]}: {self.p[ip]} (fixed)\n\t"
        return txt

    def plot_correlation_matrix(self, live_fit=False):
        """Compute and plot the correlation matrix.

        Save the plot if parameters.SAVE is True. The output file name is built from self.filename,
        adding the suffix _correlation.pdf.

        Parameters
        ----------
        live_fit: bool, optional
            If True, the plot is refreshed along the fitting procedure instead of blocking (default: False).

        Examples
        --------
        >>> from spectractor.fit.fitter import FitParameters
        >>> params = FitParameters(p=[1, 1, 1], axis_names=["x", "y", "z"])
        >>> params.cov = np.array([[1,-0.5,0],[-0.5,1,-1],[0,-1,1]])
        >>> params.plot_correlation_matrix()
        """
        ipar = self.get_free_parameters()
        fig = plt.figure()
        plot_correlation_matrix_simple(plt.gca(), self.rho, axis_names=[self.axis_names[i] for i in ipar])
        fig.tight_layout()
        if (parameters.SAVE or parameters.LSST_SAVEFIGPATH) and self.filename != "":  # pragma: no cover
            figname = os.path.splitext(self.filename)[0] + "_correlation.pdf"
            fig.savefig(figname, dpi=100, bbox_inches='tight')
        if parameters.LSST_SAVEFIGPATH:  # pragma: no cover
            figname = os.path.join(parameters.LSST_SAVEFIGPATH, "parameters_correlation.pdf")
            fig.savefig(figname, dpi=100, bbox_inches='tight')
        if parameters.DISPLAY:  # pragma: no cover
            if live_fit:
                plt.draw()
                plt.pause(1e-8)
            else:
                plt.show()

    @property
    def txt_filename(self):
        return os.path.splitext(self.filename)[0] + "_bestfit.txt"

    @property
    def json_filename(self):
        return os.path.splitext(self.filename)[0] + "_bestfit.json"

    def write_text(self, header=""):
        """Save the best fitting parameter summary in a text file.

        The file name is built from self.filename, adding the suffix _bestfit.txt.

        Parameters
        ----------
        header: str, optional
            A header to add to the file (default: "").

        Examples
        --------
        >>> params = FitParameters(p=[1, 2, 3, 4], input_labels=["x", "y", "z", "t"],  fixed=[True, False, True, False], filename="test_spectrum.fits")
        >>> params.cov = np.array([[1,-0.5,0],[-0.5,1,-1],[0,-1,1]])
        >>> params.write_text(header="chi2: 1")

        .. doctest::
            :hide:

            >>> assert os.path.isfile(params.txt_filename)
            >>> os.remove(params.txt_filename)
        """
        txt = self.filename + "\n"
        if header != "":
            txt += header + "\n"
        txt += self.print_parameters_summary()
        for row in self.cov:
            txt += np.array_str(row, max_line_width=20 * self.cov.shape[0]) + '\n'
        with open(self.txt_filename, 'w') as f:
            f.write(txt)

    def write_json(self):
        pass


def write_fitparameter_json(json_filename, params, extra=None):
    """Save FitParameters attributes as a JSON file.

    Parameters
    ----------
    json_filename: str
        JSON file name.
    params: FitParameters
        A FitParameters instance to save in JSON json_filename.
    extra: dict, optional
        Extra information to write in the JSON file.

    Returns
    -------
    jsontxt: str
        The JSON dictionary as a string.

    Examples
    --------
    >>> params = FitParameters(p=[1, 2, 3, 4], input_labels=["x", "y", "z", "t"],  fixed=[True, False, True, False], filename="test_spectrum.fits")
    >>> params.cov = np.array([[1,-0.5,0],[-0.5,1,-1],[0,-1,1]])
    >>> jsonstr = write_fitparameter_json(params.json_filename, params, extra={"chi2": 1})
    >>> jsonstr  # doctest: +ELLIPSIS
    '{"p": [1, 2, 3, 4], "input_labels": ["x", "y", "z", "t"],..."extra": {"chi2": 1}...

    .. doctest::
        :hide:

        >>> assert os.path.isfile(params.json_filename)
        >>> os.remove(params.json_filename)
    """
    if json_filename == "":
        raise ValueError("Must provide a JSON file name.")
    if extra:
        params.extra = extra
    jsontxt = json.dumps(params.__dict__, cls=NumpyArrayEncoder)
    with open(json_filename, 'w') as output_json:
        output_json.write(jsontxt)
    return jsontxt


def read_fitparameter_json(json_filename):
    """Read a JSON file and store its data in a FitParameters instance.

    Parameters
    ----------
    json_filename: str
        The JSON file name.

    Returns
    -------
    params: FitParameters
        A FitParameters instance loaded from JSON json_filename.

    Examples
    --------
    >>> params = FitParameters(p=[1, 2, 3, 4], input_labels=["x", "y", "z", "t"],  fixed=[True, False, True, False], filename="test_spectrum.fits")
    >>> params.cov = np.array([[1,-0.5,0],[-0.5,1,-1],[0,-1,1]])
    >>> _ = write_fitparameter_json(params.json_filename, params, extra={"chi2": 1})
    >>> new_params = read_fitparameter_json(params.json_filename)
    >>> new_params.p
    array([1, 2, 3, 4])

    .. doctest::
        :hide:

        >>> assert os.path.isfile(params.json_filename)
        >>> assert params == new_params
        >>> os.remove(params.json_filename)

    """
    params = FitParameters(p=[0])
    with open(json_filename, 'r') as f:
        data = json.load(f)
    for key in ["p", "cov"]:
        data[key] = np.asarray(data[key])
    for key in data:
        setattr(params, key, data[key])
    return params


class FitWorkspace:

    def __init__(self, params=None, file_name="", verbose=False, plot=False, live_fit=False, truth=None):
        """Generic class to create a fit workspace with parameters, bounds and general fitting methods.

        Parameters
        ----------
        params: FitParameters, optional
            The parameters to fit to data (default: None).
        file_name: str, optional
            The generic file name to save results. If file_name=="", nothing is saved on disk (default: "").
        verbose: bool, optional
            Level of verbosity (default: False).
        plot: bool, optional
            Level of plotting (default: False).
        live_fit: bool, optional
            If True, model, data and residuals plots are made along the fitting procedure (default: False).
        truth: array_like, optional
            Array of true parameters (default: None).

        Examples
        --------
        >>> params = FitParameters(p=[1, 1, 1, 1, 1])
        >>> w = FitWorkspace(params)
        >>> w.params.ndim
        5
        """
        self.my_logger = set_logger(self.__class__.__name__)
        self.params = params
        self.filename = file_name
        self.truth = truth
        self.verbose = verbose
        self.plot = plot
        self.live_fit = live_fit
        self.data = None
        self.err = None
        self.data_cov = None
        self.W = None
        self.x = None
        self.outliers = []
        self.mask = []
        self.sigma_clip = 5
        self.model = None
        self.model_err = None
        self.model_noconv = None
        self.params_table = None
        self.costs = np.array([[]])

    def get_bad_indices(self):
        """List the indices of data points rejected as outliers by sigma clipping or another masking method.

        Returns
        -------
        bad_indices: list

        Examples
        --------
        >>> w = FitWorkspace()
        >>> w.data = np.array([np.array([1,2,3]), np.array([1,2,3,4])], dtype=object)
        >>> w.outliers = [2, 6]
        >>> w.get_bad_indices()
        [array([2]), array([3])]
        """
        bad_indices = np.asarray(self.outliers, dtype=int)
        if self.data.dtype == object:
            if len(self.outliers) > 0:
                bad_indices = []
                start_index = 0
                for k in range(self.data.shape[0]):
                    mask = np.zeros(self.data[k].size, dtype=bool)
                    outliers = np.asarray(self.outliers)[np.logical_and(np.asarray(self.outliers) > start_index,
                                                                        np.asarray(self.outliers) < start_index +
                                                                        self.data[k].size)]
                    mask[outliers - start_index] = True
                    bad_indices.append(np.arange(self.data[k].size)[mask])
                    start_index += self.data[k].size
            else:
                bad_indices = [[] for _ in range(self.data.shape[0])]
        return bad_indices

    def simulate(self, *p):
        """Compute the model prediction given a set of parameters.

        Parameters
        ----------
        p: array_like
            Array of parameters for the computation of the model.

        Returns
        -------
        x: array_like
            The abscissa of the model prediction.
        model: array_like
            The model prediction.
        model_err: array_like
            The uncertainty on the model prediction.

        Examples
        --------
        >>> w = FitWorkspace()
        >>> p = np.zeros(3)
        >>> x, model, model_err = w.simulate(*p)

        .. doctest::
            :hide:

            >>> assert x is not None

        """
        self.x = np.array([])
        self.model = np.array([])
        self.model_err = np.array([])
        return self.x, self.model, self.model_err
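    # Editor's note: `simulate` is a stub meant to be overridden in subclasses.
    # A minimal sketch of a concrete workspace (hypothetical, not part of the
    # original file), fitting a straight line y = a*x + b:
    #
    #     class LineFitWorkspace(FitWorkspace):
    #         def __init__(self, x, y, yerr, params):
    #             FitWorkspace.__init__(self, params)
    #             self.x, self.data, self.err = x, y, yerr
    #
    #         def simulate(self, a, b):
    #             self.model = a * self.x + b
    #             self.model_err = np.zeros_like(self.model)
    #             return self.x, self.model, self.model_err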

    def plot_fit(self):
        """Generic function to plot the result of the fit for 1D curves.

        Returns
        -------
        fig: matplotlib.figure.Figure
            The figure.

        """
        fig = plt.figure()
        plt.errorbar(self.x, self.data, yerr=self.err, fmt='ko', label='Data')
        if self.truth is not None:
            x, truth, truth_err = self.simulate(*self.truth)
            plt.plot(self.x, truth, label="Truth")
        plt.plot(self.x, self.model, label='Best fitting model')
        plt.xlabel('$x$')
        plt.ylabel('$y$')
        title = ""
        for i, label in enumerate(self.params.input_labels):
            if self.params.cov.size > 0:
                err = np.sqrt(self.params.cov[i, i])
                _, par, err, _ = formatting_numbers(self.params.p[i], err, err, label=label)
                title += rf"{label} = {par} $\pm$ {err}"
            else:
                title += f"{label} = {self.params.p[i]:.3g}"
            if i < self.params.ndim - 1:
                title += ", "
        plt.title(title)
        plt.legend()
        plt.grid()
        if parameters.DISPLAY:  # pragma: no cover
            plt.show()
        return fig

    def weighted_residuals(self, p):  # pragma: nocover
        """Compute the weighted residuals array for a set of model parameters p.

        Parameters
        ----------
        p: array_like
            The array of model parameters.

        Returns
        -------
        residuals: np.array
            The array of weighted residuals.

        """
        x, model, model_err = self.simulate(*p)
        if self.data_cov is None:
            if len(self.outliers) > 0:
                model_err = model_err.flatten()
                err = self.err.flatten()
                res = (model.flatten() - self.data.flatten()) / np.sqrt(model_err * model_err + err * err)
            else:
                res = ((model - self.data) / np.sqrt(model_err * model_err + self.err * self.err)).flatten()
        else:
            if self.data_cov.ndim > 2:
                K = self.data_cov.shape[0]
                if np.any(model_err > 0):
                    cov = [self.data_cov[k] + np.diag(model_err[k] ** 2) for k in range(K)]
                    L = [np.linalg.inv(np.linalg.cholesky(cov[k])) for k in range(K)]
                else:
                    L = [np.linalg.cholesky(self.W[k]) for k in range(K)]
                res = [L[k] @ (model[k] - self.data[k]) for k in range(K)]
                res = np.concatenate(res).ravel()
            else:
                if np.any(model_err > 0):
                    cov = self.data_cov + np.diag(model_err * model_err)
                    L = np.linalg.inv(np.linalg.cholesky(cov))
                else:
                    if self.W.ndim == 1 and self.W.dtype != object:
                        L = np.sqrt(self.W)
                    elif self.W.ndim == 2 and self.W.dtype != object:
                        L = np.linalg.cholesky(self.W)
                    else:
                        raise ValueError(f"Case not implemented with self.W.ndim={self.W.ndim} "
                                         f"and self.W.dtype={self.W.dtype}")
                res = L @ (model - self.data)
        return res

    def compute_W_with_model_error(self, model_err):
        W = self.W
        zeros = W == 0
        if self.W.ndim == 1 and self.W.dtype != object:
            if np.any(model_err > 0):
                W = 1 / (self.data_cov + model_err * model_err)
        elif self.W.dtype == object:
            K = len(self.W)
            if self.W[0].ndim == 1:
                if np.any(model_err > 0):
                    W = [1 / (self.data_cov[k] + model_err[k] * model_err[k]) for k in range(K)]
            elif self.W[0].ndim == 2:
                if np.any(model_err > 0):
                    cov = [self.data_cov[k] + np.diag(model_err[k] ** 2) for k in range(K)]
                    L = [np.linalg.inv(np.linalg.cholesky(cov[k])) for k in range(K)]
                    W = [L[k].T @ L[k] for k in range(K)]
            else:
                raise ValueError(f"First element of fitworkspace.W has no ndim attribute or has a dimension above 2. "
                                 f"I get W[0]={self.W[0]}")
        elif self.W.ndim == 2 and self.W.dtype != object:
            if np.any(model_err > 0):
                cov = self.data_cov + np.diag(model_err * model_err)
                L = np.linalg.inv(np.linalg.cholesky(cov))
                W = L.T @ L
        W[zeros] = 0
        return W

    def chisq(self, p, model_output=False):
        """Compute the chi square for a set of model parameters p.

        Four cases are implemented: diagonal W, 2D W, array of diagonal Ws, array of 2D Ws. The two latter cases
        are for multiple independent data vectors with W being block diagonal.

        Parameters
        ----------
        p: array_like
            The array of model parameters.
        model_output: bool, optional
            If True, the simulated model is also returned.

        Returns
        -------
        chisq: float
            The chi square value.

        """
        # check data format
        if (self.data.dtype != object and self.data.ndim > 1) or (self.err.dtype != object and self.err.ndim > 1):
            raise ValueError("Fitworkspace.data and Fitworkspace.err must be a flat 1D array,"
                             " or an array of flat arrays of unequal lengths.")
        # prepare weight matrices in case they have not been built before
        self.prepare_weight_matrices()
        x, model, model_err = self.simulate(*p)
        W = self.compute_W_with_model_error(model_err)
        if W.ndim == 1 and W.dtype != object:
            res = (model - self.data)
            chisq = res @ (W * res)
        elif W.dtype == object:
            K = len(W)
            res = [model[k] - self.data[k] for k in range(K)]
            if W[0].ndim == 1:
                chisq = np.sum([res[k] @ (W[k] * res[k]) for k in range(K)])
            elif W[0].ndim == 2:
                chisq = np.sum([res[k] @ W[k] @ res[k] for k in range(K)])
            else:
                raise ValueError(f"First element of fitworkspace.W has no ndim attribute or has a dimension above 2. "
                                 f"I get W[0]={W[0]}")
        elif W.ndim == 2 and W.dtype != object:
            res = (model - self.data)
            chisq = res @ W @ res
        else:
            raise ValueError(
                f"Data inverse covariance matrix must be a np.ndarray of dimension 1 or 2, "
                f"either made of 1D or 2D arrays of equal lengths or not for block diagonal matrices."
                f"\nHere W type is {type(W)}, shape is {W.shape} and W is {W}.")
        if model_output:
            return chisq, x, model, model_err
        else:
            return chisq
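    # Editor's note: a minimal numeric sketch (assumed values, not from the
    # original file) of the diagonal-W branch above. With W = 1/err**2,
    # `res @ (W * res)` is the usual chi square sum(((model - data)/err)**2):
    #
    #     data = np.array([1.0, 2.0, 3.0])
    #     model = np.array([1.1, 1.9, 3.2])
    #     err = np.array([0.1, 0.1, 0.2])
    #     W = 1 / err**2
    #     res = model - data
    #     chisq = res @ (W * res)   # == 1.0 + 1.0 + 1.0 == 3.0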

    def prepare_weight_matrices(self):
        # Prepare covariance matrix for data
        if self.data_cov is None:
            self.data_cov = np.asarray(self.err.flatten() ** 2)
        # Prepare inverse covariance matrix for data
        if self.W is None:
            if self.data_cov.ndim == 1 and self.data_cov.dtype != object:
                self.W = 1 / self.data_cov
            elif self.data_cov.ndim == 2 and self.data_cov.dtype != object:
                L = np.linalg.inv(np.linalg.cholesky(self.data_cov))
                self.W = L.T @ L
            elif self.data_cov.dtype == object:
                if self.data_cov[0].ndim == 1:
                    self.W = np.array([1 / self.data_cov[k] for k in range(self.data_cov.shape[0])])
                else:
                    self.W = []
                    for k in range(len(self.data_cov)):
                        L = np.linalg.inv(np.linalg.cholesky(self.data_cov[k]))
                        self.W.append(L.T @ L)
                    self.W = np.asarray(self.W)
        if len(self.outliers) > 0:
            bad_indices = self.get_bad_indices()
            if self.W.ndim == 1 and self.W.dtype != object:
                self.W[bad_indices] = 0
            elif self.W.ndim == 2 and self.W.dtype != object:
                self.W[:, bad_indices] = 0
                self.W[bad_indices, :] = 0
            elif self.W.dtype == object:
                if self.data_cov[0].ndim == 1:
                    for k in range(len(self.W)):
                        self.W[k][bad_indices[k]] = 0
                else:
                    for k in range(len(self.W)):
                        self.W[k][:, bad_indices[k]] = 0
                        self.W[k][bad_indices[k], :] = 0
            else:
                raise ValueError(
                    f"Data inverse covariance matrix must be a np.ndarray of dimension 1 or 2, "
                    f"either made of 1D or 2D arrays of equal lengths or not for block diagonal matrices."
                    f"\nHere W type is {type(self.W)}, shape is {self.W.shape} and W is {self.W}.")

    def lnlike(self, p):
        """Compute the logarithmic likelihood for a set of model parameters p as -0.5*chisq.

        Parameters
        ----------
        p: array_like
            The array of model parameters.

        Returns
        -------
        lnlike: float
            The logarithmic likelihood value.

        """
        return -0.5 * self.chisq(p)

    def lnprior(self, p):
        """Compute the logarithmic prior for a set of model parameters p.

        The function returns 0 for good parameters, and -np.inf for parameters out of their boundaries.

        Parameters
        ----------
        p: array_like
            The array of model parameters.

        Returns
        -------
        lnprior: float
            The logarithmic value of the prior.

        """
        in_bounds = True
        for npar, par in enumerate(p):
            if par < self.params.bounds[npar][0] or par > self.params.bounds[npar][1]:
                in_bounds = False
                break
        if in_bounds:
            return 0.0
        else:
            return -np.inf
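    # Editor's note: with this flat prior, the log-posterior used for sampling is
    # simply lnprior(p) + lnlike(p) (see the module-level `lnprob` helper further
    # down). A tiny illustration (assumed values, not from the original file):
    #
    #     w = FitWorkspace(FitParameters(p=[1.0], bounds=[[0, 2]]))
    #     w.lnprior([1.0])   # 0.0, inside the bounds
    #     w.lnprior([3.0])   # -inf, outside the bounds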

    def jacobian(self, params, epsilon, model_input=None):
        """Generic function to compute the Jacobian matrix of a model, with numerical derivatives.

        Parameters
        ----------
        params: array_like
            The array of model parameters.
        epsilon: array_like
            The array of small steps to compute the partial derivatives of the model.
        model_input: array_like, optional
            A model input as a list with (x, model, model_err) to avoid an additional call to simulate().

        Returns
        -------
        J: np.array
            The Jacobian matrix.

        """
        if model_input:
            x, model, model_err = model_input
        else:
            x, model, model_err = self.simulate(*params)
        if self.W.dtype == object and self.W[0].ndim == 2:
            J = [[] for _ in range(params.size)]
        else:
            model = model.flatten()
            J = np.zeros((params.size, model.size))
        for ip, p in enumerate(params):
            if self.params.fixed[ip]:
                continue
            tmp_p = np.copy(params)
            if tmp_p[ip] + epsilon[ip] < self.params.bounds[ip][0] or tmp_p[ip] + epsilon[ip] > self.params.bounds[ip][1]:
                epsilon[ip] = - epsilon[ip]
            tmp_p[ip] += epsilon[ip]
            tmp_x, tmp_model, tmp_model_err = self.simulate(*tmp_p)
            if self.W.dtype == object and self.W[0].ndim == 2:
                for k in range(model.shape[0]):
                    J[ip].append((tmp_model[k] - model[k]) / epsilon[ip])
            else:
                J[ip] = (tmp_model.flatten() - model) / epsilon[ip]
        return np.asarray(J)
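    # Editor's note: a worked forward-difference sketch (assumed, not from the
    # original file). For model(x; a) = a * x, the Jacobian row for `a` is
    # (model(a + eps) - model(a)) / eps, which recovers x itself:
    #
    #     x = np.array([1.0, 2.0, 3.0])
    #     a, eps = 2.0, 1e-6
    #     row = ((a + eps) * x - a * x) / eps   # ~ array([1., 2., 3.])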

    def hessian(self, params, epsilon):  # pragma: nocover
        """Experimental function to compute the Hessian of a model.

        Parameters
        ----------
        params: array_like
            The array of model parameters.
        epsilon: array_like
            The array of small steps to compute the partial derivatives of the model.

        Returns
        -------
        H: np.array
            The Hessian matrix.

        """
        x, model, model_err = self.simulate(*params)
        model = model.flatten()
        J = self.jacobian(params, epsilon)
        H = np.zeros((params.size, params.size, model.size))
        tmp_p = np.copy(params)
        for ip, p1 in enumerate(params):
            print(ip, p1, params[ip], tmp_p[ip], self.params.bounds[ip], epsilon[ip], tmp_p[ip] + epsilon[ip])
            if self.params.fixed[ip]:
                continue
            if tmp_p[ip] + epsilon[ip] < self.params.bounds[ip][0] or tmp_p[ip] + epsilon[ip] > self.params.bounds[ip][1]:
                epsilon[ip] = - epsilon[ip]
            tmp_p[ip] += epsilon[ip]
            print(tmp_p)
        tmp_J = self.jacobian(tmp_p, epsilon)
        for ip, p1 in enumerate(params):
            if self.params.fixed[ip]:
                continue
            for jp, p2 in enumerate(params):
                if self.params.fixed[jp]:
                    continue
                x, modelplus, model_err = self.simulate(params + epsilon)
                x, modelmoins, model_err = self.simulate(params - epsilon)
                model = model.flatten()
                print("hh", ip, jp, tmp_J[ip], J[jp], tmp_p[ip], params, (tmp_J[ip] - J[jp]) / epsilon)
                print((modelplus + modelmoins - 2 * model) / (np.asarray(epsilon) ** 2))
                H[ip, jp] = (modelplus + modelmoins - 2 * model) / (np.asarray(epsilon) ** 2)
        return H

    def plot_gradient_descent(self):
        fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex="all")
        iterations = np.arange(self.params_table.shape[0])
        ax[0].plot(iterations, self.costs, lw=2)
        for ip in range(self.params_table.shape[1]):
            ax[1].plot(iterations, self.params_table[:, ip], label=f"{self.params.axis_names[ip]}")
        ax[1].set_yscale("symlog")
        ax[1].legend(ncol=6, loc=9)
        ax[1].grid()
        ax[0].set_yscale("log")
        ax[0].set_ylabel(r"$\chi^2$")
        ax[1].set_ylabel("Parameters")
        ax[0].grid()
        ax[1].set_xlabel("Iterations")
        ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))
        fig.tight_layout()
        plt.subplots_adjust(wspace=0, hspace=0)
        if parameters.SAVE and self.filename != "":  # pragma: no cover
            figname = os.path.splitext(self.filename)[0] + "_fitting.pdf"
            self.my_logger.info(f"\n\tSave figure {figname}.")
            fig.savefig(figname, dpi=100, bbox_inches='tight')
        if parameters.DISPLAY:  # pragma: no cover
            plt.show()
        if parameters.PdfPages:  # args from the above? MFL
            parameters.PdfPages.savefig()

        self.simulate(*self.params.p)
        self.live_fit = False
        self.plot_fit()

    def save_gradient_descent(self):
        iterations = np.arange(self.params_table.shape[0]).astype(int)
        t = np.zeros((self.params_table.shape[1] + 2, self.params_table.shape[0]))
        t[0] = iterations
        t[1] = self.costs
        t[2:] = self.params_table.T
        h = 'iter,costs,' + ','.join(self.params.input_labels)
        output_filename = os.path.splitext(self.filename)[0] + "_fitting.txt"
        np.savetxt(output_filename, t.T, header=h, delimiter=",")
        self.my_logger.info(f"\n\tSave gradient descent log {output_filename}.")


class MCMCFitWorkspace(FitWorkspace):

    def __init__(self, params, file_name="", nwalkers=18, nsteps=1000, burnin=100, nbins=10,
                 verbose=False, plot=False, live_fit=False, truth=None):
        """Generic class to create a fit workspace with parameters, bounds and MCMC fitting methods.

        Parameters
        ----------
        params: FitParameters
            The parameters to fit to data.
        file_name: str, optional
            The generic file name to save results. If file_name=="", nothing is saved on disk (default: "").
        nwalkers: int, optional
            Number of walkers for MCMC exploration (default: 18).
        nsteps: int, optional
            Number of steps for MCMC exploration (default: 1000).
        burnin: int, optional
            Number of burn-in steps for MCMC exploration (default: 100).
        nbins: int, optional
            Number of bins to make histograms after MCMC exploration (default: 10).
        verbose: bool, optional
            Level of verbosity (default: False).
        plot: bool, optional
            Level of plotting (default: False).
        live_fit: bool, optional
            If True, model, data and residuals plots are made along the fitting procedure (default: False).
        truth: array_like, optional
            Array of true parameters (default: None).

        Examples
        --------
        >>> params = FitParameters(p=[1, 1, 1, 1, 1])
        >>> w = MCMCFitWorkspace(params)
        >>> w.nwalkers
        18
        """
        FitWorkspace.__init__(self, params, file_name=file_name, verbose=verbose, plot=plot, live_fit=live_fit, truth=truth)
        self.my_logger = set_logger(self.__class__.__name__)
        self.nwalkers = max(2 * self.params.ndim, nwalkers)
        self.nsteps = nsteps
        self.nbins = nbins
        self.burnin = burnin
        self.start = []
        self.likelihood = np.array([[]])
        self.gelmans = np.array([])
        self.chains = np.array([[]])
        self.lnprobs = np.array([[]])
        self.flat_chains = np.array([[]])
        self.valid_chains = [False] * self.nwalkers
        self.global_average = None
        self.global_std = None
        self.use_grid = False
        if self.filename != "":
            if "." in self.filename:
                self.emcee_filename = os.path.splitext(self.filename)[0] + "_emcee.h5"
            else:
                self.my_logger.warning("\n\tFile name must have an extension.")
        else:
            self.emcee_filename = "emcee.h5"

    def set_start(self, percent=0.02, a_random=1e-5):
        """Set the random starting points for MCMC exploration.

        A set of parameters is drawn with a uniform distribution between +/- percent times the starting guess.
        For null guess parameters, starting points are drawn from a uniform distribution between +/- a_random.

        Parameters
        ----------
        percent: float, optional
            Percent of the guess parameters to set the uniform interval to draw random points (default: 0.02).
        a_random: float, optional
            Absolute value to set the +/- uniform interval to draw random points
            for null guess parameters (default: 1e-5).

        Returns
        -------
        start: np.array
            Array of starting points of shape (nwalkers, ndim).

        """
        self.start = np.array([np.random.uniform(self.params.p[i] - percent * self.params.p[i],
                                                 self.params.p[i] + percent * self.params.p[i],
                                                 self.nwalkers) for i in range(self.params.ndim)]).T
        self.start[self.start == 0] = a_random * np.random.uniform(-1, 1)
        return self.start
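    # Editor's note: a sketch of how these starting points would feed an emcee
    # sampler (assumed usage, not from the original file; `lnprob` is the
    # module-level posterior defined further down):
    #
    #     sampler = emcee.EnsembleSampler(w.nwalkers, w.params.ndim, lnprob)
    #     start = w.set_start()                # shape (nwalkers, ndim)
    #     sampler.run_mcmc(start, w.nsteps)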

    def load_chains(self):
        """Load the MCMC chains from an HDF5 file. The burn-in points are not rejected at this stage.

        Returns
        -------
        chains: np.array
            Array of the chains.
        lnprobs: np.array
            Array of the logarithmic posterior probability.

        """
        self.chains = [[]]
        self.lnprobs = [[]]
        self.nsteps = 0
        reader = emcee.backends.HDFBackend(self.emcee_filename)
        try:
            tau = reader.get_autocorr_time()
        except emcee.autocorr.AutocorrError:
            tau = -1
        self.chains = reader.get_chain(discard=0, flat=False, thin=1)
        self.lnprobs = reader.get_log_prob(discard=0, flat=False, thin=1)
        self.nsteps = self.chains.shape[0]
        self.nwalkers = self.chains.shape[1]
        print(f"Auto-correlation time: {tau}")
        print(f"Burn-in: {self.burnin}")
        print(f"Chains shape: {self.chains.shape}")
        print(f"Log prob shape: {self.lnprobs.shape}")
        return self.chains, self.lnprobs

    def build_flat_chains(self):
        """Flatten the chains array and apply burn-in.

        Returns
        -------
        flat_chains: np.array
            Flat chains.

        """
        self.flat_chains = self.chains[self.burnin:, self.valid_chains, :].reshape((-1, self.params.ndim))
        return self.flat_chains

    def analyze_chains(self):
        """Load the chains, build the probability densities for the parameters, compute the best fitting values
        and the uncertainties and covariance matrices, and plot.

        """
        self.load_chains()
        self.set_chain_validity()
        self.convergence_tests()
        self.build_flat_chains()
        self.likelihood = self.chain2likelihood()
        self.params.cov = self.likelihood.cov_matrix
        self.params.p = self.likelihood.mean_vec
        self.simulate(*self.params.p)
        self.plot_fit()
        figure_name = os.path.splitext(self.emcee_filename)[0] + '_triangle.pdf'
        self.likelihood.triangle_plots(output_filename=figure_name)

    def chain2likelihood(self, pdfonly=False, walker_index=-1):
        """Convert the chains to a posterior probability density function via histograms.

        Parameters
        ----------
        pdfonly: bool, optional
            If True, do not compute the covariances and the 2D correlation plots (default: False).
        walker_index: int, optional
            The walker index to plot. If -1, all walkers are selected (default: -1).

        Returns
        -------
        likelihood: Likelihood
            The posterior density function object.

        """
        if walker_index >= 0:
            chains = self.chains[self.burnin:, walker_index, :]
        else:
            chains = self.flat_chains
        rangedim = range(chains.shape[1])
        centers = []
        for i in rangedim:
            centers.append(np.linspace(np.min(chains[:, i]), np.max(chains[:, i]), self.nbins - 1))
        likelihood = Likelihood(centers, labels=self.params.input_labels, axis_names=self.params.axis_names, truth=self.params.truth)
        if walker_index < 0:
            for i in rangedim:
                likelihood.pdfs[i].fill_histogram(chains[:, i], weights=None)
                if not pdfonly:
                    for j in rangedim:
                        if i != j:
                            likelihood.contours[i][j].fill_histogram(chains[:, i], chains[:, j], weights=None)
            output_file = ""
            if self.filename != "":
                output_file = os.path.splitext(self.filename)[0] + "_bestfit.txt"
            likelihood.stats(output=output_file)
        else:
            for i in rangedim:
                likelihood.pdfs[i].fill_histogram(chains[:, i], weights=None)
        return likelihood

    def compute_local_acceptance_rate(self, start_index, last_index, walker_index):
        """Compute the local acceptance rate in a chain.

        Parameters
        ----------
        start_index: int
            Beginning index.
        last_index: int
            End index.
        walker_index: int
            Index of the walker.

        Returns
        -------
        freq: float
            The acceptance rate.

        """
        frequences = []
        test = -2 * self.lnprobs[start_index, walker_index]
        counts = 1
        for index in range(start_index + 1, last_index):
            chi2 = -2 * self.lnprobs[index, walker_index]
            if np.isclose(chi2, test):
                counts += 1
            else:
                frequences.append(float(counts))
                counts = 1
                test = chi2
        frequences.append(counts)
        return 1.0 / np.mean(frequences)
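    # Editor's note: the rate is estimated from run lengths: a rejected proposal
    # repeats the previous chi2 value, so each run of identical values costs one
    # acceptance. A tiny worked case (assumed values, not from the original file):
    # the chi2 sequence [5, 5, 5, 3, 3, 8] has runs of lengths [3, 2, 1], so the
    # estimated acceptance rate is 1 / mean([3, 2, 1]) = 0.5.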

    def set_chain_validity(self):
        """Test the validity of a chain: reject chains whose chi2 is far from the mean of the others.

        Returns
        -------
        valid_chains: list
            List of boolean values, True if the chain is valid, or False if invalid.

        """
        nchains = [k for k in range(self.nwalkers)]
        chisq_averages = []
        chisq_std = []
        for k in nchains:
            chisqs = -2 * self.lnprobs[self.burnin:, k]
            chisq_averages.append(np.mean(chisqs))
            chisq_std.append(np.std(chisqs))
        self.global_average = np.mean(chisq_averages)
        self.global_std = np.mean(chisq_std)
        self.valid_chains = [False] * self.nwalkers
        for k in nchains:
            chisqs = -2 * self.lnprobs[self.burnin:, k]
            chisq_average = np.mean(chisqs)
            chisq_std = np.std(chisqs)
            if 3 * self.global_std + self.global_average < chisq_average < 1e5:
                self.valid_chains[k] = False
            elif chisq_std < 0.1 * self.global_std:
                self.valid_chains[k] = False
            else:
                self.valid_chains[k] = True
        return self.valid_chains
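    # Editor's note (a reading of the criterion above, not from the original
    # file): a walker is rejected if its mean chi2 sits more than 3 global
    # standard deviations above the global average (a stray chain), or if its
    # chi2 scatter is below 10% of the global scatter (a stuck chain).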

    def convergence_tests(self):
        """Compute the convergence tests (Gelman-Rubin, acceptance rate).

        """
        chains = self.chains[self.burnin:, :, :]
        nchains = [k for k in range(self.nwalkers)]
        fig, ax = plt.subplots(self.params.ndim + 1, 2, figsize=(16, 7), sharex='all')
        fontsize = 8
        steps = np.arange(self.burnin, self.nsteps)
        # Chi2 vs Index
        print("Chisq statistics:")
        for k in nchains:
            chisqs = -2 * self.lnprobs[self.burnin:, k]
            text = f"\tWalker {k:d}: {float(np.mean(chisqs)):.3f} +/- {float(np.std(chisqs)):.3f}"
            if not self.valid_chains[k]:
                text += " -> excluded"
                ax[self.params.ndim, 0].plot(steps, chisqs, c='0.5', linestyle='--')
            else:
                ax[self.params.ndim, 0].plot(steps, chisqs)
            print(text)
        ax[self.params.ndim, 0].set_ylim(
            [self.global_average - 5 * self.global_std, self.global_average + 5 * self.global_std])
        # Parameter vs Index
        print("Computing Parameter vs Index plots...")
        for i in range(self.params.ndim):
            ax[i, 0].set_ylabel(self.params.axis_names[i], fontsize=fontsize)
            for k in nchains:
                if self.valid_chains[k]:
                    ax[i, 0].plot(steps, chains[:, k, i])
                else:
                    ax[i, 0].plot(steps, chains[:, k, i], c='0.5', linestyle='--')
                ax[i, 0].get_yaxis().set_label_coords(-0.05, 0.5)
        ax[self.params.ndim, 0].set_ylabel(r'$\chi^2$', fontsize=fontsize)
        ax[self.params.ndim, 0].set_xlabel('Steps', fontsize=fontsize)
        ax[self.params.ndim, 0].get_yaxis().set_label_coords(-0.05, 0.5)
        # Acceptance rate vs Index
        print("Computing acceptance rate...")
        min_len = self.nsteps
        window = 100
        if min_len > window:
            for k in nchains:
                ARs = []
                indices = []
                for pos in range(self.burnin + window, self.nsteps, window):
                    ARs.append(self.compute_local_acceptance_rate(pos - window, pos, k))
                    indices.append(pos)
                if self.valid_chains[k]:
                    ax[self.params.ndim, 1].plot(indices, ARs, label=f'Walker {k:d}')
                else:
                    ax[self.params.ndim, 1].plot(indices, ARs, label=f'Walker {k:d}', c='gray', linestyle='--')
                ax[self.params.ndim, 1].set_xlabel('Steps', fontsize=fontsize)
                ax[self.params.ndim, 1].set_ylabel('Acceptance rate', fontsize=fontsize)
        # Parameter PDFs by chain
        print("Computing chain by chain PDFs...")
        for k in nchains:
            likelihood = self.chain2likelihood(pdfonly=True, walker_index=k)
            likelihood.stats(pdfonly=True, verbose=False)
        # Gelman-Rubin test
        if len(nchains) > 1:
            step = max(1, (self.nsteps - self.burnin) // 20)
            self.gelmans = []
            print(f'Gelman-Rubin tests (burnin={self.burnin:d}, step={step:d}, nsteps={self.nsteps:d}):')
            for i in range(self.params.ndim):
                Rs = []
                lens = []
                for pos in range(self.burnin + step, self.nsteps, step):
                    chain_averages = []
                    chain_variances = []
                    global_average = np.mean(self.chains[self.burnin:pos, self.valid_chains, i])
                    for k in nchains:
                        if not self.valid_chains[k]:
                            continue
                        chain_averages.append(np.mean(self.chains[self.burnin:pos, k, i]))
                        chain_variances.append(np.var(self.chains[self.burnin:pos, k, i], ddof=1))
                    W = np.mean(chain_variances)
                    B = 0
                    for n in range(len(chain_averages)):
                        B += (chain_averages[n] - global_average) ** 2
                    B *= ((pos + 1) / (len(chain_averages) - 1))
                    R = (W * pos / (pos + 1) + B / (pos + 1) * (len(chain_averages) + 1) / len(chain_averages)) / W
                    Rs.append(R - 1)
                    lens.append(pos)
                print(f'\t{self.params.input_labels[i]}: R-1 = {Rs[-1]:.3f} (l = {lens[-1] - 1:d})')
                self.gelmans.append(Rs[-1])
                ax[i, 1].plot(lens, Rs, lw=1, label=self.params.axis_names[i])
                ax[i, 1].axhline(0.03, c='k', linestyle='--')
                ax[i, 1].set_xlabel('Walker length', fontsize=fontsize)
                ax[i, 1].set_ylabel('$R-1$', fontsize=fontsize)
                ax[i, 1].set_ylim(0, 0.6)
        self.gelmans = np.array(self.gelmans)
        fig.tight_layout()
        plt.subplots_adjust(hspace=0)
        if parameters.DISPLAY:  # pragma: no cover
            plt.show()
        if parameters.PdfPages:
            parameters.PdfPages.savefig()
        figure_name = self.emcee_filename.replace('.h5', '_convergence.pdf')
        print(f'Save figure: {figure_name}')
        fig.savefig(figure_name, dpi=100)

1280
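    # Illustrative sketch (standalone numpy, hypothetical toy chains): the
    # Gelman-Rubin statistic computed above, for one parameter and chains of
    # shape (nsteps, nchains):
    #
    #     import numpy as np
    #     chains = np.random.normal(size=(1000, 4))
    #     n = chains.shape[0]
    #     W = np.mean(np.var(chains, axis=0, ddof=1))       # within-chain variance
    #     B = n * np.var(np.mean(chains, axis=0), ddof=1)   # between-chain variance
    #     R = ((n - 1) / n * W + B / n) / W                 # R -> 1 at convergence
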
    def print_settings(self):
        """Print the main settings of the FitWorkspace."""
        print('************************************')
        print(f"Input file: {self.filename}\nWalkers: {self.nwalkers}\t Steps: {self.nsteps}")
        print(f"Output file: {self.emcee_filename}")
        print('************************************')


def lnprob(p):  # pragma: no cover
    """Log-posterior for emcee: log-prior plus log-likelihood of the module-level
    ``fit_workspace`` global, with -1e20 returned when the prior is not finite."""
    global fit_workspace
    lp = fit_workspace.lnprior(p)
    if not np.isfinite(lp):
        return -1e20
    return lp + fit_workspace.lnlike(p)


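# Illustrative sketch (hypothetical names): lnprob() reads the module-level
# ``fit_workspace`` global, which must be set before sampling, e.g.:
#
#     import spectractor.fit.fitter as fitter
#     fitter.fit_workspace = my_workspace   # any FitWorkspace instance
#     fitter.run_emcee(my_workspace, ln=fitter.lnprob)

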
def gradient_descent(fit_workspace, epsilon, niter=10, xtol=1e-3, ftol=1e-3, with_line_search=True):
    """Minimize the chi square of a FitWorkspace with iterative Gauss-Newton steps,
    optionally refined by a 1D Brent line search along each descent direction.

    Four cases are implemented for the inverse covariance matrix W: diagonal W, 2D W,
    array of diagonal Ws, array of 2D Ws. The two latter cases
    are for multiple independent data vectors with W being block diagonal.

    Parameters
    ----------
    fit_workspace: FitWorkspace
        The FitWorkspace instance to fit.
    epsilon: np.ndarray
        Steps to compute the numerical Jacobian of the model, one per parameter.
    niter: int
        Maximum number of iterations (default: 10).
    xtol: float
        Tolerance on the relative parameter shift to stop the iterations (default: 1e-3).
    ftol: float
        Tolerance on the relative change of the cost to stop the iterations (default: 1e-3).
    with_line_search: bool
        If True, perform a Brent line search along each Gauss-Newton direction (default: True).

    Returns
    -------
    tmp_params: np.ndarray
        Best-fit parameter values.
    inv_JT_W_J: np.ndarray
        Inverse of the J^T W J matrix, an estimate of the parameter covariance matrix.
    costs: np.ndarray
        Cost values along the iterations.
    params_table: np.ndarray
        Parameter values along the iterations.
    """
    my_logger = set_logger(__name__)
    tmp_params = np.copy(fit_workspace.params.p)
    fit_workspace.prepare_weight_matrices()
    n_data_masked = len(fit_workspace.mask) + len(fit_workspace.outliers)
    ipar = fit_workspace.params.get_free_parameters()
    costs = []
    params_table = []
    inv_JT_W_J = np.zeros((len(ipar), len(ipar)))
    for i in range(niter):
        start = time.time()
        cost, tmp_lambdas, tmp_model, tmp_model_err = fit_workspace.chisq(tmp_params, model_output=True)
        # W matrix
        W = fit_workspace.compute_W_with_model_error(tmp_model_err)
        # residuals
        if isinstance(W, np.ndarray) and W.dtype != object:
            residuals = (tmp_model - fit_workspace.data).flatten()
        elif isinstance(W, np.ndarray) and W.dtype == object:
            residuals = [(tmp_model[k] - fit_workspace.data[k]) for k in range(len(W))]
        else:
            raise TypeError(f"Type of fit_workspace.W is {type(W)}. It must be a np.ndarray.")
        # Jacobian
        J = fit_workspace.jacobian(tmp_params, epsilon, model_input=[tmp_lambdas, tmp_model, tmp_model_err])
        # remove parameters with unexpected null Jacobian vectors
        for ip in range(J.shape[0]):
            if ip not in ipar:
                continue
            if np.all(np.array(J[ip]).flatten() == np.zeros(np.array(J[ip]).size)):
                ipar = np.delete(ipar, list(ipar).index(ip))
                fit_workspace.params.fixed[ip] = True
                my_logger.warning(
                    f"\n\tStep {i}: {fit_workspace.params.input_labels[ip]} has a null Jacobian; parameter is fixed "
                    f"at its last known current value ({tmp_params[ip]}).")
        # remove fixed parameters
        J = J[ipar].T
        if W.ndim == 1 and W.dtype != object:
            JT_W = J.T * W
            JT_W_J = JT_W @ J
        elif W.ndim == 2 and W.dtype != object:
            JT_W = J.T @ W
            JT_W_J = JT_W @ J
        else:
            if W[0].ndim == 1:
                JT_W = np.array([j for j in J]).T * np.concatenate(W).ravel()
                JT_W_J = JT_W @ np.array([j for j in J])
            else:
                # Warning: the data arrays indexed by k can have different lengths because of outliers,
                # and W (the inverse covariance) is block diagonal with blocks of different sizes,
                # so the data arrays are temporarily flattened.
                JT_W = [np.concatenate([J[ip][k].T @ W[k]
                                        for k in range(W.shape[0])]).ravel()
                        for ip in range(len(J))]
                JT_W_J = np.array([[JT_W[ip2] @ np.concatenate(J[ip1][:]).ravel() for ip1 in range(len(J))]
                                   for ip2 in range(len(J))])
        try:
            L = np.linalg.inv(np.linalg.cholesky(JT_W_J))  # faster, but Cholesky is sensitive to numerical precision
            inv_JT_W_J = L.T @ L
        except np.linalg.LinAlgError:
            inv_JT_W_J = np.linalg.inv(JT_W_J)
        if fit_workspace.W.dtype != object:
            JT_W_R0 = JT_W @ residuals
        else:
            JT_W_R0 = JT_W @ np.concatenate(residuals).ravel()
        dparams = - inv_JT_W_J @ JT_W_R0

        if with_line_search:
            def line_search(alpha):
                tmp_params_2 = np.copy(tmp_params)
                tmp_params_2[ipar] = tmp_params[ipar] + alpha * dparams
                for ipp, pp in enumerate(tmp_params_2):
                    if pp < fit_workspace.params.bounds[ipp][0]:
                        tmp_params_2[ipp] = fit_workspace.params.bounds[ipp][0]
                    if pp > fit_workspace.params.bounds[ipp][1]:
                        tmp_params_2[ipp] = fit_workspace.params.bounds[ipp][1]
                return fit_workspace.chisq(tmp_params_2)

            # the tol parameter acts on alpha (not on the function value)
            alpha_min, fval, nit, funcalls = optimize.brent(line_search, full_output=True, tol=5e-1, brack=(0, 1))
        else:
            alpha_min = 1
            fval = np.copy(cost)
            funcalls = 0
            nit = 0

        tmp_params[ipar] += alpha_min * dparams
        # check bounds
        for ip, p in enumerate(tmp_params):
            if p < fit_workspace.params.bounds[ip][0]:
                tmp_params[ip] = fit_workspace.params.bounds[ip][0]
            if p > fit_workspace.params.bounds[ip][1]:
                tmp_params[ip] = fit_workspace.params.bounds[ip][1]

        # prepare outputs
        costs.append(fval)
        params_table.append(np.copy(tmp_params))
        fit_workspace.params.p = tmp_params
        if fit_workspace.verbose:
            my_logger.info(f"\n\tIteration={i}: initial cost={cost:.5g} initial chisq_red={cost / (tmp_model.size - n_data_masked):.5g}"
                           f"\n\t\t Line search: alpha_min={alpha_min:.3g} iter={nit} funcalls={funcalls}"
                           f"\n\tParameter shifts: {alpha_min * dparams}"
                           f"\n\tNew parameters: {tmp_params[ipar]}"
                           f"\n\tFinal cost={fval:.5g} final chisq_red={fval / (tmp_model.size - n_data_masked):.5g} "
                           f"computed in {time.time() - start:.2f}s")
        if fit_workspace.live_fit:  # pragma: no cover
            fit_workspace.simulate(*tmp_params)
            fit_workspace.plot_fit()
            fit_workspace.params.cov = inv_JT_W_J
        if len(ipar) == 0:
            my_logger.warning(f"\n\tGradient descent terminated in {i} iterations because all parameters "
                              f"have null Jacobian.")
            break
        if np.sum(np.abs(alpha_min * dparams)) / np.sum(np.abs(tmp_params[ipar])) < xtol:
            my_logger.info(f"\n\tGradient descent terminated in {i} iterations because the sum of parameter shift "
                           f"relative to the sum of the parameters is below xtol={xtol}.")
            break
        if len(costs) > 1 and np.abs(costs[-2] - fval) / np.max([np.abs(fval), np.abs(costs[-2])]) < ftol:
            my_logger.info(f"\n\tGradient descent terminated in {i} iterations because the "
                           f"relative change of cost is below ftol={ftol}.")
            break
    plt.close()
    return tmp_params, inv_JT_W_J, np.array(costs), np.array(params_table)


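# Illustrative sketch (standalone numpy, diagonal W case): the Gauss-Newton step
# solved above, dparams = -(J^T W J)^{-1} J^T W r, with hypothetical arrays:
#
#     import numpy as np
#     J = np.random.rand(100, 3)        # (n_data, n_free) Jacobian
#     W = np.ones(100)                  # inverse variances
#     r = np.random.rand(100)           # model minus data residuals
#     JT_W = J.T * W
#     dparams = -np.linalg.solve(JT_W @ J, JT_W @ r)

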
def simple_newton_minimisation(fit_workspace, epsilon, niter=10, xtol=1e-3, ftol=1e-3):  # pragma: no cover
    """Experimental Newton minimisation of a function, using numerical Jacobian and
    Hessian evaluations from the FitWorkspace.

    Parameters
    ----------
    fit_workspace: FitWorkspace
        The FitWorkspace instance to fit.
    epsilon: np.ndarray
        Steps to compute the numerical derivatives, one per parameter.
    niter: int
        Maximum number of iterations (default: 10).
    xtol: float
        Tolerance on the relative parameter shift to stop the iterations (default: 1e-3).
    ftol: float
        Tolerance on the relative change of the function to stop the iterations (default: 1e-3).

    """
    my_logger = set_logger(__name__)
    tmp_params = np.copy(fit_workspace.params.p)
    ipar = fit_workspace.params.get_free_parameters()
    funcs = []
    params_table = []
    inv_H = np.zeros((len(ipar), len(ipar)))
    for i in range(niter):
        start = time.time()
        tmp_lambdas, tmp_model, tmp_model_err = fit_workspace.simulate(*tmp_params)
        J = fit_workspace.jacobian(tmp_params, epsilon)
        # remove parameters with unexpected null Jacobian vectors
        for ip in range(J.shape[0]):
            if ip not in ipar:
                continue
            if np.all(J[ip] == np.zeros(J.shape[1])):
                ipar = np.delete(ipar, list(ipar).index(ip))
                my_logger.warning(
                    f"\n\tStep {i}: {fit_workspace.params.input_labels[ip]} has a null Jacobian; parameter is fixed "
                    f"at its last known current value ({tmp_params[ip]}).")
        # remove fixed parameters
        J = J[ipar].T
        # Hessian
        H = fit_workspace.hessian(tmp_params, epsilon)
        try:
            L = np.linalg.inv(np.linalg.cholesky(H))  # faster, but Cholesky is sensitive to numerical precision
            inv_H = L.T @ L
        except np.linalg.LinAlgError:
            inv_H = np.linalg.inv(H)
        dparams = - inv_H[:, :, 0] @ J[:, 0]
        my_logger.debug(f"\n\tdparams={dparams}, inv_H={inv_H}, J={J}, H={H}")
        tmp_params[ipar] += dparams

        # check bounds
        for ip, p in enumerate(tmp_params):
            if p < fit_workspace.params.bounds[ip][0]:
                tmp_params[ip] = fit_workspace.params.bounds[ip][0]
            if p > fit_workspace.params.bounds[ip][1]:
                tmp_params[ip] = fit_workspace.params.bounds[ip][1]

        tmp_lambdas, new_model, tmp_model_err = fit_workspace.simulate(*tmp_params)
        new_func = new_model[0]
        funcs.append(new_func)

        r = np.log10(fit_workspace.regs)
        js = [fit_workspace.jacobian(np.asarray([rr]), epsilon)[0] for rr in np.array(r)]
        plt.plot(r, js, label="J")
        plt.grid()
        plt.legend()
        plt.show()

        if parameters.DISPLAY:
            fig = plt.figure()
            plt.plot(r, js, label="prior")
            mod = tmp_model + J[0] * (r - (tmp_params - dparams)[0])
            plt.plot(r, mod)
            plt.axvline(tmp_params)
            plt.axhline(tmp_model)
            plt.grid()
            plt.legend()
            plt.draw()
            plt.pause(1e-8)
            plt.close(fig)

        # prepare outputs
        params_table.append(np.copy(tmp_params))
        if fit_workspace.verbose:
            my_logger.info(f"\n\tIteration={i}: initial func={tmp_model[0]:.5g}"
                           f"\n\tParameter shifts: {dparams}"
                           f"\n\tNew parameters: {tmp_params[ipar]}"
                           f"\n\tFinal func={new_func:.5g}"
                           f" computed in {time.time() - start:.2f}s")
        if fit_workspace.live_fit:
            fit_workspace.simulate(*tmp_params)
            fit_workspace.plot_fit()
            fit_workspace.params.cov = inv_H[:, :, 0]
        if len(ipar) == 0:
            my_logger.warning(f"\n\tNewton minimisation terminated in {i} iterations because all parameters "
                              f"have null Jacobian.")
            break
        if np.sum(np.abs(dparams)) / np.sum(np.abs(tmp_params[ipar])) < xtol:
            my_logger.info(f"\n\tNewton minimisation terminated in {i} iterations because the sum of parameter shift "
                           f"relative to the sum of the parameters is below xtol={xtol}.")
            break
        if len(funcs) > 1 and np.abs(funcs[-2] - new_func) / np.max([np.abs(new_func), np.abs(funcs[-2])]) < ftol:
            my_logger.info(f"\n\tNewton minimisation terminated in {i} iterations because the "
                           f"relative change of cost is below ftol={ftol}.")
            break
    plt.close()
    return tmp_params, inv_H[:, :, 0], np.array(funcs), np.array(params_table)


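# Illustrative sketch (standalone numpy): the Newton update used above,
# dp = -H^{-1} g for a gradient g and Hessian H:
#
#     import numpy as np
#     H = np.array([[2.0, 0.0], [0.0, 4.0]])
#     g = np.array([1.0, -2.0])
#     dp = -np.linalg.solve(H, g)

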
def run_gradient_descent(fit_workspace, epsilon, xtol, ftol, niter, verbose=False, with_line_search=True):
    """Wrapper around gradient_descent() that accumulates the cost and parameter history
    in the FitWorkspace and plots diagnostics in DEBUG mode.
    """
    if fit_workspace.costs.size == 0:
        fit_workspace.costs = np.array([fit_workspace.chisq(fit_workspace.params.p)])
        fit_workspace.params_table = np.array([fit_workspace.params.p])
    p, cov, tmp_costs, tmp_params_table = gradient_descent(fit_workspace, epsilon, niter=niter, xtol=xtol, ftol=ftol,
                                                           with_line_search=with_line_search)
    fit_workspace.params.p, fit_workspace.params.cov = p, cov
    fit_workspace.params_table = np.concatenate([fit_workspace.params_table, tmp_params_table])
    fit_workspace.costs = np.concatenate([fit_workspace.costs, tmp_costs])
    if verbose or fit_workspace.verbose:
        fit_workspace.my_logger.info(f"\n\t{fit_workspace.params.print_parameters_summary()}")
    if parameters.DEBUG and (verbose or fit_workspace.verbose):
        fit_workspace.plot_gradient_descent()
        if len(fit_workspace.params.get_free_parameters()) > 1:
            fit_workspace.params.plot_correlation_matrix()


def run_simple_newton_minimisation(fit_workspace, epsilon, xtol=1e-8, ftol=1e-8, niter=50, verbose=False):  # pragma: no cover
    """Wrapper around simple_newton_minimisation() storing the result in the FitWorkspace."""
    fit_workspace.params.p, fit_workspace.params.cov, funcs, params_table = simple_newton_minimisation(fit_workspace,
                                                                                                       epsilon, niter=niter,
                                                                                                       xtol=xtol, ftol=ftol)
    if verbose or fit_workspace.verbose:
        fit_workspace.my_logger.info(f"\n\t{fit_workspace.params.print_parameters_summary()}")
    if parameters.DEBUG and (verbose or fit_workspace.verbose):
        fit_workspace.plot_gradient_descent()
        if len(fit_workspace.params.get_free_parameters()) > 1:
            fit_workspace.params.plot_correlation_matrix()
    return params_table, funcs


def run_minimisation(fit_workspace, method="newton", epsilon=None, xtol=1e-4, ftol=1e-4, niter=50,
                     verbose=False, with_line_search=True, minimizer_method="L-BFGS-B"):
    """Minimise the negative log-likelihood of a FitWorkspace.

    The method argument selects the minimiser: "minimize" and "basinhopping" wrap
    scipy.optimize, "least_squares" wraps scipy.optimize.least_squares, "minuit" wraps
    iminuit, and "newton" (default) runs the Gauss-Newton gradient descent implemented
    in this module.
    """
    my_logger = set_logger(__name__)

    bounds = fit_workspace.params.bounds

    nll = lambda params: -fit_workspace.lnlike(params)

    guess = fit_workspace.params.p.astype('float64')
    if verbose:
        my_logger.debug(f"\n\tStart guess: {guess}")

    if method == "minimize":
        start = time.time()
        result = optimize.minimize(nll, fit_workspace.params.p, method=minimizer_method,
                                   options={'ftol': ftol, 'maxiter': 100000}, bounds=bounds)
        fit_workspace.params.p = result['x']
        if verbose:
            my_logger.debug(f"\n\t{result}")
            my_logger.debug(f"\n\tMinimize: total computation time: {time.time() - start}s")
            if parameters.DEBUG:
                fit_workspace.plot_fit()
    elif method == 'basinhopping':
        start = time.time()
        minimizer_kwargs = dict(method=minimizer_method, bounds=bounds)
        result = optimize.basinhopping(nll, guess, minimizer_kwargs=minimizer_kwargs)
        fit_workspace.params.p = result['x']
        if verbose:
            my_logger.debug(f"\n\t{result}")
            my_logger.debug(f"\n\tBasin-hopping: total computation time: {time.time() - start}s")
            if parameters.DEBUG:
                fit_workspace.plot_fit()
    elif method == "least_squares":  # pragma: no cover
        fit_workspace.my_logger.warning("least_squares might not work, use with caution... "
                                        "or repair carefully the function weighted_residuals()")
        start = time.time()
        x_scale = np.abs(guess)
        x_scale[x_scale == 0] = 0.1
        p = optimize.least_squares(fit_workspace.weighted_residuals, guess, verbose=2, ftol=1e-6, x_scale=x_scale,
                                   diff_step=0.001, bounds=np.array(bounds).T)
        fit_workspace.params.p = p.x
        if verbose:
            my_logger.debug(f"\n\t{p}")
            my_logger.debug(f"\n\tLeast_squares: total computation time: {time.time() - start}s")
            if parameters.DEBUG:
                fit_workspace.plot_fit()
    elif method == "minuit":
        start = time.time()
        error = 0.1 * np.abs(guess) * np.ones_like(guess)
        error[2:5] = 0.3 * np.abs(guess[2:5]) * np.ones_like(guess[2:5])
        z = np.where(np.isclose(error, 0.0, 1e-6))
        error[z] = 1.
        # noinspection PyArgumentList
        m = Minuit(nll, np.copy(guess))
        m.errors = error
        m.errordef = 1
        m.fixed = fit_workspace.params.fixed
        m.print_level = verbose
        m.limits = bounds
        m.tol = 10
        m.migrad()
        fit_workspace.params.p = np.array(m.values[:])
        if verbose:
            my_logger.debug(f"\n\t{m}")
            my_logger.debug(f"\n\tMinuit: total computation time: {time.time() - start}s")
            if parameters.DEBUG:
                fit_workspace.plot_fit()
    elif method == "newton":
        if epsilon is None:
            epsilon = 1e-4 * guess
            epsilon[epsilon == 0] = 1e-4

        start = time.time()
        run_gradient_descent(fit_workspace, epsilon, xtol=xtol, ftol=ftol, niter=niter, verbose=verbose,
                             with_line_search=with_line_search)
        if verbose:
            my_logger.debug(f"\n\tNewton: total computation time: {time.time() - start}s")
        if fit_workspace.filename != "":
            write_fitparameter_json(fit_workspace.params.json_filename, fit_workspace.params)
            fit_workspace.save_gradient_descent()


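# Illustrative sketch (hypothetical workspace): selecting a minimisation backend.
#
#     w = MyFitWorkspace(...)                      # any FitWorkspace subclass
#     run_minimisation(w, method="newton", xtol=1e-5, ftol=1e-5, niter=100)
#     print(w.params.p)                            # best-fit parameter values

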
def run_minimisation_sigma_clipping(fit_workspace, method="newton", epsilon=None, xtol=1e-4, ftol=1e-4,
                                    niter=50, sigma_clip=5.0, niter_clip=3, verbose=False):
    """Run run_minimisation() iteratively, flagging as outliers the data points lying
    more than sigma_clip error bars away from the best-fit model between iterations.
    """
    my_logger = set_logger(__name__)
    fit_workspace.sigma_clip = sigma_clip
    for step in range(niter_clip):
        if verbose:
            my_logger.info(f"\n\tSigma-clipping step {step}/{niter_clip} (sigma={sigma_clip})")
        run_minimisation(fit_workspace, method=method, epsilon=epsilon, xtol=xtol, ftol=ftol, niter=niter)
        # remove outliers
        if fit_workspace.data.dtype == object:
            data = np.concatenate(fit_workspace.data).ravel()
            model = np.concatenate(fit_workspace.model).ravel()
            err = np.concatenate(fit_workspace.err).ravel()
        else:
            data = fit_workspace.data.flatten()
            model = fit_workspace.model.flatten()
            err = fit_workspace.err.flatten()
        residuals = np.abs(data - model) / err
        outliers = residuals > sigma_clip
        outliers = [i for i in range(data.size) if outliers[i]]
        outliers.sort()
        if len(outliers) > 0:
            my_logger.debug(f'\n\tOutliers flat index list: {outliers}')
            my_logger.info(f'\n\tOutliers: {len(outliers)} / {data.size - len(fit_workspace.mask)} data points '
                           f'({100 * len(outliers) / (data.size - len(fit_workspace.mask)):.2f}%) '
                           f'at more than {sigma_clip}-sigma from best-fit model.')
            if np.all(fit_workspace.outliers == outliers):
                my_logger.info('\n\tOutliers flat index list unchanged since last iteration: '
                               'break the sigma clipping iterations.')
                break
            else:
                fit_workspace.outliers = outliers
        else:
            my_logger.info('\n\tNo outliers detected: break the sigma clipping iterations.')
            break


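# Illustrative sketch (standalone numpy): the outlier cut applied at each
# sigma-clipping step above, with hypothetical data, model and err arrays:
#
#     import numpy as np
#     residuals = np.abs(data - model) / err
#     outliers = np.where(residuals > 5.0)[0]      # flat indices beyond 5 sigma

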
def run_emcee(mcmc_fit_workspace, ln=lnprob):
    """Sample the posterior of a fit workspace with emcee, resuming from the HDF5
    backend if it already contains steps. An MPI pool is used when available;
    otherwise the sampler falls back to a local run.
    """
    my_logger = set_logger(__name__)
    mcmc_fit_workspace.print_settings()
    nsamples = mcmc_fit_workspace.nsteps
    p0 = mcmc_fit_workspace.set_start()
    filename = mcmc_fit_workspace.emcee_filename
    backend = emcee.backends.HDFBackend(filename)
    try:  # pragma: no cover
        pool = MPIPool()
        if not pool.is_master():
            pool.wait()
            sys.exit(0)
        sampler = emcee.EnsembleSampler(mcmc_fit_workspace.nwalkers, mcmc_fit_workspace.params.ndim, ln, args=(),
                                        pool=pool, backend=backend)
        my_logger.info(f"\n\tInitial size: {backend.iteration}")
        if backend.iteration > 0:
            p0 = backend.get_last_sample()
        if nsamples - backend.iteration > 0:
            sampler.run_mcmc(p0, nsteps=max(0, nsamples - backend.iteration), progress=True)
        pool.close()
    except ValueError:
        sampler = emcee.EnsembleSampler(mcmc_fit_workspace.nwalkers, mcmc_fit_workspace.params.ndim, ln, args=(),
                                        threads=multiprocessing.cpu_count(), backend=backend)
        my_logger.info(f"\n\tInitial size: {backend.iteration}")
        if backend.iteration > 0:
            p0 = sampler.get_last_sample()
        for _ in sampler.sample(p0, iterations=max(0, nsamples - backend.iteration), progress=True, store=True):
            continue
    mcmc_fit_workspace.chains = sampler.chain
    mcmc_fit_workspace.lnprobs = sampler.lnprobability


class RegFitWorkspace(FitWorkspace):

    def __init__(self, w, opt_reg=parameters.PSF_FIT_REG_PARAM, verbose=False, live_fit=False):
        """Workspace to optimise the regularisation hyper-parameter of a linear fit.

        Parameters
        ----------
        w: ChromaticPSFFitWorkspace, FullForwardModelFitWorkspace
            FitWorkspace instance where to apply regularisation.
        opt_reg: float
            Input value for optimal regularisation parameter (default: parameters.PSF_FIT_REG_PARAM).
        verbose: bool, optional
            Level of verbosity (default: False).
        live_fit: bool, optional
            If True, model, data and residuals plots are made along the fitting procedure (default: False).

        """
        params = FitParameters(np.asarray([np.log10(opt_reg)]), input_labels=["log10_reg"],
                               axis_names=[r"$\log_{10} r$"], fixed=None,
                               bounds=[(-20, np.log10(w.amplitude_priors.size) + 2)])
        FitWorkspace.__init__(self, params, verbose=verbose, live_fit=live_fit)
        self.x = np.array([0])
        self.data = np.array([0])
        self.err = np.array([1])
        self.w = w
        self.opt_reg = opt_reg
        self.resolution = np.zeros((self.w.amplitude_params.size, self.w.amplitude_params.size))
        self.G = 0
        self.chisquare = -1

    def print_regularisation_summary(self):
        self.my_logger.info(f"\n\tOptimal regularisation parameter: {self.opt_reg}"
                            f"\n\tTr(R) = {np.trace(self.resolution)}"
                            f"\n\tN_params = {len(self.w.amplitude_params)}"
                            f"\n\tN_data = {self.w.data.size - len(self.w.mask) - len(self.w.outliers)}"
                            f" (without mask and outliers)")

    def simulate(self, log10_r):
        """Solve the regularised linear system for a given log10 regularisation parameter
        and return a generalised cross-validation (GCV) criterion G to be minimised.
        """
        reg = 10 ** log10_r
        M_dot_W_dot_M_plus_Q = self.w.M_dot_W_dot_M + reg * self.w.Q
        try:
            L = np.linalg.inv(np.linalg.cholesky(M_dot_W_dot_M_plus_Q))
            cov = L.T @ L
        except np.linalg.LinAlgError:
            cov = np.linalg.inv(M_dot_W_dot_M_plus_Q)
        if self.w.W.ndim == 1:
            A = cov @ (self.w.M.T @ (self.w.W * self.w.data) + reg * self.w.Q_dot_A0)
        else:
            A = cov @ (self.w.M.T @ (self.w.W @ self.w.data) + reg * self.w.Q_dot_A0)
        if A.ndim == 2:  # ndim == 2 when A comes from a sparse matrix computation
            A = np.asarray(A).reshape(-1)
        self.resolution = np.eye(A.size) - reg * cov @ self.w.Q
        diff = self.w.data - self.w.M @ A
        if self.w.W.ndim == 1:
            self.chisquare = diff @ (self.w.W * diff)
        else:
            self.chisquare = diff @ self.w.W @ diff
        self.w.amplitude_params = A
        self.w.amplitude_cov_matrix = cov
        self.w.amplitude_params_err = np.array([np.sqrt(cov[x, x]) for x in range(cov.shape[0])])
        self.G = self.chisquare / ((self.w.data.size - len(self.w.mask) - len(self.w.outliers)) - np.trace(self.resolution)) ** 2
        return np.asarray([log10_r]), np.asarray([self.G]), np.zeros_like(self.data)

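    # Illustrative sketch (standalone numpy, diagonal W): the Tikhonov-regularised
    # solution computed in simulate() above,
    # A = argmin (data - M A)^T W (data - M A) + r (A - A0)^T Q (A - A0):
    #
    #     import numpy as np
    #     cov = np.linalg.inv(M.T @ (W[:, None] * M) + r * Q)   # hypothetical M, W, Q
    #     A = cov @ (M.T @ (W * data) + r * Q @ A0)
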
    def plot_fit(self):
        """Plot the GCV criterion, the chi-square and the resolution trace as functions
        of the regularisation hyper-parameter, together with the amplitude correlation matrix.
        """
        log10_opt_reg = self.params.p[0]
        opt_reg = 10 ** log10_opt_reg
        regs = 10 ** np.linspace(min(-10, 0.9 * log10_opt_reg), max(3, 1.2 * log10_opt_reg), 50)
        Gs = []
        chisqs = []
        resolutions = []
        x = np.arange(len(self.w.amplitude_priors))
        for r in regs:
            self.simulate(np.log10(r))
            if parameters.DISPLAY and False:  # pragma: no cover  # debugging block, intentionally disabled
                fig = plt.figure()
                plt.errorbar(x, self.w.amplitude_params, yerr=[np.sqrt(self.w.amplitude_cov_matrix[i, i]) for i in x],
                             label=f"fit r={r:.2g}")
                plt.plot(x, self.w.amplitude_priors, label="prior")
                plt.grid()
                plt.legend()
                plt.draw()
                plt.pause(1e-8)
                plt.close(fig)
            Gs.append(self.G)
            chisqs.append(self.chisquare)
            resolutions.append(np.trace(self.resolution))
        fig, ax = plt.subplots(3, 1, figsize=(7, 5), sharex="all")
        ax[0].plot(regs, Gs)
        ax[0].axvline(opt_reg, color="k")
        ax[1].axvline(opt_reg, color="k")
        ax[2].axvline(opt_reg, color="k")
        ax[0].set_ylabel(r"$G(r)$")
        ax[0].set_xlabel("Regularisation hyper-parameter $r$")
        ax[0].grid()
        ax[0].set_title(f"Optimal regularisation parameter: {opt_reg:.3g}")
        ax[1].plot(regs, chisqs)
        ax[1].set_ylabel(r"$\chi^2(\mathbf{A}(r) \vert \mathbf{\theta})$")
        ax[1].set_xlabel("Regularisation hyper-parameter $r$")
        ax[1].grid()
        ax[1].set_xscale("log")
        ax[2].set_xscale("log")
        ax[2].plot(regs, resolutions)
        ax[2].set_ylabel(r"$\mathrm{Tr}\,\mathbf{R}$")
        ax[2].set_xlabel("Regularisation hyper-parameter $r$")
        ax[2].grid()
        fig.tight_layout()
        plt.subplots_adjust(hspace=0)
        if parameters.DISPLAY:
            plt.show()
        if parameters.LSST_SAVEFIGPATH:
            fig.savefig(os.path.join(parameters.LSST_SAVEFIGPATH, 'regularisation.pdf'))

        fig = plt.figure(figsize=(7, 5))
        rho = compute_correlation_matrix(self.w.amplitude_cov_matrix)
        plot_correlation_matrix_simple(plt.gca(), rho, axis_names=[''] * len(self.w.amplitude_params))
        plt.gca().set_title(r"Correlation matrix $\mathbf{\rho}$")
        fig.tight_layout()
        if parameters.LSST_SAVEFIGPATH:
            fig.savefig(os.path.join(parameters.LSST_SAVEFIGPATH, 'amplitude_correlation_matrix.pdf'))
        if parameters.DISPLAY:
            plt.show()

    def run_regularisation(self):
        """Optimise the regularisation parameter by minimising the GCV criterion, then
        update the workspace amplitudes at the optimum.
        """
        run_minimisation(self, method="minimize", ftol=1e-4, xtol=1e-2, verbose=self.verbose, epsilon=[1e-1],
                         minimizer_method="Nelder-Mead")
        self.opt_reg = 10 ** self.params.p[0]
        self.simulate(np.log10(self.opt_reg))
        self.print_regularisation_summary()


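# Illustrative sketch (hypothetical workspace `w` exposing M, W, Q and amplitude
# priors): optimising the regularisation parameter.
#
#     reg = RegFitWorkspace(w, opt_reg=parameters.PSF_FIT_REG_PARAM)
#     reg.run_regularisation()
#     print(reg.opt_reg)                           # optimal hyper-parameter

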
if __name__ == "__main__":
    import doctest

    doctest.testmod()