• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pypest / pyemu / 20395943807

20 Dec 2025 02:48PM UTC coverage: 76.3% (-1.5%) from 77.802%
20395943807

Pull #634

github

web-flow
Merge a7465fa40 into c55c668f9
Pull Request #634: feat_DSI-AE

373 of 836 new or added lines in 8 files covered. (44.62%)

35 existing lines in 4 files now uncovered.

14204 of 18616 relevant lines covered (76.3%)

8.06 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.06
/pyemu/plot/plot_utils.py
1
"""Plotting functions for various PEST(++) and pyemu operations"""
2
import os
11✔
3
import numpy as np
11✔
4
import pandas as pd
11✔
5
import warnings
11✔
6
from datetime import datetime
11✔
7
import string
11✔
8
from pyemu.logger import Logger
11✔
9
from pyemu.pst import pst_utils
11✔
10
from ..pyemu_warnings import PyemuWarning
11✔
11
import importlib.util
11✔
12
import pyemu
11✔
13

14
# True when matplotlib is importable.  The plotting helpers check this flag so
# that importing this module never requires matplotlib itself.
HAS_MATPLOTLIB = importlib.util.find_spec("matplotlib") is not None

# default rcParams applied by `apply_custom_font` when none are supplied
font = {"font.size": 6}

# default page size (inches) and the rows x columns subplot grid used for
# each page of the multi-page plotting helpers
figsize = (8, 10.5)
nr = 4
nc = 2

# letters used to label the panels on a page ("A)", "B)", ...)
abet = string.ascii_uppercase
23

24
def apply_custom_font(rc_params=None):
    """Decorator factory that runs the wrapped function inside a temporary
    ``matplotlib.pyplot.rc_context``.

    Args:
        rc_params (`dict`): matplotlib rcParams to apply while the wrapped
            function runs.  If None, the module-level ``font`` dict is used.

    Returns:
        `callable`: a decorator.  The decorated function keeps the original
        function's ``__name__``, ``__doc__`` and ``__wrapped__`` metadata.

    Note:
        If matplotlib is not installed, the wrapped function is called
        directly (no rc context), leaving it to the function itself to
        handle the missing dependency.
    """
    import functools

    if rc_params is None:
        rc_params = font

    def decorator(func):
        # functools.wraps keeps the decorated function's name and docstring,
        # so help() and the generated API docs describe the real function
        # instead of an anonymous "wrapper"
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not HAS_MATPLOTLIB:
                return func(*args, **kwargs)

            import matplotlib.pyplot as plt

            with plt.rc_context(rc_params):
                return func(*args, **kwargs)

        return wrapper

    return decorator
37

38

39
@apply_custom_font({"font.size": 6})
def plot_summary_distributions(
    df,
    ax=None,
    label_post=False,
    label_prior=False,
    subplots=False,
    figsize=(11, 8.5),
    pt_color="b",
):
    """helper function to plot gaussian distributions from prior and posterior
    means and standard deviations

    Args:
        df (`pandas.DataFrame` or `str`): a dataframe or the name of a csv file.
            Must have columns named 'prior_mean','prior_stdev','post_mean',
            'post_stdev' ('prior_expt'/'post_expt' may stand in for the
            '_mean' columns and 'prior_var'/'post_var' for the '_stdev'
            columns).  If loaded from a csv file, column 0 is assumed to be
            the index.  A dataframe argument is copied, never modified.
        ax (`matplotlib.pyplot.axis`): If None, and not subplots, then one is created
            and all distributions are plotted on a single plot
        label_post (`bool`): flag to add text labels to the peak of the posterior
        label_prior (`bool`): flag to add text labels to the peak of the prior
        subplots (`bool`): flag to use subplots.  If True, then 6 axes per page
            are used and a single prior and posterior is plotted on each
        figsize (`tuple`): matplotlib figure size
        pt_color (`str`): matplotlib color used to fill the posterior
            distributions.  Default is "b"

    Returns:
        if `subplots` is True, a tuple containing:

        - **[`matplotlib.figure`]**: list of figures
        - **[`matplotlib.axis`]**: list of axes

        otherwise the single `matplotlib.axis` holding all distributions.

    Note:
        This is useful for demystifying FOSM results

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pyemu.plot_utils.plot_summary_distributions("pest.par.usum.csv")
        plt.show()

    """
    import matplotlib.pyplot as plt

    if isinstance(df, str):
        df = pd.read_csv(df, index_col=0)
    else:
        # the derived columns added below must not leak into the caller's frame
        df = df.copy()
    if ax is None and not subplots:
        plt.figure(figsize=figsize)
        ax = plt.subplot(111)
        ax.grid()

    # fill in the standard column names from their equivalents
    if "post_stdev" not in df.columns and "post_var" in df.columns:
        df.loc[:, "post_stdev"] = df.post_var.apply(np.sqrt)
    if "prior_stdev" not in df.columns and "prior_var" in df.columns:
        df.loc[:, "prior_stdev"] = df.prior_var.apply(np.sqrt)
    if "prior_expt" not in df.columns and "prior_mean" in df.columns:
        df.loc[:, "prior_expt"] = df.prior_mean
    if "post_expt" not in df.columns and "post_mean" in df.columns:
        df.loc[:, "post_expt"] = df.post_mean

    if subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(2, 3, 1)
        ax_per_page = 6
        ax_count = 0
        axes = []
        figs = []
    n_rows = df.shape[0]
    for pos, name in enumerate(df.index):
        # positional access so repeated index labels still yield scalars
        row = df.iloc[pos]

        # posterior: filled curve
        x, y = gaussian_distribution(row["post_expt"], row["post_stdev"])
        ax.fill_between(x, 0, y, facecolor=pt_color, edgecolor="none", alpha=0.25)
        if label_post:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        # prior: dashed grey line
        x, y = gaussian_distribution(row["prior_expt"], row["prior_stdev"])
        ax.plot(x, y, color="0.5", lw=3.0, dashes=(2, 1))
        if label_prior:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        if subplots:
            ax.set_title(name)
            ax_count += 1
            ax.set_yticklabels([])
            axes.append(ax)
            # stop after the last row so no extra, empty axis is created
            if pos == n_rows - 1:
                break
            if ax_count >= ax_per_page:
                figs.append(fig)
                fig = plt.figure(figsize=figsize)
                ax_count = 0
            ax = plt.subplot(2, 3, ax_count + 1)
    if subplots:
        figs.append(fig)
        return figs, axes
    ylim = list(ax.get_ylim())
    ylim[1] *= 1.2
    ylim[0] = 0.0
    ax.set_ylim(ylim)
    ax.set_yticklabels([])
    return ax
150

151

152
def gaussian_distribution(mean, stdev, num_pts=50):
    """get an x and y numpy.ndarray that spans the +/- 4
    standard deviation range of a gaussian distribution with
    a given mean and standard deviation. useful for plotting

    Args:
        mean (`float`): the mean of the distribution
        stdev (`float`): the standard deviation of the distribution
        num_pts (`int`): the number of points in the returned ndarrays.
            Default is 50

    Returns:
        tuple containing:

        - **numpy.ndarray**: the x-values of the distribution
        - **numpy.ndarray**: the y-values (probability density) of the distribution

    Note:
        This is pure numpy and needs no plotting context, so it is not wrapped
        in `apply_custom_font`: it is called once per parameter/observation
        inside the plotting loops and should stay cheap (and usable without
        matplotlib).

    Example::

        mean,std = 1.0, 2.0
        x,y = pyemu.plot.gaussian_distribution(mean,std)
        plt.fill_between(x,0,y)
        plt.show()


    """
    xstart = mean - (4.0 * stdev)
    xend = mean + (4.0 * stdev)
    x = np.linspace(xstart, xend, num_pts)
    variance = stdev * stdev
    # normal probability density function evaluated at x
    y = (1.0 / np.sqrt(2.0 * np.pi * variance)) * np.exp(
        -1.0 * ((x - mean) ** 2) / (2.0 * variance)
    )
    return x, y
186

187

188
@apply_custom_font()
def pst_helper(pst, kind=None, **kwargs):
    """`pyemu.Pst` plot helper - takes the
    handoff from `pyemu.Pst.plot()`

    Args:
        pst (`pyemu.Pst`): the control file instance to plot
        kind (`str`): the kind of plot to make: one of "prior", "1to1",
            "phi_pie" or "phi_progress".  If None, every kind is made and
            each is saved to "<case>.<kind>.pdf" next to the control file.
        **kwargs (`dict`): keyword arguments to pass to the
            plotting function and ultimately to `matplotlib`.  The "echo"
            key is also read here to control echoing of the logger.

    Returns:
        varies: usually a combination of `matplotlib.figure` (s) and/or
        `matplotlib.axis` (s).  If `kind` is None, a list of the return
        values of every plotting function.

    Example::

        pst = pyemu.Pst("pest.pst") #assumes pest.res or pest.rei is found
        pst.plot(kind="1to1")
        plt.show()
        pst.plot(kind="phi_pie")
        plt.show()
        pst.plot(kind="prior")
        plt.show()

    """

    echo = kwargs.get("echo", False)
    logger = pyemu.Logger("plot_pst_helper.log", echo=echo)
    logger.statement("plot_utils.pst_helper()")

    kinds = {
        "prior": pst_prior,
        "1to1": res_1to1,
        "phi_pie": res_phi_pie,
        "phi_progress": phi_progress,
    }

    if kind is None:
        base_filename = pst.filename
        if pst.new_filename is not None:
            base_filename = pst.new_filename
        # strip only a trailing ".pst" so a ".pst" elsewhere in the path
        # (e.g. a directory name) is left intact
        if base_filename.endswith(".pst"):
            base_filename = base_filename[: -len(".pst")]
        return [
            plot_func(pst, logger=logger, filename=base_filename + "." + name + ".pdf")
            for name, plot_func in kinds.items()
        ]
    if kind not in kinds:
        logger.lraise(
            "unrecognized kind:{0}, should one of {1}".format(
                kind, ",".join(list(kinds.keys()))
            )
        )
    return kinds[kind](pst, logger, **kwargs)
243

244

245
@apply_custom_font()
def phi_progress(pst, logger=None, filename=None, **kwargs):
    """make plot of phi vs number of model runs - requires
    available  ".iobj" file generated by a PESTPP-GLM run.

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): file name to save the figure to.  If None, the
            figure is not saved.  Default is None
        kwargs (`dict`): optional keyword args.  "ax" may be an existing
            `matplotlib.axis` to plot on; otherwise a new figure is created.

    Returns:
        `matplotlib.axis`: the axis the plot was made on

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.phi_progress(pst)
        plt.show()

    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot phi_progress")

    # the ".iobj" file shares the control file's case name; only a trailing
    # ".pst" is swapped so other ".pst" substrings in the path are untouched
    case = pst.filename
    if case.endswith(".pst"):
        case = case[: -len(".pst")]
    iobj_file = case + ".iobj"
    if not os.path.exists(iobj_file):
        logger.lraise("couldn't find iobj file {0}".format(iobj_file))
    df = pd.read_csv(iobj_file)
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)
    ax.plot(df.model_runs_completed, df.total_phi, marker=".")
    ax.set_xlabel("model runs")
    ax.set_ylabel(r"$\phi$")
    ax.grid()
    if filename is not None:
        # save the figure that owns `ax`: with a caller-supplied axis the
        # "current" pyplot figure may be a different one
        ax.figure.savefig(filename)
    logger.log("plot phi_progress")
    return ax
293

294

295
def _get_page_axes(count=nr * nc):
    """Create subplots on the current figure using the module's page grid.

    Args:
        count (`int`): number of axes wanted; capped at ``nr * nc``.

    Returns:
        [`matplotlib.axis`]: the newly created axes, in grid order.
    """
    import matplotlib.pyplot as plt

    slots_per_page = nr * nc
    n_axes = min(count, slots_per_page)
    return [plt.subplot(nr, nc, slot) for slot in range(1, n_axes + 1)]
301

302

303
def _padded_limits(lo, hi, frac=0.1):
    """Widen (lo, hi) outward by scaling each end by ``frac`` away from zero.

    Positive ends get the historical ``lo * 0.9`` / ``hi * 1.1`` limits;
    negative ends are scaled the other way so they also move outward
    instead of being pulled inward (which clipped the data).
    """
    lo = lo * (1.0 - frac) if lo >= 0 else lo * (1.0 + frac)
    hi = hi * (1.0 + frac) if hi >= 0 else hi * (1.0 - frac)
    return lo, hi


@apply_custom_font()
def res_1to1(
    pst, logger=None, filename=None, plot_hexbin=False, histogram=False, **kwargs
):
    """make 1-to-1 plots and also observed vs residual by observation group

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): PDF filename to save figures to.  If None, figures
            are returned.  Default is None
        plot_hexbin (`bool`): flag to use the hexbinning for large numbers of residuals.
            Default is False
        histogram (`bool`): flag to plot residual histograms instead of obs vs residual.
            Default is False (use `matplotlib.pyplot.scatter` )
        kwargs (`dict`): optional keyword args.  Recognized keys:
            "ensemble" (passed to `pst_utils.res_from_en()` to build the
            residuals instead of using `pst.res`), "include_zero" (if True,
            zero-weight observations are also plotted) and "fig_title"
            (text for the title page).

    Returns:
        [`matplotlib.Figure`]: the figures (the first one is a title page) if
        `filename` is None.  Otherwise every figure is written as a page of
        `filename`, closed, and None is returned.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_1to1(pst)
        plt.show()

    """
    if not HAS_MATPLOTLIB:
        msg = (
            "'res_1to1' requires the 'matplotlib' package. Install it "
            "with 'pip install pyemu[optional]'."
        )
        raise ImportError(msg)

    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_1to1")

    if "ensemble" in kwargs:
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception as e:
            logger.lraise("res_1to1: error loading ensemble file: {0}".format(str(e)))
    else:
        try:
            res = pst.res
        except Exception:
            res = None
        if res is None:
            logger.lraise("res_1to1: pst.res is None, couldn't find residuals file")

    obs = pst.observation_data

    if "grouper" in kwargs:
        raise NotImplementedError()
    grouper = obs.groupby(obs.obgnme).groups

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    axes = []  # stays empty if every group is skipped
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))

        # copy so the added "sim"/"res" columns never touch observation_data
        obs_g = obs.loc[names, :].copy()
        obs_g.loc[:, "sim"] = res.loc[names, "modelled"]
        logger.statement("using control file obsvals to calculate residuals")
        obs_g.loc[:, "res"] = obs_g.sim - obs_g.obsval
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue

        # start a new page when the current one is full (or on the first group)
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # --- simulated vs observed ---
        ax = axes[ax_count]
        mn, mx = _padded_limits(
            min(obs_g.obsval.min(), obs_g.sim.min()),
            max(obs_g.obsval.max(), obs_g.sim.max()),
        )
        ax.axis("square")
        if plot_hexbin:
            ax.hexbin(
                obs_g.obsval.values,
                obs_g.sim.values,
                mincnt=1,
                gridsize=(75, 75),
                extent=(mn, mx, mn, mx),
                bins="log",
                edgecolors=None,
            )
        else:
            ax.scatter(obs_g.obsval.values, obs_g.sim.values,
                       marker=".", s=10, color="b")

        ax.plot([mn, mx], [mn, mx], "k--", lw=1.0)
        xlim = (mn, mx)
        ax.set_xlim(mn, mx)
        ax.set_ylim(mn, mx)
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        ax_count += 1

        if not histogram:
            # --- residual vs observed ---
            ax = axes[ax_count]
            ax.scatter(obs_g.obsval.values, obs_g.res.values,
                       marker=".", s=10, color="b")
            ylim = ax.get_ylim()
            rmax = max(np.abs(ylim[0]), np.abs(ylim[1]))
            if obs_g.shape[0] == 1:
                rmax *= 1.1
            ax.set_ylim(-rmax, rmax)
            # show a zero residuals line
            ax.plot(xlim, [0, 0], "k--", lw=1.0)
            meanres = obs_g.res.mean()
            # show mean residuals line
            ax.plot(xlim, [meanres, meanres], "r-", lw=1.0)
            ax.set_xlim(xlim)
            ax.set_ylabel("residual", labelpad=0.1)
            ax.set_xlabel("observed", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        else:
            # --- residual histogram ---
            # explicit xlim from the padded residual range, otherwise the
            # figure layout gets wonky; residual minima are usually negative,
            # so padding must move them outward
            rlim = _padded_limits(obs_g.res.min(), obs_g.res.max())

            ax = axes[ax_count]
            ax.hist(obs_g.res, bins=50, color="b")
            meanres = obs_g.res.mean()
            ax.axvline(meanres, color="r", lw=1)
            b, t = ax.get_ylim()
            ax.text(meanres + meanres / 10, t - t / 10, "Mean: {:.2f}".format(meanres))
            ax.set_xlim(rlim)
            ax.set_ylabel("count", labelpad=0.1)
            ax.set_xlabel("residual", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        logger.log("plotting 1to1 for {0}".format(g))

    # blank out the unused panels on the last page (if a page was started)
    for a in range(ax_count, len(axes)):
        axes[a].set_axis_off()
        axes[a].set_yticks([])
        axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot res_1to1")
        return None
    logger.log("plot res_1to1")
    return figs
523

524

525
@apply_custom_font()
def plot_id_bar(id_df, nsv=None, logger=None, **kwargs):
    """Plot a stacked bar chart of identifiability based on
    the `pyemu.ErrVar.get_identifiability_dataframe()` dataframe

    Args:
        id_df (`pandas.DataFrame`) : dataframe of identifiability
        nsv (`int`): number of singular values to consider.  If None or
            larger than the number of singular-value columns, all are used.
        logger (`pyemu.Logger`, optional): a logger.  If None, a generic
            one is created
        kwargs (`dict`): optional keyword arguments.  "figsize" sets the
            figure size (default (8, 10.5)); "ax" is an existing axis to
            plot on instead of creating a new figure.

    Returns:
        `matplotlib.Axis`: the axis with the plot

    Example::

        import pyemu
        pest_obj = pyemu.Pst(pest_control_file)
        ev = pyemu.ErrVar(jco='freyberg_jac.jcb')
        id_df = ev.get_identifiability_dataframe(singular_value=48)
        pyemu.plot_utils.plot_id_bar(id_df, nsv=12, figsize=(12,4))

    """
    _ensure_matplotlib()
    import matplotlib.colors
    import matplotlib.pyplot as plt

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot id bar")

    df = id_df.copy()

    # drop the final `ident` column
    if "ident" in df.columns:
        df.drop("ident", inplace=True, axis=1)

    if nsv is None or nsv > len(df.columns):
        nsv = len(df.columns)
        # one-off message: `logger.log` would open a timing entry that is
        # never closed, so use `statement`
        logger.statement("set number of SVs and number in the dataframe")

    # independent frame (not a slice) so the temporary column below is safe
    df = df[df.columns[:nsv]].copy()

    # order the bars by total identifiability
    df["ident"] = df.sum(axis=1)
    df.sort_values(by="ident", inplace=True, ascending=False)
    df.drop("ident", inplace=True, axis=1)

    fig_size = kwargs.get("figsize", (8, 10.5))
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=fig_size)
        ax = plt.subplot(1, 1, 1)

    # plot the stacked bar chart (the easy part!)
    df.plot.bar(stacked=True, cmap="jet_r", legend=False, ax=ax)

    # build a colorbar (rather than a legend) keyed to the singular values
    if nsv == 1:
        # special case colormap: just dark red for one SV
        tcm = matplotlib.colors.LinearSegmentedColormap.from_list(
            "one_sv", [plt.get_cmap("jet_r")(0)] * 2, N=2
        )
        sm = plt.cm.ScalarMappable(
            cmap=tcm, norm=matplotlib.colors.Normalize(vmin=0, vmax=nsv + 1)
        )
    else:
        # typically just the jet_r colormap over the range of SVs
        sm = plt.cm.ScalarMappable(
            cmap=plt.get_cmap("jet_r"),
            norm=matplotlib.colors.Normalize(vmin=1, vmax=nsv),
        )
    sm.set_array([])

    # summarize the ticks when there are too many for the colorbar
    if nsv < 20:
        ticks = range(1, nsv + 1)
    else:
        # aim for ~30 ticks; the step must never be 0 (it would be for
        # 20 <= nsv <= 28, making np.arange raise ZeroDivisionError)
        ticks = np.arange(1, nsv + 1, max(1, (nsv + 1) // 30))

    cb = plt.colorbar(sm, ax=ax)
    cb.set_ticks(ticks)

    logger.log("plot id bar")

    return ax
619

620

621
@apply_custom_font()
def res_phi_pie(pst, logger=None, **kwargs):
    """plot current phi components as a pie chart.

    Args:
        pst (`pyemu.Pst`): a control file instance with the residual dataframe
            instance available.
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created
        kwargs (`dict`): a dict of plotting options. Accepts 'include_zero'
            as a flag to include phi groups whose phi component is zero (e.g.
            groups with only zero-weight obs).

            Also accepts 'label_comps': list of components for the labels. Options are
            ['name', 'phi_comp', 'phi_percent']. Labels will use those components,
            one per line, in the order of the 'label_comps' list.

            'ax' (existing axis to plot on), 'figsize' (size of a newly created
            figure) and 'filename' (file to save the figure to) are also accepted.

    Returns:
        `matplotlib.Axis`: the axis with the plot.

    Example::

        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_phi_pie(pst,figsize=(12,4))
        pyemu.plot_utils.res_phi_pie(pst,label_comps = ['name','phi_percent'], figsize=(12,4))


    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_phi_pie")

    if "ensemble" in kwargs:
        # NOTE(review): the residuals loaded here are not used below - the pie
        # is drawn from pst.phi / pst.phi_components.  Kept so a bad ensemble
        # argument is still reported; confirm the intended behavior.
        try:
            pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception:
            logger.statement(
                "res_phi_pie: could not find ensemble file {0}".format(kwargs["ensemble"])
            )
    else:
        try:
            pst.res  # accessed only to verify residuals are available
        except Exception:
            logger.lraise("res_phi_pie: pst.res is None, couldn't find residuals file")

    phi = pst.phi
    phi_comps = pst.phi_components
    norm_phi_comps = pst.phi_components_normalized
    keys = list(phi_comps.keys())
    if "include_zero" not in kwargs or kwargs["include_zero"] is False:
        phi_comps = {k: phi_comps[k] for k in keys if phi_comps[k] > 0.0}
        keys = list(phi_comps.keys())
        norm_phi_comps = {k: norm_phi_comps[k] for k in keys}
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        # honor a "figsize" kwarg, falling back to the module default
        plt.figure(figsize=kwargs.get("figsize", figsize))
        ax = plt.subplot(1, 1, 1, aspect="equal")

    # percent of total phi per group (guarded against a zero total phi)
    pcts = [100.0 * (phi_comps[k] / phi) if phi != 0 else 0.0 for k in keys]

    if "label_comps" not in kwargs:
        labels = [
            "{0}\n{1:4G}\n({2:3.1f}%)".format(k, phi_comps[k], pct)
            for k, pct in zip(keys, pcts)
        ]
    else:
        # make sure the components for the labels are in a list
        fmtchoices = kwargs["label_comps"]
        if not isinstance(fmtchoices, list):
            fmtchoices = [fmtchoices]
        # format and values for every possible label component
        fmts = {"name": "{}", "phi_comp": "{:4G}", "phi_percent": "({:3.1f}%)"}
        values = {
            "name": keys,
            "phi_comp": [phi_comps[k] for k in keys],
            "phi_percent": pcts,
        }
        # one component per line, in the requested order, no trailing newline
        labels = [
            "\n".join(fmts[c].format(v) for c, v in zip(fmtchoices, row))
            for row in zip(*[values[c] for c in fmtchoices])
        ]

    ax.pie([float(norm_phi_comps[k]) for k in keys], labels=labels)
    logger.log("plot res_phi_pie")
    if "filename" in kwargs:
        # save the figure that owns `ax`, which may not be the current figure
        ax.figure.savefig(kwargs["filename"])
    return ax
720

721

722
@apply_custom_font()
def pst_prior(pst, logger=None, filename=None, **kwargs):
    """helper to plot prior parameter histograms implied by
    parameter bounds.

    Args:
        pst (`pyemu.Pst`): control file
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created.
        filename (`str`):  PDF filename to save plots to (one page per figure).
            If None, return figs without saving.  Default is None.
        kwargs (`dict`): additional plotting options. 'grouper' (a dict to
            group parameters on to a single axis) is not implemented yet and
            raises NotImplementedError; parameter groups are used instead.
            'unique_only' only shows unique mean-stdev combinations within a
            given group.  'fig_title' sets the title page text.

    Returns:
        [`matplotlib.Figure`]: a list of figures created (the first is a title
        page) if `filename` is None; otherwise the figures are written to
        `filename`, closed, and None is returned.

    Example::

        pst = pyemu.Pst("pest.pst")
        pyemu.plot_utils.pst_prior(pst)
        plt.show()

    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot pst_prior")
    par = pst.parameter_data

    if "parcov_filename" in pst.pestpp_options:
        logger.warn("ignoring parcov_filename, using parameter bounds for prior cov")
    logger.log("loading cov from parameter data")
    cov = pyemu.Cov.from_parameter_data(pst)
    logger.log("loading cov from parameter data")

    logger.log("building mean parameter values")
    li = par.partrans.loc[cov.names] == "log"
    mean = par.parval1.loc[cov.names]
    info = par.loc[cov.names, :].copy()
    info.loc[:, "mean"] = mean
    # log-transformed parameters are plotted in log10 space
    info.loc[li, "mean"] = mean[li].apply(np.log10)
    logger.log("building mean parameter values")

    logger.log("building stdev parameter values")
    if cov.isdiagonal:
        std = cov.x.flatten()
    else:
        std = np.diag(cov.x)
    std = np.sqrt(std)
    # check before assigning so a mismatch produces this clear message
    if std.shape != mean.shape:
        logger.lraise("mean.shape {0} != std.shape {1}".format(mean.shape, std.shape))
    info.loc[:, "prior_std"] = std
    logger.log("building stdev parameter values")

    if "grouper" in kwargs:
        raise NotImplementedError()
    par_adj = par.loc[par.partrans.apply(lambda x: x in ["log", "none"]), :]
    grouper = par_adj.groupby(par_adj.pargp).groups

    if len(grouper) == 0:
        raise Exception("no adustable parameters to plot")

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='prior')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    axes = []
    ax_count = 0
    for g in sorted(grouper.keys()):
        names = grouper[g]
        logger.log("plotting priors for {0}".format(",".join(list(names))))
        # start a new page when the current one is full (or on the first group)
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # decide the x-axis label from this group's own transforms
        vc = info.loc[names, "partrans"].value_counts()
        islog = False
        if vc.shape[0] > 1:
            logger.warn("mixed partrans for group {0}".format(g))
        elif "log" in vc.index:
            islog = True
        ax = axes[ax_count]
        if "unique_only" in kwargs and kwargs["unique_only"]:
            ms = (
                info.loc[names, :]
                .apply(lambda x: (x["mean"], x["prior_std"]), axis=1)
                .unique()
            )
            for (m, s) in ms:
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        else:
            for m, s in zip(info.loc[names, "mean"], info.loc[names, "prior_std"]):
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        ax.set_title(
            "{0}) group:{1}, {2} parameters".format(abet[ax_count], g, names.shape[0]),
            loc="left",
        )

        ax.set_yticks([])
        if islog:
            ax.set_xlabel("$log_{10}$ parameter value", labelpad=0.1)
        else:
            ax.set_xlabel("parameter value", labelpad=0.1)
        logger.log("plotting priors for {0}".format(",".join(list(names))))

        ax_count += 1

    # blank out the unused panels on the last page
    for a in range(ax_count, len(axes)):
        axes[a].set_axis_off()
        axes[a].set_yticks([])
        axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        # write every page (title page included), not just the last figure
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot pst_prior")
        return None
    logger.log("plot pst_prior")
    return figs
878

879

880
@apply_custom_font()
def ensemble_helper(
    ensemble,
    bins=10,
    facecolor="0.5",
    plot_cols=None,
    filename=None,
    func_dict=None,
    sync_bins=True,
    deter_vals=None,
    std_window=None,
    deter_range=False,
    **kwargs
):
    """helper function to plot ensemble histograms

    Args:
        ensemble : varies
            the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        bins : int
            the histogram bins argument.  When `sync_bins` is True, this is the number of
            shared bin edges spanning the min-max range of all ensembles.  Default is 10
        facecolor : str
            the histogram facecolor.  Only applies if ensemble is a single thing
        plot_cols : enumerable
            a collection of columns (in form of a list of parameters, or a dict with keys for
            parsing plot axes and values of parameters) from the ensemble(s) to plot.  If None,
            (the union of) all cols are plotted. Default is None
        filename : str
            the name of the pdf to create.  If None, return figs without saving.  Default is None.
        func_dict : dict
            a dictionary of unary functions (e.g., `np.log10`) to apply to columns.  Key is
            column name.  The functions are applied to copies of the ensembles, so the
            caller's ensembles are not modified.  Default is None
        sync_bins : bool
            flag to compute shared bin edges (`np.linspace(min, max, bins)`) across all
            ensembles for each axis.  Default is True
        deter_vals : dict
            dict of deterministic values to plot as a vertical line. key is ensemble column name
        std_window : float
            the width, in standard deviations, of a window centered on each ensemble's mean
            (computed over the values plotted on that axis); vertical lines are drawn at
            mean +/- (std_window / 2) standard deviations.  If None, nothing happens.
            Default is None
        deter_range : bool
            flag to set xlims to deterministic value +/- the std window half-width.  Only
            used when both `std_window` and `deter_vals` are supplied.  Default is False
        **kwargs : dict
            "fig_title" (str) replaces the default title-page text

    Returns:
        [`matplotlib.Figure`]: the list of figures (the first is a title page) if
        `filename` is None.  Returns None if `filename` is not None (the figures
        are saved to the pdf and closed).

    Example::

        # plot prior and posterior par ensembles
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post},filename="ensemble.pdf")

        #plot prior and posterior simulated equivalents to observations with obs noise and obs vals
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        noise = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.obs+noise.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post,"r":noise},
                                         filename="ensemble.pdf",
                                         deter_vals=pst.observation_data.obsval.to_dict())


    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    logger = pyemu.Logger("ensemble_helper.log")
    logger.log("pyemu.plot_utils.ensemble_helper()")
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
    if len(ensembles) == 0:
        raise Exception("plot_uitls.ensemble_helper() error processing `ensemble` arg")

    # apply any functions
    if func_dict is not None:
        logger.log("applying functions")
        # transform copies: otherwise the caller's ensembles are modified in
        # place and repeated calls would re-apply the functions
        ensembles = {fc: en.copy() for fc, en in ensembles.items()}
        for col, funcc in func_dict.items():
            for en in ensembles.values():
                if col in en.columns:
                    en.loc[:, col] = en.loc[:, col].apply(funcc)
        logger.log("applying functions")

    # get a set of all cols (union across ensembles)
    all_cols = set()
    for en in ensembles.values():
        all_cols.update(en.columns)

    # normalize plot_cols to {axis label: [column names]}
    if plot_cols is None:
        plot_cols = {col: [col] for col in all_cols}
    else:
        if isinstance(plot_cols, list):
            splot_cols = set(plot_cols)
            plot_cols = {col: [col] for col in plot_cols}
        elif isinstance(plot_cols, dict):
            splot_cols = set()
            for pcols in plot_cols.values():
                splot_cols.update(list(pcols))
        else:
            logger.lraise(
                "unrecognized plot_cols type: {0}, should be list or dict".format(
                    type(plot_cols)
                )
            )

        missing = splot_cols - all_cols
        if len(missing) > 0:
            logger.lraise(
                "the following plot_cols are missing: {0}".format(",".join(missing))
            )

    logger.statement("plotting {0} histograms".format(len(plot_cols)))

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_helper()\n at {0}".format(str(datetime.now())),
            ha="center",
        )
    labels = sorted(plot_cols.keys())
    logger.statement("saving pdf to {0}".format(filename))
    figs = []
    axes = None
    ax_count = 0

    for label in labels:
        plot_col = plot_cols[label]
        logger.log("plotting reals for {0}".format(label))
        # start a new page on the first label and whenever the page is full
        if ax_count % (nr * nc) == 0:
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            for page_ax in axes:
                page_ax.set_yticks([])
            ax_count = 0

        ax = axes[ax_count]

        if sync_bins:
            # shared bin edges spanning the min-max of all ensembles; the
            # sentinels survive only if every ensemble is all-NaN here
            mx, mn = -1.0e30, 1.0e30
            for en in ensembles.values():
                col_vals = en.loc[:, plot_col].values
                emn = np.nanmin(col_vals)
                emx = np.nanmax(col_vals)
                mx = max(mx, emx)
                mn = min(mn, emn)
            if mx == -1.0e30 and mn == 1.0e30:
                logger.warn("all NaNs for label: {0}".format(label))
                ax.set_title(
                    "{0}) {1}, count:{2} - all NaN".format(
                        abet[ax_count], label, len(plot_col)
                    ),
                    loc="left",
                )
                ax.set_yticks([])
                ax.set_xticks([])
                ax_count += 1
                logger.log("plotting reals for {0}".format(label))
                continue
            plot_bins = np.linspace(mn, mx, num=bins)
            logger.statement("{0} min:{1:5G}, max:{2:5G}".format(label, mn, mx))
        else:
            plot_bins = bins

        for fc, en in ensembles.items():
            vals = en.loc[:, plot_col].values.flatten()
            ax.hist(
                vals,
                bins=plot_bins,
                edgecolor="none",
                alpha=0.5,
                density=True,
                facecolor=fc,
            )
            v = None
            if deter_vals is not None:
                for pc in plot_col:
                    if pc in deter_vals:
                        ylim = ax.get_ylim()
                        v = deter_vals[pc]
                        ax.plot([v, v], ylim, "k--", lw=1.5)
                        ax.set_ylim(ylim)

            if std_window is not None:
                try:
                    ylim = ax.get_ylim()
                    # stats over every value plotted on this axis for this
                    # ensemble (pandas semantics: NaNs skipped, ddof=1)
                    val_series = pd.Series(vals)
                    mean = val_series.mean()
                    half_width = val_series.std() * (std_window / 2.0)
                    lo, hi = mean - half_width, mean + half_width
                    ax.plot([lo, lo], ylim, color=fc, lw=1.5, ls="--")
                    ax.plot([hi, hi], ylim, color=fc, lw=1.5, ls="--")
                    ax.set_ylim(ylim)
                    if deter_range and v is not None:
                        ax.set_xlim(v - half_width, v + half_width)
                except Exception as e:
                    logger.warn(
                        "error plotting std window for {0}: {1}".format(label, str(e))
                    )

        ax.grid()
        if len(ensembles) > 1:
            ax.set_title(
                "{0}) {1}, count: {2}".format(abet[ax_count], label, len(plot_col)),
                loc="left",
            )
        else:
            # single ensemble: vals holds that ensemble's values for this axis
            ax.set_title(
                "{0}) {1}, count:{2}\nmin:{3:3.1E}, max:{4:3.1E}".format(
                    abet[ax_count],
                    label,
                    len(plot_col),
                    np.nanmin(vals),
                    np.nanmax(vals),
                ),
                loc="left",
            )
        ax_count += 1
        logger.log("plotting reals for {0}".format(label))

    # blank out unused axes on the last page (no page exists if nothing was plotted)
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("pyemu.plot_utils.ensemble_helper()")
        return None
    logger.log("pyemu.plot_utils.ensemble_helper()")
    return figs
1138

1139

1140
@apply_custom_font()
def ensemble_change_summary(
    ensemble1,
    ensemble2,
    pst,
    bins=10,
    facecolor="0.5",
    logger=None,
    filename=None,
    **kwargs
):
    """helper function to plot first and second moment change histograms between two
    ensembles

    Args:
        ensemble1 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        ensemble2 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        pst (`pyemu.Pst`): control file
        bins (`int`): the number of histogram bins.  Default is 10
        facecolor (`str`): the histogram facecolor.
        logger (`pyemu.Logger`): optional logger.  If None, a logger writing to
            "Default_Logger.log" is created
        filename (`str`): the name of the multi-pdf to create. If None, return figs without saving.  Default is None.
        **kwargs: "fig_title" (`str`) replaces the default title-page text

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:
        Ensembles passed as objects are copied, so the caller's ensembles are not
        modified (column names are lowercased and log-transformed parameters are
        converted to log10 on the copies only).  The mean change is
        ``mean(ensemble1) - mean(ensemble2)``; the sigma change is the percent
        reduction ``100 * (std1 - std2) / std1``, which is NaN where ``std1`` is zero.

    Example::

        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_change_summary(prior, post, pst)
        plt.show()


    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    def _prep_ensemble(en):
        # read from csv, or copy so the caller's object is not modified below
        if isinstance(en, str):
            en = pd.read_csv(en, index_col=0)
        else:
            en = en.copy()
        en.columns = en.columns.str.lower()
        # better to ensure this is caught by pestpp-ies ensemble csvs:
        # an "unnamed:" column is assumed to be the trailing artifact of a
        # poor csv read, so only the last column is dropped
        if any("unnamed:" in col for col in en.columns):
            en = en.iloc[:, :-1]
        return en

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot ensemble change")

    ensemble1 = _prep_ensemble(ensemble1)
    ensemble2 = _prep_ensemble(ensemble2)

    d = set(ensemble1.columns).symmetric_difference(set(ensemble2.columns))
    if len(d) != 0:
        logger.lraise(
            "ensemble1 does not have the same columns as ensemble2: {0}".format(
                ",".join(d)
            )
        )
    if "grouper" in kwargs:
        raise NotImplementedError()

    en_cols = ensemble1.columns
    if len(en_cols.difference(pst.par_names)) == 0:
        par = pst.parameter_data.loc[en_cols, :]
        grouper = par.groupby(par.pargp).groups
        grouper["all"] = pst.adj_par_names
        # compare log-transformed parameters in log10 space (on the copies)
        li = par.loc[par.partrans == "log", "parnme"]
        ensemble1.loc[:, li] = ensemble1.loc[:, li].apply(np.log10)
        ensemble2.loc[:, li] = ensemble2.loc[:, li].apply(np.log10)
    elif len(en_cols.difference(pst.obs_names)) == 0:
        obs = pst.observation_data.loc[en_cols, :]
        grouper = obs.groupby(obs.obgnme).groups
        grouper["all"] = pst.nnz_obs_names
    else:
        logger.lraise("could not match ensemble cols with par or obs...")

    en1_mn, en1_std = ensemble1.mean(axis=0), ensemble1.std(axis=0)
    en2_mn, en2_std = ensemble2.mean(axis=0), ensemble2.std(axis=0)

    mn_diff = -1 * (en2_mn - en1_mn)
    # percent reduction in std; a zero ensemble1 std gives NaN instead of
    # +/-inf (an infinite value breaks the histogram range)
    std_diff = 100 * ((en1_std - en2_std) / en1_std.replace(0.0, np.nan))

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_change_summary()\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    present = set(mn_diff.index)
    figs = []
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting change for {0}".format(g))

        # group members (e.g. the "all" group) may not all be ensemble columns
        names = [n for n in names if n in present]
        mn_g = mn_diff.loc[names]
        std_g = std_diff.loc[names]

        if mn_g.shape[0] == 0:
            logger.statement("no entries for group '{0}'".format(g))
            logger.log("plotting change for {0}".format(g))
            continue

        # each group uses two axes; start a new page when the current one is full
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        ax = axes[ax_count]
        mn_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("mean change", labelpad=0.1)
        ax.set_title(
            "{0}) mean change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], mn_g.max(), mn_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        ax = axes[ax_count]
        std_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("sigma percent reduction", labelpad=0.1)
        ax.set_title(
            "{0}) sigma change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], std_g.max(), std_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        logger.log("plotting change for {0}".format(g))

    # blank out unused axes on the last page (no page exists if nothing was plotted)
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot ensemble change")
    else:
        logger.log("plot ensemble change")
        return figs
1339

1340

1341
def _process_ensemble_arg(ensemble, facecolor, logger):
11✔
1342
    """private method to work out ensemble plot args"""
1343
    ensembles = {}
11✔
1344
    if isinstance(ensemble, pd.DataFrame) or isinstance(ensemble, pyemu.Ensemble):
11✔
1345
        if not isinstance(facecolor, str):
11✔
1346
            logger.lraise("facecolor must be str")
×
1347
        ensembles[facecolor] = ensemble
11✔
1348
    elif isinstance(ensemble, str):
11✔
1349
        if not isinstance(facecolor, str):
11✔
1350
            logger.lraise("facecolor must be str")
×
1351

1352
        logger.log("loading ensemble from csv file {0}".format(ensemble))
11✔
1353
        en = pd.read_csv(ensemble, index_col=0)
11✔
1354
        logger.statement("{0} shape: {1}".format(ensemble, en.shape))
11✔
1355
        ensembles[facecolor] = en
11✔
1356
        logger.log("loading ensemble from csv file {0}".format(ensemble))
11✔
1357

1358
    elif isinstance(ensemble, list):
11✔
1359
        if isinstance(facecolor, list):
11✔
1360
            if len(ensemble) != len(facecolor):
×
1361
                logger.lraise("facecolor list len != ensemble list len")
×
1362
        else:
1363
            colors = ["m", "c", "b", "r", "g", "y"]
11✔
1364

1365
            facecolor = [colors[i] for i in range(len(ensemble))]
11✔
1366
        ensembles = {}
11✔
1367
        for fc, en_arg in zip(facecolor, ensemble):
11✔
1368
            if isinstance(en_arg, str):
11✔
1369
                logger.log("loading ensemble from csv file {0}".format(en_arg))
11✔
1370
                en = pd.read_csv(en_arg, index_col=0)
11✔
1371
                logger.log("loading ensemble from csv file {0}".format(en_arg))
11✔
1372
                logger.statement("ensemble {0} gets facecolor {1}".format(en_arg, fc))
11✔
1373

1374
            elif isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
11✔
1375
                en = en_arg
11✔
1376
            else:
1377
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1378
            ensembles[fc] = en
11✔
1379

1380
    elif isinstance(ensemble, dict):
11✔
1381
        for fc, en_arg in ensemble.items():
11✔
1382
            if isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
11✔
1383
                ensembles[fc] = en_arg
11✔
1384
            elif isinstance(en_arg, str):
11✔
1385
                logger.log("loading ensemble from csv file {0}".format(en_arg))
11✔
1386
                en = pd.read_csv(en_arg, index_col=0)
11✔
1387
                logger.log("loading ensemble from csv file {0}".format(en_arg))
11✔
1388
                ensembles[fc] = en
11✔
1389
            else:
1390
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1391
    try:
11✔
1392
        for fc in ensembles:
11✔
1393
            ensembles[fc].columns = ensembles[fc].columns.str.lower()
11✔
1394
    except:
×
1395
        logger.lraise("error processing ensemble")
×
1396

1397
    return ensembles
11✔
1398

1399

1400
@apply_custom_font()
11✔
1401
def ensemble_res_1to1(
11✔
1402
    ensemble,
1403
    pst,
1404
    facecolor="0.5",
1405
    logger=None,
1406
    filename=None,
1407
    skip_groups=[],
1408
    base_ensemble=None,
1409
    **kwargs
1410
):
1411
    """helper function to plot ensemble 1-to-1 plots showing the simulated range
1412

1413
    Args:
1414
        ensemble (varies):  the ensemble argument can be a pandas.DataFrame or derived type or a str, which
1415
            is treated as a filename.  Optionally, ensemble can be a list of these types or
1416
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
1417
        pst (`pyemu.Pst`): a control file instance
1418
        facecolor (`str`): the histogram facecolor.  Only applies if `ensemble` is a single thing
1419
        filename (`str`): the name of the pdf to create. If None, return figs
1420
            without saving.  Default is None.
1421
        base_ensemble (`varies`): an optional ensemble argument for the observations + noise ensemble.
1422
            This will be plotted as a transparent red bar on the 1to1 plot.
1423

1424
    Note:
1425

1426
        the vertical bar on each plot the min-max range
1427

1428
    Example::
1429

1430

1431
        pst = pyemu.Pst("my.pst")
1432
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
1433
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
1434
        pyemu.plot_utils.ensemble_res_1to1(ensemble={"0.5":prior, "b":post})
1435
        plt.show()
1436

1437
    """
1438
    _ensure_matplotlib()
11✔
1439
    import matplotlib.pyplot as plt
11✔
1440
    import matplotlib.ticker
11✔
1441
    from matplotlib.backends.backend_pdf import PdfPages
11✔
1442

1443
    def _get_plotlims(oen, ben, obsnames):
11✔
1444
        if not isinstance(oen, dict):
11✔
1445
            oen = {'g': oen.loc[:, obsnames]}
×
1446
        if not isinstance(ben, dict):
11✔
1447
            ben = {'g': ben.get(obsnames)}
11✔
1448
        outofrange = False
11✔
1449
        # work back from crazy values
1450
        oemin = 1e32
11✔
1451
        oemeanmin = 1e32
11✔
1452
        oemax = -1e32
11✔
1453
        oemeanmax = -1e32
11✔
1454
        bemin = 1e32
11✔
1455
        bemeanmin = 1e32
11✔
1456
        bemax = -1e32
11✔
1457
        bemeanmax = -1e32
11✔
1458
        for _, oeni in oen.items():  # loop over ensembles
11✔
1459
            oeni = oeni.loc[:, obsnames]  # slice group obs
11✔
1460
            oemin = np.nanmin([oemin, oeni.min().min()])
11✔
1461
            oemax = np.nanmax([oemax, oeni.max().max()])
11✔
1462
            # get min and max of mean sim vals
1463
            # (in case we want plot to ignore extremes)
1464
            oemeanmin = np.nanmin([oemeanmin, oeni.mean().min()])
11✔
1465
            oemeanmax = np.nanmax([oemeanmax, oeni.mean().max()])
11✔
1466
        for _, beni in ben.items():  # same with base ensemble/obsval
11✔
1467
            # work with either ensemble or obsval series
1468
            beni = beni.get(obsnames)
11✔
1469
            bemin = np.nanmin([bemin, beni.min().min()])
11✔
1470
            bemax = np.nanmax([bemax, beni.max().max()])
11✔
1471
            bemeanmin = np.nanmin([bemeanmin, beni.mean().min()])
11✔
1472
            bemeanmax = np.nanmax([bemeanmax, beni.mean().max()])
11✔
1473
        # get base ensemble range
1474
        berange = bemax-bemin
11✔
1475
        if berange == 0.:  # only one obs in group (probs)
11✔
1476
            berange = bemeanmax * 1.1  # expand a little
11✔
1477
        # add buffer to obs endpoints
1478
        bemin = bemin - (berange*0.05)
11✔
1479
        bemax = bemax + (berange*0.05)
11✔
1480
        if oemax < bemin:  # sim well below obs
11✔
1481
            oemin = oemeanmin  # set min to mean min
11✔
1482
            # (sim captured but not extremes)
1483
            outofrange = True
11✔
1484
        if oemin > bemax:  # sim well above obs
11✔
1485
            oemax = oemeanmax
11✔
1486
            outofrange = True
11✔
1487
        oerange = oemax - oemin
11✔
1488
        if bemax > oemax + (0.1*oerange):  # obs max well above sim
11✔
1489
            if not outofrange:  # but sim still in range
11✔
1490
                # zoom to sim
1491
                bemax = oemax + (0.1*oerange)
11✔
1492
            else:  # use obs mean max
1493
                bemax = bemeanmax
11✔
1494
        if bemin < oemin - (0.1 * oerange):  # obs min well below sim
11✔
1495
            if not outofrange:  # but sim still in range
11✔
1496
                # zoom to sim
1497
                bemin = oemin - (0.1 * oerange)
11✔
1498
            else:
1499
                bemin = bemeanmin
11✔
1500
        pmin = np.nanmin([oemin, bemin])
11✔
1501
        pmax = np.nanmax([oemax, bemax])
11✔
1502
        return pmin, pmax
11✔
1503

1504

1505
    if logger is None:
11✔
1506
        logger = Logger("Default_Logger.log", echo=False)
11✔
1507
    logger.log("plot res_1to1")
11✔
1508
    obs = pst.observation_data
11✔
1509
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
11✔
1510

1511
    if base_ensemble is not None:
11✔
1512
        base_ensemble = _process_ensemble_arg(base_ensemble, "r", logger)
11✔
1513

1514
    if "grouper" in kwargs:
11✔
1515
        raise NotImplementedError()
×
1516
    else:
1517
        grouper = obs.groupby(obs.obgnme).groups
11✔
1518
        for skip_group in skip_groups:
11✔
1519
            grouper.pop(skip_group)
×
1520

1521
    fig = plt.figure(figsize=figsize)
11✔
1522
    if "fig_title" in kwargs:
11✔
1523
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
×
1524
    else:
1525
        plt.figtext(
11✔
1526
            0.5,
1527
            0.5,
1528
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
1529
                pst.filename, str(datetime.now())
1530
            ),
1531
            ha="center",
1532
        )
1533

1534
    figs = []
11✔
1535
    ax_count = 0
11✔
1536
    for g, names in grouper.items():
11✔
1537
        logger.log("plotting 1to1 for {0}".format(g))
11✔
1538
        # control file observation for group
1539
        obs_g = obs.loc[names, :]
11✔
1540
        # normally only look a non-zero weighted obs
1541
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
11✔
1542
            obs_g = obs_g.loc[obs_g.weight > 0, :]
11✔
1543
        if obs_g.shape[0] == 0:
11✔
1544
            logger.statement("no non-zero obs for group '{0}'".format(g))
11✔
1545
            logger.log("plotting 1to1 for {0}".format(g))
11✔
1546
            continue
11✔
1547
        # if the first axis in page
1548
        if ax_count % (nr * nc) == 0:
11✔
1549
            if ax_count > 0:
11✔
1550
                plt.tight_layout()
×
1551
            figs.append(fig)
11✔
1552
            fig = plt.figure(figsize=figsize)
11✔
1553
            axes = _get_page_axes()
11✔
1554
            ax_count = 0
11✔
1555
        ax = axes[ax_count]
11✔
1556

1557
        if base_ensemble is None:
11✔
1558
            # if obs not defined by obs+noise ensemble,
1559
            # use min and max for obsval from control file
1560
            pmin, pmax = _get_plotlims(ensembles, obs_g.obsval, obs_g.obsnme)
11✔
1561
        else:
1562
            # if obs defined by obs+noise use obs+noise min and max
1563
            pmin, pmax = _get_plotlims(ensembles, base_ensemble, obs_g.obsnme)
11✔
1564
            obs_gg = obs_g.sort_values(by="obsval")
11✔
1565
            for c, en in base_ensemble.items():
11✔
1566
                en_g = en.loc[:, obs_gg.obsnme]
11✔
1567
                emx = en_g.max()
11✔
1568
                emn = en_g.min()
11✔
1569
                
1570
                #exit()
1571
                # update y min and max for obs+noise ensembles
1572
                if len(obs_gg.obsval) > 1:
11✔
1573

1574
                    emx = np.zeros(obs_gg.shape[0]) + emx
11✔
1575
                    emn = np.zeros(obs_gg.shape[0]) + emn
11✔
1576
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
11✔
1577
                                    facecolor=c, alpha=0.2, zorder=2)
1578
                else:
1579
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c, alpha=0.2, zorder=2)
11✔
1580
        for c, en in ensembles.items():
11✔
1581
            en_g = en.loc[:, obs_g.obsnme]
11✔
1582
            # output mins and maxs
1583
            emx = en_g.max()
11✔
1584
            emn = en_g.min()
11✔
1585
            [
11✔
1586
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
1587
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
1588
            ]
1589
        ax.plot([pmin, pmax], [pmin, pmax], "k--", lw=1.0, zorder=3)
11✔
1590
        xlim = (pmin, pmax)
11✔
1591
        ax.set_xlim(pmin, pmax)
11✔
1592
        ax.set_ylim(pmin, pmax)
11✔
1593

1594
        if max(np.abs(xlim)) > 1.0e5:
11✔
1595
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
1596
            ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
1597
        ax.grid()
11✔
1598

1599
        ax.set_xlabel("observed", labelpad=0.1)
11✔
1600
        ax.set_ylabel("simulated", labelpad=0.1)
11✔
1601
        ax.set_title(
11✔
1602
            "{0}) group:{1}, {2} observations".format(
1603
                abet[ax_count], g, obs_g.shape[0]
1604
            ),
1605
            loc="left",
1606
        )
1607

1608
        # Residual (RHS plot)
1609
        ax_count += 1
11✔
1610
        ax = axes[ax_count]
11✔
1611
        # ax.scatter(obs_g.obsval, obs_g.res, marker='.', s=10, color='b')
1612

1613
        if base_ensemble is not None:
11✔
1614
            obs_gg = obs_g.sort_values(by="obsval")
11✔
1615
            for c, en in base_ensemble.items():
11✔
1616
                en_g = en.loc[:, obs_gg.obsnme].subtract(obs_gg.obsval)
11✔
1617
                emx = en_g.max()
11✔
1618
                emn = en_g.min()
11✔
1619
                if len(obs_gg.obsval) > 1:
11✔
1620
                    emx = np.zeros(obs_gg.shape[0]) + emx
11✔
1621
                    emn = np.zeros(obs_gg.shape[0]) + emn
11✔
1622
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
11✔
1623
                                    facecolor=c, alpha=0.2, zorder=2)
1624
                else:
1625
                    # [ax.plot([ov, ov], [een, eex], color=c,alpha=0.3) for ov, een, eex in zip(obs_g.obsval.values, en.values, ex.values)]
1626
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c,
11✔
1627
                            alpha=0.2, zorder=2)
1628
        omn = []
11✔
1629
        omx = []
11✔
1630
        for c, en in ensembles.items():
11✔
1631
            en_g = en.loc[:, obs_g.obsnme].subtract(obs_g.obsval, axis=1)
11✔
1632
            emx = en_g.max()
11✔
1633
            emn = en_g.min()
11✔
1634
            omn.append(emn)
11✔
1635
            omx.append(emx)
11✔
1636
            [
11✔
1637
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
1638
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
1639
            ]
1640

1641
        omn = pd.concat(omn).min()
11✔
1642
        omx = pd.concat(omx).max()
11✔
1643
        mx = np.nanmax([np.abs(omn), np.abs(omx)])  # ensure symmetric about y=0
11✔
1644
        if obs_g.shape[0] == 1:
11✔
1645
            mx *= 1.05
11✔
1646
        else:
1647
            mx *= 1.02
11✔
1648
        if np.sign(omn) == np.sign(omx):
11✔
1649
            # allow y axis asymm if all above or below
1650
            mn = np.nanmin([0, np.sign(omn) * mx])
11✔
1651
            mx = np.nanmax([0, np.sign(omn) * mx])
11✔
1652
        else:
1653
            mn = -mx
11✔
1654
        ax.set_ylim(mn, mx)
11✔
1655
        bmin = obs_g.obsval.values.min()
11✔
1656
        bmax = obs_g.obsval.values.max()
11✔
1657
        brange = (bmax - bmin)
11✔
1658
        if brange == 0.:
11✔
1659
            brange = obs_g.obsval.values.mean()
11✔
1660
        bmin = bmin - 0.1*brange
11✔
1661
        bmax = bmax + 0.1*brange
11✔
1662
        xlim = (bmin, bmax)
11✔
1663
        # show a zero residuals line
1664
        ax.plot(xlim, [0, 0], "k--", lw=1.0, zorder=3)
11✔
1665

1666
        ax.set_xlim(xlim)
11✔
1667
        ax.set_ylabel("residual", labelpad=0.1)
11✔
1668
        ax.set_xlabel("observed", labelpad=0.1)
11✔
1669
        ax.set_title(
11✔
1670
            "{0}) group:{1}, {2} observations".format(
1671
                abet[ax_count], g, obs_g.shape[0]
1672
            ),
1673
            loc="left",
1674
        )
1675
        ax.grid()
11✔
1676
        if ax.get_xlim()[1] > 1.0e5:
11✔
1677
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
×
1678

1679
        ax_count += 1
11✔
1680

1681
        logger.log("plotting 1to1 for {0}".format(g))
11✔
1682

1683
    for a in range(ax_count, nr * nc):
11✔
1684
        axes[a].set_axis_off()
11✔
1685
        axes[a].set_yticks([])
11✔
1686
        axes[a].set_xticks([])
11✔
1687

1688
    plt.tight_layout()
11✔
1689
    figs.append(fig)
11✔
1690
    if filename is not None:
11✔
1691
        # plt.tight_layout()
1692
        with PdfPages(filename) as pdf:
11✔
1693
            for fig in figs:
11✔
1694
                pdf.savefig(fig)
11✔
1695
                plt.close(fig)
11✔
1696
        logger.log("plot res_1to1")
11✔
1697
    else:
1698
        logger.log("plot res_1to1")
×
1699
        return figs
×
1700

1701

1702
@apply_custom_font()
def plot_jac_test(
    csvin, csvout, targetobs=None, filetype=None, maxoutputpages=1, outputdirectory=None
):
    """helper function to plot results of the Jacobian test performed using the pest++
    program pestpp-swp.

    Args:
        csvin (`str`): name of csv file used as input to sweep, typically developed with
            static method pyemu.helpers.build_jac_test_csv()
        csvout (`str`): name of csv file with output generated by sweep, both input
            and output files can be specified in the pest++ control file
            with pyemu using: pest_object.pestpp_options["sweep_parameter_csv_file"] = jactest_in_file.csv
            pest_object.pestpp_options["sweep_output_csv_file"] = jactest_out_file.csv
        targetobs ([`str`]): list of observation names (columns of the output csv) to
            plot; a single name may also be passed as a plain string. Each parameter
            used for jactest can have up to 32 observations plotted per page, throws
            a warning if more than 10 pages of output are requested per parameter.
            If None, all observations in the output csv file are used.
        filetype (`str`): file type to store output, if None, plt.show() is called.
        maxoutputpages (`int`): maximum number of pages of output per parameter.  Each page can
            hold up to 32 observation derivatives.  If value > 10, set it to
            10 and throw a warning.  If observations in targetobs > 32*maxoutputpages,
            then a random set is selected from the targetobs list (or all observations
            in the csv file if targetobs=None).
        outputdirectory (`str`):  directory to store results, if None, current working directory is used.
            If a relative path is passed, it is joined to the current working directory;
            an absolute path is used directly. The directory (including any missing
            parents) is created if needed.

    Note:
        Used in conjunction with pyemu.helpers.build_jac_test_csv() and sweep to perform
        a Jacobian Test and then view the results. Can generate a lot of plots so easiest
        to put into a separate directory and view the files.

    """
    _ensure_matplotlib()
    import matplotlib.pyplot as plt

    # page layout: one subplot per observation, nrows x ncols per page
    nrows, ncols = 8, 4
    per_page = nrows * ncols

    localhome = os.getcwd()
    if outputdirectory is None:
        figures_dir = localhome
    else:
        # os.path.join returns outputdirectory unchanged when it is absolute
        figures_dir = os.path.join(localhome, outputdirectory)
        # makedirs also creates missing intermediate directories
        os.makedirs(figures_dir, exist_ok=True)

    # read the input and output files into pandas dataframes
    jactest_in_df = pd.read_csv(csvin, engine="python", index_col=0)
    jactest_in_df.index.name = "input_run_id"
    jactest_out_df = pd.read_csv(csvout, engine="python", index_col=1)

    # subtract the base run from every row, leaving the one perturbed parameter
    # as the only non-zero value. Zeros are set to nan so round-off doesn't get
    # us, then summing across each row yields the perturbation for that run.
    base_par = jactest_in_df.loc["base"]
    delta_par_df = jactest_in_df.subtract(base_par, axis="columns")
    delta_par_df = delta_par_df.replace(0, np.nan).drop("base", axis="index")
    delta_par = delta_par_df.sum(axis="columns")

    base_obs = jactest_out_df.loc["base"]
    delta_obs = jactest_out_df.subtract(base_obs).drop("base", axis="index")

    if targetobs is None:
        # NOTE(review): the first 8 output columns are skipped, presumably sweep
        # run bookkeeping rather than observations -- confirm against pestpp-swp
        targetobs = jactest_out_df.columns.tolist()[8:]
    elif isinstance(targetobs, str):
        # a lone name would otherwise select a Series and break len()/indexing
        targetobs = [targetobs]
    else:
        targetobs = list(targetobs)
    delta_obs = delta_obs[targetobs]

    # Jacobian: change in observation divided by change in parameter
    jacobian = delta_obs.divide(delta_par, axis="index")

    # use the index created by build_jac_test_csv to get a column of parameter
    # names and increments, so derivative vs. increment can be plotted per parameter
    extr_df = pd.Series(jacobian.index.values).str.extract(r"(.+)(_\d+$)", expand=True)
    extr_df[1] = pd.to_numeric(extr_df[1].str.replace("_", "")) + 1
    extr_df.rename(columns={0: "parameter", 1: "increment"}, inplace=True)
    extr_df.index = jacobian.index

    plotframe = pd.concat([extr_df, jacobian], axis=1, join="inner")

    if maxoutputpages > 10:
        warnings.warn(
            "more than 10 pages of output requested per parameter, "
            "maxoutputpages reset to 10."
        )
        maxoutputpages = 10
    num_obs_plotted = min(maxoutputpages * per_page, len(targetobs))
    if num_obs_plotted < len(targetobs):
        # too many observations for the allowed pages: plot a random subset
        index_plotted = np.random.choice(len(targetobs), num_obs_plotted, replace=False)
        obs_plotted = [targetobs[i] for i in index_plotted]
    else:
        obs_plotted = targetobs
    # ceiling division: just enough pages to hold every plotted observation
    real_pages = -(-num_obs_plotted // per_page)

    # one set of pages per parameter, one subplot per observation derivative
    for param, group in plotframe.groupby("parameter"):
        group = group.sort_values("increment")
        increments = group["increment"].unique()
        for page in range(real_pages):
            fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=(10, 15))
            for row in range(nrows):
                for col in range(ncols):
                    ax = axes[row, col]
                    count = per_page * page + ncols * row + col
                    if count >= num_obs_plotted:
                        ax.set_axis_off()
                        continue
                    obsname = obs_plotted[count]
                    ax.scatter(group["increment"], group[obsname])
                    ax.plot(group["increment"], group[obsname], "r")
                    ax.set_title(obsname)
                    ax.set_xticks(increments)
                    ax.tick_params(direction="in")
                    # label the bottom-most populated axis in each column
                    if row == nrows - 1 or count + ncols >= num_obs_plotted:
                        ax.set_xlabel("Increment")
            fig.tight_layout()

            if filetype is None:
                plt.show()
            else:
                fig.savefig(
                    os.path.join(
                        figures_dir, "{0}_jactest_{1}.{2}".format(param, page, filetype)
                    )
                )
            plt.close(fig)
1836

1837
def _ensure_matplotlib():
11✔
1838
    if not HAS_MATPLOTLIB:
11✔
NEW
1839
        msg = (
×
1840
            "Plotting functions require the 'matplotlib' package. Install it "
1841
            "with 'pip install matplotlib'."
1842
        )
NEW
1843
        raise ImportError(msg)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc