• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pypest / pyemu / 19207387408

09 Nov 2025 10:49AM UTC coverage: 57.256% (-20.4%) from 77.703%
19207387408

Pull #633

github

web-flow
Merge a1b5645d3 into 04c2058d1
Pull Request #633: parallel tests again -- trying to get something closer to 1 hour

0 of 1 new or added line in 1 file covered. (0.0%)

3455 existing lines in 22 files now uncovered.

11008 of 19226 relevant lines covered (57.26%)

5.22 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

15.38
/pyemu/plot/plot_utils.py
1
"""Plotting functions for various PEST(++) and pyemu operations"""
2
import os
12✔
3
import numpy as np
12✔
4
import pandas as pd
12✔
5
import warnings
12✔
6
from datetime import datetime
12✔
7
import string
12✔
8
from pyemu.logger import Logger
12✔
9
from pyemu.pst import pst_utils
12✔
10
from ..pyemu_warnings import PyemuWarning
12✔
11

12
# Default matplotlib rcParams applied by `apply_custom_font` when the
# decorator is used without an explicit `rc_params` argument.
font = {"font.size": 6}
12✔
13
try:
12✔
14
    import matplotlib
12✔
15
    import matplotlib.pyplot as plt
12✔
16
    from matplotlib.backends.backend_pdf import PdfPages
12✔
17
    from matplotlib.gridspec import GridSpec
12✔
18
except Exception as e:
×
19
    # raise Exception("error importing matplotlib: {0}".format(str(e)))
20
    warnings.warn("error importing matplotlib: {0}".format(str(e)), PyemuWarning)
×
21

22
import pyemu
12✔
23

24
# Default figure size passed to `plt.figure` by the multi-page helpers
# in this module (phi_progress, res_1to1, res_phi_pie, pst_prior).
figsize = (8, 10.5)
12✔
25
# Number of subplot rows and columns laid out on each page by
# `_get_page_axes`; a page therefore holds nr * nc axes.
nr, nc = 4, 2
12✔
26
# page_gs = GridSpec(nr,nc)
27

28
# Uppercase letters used to prefix subplot titles ("A) ...", "B) ...").
abet = string.ascii_uppercase
12✔
29

30
def apply_custom_font(rc_params=None):
    """Decorator factory that runs a function inside a matplotlib rc context

    Args:
        rc_params (`dict`): matplotlib rcParams to apply while the decorated
            function executes.  If None, the module-level `font` dict is used.

    Returns:
        callable: a decorator.  The function it produces calls the wrapped
        function inside `plt.rc_context(rc_params)` and returns whatever the
        wrapped function returns.

    Example::

        @apply_custom_font({"font.size": 8})
        def my_plot(ax):
            ax.set_title("small title")

    """
    import functools

    if rc_params is None:
        rc_params = font

    def decorator(func):
        # functools.wraps copies __name__, __doc__, etc. onto the wrapper so
        # help() and generated API docs still describe the real function
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with plt.rc_context(rc_params):
                return func(*args, **kwargs)

        return wrapper

    return decorator
39

40

41
@apply_custom_font({"font.size": 6})
def plot_summary_distributions(
    df,
    ax=None,
    label_post=False,
    label_prior=False,
    subplots=False,
    figsize=(11, 8.5),
    pt_color="b",
):
    """helper function to plot gaussian distributions from prior and posterior
    means and standard deviations

    Args:
        df (`pandas.DataFrame`): a dataframe or the name of a csv file.  Must
            have columns named: 'prior_mean','prior_stdev','post_mean',
            'post_stdev' ('prior_var'/'post_var' are accepted in place of the
            stdev columns and 'prior_expt'/'post_expt' in place of the mean
            columns).  If loaded from a csv file, column 0 is assumed to be
            the index.  The dataframe passed in is not modified.
        ax (`matplotlib.pyplot.axis`): If None, and not subplots, then one is created
            and all distributions are plotted on a single plot
        label_post (`bool`): flag to add text labels to the peak of the posterior
        label_prior (`bool`): flag to add text labels to the peak of the prior
        subplots (`bool`): flag to use subplots.  If True, then 6 axes per page
            are used and a single prior and posterior is plotted on each
        figsize (`tuple`): matplotlib figure size
        pt_color (`str`): matplotlib color used to fill the posterior
            distributions.  Default is "b"

    Returns:
        if `subplots` is True, a tuple containing:

        - **[`matplotlib.figure`]**: list of figures
        - **[`matplotlib.axis`]**: list of axes

        otherwise the single `matplotlib.axis` that was plotted on

    Note:
        This is useful for demystifying FOSM results

        if subplots is False, a single axis is returned

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pyemu.plot_utils.plot_summary_distributions("pest.par.usum.csv")
        plt.show()

    """
    import matplotlib.pyplot as plt

    if isinstance(df, str):
        df = pd.read_csv(df, index_col=0)
    else:
        # work on a copy so the derived columns added below do not leak
        # into the caller's dataframe
        df = df.copy()

    if not subplots and ax is None:
        plt.figure(figsize=figsize)
        ax = plt.subplot(111)
        ax.grid()

    # derive the stdev/expectation columns used below from their aliases
    if "post_stdev" not in df.columns and "post_var" in df.columns:
        df.loc[:, "post_stdev"] = df.post_var.apply(np.sqrt)
    if "prior_stdev" not in df.columns and "prior_var" in df.columns:
        df.loc[:, "prior_stdev"] = df.prior_var.apply(np.sqrt)
    if "prior_expt" not in df.columns and "prior_mean" in df.columns:
        df.loc[:, "prior_expt"] = df.prior_mean
    if "post_expt" not in df.columns and "post_mean" in df.columns:
        df.loc[:, "post_expt"] = df.post_mean

    if subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(2, 3, 1)
        ax_per_page = 6
        ax_count = 0
        axes = []
        figs = []

    last_pos = len(df.index) - 1
    for pos, name in enumerate(df.index):
        # posterior: filled curve
        x, y = gaussian_distribution(
            df.loc[name, "post_expt"], df.loc[name, "post_stdev"]
        )
        ax.fill_between(x, 0, y, facecolor=pt_color, edgecolor="none", alpha=0.25)
        if label_post:
            peak = np.argmax(y)
            ax.text(x[peak], y[peak] * 1.001, name, ha="center", alpha=0.5)

        # prior: dashed grey line
        x, y = gaussian_distribution(
            df.loc[name, "prior_expt"], df.loc[name, "prior_stdev"]
        )
        ax.plot(x, y, color="0.5", lw=3.0, dashes=(2, 1))
        if label_prior:
            peak = np.argmax(y)
            ax.text(x[peak], y[peak] * 1.001, name, ha="center", alpha=0.5)

        if subplots:
            ax.set_title(name)
            ax_count += 1
            ax.set_yticklabels([])
            axes.append(ax)
            # compare positions rather than labels so a repeated index
            # label cannot end the loop early
            if pos == last_pos:
                break
            if ax_count >= ax_per_page:
                # page is full: stash it and start a new figure
                figs.append(fig)
                fig = plt.figure(figsize=figsize)
                ax_count = 0
            ax = plt.subplot(2, 3, ax_count + 1)

    if subplots:
        figs.append(fig)
        return figs, axes

    # single-axis mode: leave head room above the tallest curve
    ylim = list(ax.get_ylim())
    ylim[1] *= 1.2
    ylim[0] = 0.0
    ax.set_ylim(ylim)
    ax.set_yticklabels([])
    return ax
152

153

154
def gaussian_distribution(mean, stdev, num_pts=50):
    """get an x and y numpy.ndarray that spans the +/- 4
    standard deviation range of a gaussian distribution with
    a given mean and standard deviation. useful for plotting

    Args:
        mean (`float`): the mean of the distribution
        stdev (`float`): the standard deviation of the distribution
        num_pts (`int`): the number of points in the returned ndarrays.
            Default is 50

    Returns:
        tuple containing:

        - **numpy.ndarray**: the x-values of the distribution
        - **numpy.ndarray**: the y-values of the distribution

    Note:
        This function only does array math, so it is deliberately not
        wrapped in the plotting font decorator: it does not need matplotlib
        and is called once per parameter inside plotting loops.

    Example::

        mean,std = 1.0, 2.0
        x,y = pyemu.plot_utils.gaussian_distribution(mean,std)
        plt.fill_between(x,0,y)
        plt.show()

    """
    half_width = 4.0 * stdev
    x = np.linspace(mean - half_width, mean + half_width, num_pts)
    variance = stdev * stdev
    # normal probability density evaluated at each x
    y = (1.0 / np.sqrt(2.0 * np.pi * variance)) * np.exp(
        -1.0 * ((x - mean) ** 2) / (2.0 * variance)
    )
    return x, y
188

189

190
@apply_custom_font()
def pst_helper(pst, kind=None, **kwargs):
    """`pyemu.Pst` plot helper - takes the
    handoff from `pyemu.Pst.plot()`

    Args:
        pst (`pyemu.Pst`): a control file instance
        kind (`str`): the kind of plot to make.  One of "prior", "1to1",
            "phi_pie" or "phi_progress".  If None, every kind is made and
            each is saved to "<case>.<kind>.pdf", where <case> is
            `pst.new_filename` (or `pst.filename` if that is None) with
            ".pst" removed.
        **kwargs (`dict`): keyword arguments to pass to the
            plotting function and ultimately to `matplotlib`.  'echo'
            (`bool`) controls whether the logger echoes to the screen.

    Returns:
        varies: usually a combination of `matplotlib.figure` (s) and/or
        `matplotlib.axis` (s).  If `kind` is None, a list with the return
        value of each plotting function.

    Example::

        pst = pyemu.Pst("pest.pst") #assumes pest.res or pest.rei is found
        pst.plot(kind="1to1")
        plt.show()
        pst.plot(kind="phi_pie")
        plt.show()
        pst.plot(kind="prior")
        plt.show()

    """

    echo = kwargs.get("echo", False)
    logger = pyemu.Logger("plot_pst_helper.log", echo=echo)
    logger.statement("plot_utils.pst_helper()")

    # dispatch table: plot kind -> plotting function
    kinds = {
        "prior": pst_prior,
        "1to1": res_1to1,
        "phi_pie": res_phi_pie,
        "phi_progress": phi_progress,
    }

    if kind is None:
        base_filename = pst.filename
        if pst.new_filename is not None:
            base_filename = pst.new_filename
        base_filename = base_filename.replace(".pst", "")
        return [
            func(
                pst,
                logger=logger,
                filename="{0}.{1}.pdf".format(base_filename, name),
            )
            for name, func in kinds.items()
        ]

    if kind not in kinds:
        logger.lraise(
            "unrecognized kind:{0}, should be one of {1}".format(
                kind, ",".join(kinds)
            )
        )
    return kinds[kind](pst, logger, **kwargs)
245

246

247
@apply_custom_font()
def phi_progress(pst, logger=None, filename=None, **kwargs):
    """make plot of phi vs number of model runs - requires
    available  ".iobj" file generated by a PESTPP-GLM run.

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): filename to save the figure to.  If None, the
            figure is not saved.  Default is None
        kwargs (`dict`): optional keyword args.  Accepts 'ax', an existing
            `matplotlib.axis` to plot on; if not passed (or None) a new
            figure and axis are created.

    Returns:
        `matplotlib.axis`: the axis the plot was made on

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.phi_progress(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot phi_progress")

    # the iobj file is expected next to the control file
    iobj_file = pst.filename.replace(".pst", ".iobj")
    if not os.path.exists(iobj_file):
        logger.lraise("couldn't find iobj file {0}".format(iobj_file))
    df = pd.read_csv(iobj_file)

    ax = kwargs.get("ax")
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)

    ax.plot(df.model_runs_completed, df.total_phi, marker=".")
    ax.set_xlabel("model runs")
    ax.set_ylabel(r"$\phi$")
    ax.grid()
    if filename is not None:
        # save the figure that owns `ax`; pyplot's "current" figure may be
        # a different one when the caller supplied the axis
        ax.get_figure().savefig(filename)
    logger.log("plot phi_progress")
    return ax
292

293

294
def _get_page_axes(count=nr * nc):
    """create the subplot axes for one page of the current figure

    Args:
        count (`int`): number of axes to create.  Capped at `nr * nc`,
            which is also the default.

    Returns:
        [`matplotlib.axis`]: the axes, in subplot-index order on an
        `nr` x `nc` grid

    """
    n_axes = min(count, nr * nc)
    page_axes = []
    for position in range(1, n_axes + 1):
        page_axes.append(plt.subplot(nr, nc, position))
    return page_axes
298

299

300
@apply_custom_font()
def res_1to1(
    pst, logger=None, filename=None, plot_hexbin=False, histogram=False, **kwargs
):
    """make 1-to-1 plots and also observed vs residual by observation group

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): PDF filename to save figures to.  If None, figures
            are returned.  Default is None
        plot_hexbin (`bool`): flag to use the hexbinning for large numbers of residuals.
            Default is False (use `matplotlib.pyplot.scatter` )
        histogram (`bool`): flag to plot residual histograms instead of obs vs residual.
            Default is False
        kwargs (`dict`): optional keyword args.  Accepts 'ensemble' (passed
            to `pst_utils.res_from_en` to build the residuals instead of
            using `pst.res`), 'include_zero' (if present and not False,
            zero-weighted observations are plotted too) and 'fig_title'
            (text for the title page).  'grouper' is not implemented.

    Returns:
        [`matplotlib.figure`]: the list of figures (the title page first) if
        `filename` is None.  Otherwise the figures are written to the PDF,
        closed, and None is returned.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_1to1(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_1to1")

    if "ensemble" in kwargs:
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception as e:
            logger.lraise("res_1to1: error loading ensemble file: {0}".format(str(e)))
    else:
        try:
            res = pst.res
        except Exception:
            logger.lraise("res_1to1: pst.res is None, couldn't find residuals file")

    obs = pst.observation_data

    if "grouper" in kwargs:
        raise NotImplementedError()
    grouper = obs.groupby(obs.obgnme).groups

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    # zero-weight obs are dropped unless include_zero is passed and not False
    drop_zero_weight = "include_zero" not in kwargs or kwargs["include_zero"] is False
    axes_per_page = nr * nc

    figs = []
    axes = []  # stays empty if every group is skipped
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))

        # copy so the new columns are not assigned on a slice of `obs`
        obs_g = obs.loc[names, :].copy()
        obs_g.loc[:, "sim"] = res.loc[names, "modelled"]
        logger.statement("using control file obsvals to calculate residuals")
        obs_g.loc[:, "res"] = obs_g.sim - obs_g.obsval
        if drop_zero_weight:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue

        if ax_count % axes_per_page == 0:
            # first plotted group or a full page: stash the current figure
            # (title page or finished page) and start a new page
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # ---- left axis: simulated vs observed ----
        ax = axes[ax_count]
        mx = max(obs_g.obsval.max(), obs_g.sim.max())
        mn = min(obs_g.obsval.min(), obs_g.sim.min())
        # pad outward by 10% of the magnitude; scaling by 1.1/0.9 would
        # shrink the range (and clip points) when the values are negative
        mx += 0.1 * abs(mx)
        mn -= 0.1 * abs(mn)
        ax.axis("square")
        if plot_hexbin:
            ax.hexbin(
                obs_g.obsval.values,
                obs_g.sim.values,
                mincnt=1,
                gridsize=(75, 75),
                extent=(mn, mx, mn, mx),
                bins="log",
                edgecolors=None,
            )
        else:
            ax.scatter(obs_g.obsval.values, obs_g.sim.values,
                       marker=".", s=10, color="b")

        ax.plot([mn, mx], [mn, mx], "k--", lw=1.0)
        xlim = (mn, mx)
        ax.set_xlim(mn, mx)
        ax.set_ylim(mn, mx)
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax_count += 1

        # ---- right axis: residuals ----
        ax = axes[ax_count]
        if not histogram:
            ax.scatter(obs_g.obsval.values, obs_g.res.values,
                       marker=".", s=10, color="b")
            ylim = ax.get_ylim()
            mx = max(np.abs(ylim[0]), np.abs(ylim[1]))
            if obs_g.shape[0] == 1:
                mx *= 1.1
            # symmetric about zero
            ax.set_ylim(-mx, mx)
            # show a zero residuals line
            ax.plot(xlim, [0, 0], "k--", lw=1.0)
            meanres = obs_g.res.mean()
            # show mean residuals line
            ax.plot(xlim, [meanres, meanres], "r-", lw=1.0)
            ax.set_xlim(xlim)
            ax.set_ylabel("residual", labelpad=0.1)
            ax.set_xlabel("observed", labelpad=0.1)
        else:
            # need max and min res to set xlim, otherwise wonky figsize
            mxr = obs_g.res.max()
            mnr = obs_g.res.min()
            # pad outward; residuals are routinely negative
            mxr += 0.1 * abs(mxr)
            mnr -= 0.1 * abs(mnr)
            rlim = (mnr, mxr)

            ax.hist(obs_g.res, bins=50, color="b")
            meanres = obs_g.res.mean()
            ax.axvline(meanres, color="r", lw=1)
            b, t = ax.get_ylim()
            ax.text(meanres + meanres / 10, t - t / 10, "Mean: {:.2f}".format(meanres))
            ax.set_xlim(rlim)
            ax.set_ylabel("count", labelpad=0.1)
            ax.set_xlabel("residual", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1
        logger.log("plotting 1to1 for {0}".format(g))

    # blank out the unused axes on the last page
    for unused in axes[ax_count:]:
        unused.set_axis_off()
        unused.set_yticks([])
        unused.set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot res_1to1")
        return None
    logger.log("plot res_1to1")
    return figs
510

511

512
@apply_custom_font()
def plot_id_bar(id_df, nsv=None, logger=None, **kwargs):
    """Plot a stacked bar chart of identifiability based on
    a the `pyemu.ErrVar.get_identifiability()` dataframe

    Args:
        id_df (`pandas.DataFrame`) : dataframe of identifiability.  It is
            copied, not modified.  A column named 'ident', if present, is
            ignored.
        nsv (`int`): number of singular values to consider.  If None or
            larger than the number of columns, all columns are used.
        logger (`pyemu.Logger`, optional): a logger.  If None, a generic
            one is created
        kwargs (`dict`): a dict of keyword arguments.  Accepts 'figsize'
            (default (8, 10.5)) and 'ax', an existing axis to plot on.

    Returns:
        `matplotlib.Axis`: the axis with the plot

    Example::

        import pyemu
        pest_obj = pyemu.Pst(pest_control_file)
        ev = pyemu.ErrVar(jco='freyberg_jac.jcb')
        id_df = ev.get_identifiability_dataframe(singular_value=48)
        pyemu.plot_utils.plot_id_bar(id_df, nsv=12, figsize=(12,4))

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot id bar")

    df = id_df.copy()

    # drop the final `ident` column
    if "ident" in df.columns:
        df.drop("ident", inplace=True, axis=1)

    if nsv is None or nsv > len(df.columns):
        nsv = len(df.columns)
        # a one-off message: statement() rather than log(), which would
        # open a timed event that is never closed
        logger.statement("set number of SVs and number in the dataframe")

    # copy so the temporary sort column is not assigned on a slice
    df = df[df.columns[:nsv]].copy()

    # order the bars from most to least identifiable
    df["ident"] = df.sum(axis=1)
    df.sort_values(by="ident", inplace=True, ascending=False)
    df.drop("ident", inplace=True, axis=1)

    figsize = kwargs.get("figsize", (8, 10.5))
    ax = kwargs.get("ax")
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)

    # plot the stacked bar chart (the easy part!)
    df.plot.bar(stacked=True, cmap="jet_r", legend=False, ax=ax)

    #
    # horrible shenanigans to make a colorbar rather than a legend
    #

    # special case colormap just dark red if one SV
    if nsv == 1:
        tcm = matplotlib.colors.LinearSegmentedColormap.from_list(
            "one_sv", [plt.get_cmap("jet_r")(0)] * 2, N=2
        )
        sm = plt.cm.ScalarMappable(
            cmap=tcm, norm=matplotlib.colors.Normalize(vmin=0, vmax=nsv + 1)
        )
    # or typically just rock the jet_r colormap over the range of SVs
    else:
        sm = plt.cm.ScalarMappable(
            cmap=plt.get_cmap("jet_r"),
            norm=matplotlib.colors.Normalize(vmin=1, vmax=nsv),
        )
    # the mappable carries no data; it only drives the colorbar
    sm.set_array([])

    # now, if too many ticks for the colorbar, summarize them
    if nsv < 20:
        ticks = list(range(1, nsv + 1))
    else:
        # int((nsv + 1) / 30) is 0 for 20 <= nsv < 29, and np.arange
        # cannot take a zero step, so never let the step drop below 1
        step = max(1, int((nsv + 1) / 30))
        ticks = np.arange(1, nsv + 1, step)

    cb = plt.colorbar(sm, ax=ax)
    cb.set_ticks(ticks)

    logger.log("plot id bar")

    return ax
602

603

604
@apply_custom_font()
def res_phi_pie(pst, logger=None, **kwargs):
    """plot current phi components as a pie chart.

    Args:
        pst (`pyemu.Pst`): a control file instance with the residual dataframe
            instance available.
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created
        kwargs (`dict`): a dict of plotting options. Accepts 'include_zero'
            as a flag to include phi groups with only zero-weight obs (not
            sure why anyone would do this, but whatevs).

            Also accepts 'label_comps': list of components for the labels. Options are
            ['name', 'phi_comp', 'phi_percent']. Labels will use those three components
            in the order of the 'label_comps' list, one component per line.

            Also accepts 'ax' (an existing axis to plot on), 'filename'
            (file to save the figure to) and 'ensemble'.

    Returns:
        `matplotlib.Axis`: the axis with the plot.

    Example::

        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_phi_pie(pst,figsize=(12,4))
        pyemu.plot_utils.res_phi_pie(pst,label_comps = ['name','phi_percent'], figsize=(12,4))


    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_phi_pie")

    if "ensemble" in kwargs:
        # NOTE(review): the residuals built from the ensemble are not used
        # below - the phi values plotted come from `pst` itself.  Confirm
        # whether they were meant to replace pst.res.
        try:
            pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception:
            logger.statement(
                "res_phi_pie: could not find ensemble file {0}".format(
                    kwargs["ensemble"]
                )
            )
    else:
        # touch pst.res up front so a missing residuals file is reported
        # with a clear message
        try:
            _ = pst.res
        except Exception:
            logger.lraise("res_phi_pie: pst.res is None, couldn't find residuals file")

    phi = pst.phi
    phi_comps = pst.phi_components
    norm_phi_comps = pst.phi_components_normalized
    keys = list(phi_comps.keys())
    if "include_zero" not in kwargs or kwargs["include_zero"] is False:
        # keep only the groups that contribute to phi
        phi_comps = {k: phi_comps[k] for k in keys if phi_comps[k] > 0.0}
        keys = list(phi_comps.keys())
        norm_phi_comps = {k: norm_phi_comps[k] for k in keys}

    ax = kwargs.get("ax")
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1, aspect="equal")

    if "label_comps" not in kwargs:
        labels = [
            "{0}\n{1:4G}\n({2:3.1f}%)".format(
                k, phi_comps[k], 100.0 * (phi_comps[k] / phi)
            )
            for k in keys
        ]
    else:
        # make sure the components for the labels are in a list
        fmtchoices = kwargs["label_comps"]
        if not isinstance(fmtchoices, (list, tuple)):
            fmtchoices = [fmtchoices]
        # assemble all possible label components: (format, values per key)
        labfmts = {
            "name": ("{}", keys),
            "phi_comp": ("{:4G}", [phi_comps[k] for k in keys]),
            "phi_percent": (
                "({:3.1f}%)",
                [100.0 * (phi_comps[k] / phi) for k in keys],
            ),
        }
        # one component per line, whatever order was requested; joining
        # avoids a trailing newline or a missing separator
        labfmtstr = "\n".join(labfmts[c][0] for c in fmtchoices)
        # pull it together
        labels = [
            labfmtstr.format(*vals)
            for vals in zip(*[labfmts[c][1] for c in fmtchoices])
        ]

    ax.pie([float(norm_phi_comps[k]) for k in keys], labels=labels)
    logger.log("plot res_phi_pie")
    if "filename" in kwargs:
        # save the figure that owns `ax`, which is not necessarily
        # pyplot's current figure when the caller supplied the axis
        ax.get_figure().savefig(kwargs["filename"])
    return ax
700

701

702
@apply_custom_font()
def pst_prior(pst, logger=None, filename=None, **kwargs):
    """helper to plot prior parameter histograms implied by
    parameter bounds. Saves a multipage pdf named <case>.prior.pdf

    Args:
        pst (`pyemu.Pst`): control file
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created.
        filename (`str`):  PDF filename to save plots to.
            If None, return figs without saving.  Default is None.
        kwargs (`dict`): additional plotting options. Accepts 'grouper' as
            dict to group parameters on to a single axis (not implemented
            yet; parameter groups are used), 'unique_only' to only show
            unique mean-stdev combinations within a given group, and
            'fig_title' for the text of the title page.

    Returns:
        [`matplotlib.Figure`]: a list of figures created (the title page
        first) if `filename` is None.  Otherwise every figure is written to
        the PDF, closed, and None is returned.

    Example::

        pst = pyemu.Pst("pest.pst")
        pyemu.plot_utils.pst_prior(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot pst_prior")
    par = pst.parameter_data

    if "parcov_filename" in pst.pestpp_options:
        logger.warn("ignoring parcov_filename, using parameter bounds for prior cov")
    logger.log("loading cov from parameter data")
    cov = pyemu.Cov.from_parameter_data(pst)
    logger.log("loading cov from parameter data")

    logger.log("building mean parameter values")
    li = par.partrans.loc[cov.names] == "log"
    mean = par.parval1.loc[cov.names]
    info = par.loc[cov.names, :].copy()
    info.loc[:, "mean"] = mean
    # log-transformed parameters are plotted in log10 space
    info.loc[li, "mean"] = mean[li].apply(np.log10)
    logger.log("building mean parameter values")

    logger.log("building stdev parameter values")
    # the variances are the diagonal of the covariance matrix
    if cov.isdiagonal:
        std = cov.x.flatten()
    else:
        std = np.diag(cov.x)
    std = np.sqrt(std)
    # check before assigning so a mismatch gives this message rather
    # than a pandas length error
    if std.shape != mean.shape:
        logger.lraise("mean.shape {0} != std.shape {1}".format(mean.shape, std.shape))
    info.loc[:, "prior_std"] = std
    logger.log("building stdev parameter values")

    if "grouper" in kwargs:
        raise NotImplementedError()
        # check for consistency here
    par_adj = par.loc[par.partrans.isin(["log", "none"]), :]
    grouper = par_adj.groupby(par_adj.pargp).groups

    if len(grouper) == 0:
        raise Exception("no adjustable parameters to plot")

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='prior')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    unique_only = "unique_only" in kwargs and kwargs["unique_only"]
    axes_per_page = nr * nc

    figs = []
    axes = []
    ax_count = 0
    for g in sorted(grouper.keys()):
        names = grouper[g]
        logger.log("plotting priors for {0}".format(",".join(list(names))))
        if ax_count % axes_per_page == 0:
            # first group or a full page: stash the current figure (title
            # page or finished page) and start a new page
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # the x-axis label depends on the transform of THIS group's
        # parameters, so only look at the group's rows
        islog = False
        vc = info.loc[names, "partrans"].value_counts()
        if vc.shape[0] > 1:
            logger.warn("mixed partrans for group {0}".format(g))
        elif "log" in vc.index:
            islog = True

        ax = axes[ax_count]
        if unique_only:
            # one curve per distinct (mean, stdev) pair in the group
            pairs = info.loc[names, ["mean", "prior_std"]].drop_duplicates()
            mean_std = zip(pairs["mean"], pairs["prior_std"])
        else:
            mean_std = zip(info.loc[names, "mean"], info.loc[names, "prior_std"])
        for m, s in mean_std:
            x, y = gaussian_distribution(m, s)
            ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")

        ax.set_title(
            "{0}) group:{1}, {2} parameters".format(abet[ax_count], g, names.shape[0]),
            loc="left",
        )

        ax.set_yticks([])
        if islog:
            ax.set_xlabel("$log_{10}$ parameter value", labelpad=0.1)
        else:
            ax.set_xlabel("parameter value", labelpad=0.1)
        logger.log("plotting priors for {0}".format(",".join(list(names))))

        ax_count += 1

    # blank out the unused axes on the last page
    for unused in axes[ax_count:]:
        unused.set_axis_off()
        unused.set_yticks([])
        unused.set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        # write every page, not just the last one, and close each figure
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot pst_prior")
        return None
    logger.log("plot pst_prior")
    return figs
854

855

856
@apply_custom_font()
def ensemble_helper(
    ensemble,
    bins=10,
    facecolor="0.5",
    plot_cols=None,
    filename=None,
    func_dict=None,
    sync_bins=True,
    deter_vals=None,
    std_window=None,
    deter_range=False,
    **kwargs
):
    """helper function to plot ensemble histograms

    Args:
        ensemble : varies
            the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        bins : int
            if `sync_bins` is True, the number of evenly spaced bin edges generated between
            the pooled min and max of the plotted columns.  Otherwise the value is passed
            unchanged to `Axes.hist`.  Default is 10
        facecolor : str
            the histogram facecolor.  Only applies if ensemble is a single thing
        plot_cols : enumerable
            a collection of columns (in form of a list of parameters, or a dict with keys for
            parsing plot axes and values of parameters) from the ensemble(s) to plot.  If None,
            (the union of) all cols are plotted. Default is None
        filename : str
            the name of the pdf to create.  If None, return figs without saving.  Default is None.
        func_dict : dict
            a dictionary of unary functions (e.g., `np.log10`) to apply to columns.  Key is
            column name.  The functions are applied in place to the ensemble objects.
            Default is None
        sync_bins : bool
            flag to use the same bin edges for all ensembles. Only applies if more than
            one ensemble is being plotted.  Default is True
        deter_vals : dict
            dict of deterministic values to plot as a vertical line. key is ensemble column name
        std_window : float
            the number of standard deviations around the mean to mark as vertical lines.  If None,
            nothing happens.  Default is None
        deter_range : bool
            flag to set xlims to deterministic value +/- std window.  If True, std_window must not be None.
            Default is False
        **kwargs : dict
            only "fig_title" is used: the text written on the leading title page

    Returns:
        [`matplotlib.Figure`]: a list of figures (title page first) if `filename`
        is None, otherwise None (the figures are written to the pdf and closed)

    Example::

        # plot prior and posterior par ensembles
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post},filename="ensemble.pdf")

        #plot prior and posterior simulated equivalents to observations with obs noise and obs vals
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        noise = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.obs+noise.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post,"r":noise},
                                         filename="ensemble.pdf",
                                         deter_vals=pst.observation_data.obsval.to_dict())


    """
    logger = pyemu.Logger("ensemble_helper.log")
    logger.log("pyemu.plot_utils.ensemble_helper()")
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
    if len(ensembles) == 0:
        raise Exception("plot_utils.ensemble_helper() error processing `ensemble` arg")

    # apply any functions (in place, on the ensemble objects themselves)
    if func_dict is not None:
        logger.log("applying functions")
        for col, funcc in func_dict.items():
            for en in ensembles.values():
                if col in en.columns:
                    en.loc[:, col] = en.loc[:, col].apply(funcc)
        logger.log("applying functions")

    # get a set of all cols (union across ensembles)
    all_cols = set()
    for en in ensembles.values():
        all_cols.update(en.columns)

    # normalize plot_cols to a dict of {axis label: [column names]}
    if plot_cols is None:
        plot_cols = {col: [col] for col in all_cols}
    else:
        if isinstance(plot_cols, list):
            splot_cols = set(plot_cols)
            plot_cols = {col: [col] for col in plot_cols}
        elif isinstance(plot_cols, dict):
            splot_cols = set()
            for pcols in plot_cols.values():
                splot_cols.update(pcols)
        else:
            logger.lraise(
                "unrecognized plot_cols type: {0}, should be list or dict".format(
                    type(plot_cols)
                )
            )

        missing = splot_cols - all_cols
        if len(missing) > 0:
            logger.lraise(
                "the following plot_cols are missing: {0}".format(",".join(missing))
            )

    logger.statement("plotting {0} histograms".format(len(plot_cols)))

    # the first figure is a title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_helper()\n at {0}".format(str(datetime.now())),
            ha="center",
        )

    labels = list(plot_cols.keys())
    labels.sort()
    logger.statement("saving pdf to {0}".format(filename))
    figs = []

    # axes stays None until the first histogram page is created, so the
    # cleanup after the loop can tell whether there is a page to tidy up
    axes = None
    ax_count = 0

    for label in labels:
        plot_col = plot_cols[label]
        logger.log("plotting reals for {0}".format(label))
        if ax_count % (nr * nc) == 0:
            # current page is full (or is the title page): stash it, start a new one
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            for page_ax in axes:
                page_ax.set_yticks([])
            ax_count = 0

        ax = axes[ax_count]

        if sync_bins:
            # pooled min/max across every ensemble so all share the same bin edges
            mx, mn = -1.0e30, 1.0e30
            for en in ensembles.values():
                emn = np.nanmin(en.loc[:, plot_col].values)
                emx = np.nanmax(en.loc[:, plot_col].values)
                mx = max(mx, emx)
                mn = min(mn, emn)
            if mx == -1.0e30 and mn == 1.0e30:
                # sentinels untouched: nothing finite to plot for this label
                logger.warn("all NaNs for label: {0}".format(label))
                ax.set_title(
                    "{0}) {1}, count:{2} - all NaN".format(
                        abet[ax_count], label, len(plot_col)
                    ),
                    loc="left",
                )
                ax.set_yticks([])
                ax.set_xticks([])
                ax_count += 1
                logger.log("plotting reals for {0}".format(label))
                continue
            plot_bins = np.linspace(mn, mx, num=bins)
            logger.statement("{0} min:{1:5G}, max:{2:5G}".format(label, mn, mx))
        else:
            plot_bins = bins

        for fc, en in ensembles.items():
            vals = en.loc[:, plot_col].values.flatten()

            ax.hist(
                vals,
                bins=plot_bins,
                edgecolor="none",
                alpha=0.5,
                density=True,
                facecolor=fc,
            )
            v = None
            if deter_vals is not None:
                for pc in plot_col:
                    if pc in deter_vals:
                        ylim = ax.get_ylim()
                        v = deter_vals[pc]
                        ax.plot([v, v], ylim, "k--", lw=1.5)
                        ax.set_ylim(ylim)

            if std_window is not None:
                # The window is computed from the last column of the group.
                # It is resolved here rather than taken from the deter_vals
                # loop variable, which is unbound when deter_vals is None.
                std_col = None
                try:
                    std_col = list(plot_col)[-1]
                    ylim = ax.get_ylim()
                    center = en.loc[:, std_col].mean()
                    half_width = en.loc[:, std_col].std() * (std_window / 2.0)

                    ax.plot(
                        [center - half_width, center - half_width],
                        ylim,
                        color=fc,
                        lw=1.5,
                        ls="--",
                    )
                    ax.plot(
                        [center + half_width, center + half_width],
                        ylim,
                        color=fc,
                        lw=1.5,
                        ls="--",
                    )
                    ax.set_ylim(ylim)
                    if deter_range and v is not None:
                        ax.set_xlim(v - half_width, v + half_width)
                except Exception as e:
                    # best effort: a failed std window should not kill the plot
                    logger.warn(
                        "error plotting std window for {0}: {1}".format(
                            std_col, str(e)
                        )
                    )
        ax.grid()
        if len(ensembles) > 1:
            ax.set_title(
                "{0}) {1}, count: {2}".format(abet[ax_count], label, len(plot_col)),
                loc="left",
            )
        else:
            # single ensemble: `vals` still holds that ensemble's plotted values
            ax.set_title(
                "{0}) {1}, count:{2}\nmin:{3:3.1E}, max:{4:3.1E}".format(
                    abet[ax_count],
                    label,
                    len(plot_col),
                    np.nanmin(vals),
                    np.nanmax(vals),
                ),
                loc="left",
            )
        ax_count += 1
        logger.log("plotting reals for {0}".format(label))

    # blank out the unused axes on the last histogram page (if there is one)
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
    logger.log("pyemu.plot_utils.ensemble_helper()")
    if filename is None:
        return figs
    return None
×
1110

1111

1112
@apply_custom_font()
def ensemble_change_summary(
    ensemble1,
    ensemble2,
    pst,
    bins=10,
    facecolor="0.5",
    logger=None,
    filename=None,
    **kwargs
):
    """helper function to plot first and second moment change histograms between two
    ensembles

    Args:
        ensemble1 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        ensemble2 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        pst (`pyemu.Pst`): control file
        bins (`int`): passed to the histogram calls.  Default is 10
        facecolor (`str`): the histogram facecolor.
        logger (`pyemu.Logger`): logger to use.  If None, a default logger is created.
        filename (`str`): the name of the multi-pdf to create. If None, return figs without saving.  Default is None.
        **kwargs (`dict`): "fig_title" sets the title page text.  "grouper" is
            reserved and raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:
        The ensembles are copied before the column names are lower-cased and
        log-transformed parameters are converted with `np.log10`, so the
        objects passed by the caller are left untouched.

        For each group, the "mean change" histogram shows
        `ensemble1.mean() - ensemble2.mean()` and the "sigma percent reduction"
        histogram shows `100 * (std1 - std2) / std1`.

    Example::

        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_change_summary(prior, post, pst)
        plt.show()


    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot ensemble change")

    # Work on private copies: the column renaming and log10 transform below
    # would otherwise alter the caller's ensembles (and transform twice if the
    # function were called again with the same objects).
    if isinstance(ensemble1, str):
        ensemble1 = pd.read_csv(ensemble1, index_col=0)
    else:
        ensemble1 = ensemble1.copy()
    ensemble1.columns = ensemble1.columns.str.lower()

    if isinstance(ensemble2, str):
        ensemble2 = pd.read_csv(ensemble2, index_col=0)
    else:
        ensemble2 = ensemble2.copy()
    ensemble2.columns = ensemble2.columns.str.lower()

    # better to ensure this is caught by pestpp-ies ensemble csvs
    # (an "unnamed:" column is taken to be the trailing column that results
    # from a poorly formed csv, so only the last column is dropped)
    if any("unnamed:" in col for col in ensemble1.columns):
        ensemble1 = ensemble1.iloc[:, :-1]
    if any("unnamed:" in col for col in ensemble2.columns):
        ensemble2 = ensemble2.iloc[:, :-1]

    d = set(ensemble1.columns).symmetric_difference(set(ensemble2.columns))

    if len(d) != 0:
        logger.lraise(
            "ensemble1 does not have the same columns as ensemble2: {0}".format(
                ",".join(d)
            )
        )
    if "grouper" in kwargs:
        raise NotImplementedError()

    en_cols = ensemble1.columns
    en_col_set = set(en_cols)
    if len(en_cols.difference(pst.par_names)) == 0:
        par = pst.parameter_data.loc[en_cols, :]
        grouper = dict(par.groupby(par.pargp).groups)
        # only names actually present in the ensembles can be looked up later
        grouper["all"] = [n for n in pst.adj_par_names if n in en_col_set]
        li = par.loc[par.partrans == "log", "parnme"]
        ensemble1.loc[:, li] = ensemble1.loc[:, li].apply(np.log10)
        ensemble2.loc[:, li] = ensemble2.loc[:, li].apply(np.log10)
    elif len(en_cols.difference(pst.obs_names)) == 0:
        obs = pst.observation_data.loc[en_cols, :]
        grouper = dict(obs.groupby(obs.obgnme).groups)
        grouper["all"] = [n for n in pst.nnz_obs_names if n in en_col_set]
    else:
        logger.lraise("could not match ensemble cols with par or obs...")

    en1_mn, en1_std = ensemble1.mean(axis=0), ensemble1.std(axis=0)
    en2_mn, en2_std = ensemble2.mean(axis=0), ensemble2.std(axis=0)

    mn_diff = -1 * (en2_mn - en1_mn)
    std_diff = 100 * (((en1_std - en2_std) / en1_std))
    # a zero en1_std with a non-zero en2_std yields +/-inf, which the
    # histogram cannot bin; treat those entries as missing instead
    std_diff = std_diff.replace([np.inf, -np.inf], np.nan)

    # the first figure is a title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_change_summary()\n"
            "from pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    # axes stays None until the first histogram page is created
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting change for {0}".format(g))

        mn_g = mn_diff.loc[names]
        std_g = std_diff.loc[names]

        if mn_g.shape[0] == 0:
            logger.statement("no entries for group '{0}'".format(g))
            logger.log("plotting change for {0}".format(g))
            continue

        if ax_count % (nr * nc) == 0:
            # page is full (or is the title page): stash it and start a new one
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # each group takes two consecutive axes: mean change, then sigma change
        panels = (
            (mn_g, "mean", "mean change"),
            (std_g, "sigma", "sigma percent reduction"),
        )
        for series, kind, xlabel in panels:
            ax = axes[ax_count]
            series.hist(
                ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins
            )
            ax.set_yticklabels([])
            ax.set_xlabel(xlabel, labelpad=0.1)
            ax.set_title(
                "{0}) {1} change group:{2}, {3} entries\n"
                "max:{4:10G}, min:{5:10G}".format(
                    abet[ax_count],
                    kind,
                    g,
                    series.shape[0],
                    series.max(),
                    series.min(),
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1

        logger.log("plotting change for {0}".format(g))

    # blank out the unused axes on the last histogram page (if there is one)
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot ensemble change")
        return None

    logger.log("plot ensemble change")
    return figs
×
1307

1308

1309
def _process_ensemble_arg(ensemble, facecolor, logger):
    """private method to work out ensemble plot args

    Args:
        ensemble (varies): a `pandas.DataFrame`, a `pyemu.Ensemble`, a str
            (read as a csv file with the first column as the index), or a
            list or dict of these.  Dict keys are used as facecolors.
        facecolor (varies): a str when `ensemble` is a single thing.  When
            `ensemble` is a list, either a list of the same length or anything
            else, in which case default colors are assigned.  Ignored when
            `ensemble` is a dict.
        logger (`pyemu.Logger`): logger used for progress messages and to
            raise errors

    Returns:
        `dict`: facecolor keys and ensemble values.  Empty if `ensemble` is
        not one of the recognized types.

    Note:
        Column names are lower-cased by assigning to `.columns` on each
        ensemble, i.e. on the objects that were passed in, not on copies.

    """
    frame_types = (pd.DataFrame, pyemu.Ensemble)

    def _read_csv(path):
        # load one csv ensemble, bracketed by start/finish log entries
        logger.log("loading ensemble from csv file {0}".format(path))
        loaded = pd.read_csv(path, index_col=0)
        logger.log("loading ensemble from csv file {0}".format(path))
        return loaded

    ensembles = {}
    if isinstance(ensemble, frame_types):
        if not isinstance(facecolor, str):
            logger.lraise("facecolor must be str")
        ensembles[facecolor] = ensemble

    elif isinstance(ensemble, str):
        if not isinstance(facecolor, str):
            logger.lraise("facecolor must be str")

        logger.log("loading ensemble from csv file {0}".format(ensemble))
        en = pd.read_csv(ensemble, index_col=0)
        logger.statement("{0} shape: {1}".format(ensemble, en.shape))
        ensembles[facecolor] = en
        logger.log("loading ensemble from csv file {0}".format(ensemble))

    elif isinstance(ensemble, list):
        if isinstance(facecolor, list):
            if len(ensemble) != len(facecolor):
                logger.lraise("facecolor list len != ensemble list len")
        else:
            colors = ["m", "c", "b", "r", "g", "y"]
            # The result is keyed by color, so the defaults cannot be reused:
            # fail with a clear message instead of an IndexError.
            if len(ensemble) > len(colors):
                logger.lraise(
                    "only {0} default facecolors available for {1} ensembles, "
                    "pass `facecolor` as a list of unique colors".format(
                        len(colors), len(ensemble)
                    )
                )
            facecolor = colors[: len(ensemble)]

        for fc, en_arg in zip(facecolor, ensemble):
            if isinstance(en_arg, str):
                en = _read_csv(en_arg)
                logger.statement("ensemble {0} gets facecolor {1}".format(en_arg, fc))
            elif isinstance(en_arg, frame_types):
                en = en_arg
            else:
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
            ensembles[fc] = en

    elif isinstance(ensemble, dict):
        for fc, en_arg in ensemble.items():
            if isinstance(en_arg, frame_types):
                ensembles[fc] = en_arg
            elif isinstance(en_arg, str):
                ensembles[fc] = _read_csv(en_arg)
            else:
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))

    try:
        for fc in ensembles:
            ensembles[fc].columns = ensembles[fc].columns.str.lower()
    except Exception as e:
        # e.g. non-string column labels; report the underlying cause
        logger.lraise("error processing ensemble: {0}".format(str(e)))

    return ensembles
×
1366

1367

1368
@apply_custom_font()
12✔
1369
def ensemble_res_1to1(
12✔
1370
    ensemble,
1371
    pst,
1372
    facecolor="0.5",
1373
    logger=None,
1374
    filename=None,
1375
    skip_groups=[],
1376
    base_ensemble=None,
1377
    **kwargs
1378
):
1379
    """helper function to plot ensemble 1-to-1 plots showing the simulated range
1380

1381
    Args:
1382
        ensemble (varies):  the ensemble argument can be a pandas.DataFrame or derived type or a str, which
1383
            is treated as a filename.  Optionally, ensemble can be a list of these types or
1384
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
1385
        pst (`pyemu.Pst`): a control file instance
1386
        facecolor (`str`): the histogram facecolor.  Only applies if `ensemble` is a single thing
1387
        filename (`str`): the name of the pdf to create. If None, return figs
1388
            without saving.  Default is None.
1389
        base_ensemble (`varies`): an optional ensemble argument for the observations + noise ensemble.
1390
            This will be plotted as a transparent red bar on the 1to1 plot.
1391

1392
    Note:
1393

1394
        the vertical bar on each plot the min-max range
1395

1396
    Example::
1397

1398

1399
        pst = pyemu.Pst("my.pst")
1400
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
1401
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
1402
        pyemu.plot_utils.ensemble_res_1to1(ensemble={"0.5":prior, "b":post})
1403
        plt.show()
1404

1405
    """
UNCOV
1406
    def _get_plotlims(oen, ben, obsnames):
×
UNCOV
1407
        if not isinstance(oen, dict):
×
1408
            oen = {'g': oen.loc[:, obsnames]}
×
UNCOV
1409
        if not isinstance(ben, dict):
×
UNCOV
1410
            ben = {'g': ben.get(obsnames)}
×
UNCOV
1411
        outofrange = False
×
1412
        # work back from crazy values
UNCOV
1413
        oemin = 1e32
×
UNCOV
1414
        oemeanmin = 1e32
×
UNCOV
1415
        oemax = -1e32
×
UNCOV
1416
        oemeanmax = -1e32
×
UNCOV
1417
        bemin = 1e32
×
UNCOV
1418
        bemeanmin = 1e32
×
UNCOV
1419
        bemax = -1e32
×
UNCOV
1420
        bemeanmax = -1e32
×
UNCOV
1421
        for _, oeni in oen.items():  # loop over ensembles
×
UNCOV
1422
            oeni = oeni.loc[:, obsnames]  # slice group obs
×
UNCOV
1423
            oemin = np.nanmin([oemin, oeni.min().min()])
×
UNCOV
1424
            oemax = np.nanmax([oemax, oeni.max().max()])
×
1425
            # get min and max of mean sim vals
1426
            # (in case we want plot to ignore extremes)
UNCOV
1427
            oemeanmin = np.nanmin([oemeanmin, oeni.mean().min()])
×
UNCOV
1428
            oemeanmax = np.nanmax([oemeanmax, oeni.mean().max()])
×
UNCOV
1429
        for _, beni in ben.items():  # same with base ensemble/obsval
×
1430
            # work with either ensemble or obsval series
UNCOV
1431
            beni = beni.get(obsnames)
×
UNCOV
1432
            bemin = np.nanmin([bemin, beni.min().min()])
×
UNCOV
1433
            bemax = np.nanmax([bemax, beni.max().max()])
×
UNCOV
1434
            bemeanmin = np.nanmin([bemeanmin, beni.mean().min()])
×
UNCOV
1435
            bemeanmax = np.nanmax([bemeanmax, beni.mean().max()])
×
1436
        # get base ensemble range
UNCOV
1437
        berange = bemax-bemin
×
UNCOV
1438
        if berange == 0.:  # only one obs in group (probs)
×
UNCOV
1439
            berange = bemeanmax * 1.1  # expand a little
×
1440
        # add buffer to obs endpoints
UNCOV
1441
        bemin = bemin - (berange*0.05)
×
UNCOV
1442
        bemax = bemax + (berange*0.05)
×
UNCOV
1443
        if oemax < bemin:  # sim well below obs
×
UNCOV
1444
            oemin = oemeanmin  # set min to mean min
×
1445
            # (sim captured but not extremes)
UNCOV
1446
            outofrange = True
×
UNCOV
1447
        if oemin > bemax:  # sim well above obs
×
UNCOV
1448
            oemax = oemeanmax
×
UNCOV
1449
            outofrange = True
×
UNCOV
1450
        oerange = oemax - oemin
×
UNCOV
1451
        if bemax > oemax + (0.1*oerange):  # obs max well above sim
×
UNCOV
1452
            if not outofrange:  # but sim still in range
×
1453
                # zoom to sim
UNCOV
1454
                bemax = oemax + (0.1*oerange)
×
1455
            else:  # use obs mean max
UNCOV
1456
                bemax = bemeanmax
×
UNCOV
1457
        if bemin < oemin - (0.1 * oerange):  # obs min well below sim
×
UNCOV
1458
            if not outofrange:  # but sim still in range
×
1459
                # zoom to sim
UNCOV
1460
                bemin = oemin - (0.1 * oerange)
×
1461
            else:
UNCOV
1462
                bemin = bemeanmin
×
UNCOV
1463
        pmin = np.nanmin([oemin, bemin])
×
UNCOV
1464
        pmax = np.nanmax([oemax, bemax])
×
UNCOV
1465
        return pmin, pmax
×
1466

1467

UNCOV
1468
    if logger is None:
×
UNCOV
1469
        logger = Logger("Default_Logger.log", echo=False)
×
UNCOV
1470
    logger.log("plot res_1to1")
×
UNCOV
1471
    obs = pst.observation_data
×
UNCOV
1472
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
×
1473

UNCOV
1474
    if base_ensemble is not None:
×
UNCOV
1475
        base_ensemble = _process_ensemble_arg(base_ensemble, "r", logger)
×
1476

UNCOV
1477
    if "grouper" in kwargs:
×
1478
        raise NotImplementedError()
×
1479
    else:
UNCOV
1480
        grouper = obs.groupby(obs.obgnme).groups
×
UNCOV
1481
        for skip_group in skip_groups:
×
1482
            grouper.pop(skip_group)
×
1483

UNCOV
1484
    fig = plt.figure(figsize=figsize)
×
UNCOV
1485
    if "fig_title" in kwargs:
×
1486
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
×
1487
    else:
UNCOV
1488
        plt.figtext(
×
1489
            0.5,
1490
            0.5,
1491
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
1492
                pst.filename, str(datetime.now())
1493
            ),
1494
            ha="center",
1495
        )
1496

UNCOV
1497
    figs = []
×
UNCOV
1498
    ax_count = 0
×
UNCOV
1499
    for g, names in grouper.items():
×
UNCOV
1500
        logger.log("plotting 1to1 for {0}".format(g))
×
1501
        # control file observation for group
UNCOV
1502
        obs_g = obs.loc[names, :]
×
1503
        # normally only look a non-zero weighted obs
UNCOV
1504
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
×
UNCOV
1505
            obs_g = obs_g.loc[obs_g.weight > 0, :]
×
UNCOV
1506
        if obs_g.shape[0] == 0:
×
UNCOV
1507
            logger.statement("no non-zero obs for group '{0}'".format(g))
×
UNCOV
1508
            logger.log("plotting 1to1 for {0}".format(g))
×
UNCOV
1509
            continue
×
1510
        # if the first axis in page
UNCOV
1511
        if ax_count % (nr * nc) == 0:
×
UNCOV
1512
            if ax_count > 0:
×
1513
                plt.tight_layout()
×
UNCOV
1514
            figs.append(fig)
×
UNCOV
1515
            fig = plt.figure(figsize=figsize)
×
UNCOV
1516
            axes = _get_page_axes()
×
UNCOV
1517
            ax_count = 0
×
UNCOV
1518
        ax = axes[ax_count]
×
1519

UNCOV
1520
        if base_ensemble is None:
×
1521
            # if obs not defined by obs+noise ensemble,
1522
            # use min and max for obsval from control file
UNCOV
1523
            pmin, pmax = _get_plotlims(ensembles, obs_g.obsval, obs_g.obsnme)
×
1524
        else:
1525
            # if obs defined by obs+noise use obs+noise min and max
UNCOV
1526
            pmin, pmax = _get_plotlims(ensembles, base_ensemble, obs_g.obsnme)
×
UNCOV
1527
            obs_gg = obs_g.sort_values(by="obsval")
×
UNCOV
1528
            for c, en in base_ensemble.items():
×
UNCOV
1529
                en_g = en.loc[:, obs_gg.obsnme]
×
UNCOV
1530
                emx = en_g.max()
×
UNCOV
1531
                emn = en_g.min()
×
1532
                
1533
                #exit()
1534
                # update y min and max for obs+noise ensembles
UNCOV
1535
                if len(obs_gg.obsval) > 1:
×
1536

UNCOV
1537
                    emx = np.zeros(obs_gg.shape[0]) + emx
×
UNCOV
1538
                    emn = np.zeros(obs_gg.shape[0]) + emn
×
UNCOV
1539
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
×
1540
                                    facecolor=c, alpha=0.2, zorder=2)
1541
                else:
UNCOV
1542
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c, alpha=0.2, zorder=2)
×
UNCOV
1543
        for c, en in ensembles.items():
×
UNCOV
1544
            en_g = en.loc[:, obs_g.obsnme]
×
1545
            # output mins and maxs
UNCOV
1546
            emx = en_g.max()
×
UNCOV
1547
            emn = en_g.min()
×
UNCOV
1548
            [
×
1549
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
1550
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
1551
            ]
UNCOV
1552
        ax.plot([pmin, pmax], [pmin, pmax], "k--", lw=1.0, zorder=3)
×
UNCOV
1553
        xlim = (pmin, pmax)
×
UNCOV
1554
        ax.set_xlim(pmin, pmax)
×
UNCOV
1555
        ax.set_ylim(pmin, pmax)
×
1556

UNCOV
1557
        if max(np.abs(xlim)) > 1.0e5:
×
1558
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
1559
            ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
UNCOV
1560
        ax.grid()
×
1561

UNCOV
1562
        ax.set_xlabel("observed", labelpad=0.1)
×
UNCOV
1563
        ax.set_ylabel("simulated", labelpad=0.1)
×
UNCOV
1564
        ax.set_title(
×
1565
            "{0}) group:{1}, {2} observations".format(
1566
                abet[ax_count], g, obs_g.shape[0]
1567
            ),
1568
            loc="left",
1569
        )
1570

1571
        # Residual (RHS plot)
UNCOV
1572
        ax_count += 1
×
UNCOV
1573
        ax = axes[ax_count]
×
1574
        # ax.scatter(obs_g.obsval, obs_g.res, marker='.', s=10, color='b')
1575

UNCOV
1576
        if base_ensemble is not None:
×
UNCOV
1577
            obs_gg = obs_g.sort_values(by="obsval")
×
UNCOV
1578
            for c, en in base_ensemble.items():
×
UNCOV
1579
                en_g = en.loc[:, obs_gg.obsnme].subtract(obs_gg.obsval)
×
UNCOV
1580
                emx = en_g.max()
×
UNCOV
1581
                emn = en_g.min()
×
UNCOV
1582
                if len(obs_gg.obsval) > 1:
×
UNCOV
1583
                    emx = np.zeros(obs_gg.shape[0]) + emx
×
UNCOV
1584
                    emn = np.zeros(obs_gg.shape[0]) + emn
×
UNCOV
1585
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
×
1586
                                    facecolor=c, alpha=0.2, zorder=2)
1587
                else:
1588
                    # [ax.plot([ov, ov], [een, eex], color=c,alpha=0.3) for ov, een, eex in zip(obs_g.obsval.values, en.values, ex.values)]
UNCOV
1589
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c,
×
1590
                            alpha=0.2, zorder=2)
UNCOV
1591
        omn = []
×
UNCOV
1592
        omx = []
×
UNCOV
1593
        for c, en in ensembles.items():
×
UNCOV
1594
            en_g = en.loc[:, obs_g.obsnme].subtract(obs_g.obsval, axis=1)
×
UNCOV
1595
            emx = en_g.max()
×
UNCOV
1596
            emn = en_g.min()
×
UNCOV
1597
            omn.append(emn)
×
UNCOV
1598
            omx.append(emx)
×
UNCOV
1599
            [
×
1600
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
1601
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
1602
            ]
1603

UNCOV
1604
        omn = pd.concat(omn).min()
×
UNCOV
1605
        omx = pd.concat(omx).max()
×
UNCOV
1606
        mx = np.nanmax([np.abs(omn), np.abs(omx)])  # ensure symmetric about y=0
×
UNCOV
1607
        if obs_g.shape[0] == 1:
×
UNCOV
1608
            mx *= 1.05
×
1609
        else:
UNCOV
1610
            mx *= 1.02
×
UNCOV
1611
        if np.sign(omn) == np.sign(omx):
×
1612
            # allow y axis asymm if all above or below
UNCOV
1613
            mn = np.nanmin([0, np.sign(omn) * mx])
×
UNCOV
1614
            mx = np.nanmax([0, np.sign(omn) * mx])
×
1615
        else:
UNCOV
1616
            mn = -mx
×
UNCOV
1617
        ax.set_ylim(mn, mx)
×
UNCOV
1618
        bmin = obs_g.obsval.values.min()
×
UNCOV
1619
        bmax = obs_g.obsval.values.max()
×
UNCOV
1620
        brange = (bmax - bmin)
×
UNCOV
1621
        if brange == 0.:
×
UNCOV
1622
            brange = obs_g.obsval.values.mean()
×
UNCOV
1623
        bmin = bmin - 0.1*brange
×
UNCOV
1624
        bmax = bmax + 0.1*brange
×
UNCOV
1625
        xlim = (bmin, bmax)
×
1626
        # show a zero residuals line
UNCOV
1627
        ax.plot(xlim, [0, 0], "k--", lw=1.0, zorder=3)
×
1628

UNCOV
1629
        ax.set_xlim(xlim)
×
UNCOV
1630
        ax.set_ylabel("residual", labelpad=0.1)
×
UNCOV
1631
        ax.set_xlabel("observed", labelpad=0.1)
×
UNCOV
1632
        ax.set_title(
×
1633
            "{0}) group:{1}, {2} observations".format(
1634
                abet[ax_count], g, obs_g.shape[0]
1635
            ),
1636
            loc="left",
1637
        )
UNCOV
1638
        ax.grid()
×
UNCOV
1639
        if ax.get_xlim()[1] > 1.0e5:
×
1640
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
1641

UNCOV
1642
        ax_count += 1
×
1643

UNCOV
1644
        logger.log("plotting 1to1 for {0}".format(g))
×
1645

UNCOV
1646
    for a in range(ax_count, nr * nc):
×
UNCOV
1647
        axes[a].set_axis_off()
×
UNCOV
1648
        axes[a].set_yticks([])
×
UNCOV
1649
        axes[a].set_xticks([])
×
1650

UNCOV
1651
    plt.tight_layout()
×
UNCOV
1652
    figs.append(fig)
×
UNCOV
1653
    if filename is not None:
×
1654
        # plt.tight_layout()
UNCOV
1655
        with PdfPages(filename) as pdf:
×
UNCOV
1656
            for fig in figs:
×
UNCOV
1657
                pdf.savefig(fig)
×
UNCOV
1658
                plt.close(fig)
×
UNCOV
1659
        logger.log("plot res_1to1")
×
1660
    else:
1661
        logger.log("plot res_1to1")
×
1662
        return figs
×
1663

1664

1665
@apply_custom_font()
def plot_jac_test(
    csvin, csvout, targetobs=None, filetype=None, maxoutputpages=1, outputdirectory=None
):
    """Plot the results of a Jacobian test performed with the pest++ program pestpp-swp.

    For each perturbed parameter, the finite-difference derivative of every
    selected observation is plotted against the perturbation increment, with
    up to 32 observations (an 8 x 4 grid of subplots) per page.

    Args:
        csvin (`str`): name of csv file used as input to sweep, typically developed with
            static method pyemu.helpers.build_jac_test_csv().  Must contain a row
            labelled "base".
        csvout (`str`): name of csv file with output generated by sweep, both input
            and output files can be specified in the pest++ control file
            with pyemu using: pest_object.pestpp_options["sweep_parameter_csv_file"] = jactest_in_file.csv
            pest_object.pestpp_options["sweep_output_csv_file"] = jactest_out_file.csv
        targetobs ([`str`]): list of observation names (columns of `csvout`) to plot.
            If None, every column of the output csv file after the first
            eight is used.
        filetype (`str`): file extension used to save each page (passed to
            `savefig` through the file name).  If None, plt.show() is called instead.
        maxoutputpages (`int`): maximum number of pages of output per parameter.  Each page can
            hold up to 32 observation derivatives.  If value > 10, it is reset to
            10 and a `PyemuWarning` is issued.  If the number of observations in
            `targetobs` exceeds 32 * maxoutputpages, a random subset of that size is plotted.
        outputdirectory (`str`):  directory to store results, if None, the current working
            directory is used.  A relative path is joined to the current working
            directory; an absolute path is used as is.  The directory (and any
            missing parents) is created if needed.

    Returns:
        None: figures are either shown or written to
        "<parameter>_jactest_<page>.<filetype>" in the output directory.

    Note:
        Used in conjunction with pyemu.helpers.build_jac_test_csv() and sweep to perform
        a Jacobian Test and then view the results. Can generate a lot of plots so easiest
        to put into a separate directory and view the files.

    """
    n_rows, n_cols = 8, 4
    obs_per_page = n_rows * n_cols
    page_limit = 10

    # resolve (and create, if requested) the directory that receives the figures;
    # os.path.join() returns `outputdirectory` unchanged when it is absolute
    if outputdirectory is None:
        figures_dir = os.getcwd()
    else:
        figures_dir = os.path.join(os.getcwd(), outputdirectory)
        os.makedirs(figures_dir, exist_ok=True)

    # read the input and output files into pandas dataframes
    jactest_in_df = pd.read_csv(csvin, engine="python", index_col=0)
    jactest_in_df.index.name = "input_run_id"
    jactest_out_df = pd.read_csv(csvout, engine="python", index_col=1)

    # subtract the base run from every row, which leaves the one parameter that
    # was perturbed in any row as the only non-zero value.  Zeros are set to nan
    # so round-off doesn't get us, then summing across each row (nan is
    # skipped) yields the perturbation applied in that run.
    base_par = jactest_in_df.loc["base"]
    delta_par_df = jactest_in_df.subtract(base_par, axis="columns")
    delta_par_df = delta_par_df.replace(0, np.nan).drop("base", axis="index")
    delta_par = delta_par_df.sum(axis="columns")

    # same for the outputs: change of every column relative to the base run
    base_obs = jactest_out_df.loc["base"]
    delta_obs = jactest_out_df.subtract(base_obs).drop("base", axis="index")
    if targetobs is None:
        # NOTE(review): the first 8 columns are skipped -- presumably the
        # bookkeeping/phi columns written by sweep; confirm against the
        # sweep output format.
        targetobs = jactest_out_df.columns.tolist()[8:]
    delta_obs = delta_obs[targetobs]

    # get the Jacobian by dividing the change in observation by the change in parameter
    # for the perturbed parameters
    jacobian = delta_obs.divide(delta_par, axis="index")

    # run names look like "<parameter>_<n>"; split them into the parameter name
    # and a 1-based increment so derivative vs. increment can be plotted
    extr_df = pd.Series(jacobian.index.values).str.extract(r"(.+)(_\d+$)", expand=True)
    extr_df[1] = pd.to_numeric(extr_df[1].str.replace("_", "")) + 1
    extr_df.rename(columns={0: "parameter", 1: "increment"}, inplace=True)
    extr_df.index = jacobian.index

    # make a dataframe for plotting the Jacobian by combining the parameter name
    # and increments frame with the Jacobian frame
    plotframe = pd.concat([extr_df, jacobian], axis=1, join="inner")

    # decide which observations are plotted, honouring maxoutputpages
    if maxoutputpages > page_limit:
        warnings.warn(
            "more than {0} pages of output requested per parameter, "
            "maxoutputpages reset to {0}".format(page_limit),
            PyemuWarning,
        )
        maxoutputpages = page_limit
    num_obs_plotted = min(maxoutputpages * obs_per_page, len(targetobs))
    if num_obs_plotted < len(targetobs):
        # too many observations requested: plot a random sample
        index_plotted = np.random.choice(len(targetobs), num_obs_plotted, replace=False)
        obs_plotted = [targetobs[x] for x in index_plotted]
    else:
        obs_plotted = targetobs
    # ceiling division: only as many pages as are needed, no trailing blank page
    real_pages = (num_obs_plotted + obs_per_page - 1) // obs_per_page

    # one set of pages per parameter, one subplot of derivative vs. increment
    # for each plotted observation
    for param, group in plotframe.groupby("parameter"):
        for page in range(real_pages):
            fig, axes = plt.subplots(n_rows, n_cols, sharex=True, figsize=(10, 15))
            for slot in range(obs_per_page):
                count = obs_per_page * page + slot
                if count >= num_obs_plotted:
                    break
                row, col = divmod(slot, n_cols)
                ax = axes[row, col]
                obsname = obs_plotted[count]
                ax.scatter(group["increment"], group[obsname])
                ax.plot(group["increment"], group[obsname], "r")
                ax.set_title(obsname)
                ax.set_xticks([1, 2, 3, 4, 5])
                ax.tick_params(direction="in")
                # NOTE(review): only row index 3 of the 8 rows gets an x label;
                # this looks like a leftover from a 4-row layout -- confirm
                # whether the bottom row was intended.
                if row == 3:
                    ax.set_xlabel("Increment")
            fig.tight_layout()

            if filetype is None:
                plt.show()
            else:
                fig.savefig(
                    os.path.join(
                        figures_dir, "{0}_jactest_{1}.{2}".format(param, page, filetype)
                    )
                )
            # close this specific figure so repeated pages do not accumulate
            plt.close(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc