• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pypest / pyemu / 12293359459

12 Dec 2024 09:20AM UTC coverage: 79.492% (+0.2%) from 79.318%
12293359459

Pull #562

github

web-flow
Merge 5ef4de2a3 into e799d6fdc
Pull Request #562: Fix typos

16 of 23 new or added lines in 8 files covered. (69.57%)

1 existing line in 1 file now uncovered.

12749 of 16038 relevant lines covered (79.49%)

8.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.05
/pyemu/plot/plot_utils.py
1
"""Plotting functions for various PEST(++) and pyemu operations"""
2
import os
11✔
3
import numpy as np
11✔
4
import pandas as pd
11✔
5
import warnings
11✔
6
from datetime import datetime
11✔
7
import string
11✔
8
from pyemu.logger import Logger
11✔
9
from pyemu.pst import pst_utils
11✔
10
from ..pyemu_warnings import PyemuWarning
11✔
11

12
# default matplotlib rc overrides used by ``apply_custom_font`` when it is
# called without explicit rc_params
font = {"font.size": 6}
11✔
13
try:
11✔
14
    import matplotlib
11✔
15
    import matplotlib.pyplot as plt
11✔
16
    from matplotlib.backends.backend_pdf import PdfPages
11✔
17
    from matplotlib.gridspec import GridSpec
11✔
18
except Exception as e:
×
19
    # raise Exception("error importing matplotlib: {0}".format(str(e)))
20
    warnings.warn("error importing matplotlib: {0}".format(str(e)), PyemuWarning)
×
21

22
import pyemu
11✔
23

24
# page layout shared by the multi-page plotting helpers: the figure size
# (matplotlib inches) and an nr-row by nc-column grid of axes per page
figsize = (8, 10.5)
nr = 4
nc = 2

# panel letters ("A", "B", ...) prefixed to subplot titles
abet = string.ascii_uppercase
29

30
def apply_custom_font(rc_params=None):
    """Decorator factory that runs the wrapped plotting function inside a
    temporary ``matplotlib.pyplot.rc_context``.

    Args:
        rc_params (`dict`): matplotlib rc settings to apply while the wrapped
            function runs.  If None, the module-level ``font`` dict is used.

    Returns:
        callable: a decorator.  The functions it wraps keep their original
        ``__name__``, ``__doc__`` and other metadata (via ``functools.wraps``)
        so help() and generated API docs still show the real docstring.
    """
    import functools

    if rc_params is None:
        rc_params = font

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # rc_context restores the previous rc settings on exit
            with plt.rc_context(rc_params):
                return func(*args, **kwargs)

        return wrapper

    return decorator
39

40

41
@apply_custom_font({"font.size": 6})
def plot_summary_distributions(
    df,
    ax=None,
    label_post=False,
    label_prior=False,
    subplots=False,
    figsize=(11, 8.5),
    pt_color="b",
):
    """helper function to plot gaussian distributions from prior and posterior
    means and standard deviations

    Args:
        df (`pandas.DataFrame` or `str`): a dataframe, or the name of a csv file
            to load one from.  For both "prior" and "post" it must provide an
            expected value ('<x>_expt' or '<x>_mean') and a standard deviation
            ('<x>_stdev' or '<x>_var').  If loaded from a csv file, column 0 is
            assumed to be the index.  A passed dataframe is not modified.
        ax (`matplotlib.pyplot.axis`): If None, and not subplots, then one is created
            and all distributions are plotted on a single plot
        label_post (`bool`): flag to add text labels to the peak of the posterior
        label_prior (`bool`): flag to add text labels to the peak of the prior
        subplots (`bool`): flag to use subplots.  If True, then 6 axes per page
            are used and a single prior and posterior is plotted on each
        figsize (`tuple`): matplotlib figure size
        pt_color (`str`): matplotlib color used to fill the posterior
            distributions.  Default is "b"

    Returns:
        if `subplots` is True, a tuple containing:

        - **[`matplotlib.figure`]**: list of figures
        - **[`matplotlib.axis`]**: list of axes

        otherwise, the single `matplotlib.axis` that was plotted on.

    Note:
        This is useful for demystifying FOSM results.  Posteriors are drawn
        as filled curves and priors as dashed grey lines.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pyemu.plot_utils.plot_summary_distributions("pest.par.usum.csv")
        plt.show()

    """
    import matplotlib.pyplot as plt

    if isinstance(df, str):
        df = pd.read_csv(df, index_col=0)
    else:
        # work on a copy so the derived columns added below are not written
        # back into the caller's dataframe
        df = df.copy()

    if ax is None and not subplots:
        plt.figure(figsize=figsize)
        ax = plt.subplot(111)
        ax.grid()

    # fill in any missing stdev / expected-value columns from their alternates
    if "post_stdev" not in df.columns and "post_var" in df.columns:
        df.loc[:, "post_stdev"] = df.post_var.apply(np.sqrt)
    if "prior_stdev" not in df.columns and "prior_var" in df.columns:
        df.loc[:, "prior_stdev"] = df.prior_var.apply(np.sqrt)
    if "prior_expt" not in df.columns and "prior_mean" in df.columns:
        df.loc[:, "prior_expt"] = df.prior_mean
    if "post_expt" not in df.columns and "post_mean" in df.columns:
        df.loc[:, "post_expt"] = df.post_mean

    if subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(2, 3, 1)
        ax_per_page = 6
        ax_count = 0
        axes = []
        figs = []
    for name in df.index:
        # posterior: filled curve
        x, y = gaussian_distribution(
            df.loc[name, "post_expt"], df.loc[name, "post_stdev"]
        )
        ax.fill_between(x, 0, y, facecolor=pt_color, edgecolor="none", alpha=0.25)
        if label_post:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        # prior: dashed grey line
        x, y = gaussian_distribution(
            df.loc[name, "prior_expt"], df.loc[name, "prior_stdev"]
        )
        ax.plot(x, y, color="0.5", lw=3.0, dashes=(2, 1))
        if label_prior:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        if subplots:
            ax.set_title(name)
            ax_count += 1
            ax.set_yticklabels([])
            axes.append(ax)
            # don't create an (empty) axis after the final entry
            if name == df.index[-1]:
                break
            if ax_count >= ax_per_page:
                figs.append(fig)
                fig = plt.figure(figsize=figsize)
                ax_count = 0
            ax = plt.subplot(2, 3, ax_count + 1)
    if subplots:
        figs.append(fig)
        return figs, axes
    ylim = list(ax.get_ylim())
    ylim[1] *= 1.2
    ylim[0] = 0.0
    ax.set_ylim(ylim)
    ax.set_yticklabels([])
    return ax
152

153

154
def gaussian_distribution(mean, stdev, num_pts=50):
    """get an x and y numpy.ndarray that spans the +/- 4
    standard deviation range of a gaussian distribution with
    a given mean and standard deviation. useful for plotting

    This is pure numpy (it draws nothing), so it is deliberately not wrapped
    in `apply_custom_font`: it can be called in tight plotting loops without
    entering an rc context, and without matplotlib being importable.

    Args:
        mean (`float`): the mean of the distribution
        stdev (`float`): the standard deviation of the distribution
        num_pts (`int`): the number of points in the returned ndarrays.
            Default is 50

    Returns:
        tuple containing:

        - **numpy.ndarray**: the x-values of the distribution
        - **numpy.ndarray**: the y-values (probability density) of the distribution

    Example::

        mean,std = 1.0, 2.0
        x,y = pyemu.plot.gaussian_distribution(mean,std)
        plt.fill_between(x,0,y)
        plt.show()


    """
    xstart = mean - (4.0 * stdev)
    xend = mean + (4.0 * stdev)
    x = np.linspace(xstart, xend, num_pts)
    variance = stdev * stdev
    y = (1.0 / np.sqrt(2.0 * np.pi * variance)) * np.exp(
        -1.0 * ((x - mean) ** 2) / (2.0 * variance)
    )
    return x, y
188

189

190
@apply_custom_font()
def pst_helper(pst, kind=None, **kwargs):
    """`pyemu.Pst` plot helper - takes the
    handoff from `pyemu.Pst.plot()`

    Args:
        pst (`pyemu.Pst`): the control file instance to plot
        kind (`str`): the kind of plot to make: one of "prior", "1to1",
            "phi_pie" or "phi_progress".  If None, every kind is made and
            each is saved to a pdf named "<case>.<kind>.pdf", where <case>
            is the control file name (or `pst.new_filename` if set) without
            ".pst".
        **kwargs (`dict`): keyword arguments to pass to the
            plotting function and ultimately to `matplotlib`.  The "echo"
            key is also read here to control logger echoing.  Ignored when
            `kind` is None.

    Returns:
        varies: usually a combination of `matplotlib.figure` (s) and/or
        `matplotlib.axis` (s).  If `kind` is None, a list holding the return
        value of each plotting function.

    Example::

        pst = pyemu.Pst("pest.pst") #assumes pest.res or pest.rei is found
        pst.plot(kind="1to1")
        plt.show()
        pst.plot(kind="phi_pie")
        plt.show()
        pst.plot(kind="prior")
        plt.show()

    """

    echo = kwargs.get("echo", False)
    logger = pyemu.Logger("plot_pst_helper.log", echo=echo)
    logger.statement("plot_utils.pst_helper()")

    kinds = {
        "prior": pst_prior,
        "1to1": res_1to1,
        "phi_pie": res_phi_pie,
        "phi_progress": phi_progress,
    }

    if kind is None:
        # make every kind of plot, each saved to its own pdf
        base_filename = pst.filename
        if pst.new_filename is not None:
            base_filename = pst.new_filename
        base_filename = base_filename.replace(".pst", "")
        returns = []
        for name, plot_func in kinds.items():
            plt_name = base_filename + "." + name + ".pdf"
            returns.append(plot_func(pst, logger=logger, filename=plt_name))
        return returns

    if kind not in kinds:
        # lraise raises, so execution stops here for an unknown kind
        logger.lraise(
            "unrecognized kind:{0}, should one of {1}".format(
                kind, ",".join(list(kinds.keys()))
            )
        )
    return kinds[kind](pst, logger, **kwargs)
245

246

247
@apply_custom_font()
def phi_progress(pst, logger=None, filename=None, **kwargs):
    """make plot of phi vs number of model runs - requires
    available ".iobj" file generated by a PESTPP-GLM run.

    The iobj file name is derived from `pst.filename` by replacing ".pst"
    with ".iobj".

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): file name to save the figure to.  If None, the
            figure is not saved.  Default is None
        kwargs (`dict`): optional keyword args.  "ax" is an existing
            `matplotlib.axis` to plot on; if missing or None, a new figure
            is created.

    Returns:
        `matplotlib.axis`: the axis the plot was made on

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.phi_progress(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot phi_progress")

    iobj_file = pst.filename.replace(".pst", ".iobj")
    if not os.path.exists(iobj_file):
        logger.lraise("couldn't find iobj file {0}".format(iobj_file))
    df = pd.read_csv(iobj_file)

    ax = kwargs.get("ax")
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)
    ax.plot(df.model_runs_completed, df.total_phi, marker=".")
    ax.set_xlabel("model runs")
    ax.set_ylabel(r"$\phi$")
    ax.grid()
    if filename is not None:
        # save the figure that owns ``ax``; plt.savefig would save whichever
        # figure happens to be current, which differs when ``ax`` is supplied
        ax.figure.savefig(filename)
    logger.log("plot phi_progress")
    return ax
292

293

294
def _get_page_axes(count=nr * nc):
    """Create at most one page worth (``nr`` x ``nc``) of subplots on the
    current figure and return them as a list in subplot-index order."""
    n_axes = min(count, nr * nc)
    page_axes = []
    for idx in range(1, n_axes + 1):
        page_axes.append(plt.subplot(nr, nc, idx))
    return page_axes
298

299

300
@apply_custom_font()
def res_1to1(
    pst, logger=None, filename=None, plot_hexbin=False, histogram=False, **kwargs
):
    """make 1-to-1 plots and also observed vs residual by observation group

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): PDF filename to save figures to.  If None, figures
            are returned.  Default is None
        plot_hexbin (`bool`): flag to use the hexbinning for large numbers of residuals.
            Default is False
        histogram (`bool`): flag to plot residual histograms instead of obs vs residual.
            Default is False (use `matplotlib.pyplot.scatter` )
        kwargs (`dict`): optional keyword args.  Recognized keys are
            "ensemble" (passed to `pst_utils.res_from_en()` to build the
            residuals instead of using `pst.res`), "fig_title" (text for the
            title page) and "include_zero" (if truthy, zero-weight
            observations are plotted too).

    Returns:
        [`matplotlib.Figure`]: if `filename` is None, the list of figures (a
        title page followed by pages of `nr` x `nc` panels, two panels per
        observation group).  Otherwise the figures are written to
        `filename`, closed, and None is returned.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_1to1(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_1to1")

    if "ensemble" in kwargs:
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception as e:
            logger.lraise("res_1to1: error loading ensemble file: {0}".format(str(e)))
    else:
        try:
            res = pst.res
        except Exception:
            res = None
        if res is None:
            logger.lraise("res_1to1: pst.res is None, couldn't find residuals file")

    obs = pst.observation_data

    if "grouper" in kwargs:
        raise NotImplementedError()
    grouper = obs.groupby(obs.obgnme).groups
    include_zero = kwargs.get("include_zero", False)

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    axes = []  # stays empty if no group has observations to plot
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))

        # copy so the "sim"/"res" columns are not written into (a view of)
        # the control file's observation data
        obs_g = obs.loc[names, :].copy()
        obs_g.loc[:, "sim"] = res.loc[names, "modelled"]
        logger.statement("using control file obsvals to calculate residuals")
        obs_g.loc[:, "res"] = obs_g.sim - obs_g.obsval
        if not include_zero:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue

        # start a new page once the current one is full (two panels per group)
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # --- panel 1: simulated vs observed ---
        ax = axes[ax_count]
        mx = max(obs_g.obsval.max(), obs_g.sim.max())
        mn = min(obs_g.obsval.min(), obs_g.sim.min())
        # pad by 10% of each limit's magnitude: this always widens the range,
        # whereas scaling (mx*1.1, mn*0.9) shrinks it for negative values
        mx += 0.1 * abs(mx)
        mn -= 0.1 * abs(mn)
        ax.axis("square")
        if plot_hexbin:
            ax.hexbin(
                obs_g.obsval.values,
                obs_g.sim.values,
                mincnt=1,
                gridsize=(75, 75),
                extent=(mn, mx, mn, mx),
                bins="log",
                edgecolors=None,
            )
        else:
            ax.scatter(obs_g.obsval.values, obs_g.sim.values,
                       marker=".", s=10, color="b")

        ax.plot([mn, mx], [mn, mx], "k--", lw=1.0)
        xlim = (mn, mx)
        ax.set_xlim(mn, mx)
        ax.set_ylim(mn, mx)
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        ax_count += 1

        # --- panel 2: residuals ---
        if not histogram:
            ax = axes[ax_count]
            ax.scatter(obs_g.obsval.values, obs_g.res.values,
                       marker=".", s=10, color="b")
            ylim = ax.get_ylim()
            mx = max(np.abs(ylim[0]), np.abs(ylim[1]))
            if obs_g.shape[0] == 1:
                mx *= 1.1
            ax.set_ylim(-mx, mx)
            # show a zero residuals line
            ax.plot(xlim, [0, 0], "k--", lw=1.0)
            meanres = obs_g.res.mean()
            # show mean residuals line
            ax.plot(xlim, [meanres, meanres], "r-", lw=1.0)
            ax.set_xlim(xlim)
            ax.set_ylabel("residual", labelpad=0.1)
            ax.set_xlabel("observed", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        else:
            # need max and min res to set xlim, otherwise wonky figsize;
            # residuals are usually negative at the low end, so pad outward
            # by magnitude rather than scaling
            mxr = obs_g.res.max()
            mnr = obs_g.res.min()
            mxr += 0.1 * abs(mxr)
            mnr -= 0.1 * abs(mnr)
            rlim = (mnr, mxr)

            ax = axes[ax_count]
            ax.hist(obs_g.res, bins=50, color="b")
            meanres = obs_g.res.mean()
            ax.axvline(meanres, color="r", lw=1)
            b, t = ax.get_ylim()
            ax.text(meanres + meanres / 10, t - t / 10, "Mean: {:.2f}".format(meanres))
            ax.set_xlim(rlim)
            ax.set_ylabel("count", labelpad=0.1)
            ax.set_xlabel("residual", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        logger.log("plotting 1to1 for {0}".format(g))

    # blank out unused panels on the final page
    for a in range(ax_count, len(axes)):
        axes[a].set_axis_off()
        axes[a].set_yticks([])
        axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot res_1to1")
        return None
    logger.log("plot res_1to1")
    return figs
510

511

512
@apply_custom_font()
def plot_id_bar(id_df, nsv=None, logger=None, **kwargs):
    """Plot a stacked bar chart of identifiability based on
    the `pyemu.ErrVar.get_identifiability_dataframe()` dataframe

    Bars are sorted by total identifiability (descending) and a colorbar keyed
    to singular value number replaces the legend.

    Args:
        id_df (`pandas.DataFrame`) : dataframe of identifiability.  An "ident"
            column, if present, is ignored.
        nsv (`int`): number of singular values (leading columns) to consider.
            If None or larger than the number of columns, all are used.
        logger (`pyemu.Logger`, optional): a logger.  If None, a generic
            one is created
        kwargs (`dict`): recognized keys are "figsize" (size of a newly
            created figure, default (8, 10.5)) and "ax" (an existing axis to
            plot on).

    Returns:
        `matplotlib.Axis`: the axis with the plot

    Example::

        import pyemu
        ev = pyemu.ErrVar(jco='freyberg_jac.jcb')
        id_df = ev.get_identifiability_dataframe(singular_value=48)
        pyemu.plot_utils.plot_id_bar(id_df, nsv=12, figsize=(12,4))

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot id bar")

    df = id_df.copy()

    # drop the final `ident` column
    if "ident" in df.columns:
        df.drop("ident", inplace=True, axis=1)

    if nsv is None or nsv > len(df.columns):
        nsv = len(df.columns)
        # a one-off message, not a start/stop timing pair, so use statement()
        logger.statement("set number of SVs and number in the dataframe")

    # copy the slice so the temporary "ident" column below is added to an
    # independent frame (avoids pandas' chained-assignment warning)
    df = df[df.columns[:nsv]].copy()

    df["ident"] = df.sum(axis=1)
    df.sort_values(by="ident", inplace=True, ascending=False)
    df.drop("ident", inplace=True, axis=1)

    fig_size = kwargs.get("figsize", (8, 10.5))
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=fig_size)
        ax = plt.subplot(1, 1, 1)

    # plot the stacked bar chart (the easy part!)
    df.plot.bar(stacked=True, cmap="jet_r", legend=False, ax=ax)

    #
    # horrible shenanigans to make a colorbar rather than a legend
    #

    # special case colormap just dark red if one SV
    if nsv == 1:
        tcm = matplotlib.colors.LinearSegmentedColormap.from_list(
            "one_sv", [plt.get_cmap("jet_r")(0)] * 2, N=2
        )
        sm = plt.cm.ScalarMappable(
            cmap=tcm, norm=matplotlib.colors.Normalize(vmin=0, vmax=nsv + 1)
        )
    # or typically just rock the jet_r colormap over the range of SVs
    else:
        sm = plt.cm.ScalarMappable(
            cmap=plt.get_cmap("jet_r"),
            norm=matplotlib.colors.Normalize(vmin=1, vmax=nsv),
        )
    sm.set_array([])

    # now, if too many ticks for the colorbar, thin them to roughly 30.  The
    # step must be at least 1: int((nsv + 1) / 30) is 0 for nsv < 29, which
    # made np.arange raise ZeroDivisionError for 20 <= nsv <= 28
    if nsv < 20:
        ticks = range(1, nsv + 1)
    else:
        ticks = np.arange(1, nsv + 1, max(1, int((nsv + 1) / 30)))

    cb = plt.colorbar(sm, ax=ax)
    cb.set_ticks(ticks)

    logger.log("plot id bar")

    return ax
602

603

604
@apply_custom_font()
def res_phi_pie(pst, logger=None, **kwargs):
    """plot current phi components as a pie chart.

    Args:
        pst (`pyemu.Pst`): a control file instance with the residual dataframe
            instance available.
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created
        kwargs (`dict`): a dict of plotting options.  Recognized keys:

            - 'include_zero': if truthy, also include groups whose phi
              component is zero (e.g. groups with only zero-weight obs).
            - 'label_comps': component name, or list of component names, for
              the wedge labels.  Options are 'name', 'phi_comp' and
              'phi_percent'; each requested component goes on its own line,
              in the order given.  Default is all three in that order.
            - 'ax': an existing axis to plot on.
            - 'figsize': size of a newly created figure (default is the
              module-level `figsize`).
            - 'filename': if given, the figure is saved to this file.
            - 'ensemble': passed to `pst_utils.res_from_en()`.

            Other keys are ignored.

    Returns:
        `matplotlib.Axis`: the axis with the plot.

    Example::

        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_phi_pie(pst,figsize=(12,4))
        pyemu.plot_utils.res_phi_pie(pst,label_comps = ['name','phi_percent'], figsize=(12,4))


    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_phi_pie")

    # NOTE(review): ``res`` is not used below - the wedges come from
    # pst.phi / pst.phi_components - so 'ensemble' may not change the plot.
    # Kept so a missing residuals file is still reported; confirm intent.
    if "ensemble" in kwargs:
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception:
            logger.statement(
                "res_phi_pie: could not find ensemble file {0}".format(kwargs["ensemble"])
            )
    else:
        try:
            res = pst.res
        except Exception:
            logger.lraise("res_phi_pie: pst.res is None, couldn't find residuals file")

    phi = pst.phi
    phi_comps = pst.phi_components
    norm_phi_comps = pst.phi_components_normalized
    keys = list(phi_comps.keys())
    if not kwargs.get("include_zero", False):
        phi_comps = {k: phi_comps[k] for k in keys if phi_comps[k] > 0.0}
        keys = list(phi_comps.keys())
        norm_phi_comps = {k: norm_phi_comps[k] for k in keys}

    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=kwargs.get("figsize", figsize))
        ax = plt.subplot(1, 1, 1, aspect="equal")

    # per-component format string and per-group values for the wedge labels
    label_parts = {
        "name": ("{}", keys),
        "phi_comp": ("{:4G}", [phi_comps[k] for k in keys]),
        "phi_percent": ("({:3.1f}%)", [100.0 * (phi_comps[k] / phi) for k in keys]),
    }
    fmtchoices = kwargs.get("label_comps", ["name", "phi_comp", "phi_percent"])
    if isinstance(fmtchoices, str):
        fmtchoices = [fmtchoices]
    else:
        fmtchoices = list(fmtchoices)
    # one line per requested component, in the requested order
    labfmtstr = "\n".join(label_parts[c][0] for c in fmtchoices)
    labels = [
        labfmtstr.format(*vals)
        for vals in zip(*[label_parts[c][1] for c in fmtchoices])
    ]

    ax.pie([float(norm_phi_comps[k]) for k in keys], labels=labels)
    logger.log("plot res_phi_pie")
    if "filename" in kwargs:
        # save the figure that owns ``ax`` (not just the current figure)
        ax.figure.savefig(kwargs["filename"])
    return ax
700

701

702
@apply_custom_font()
def pst_prior(pst, logger=None, filename=None, **kwargs):
    """helper to plot prior parameter distributions implied by
    parameter bounds, one axis per parameter group.

    Args:
        pst (`pyemu.Pst`): control file
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created.
        filename (`str`):  PDF filename to save all figures to (one page per
            figure).  If None, return figs without saving.  Default is None.
        kwargs (`dict`): additional plotting options.  Recognized keys:
            'unique_only' (only draw unique mean-stdev combinations within a
            group) and 'fig_title' (text for the title page).  'grouper' is
            not implemented and raises NotImplementedError.  Other keys are
            ignored.

    Returns:
        [`matplotlib.Figure`]: the list of figures created (a title page
        followed by pages of `nr` x `nc` group axes) if `filename` is None;
        otherwise the figures are written to `filename`, closed, and None is
        returned.

    Example::

        pst = pyemu.Pst("pest.pst")
        pyemu.plot_utils.pst_prior(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot pst_prior")
    par = pst.parameter_data

    if "parcov_filename" in pst.pestpp_options:
        logger.warn("ignoring parcov_filename, using parameter bounds for prior cov")
    logger.log("loading cov from parameter data")
    cov = pyemu.Cov.from_parameter_data(pst)
    logger.log("loading cov from parameter data")

    logger.log("building mean parameter values")
    li = par.partrans.loc[cov.names] == "log"
    mean = par.parval1.loc[cov.names]
    info = par.loc[cov.names, :].copy()
    info.loc[:, "mean"] = mean
    # log-transformed parameters are plotted in log10 space
    info.loc[li, "mean"] = mean[li].apply(np.log10)
    logger.log("building mean parameter values")

    logger.log("building stdev parameter values")
    if cov.isdiagonal:
        std = cov.x.flatten()
    else:
        std = np.diag(cov.x)
    std = np.sqrt(std)
    # check before assigning, so a mismatch gets this clear message rather
    # than a pandas length error
    if std.shape != mean.shape:
        logger.lraise("mean.shape {0} != std.shape {1}".format(mean.shape, std.shape))
    info.loc[:, "prior_std"] = std
    logger.log("building stdev parameter values")

    if "grouper" in kwargs:
        raise NotImplementedError()
    # only log- and none-transformed (i.e. adjustable) parameters are plotted
    par_adj = par.loc[par.partrans.apply(lambda x: x in ["log", "none"]), :]
    grouper = par_adj.groupby(par_adj.pargp).groups

    if len(grouper) == 0:
        raise Exception("no adustable parameters to plot")

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='prior')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    ax_count = 0
    for g in sorted(grouper.keys()):
        names = grouper[g]
        logger.log("plotting priors for {0}".format(",".join(list(names))))
        if ax_count % (nr * nc) == 0:
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # decide log vs. linear axis label from this group's transforms only
        islog = False
        vc = info.loc[names, "partrans"].value_counts()
        if vc.shape[0] > 1:
            logger.warn("mixed partrans for group {0}".format(g))
        elif "log" in vc.index:
            islog = True
        ax = axes[ax_count]
        if "unique_only" in kwargs and kwargs["unique_only"]:
            ms = (
                info.loc[names, :]
                .apply(lambda x: (x["mean"], x["prior_std"]), axis=1)
                .unique()
            )
            for (m, s) in ms:
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        else:
            for m, s in zip(info.loc[names, "mean"], info.loc[names, "prior_std"]):
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        ax.set_title(
            "{0}) group:{1}, {2} parameters".format(abet[ax_count], g, names.shape[0]),
            loc="left",
        )

        ax.set_yticks([])
        if islog:
            ax.set_xlabel("$log_{10}$ parameter value", labelpad=0.1)
        else:
            ax.set_xlabel("parameter value", labelpad=0.1)
        logger.log("plotting priors for {0}".format(",".join(list(names))))

        ax_count += 1

    # blank out unused axes on the final page
    for a in range(ax_count, nr * nc):
        axes[a].set_axis_off()
        axes[a].set_yticks([])
        axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        # write every page (not just the last figure) and release them
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot pst_prior")
        return None
    logger.log("plot pst_prior")
    return figs
854

855

856
@apply_custom_font()
def ensemble_helper(
    ensemble,
    bins=10,
    facecolor="0.5",
    plot_cols=None,
    filename=None,
    func_dict=None,
    sync_bins=True,
    deter_vals=None,
    std_window=None,
    deter_range=False,
    **kwargs
):
    """helper function to plot ensemble histograms

    Args:
        ensemble : varies
            the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        bins : int
            the number of histogram bins.  When `sync_bins` is True this is the number of
            bin edges (from `numpy.linspace`) shared by all ensembles.  Default is 10
        facecolor : str
            the histogram facecolor.  Only applies if ensemble is a single thing
        plot_cols : enumerable
            a collection of columns (in form of a list of parameters, or a dict with keys for
            parsing plot axes and values of parameters) from the ensemble(s) to plot.  If None,
            (the union of) all cols are plotted. Default is None
        filename : str
            the name of the pdf to create.  If None, return figs without saving.  Default is None.
        func_dict : dict
            a dictionary of unary functions (e.g., `np.log10`) to apply to columns.  Key is
            column name.  The functions are applied to copies of the ensembles, so the
            ensembles passed in are not modified.  Default is None
        sync_bins : bool
            flag to compute one set of bin edges from the min/max over all ensembles
            for each plotted label.  Default is True
        deter_vals : dict
            dict of deterministic values to plot as a vertical line. key is ensemble column name
        std_window : float
            the width, in standard deviations, of the window marked with vertical lines
            around each ensemble's mean (lines are drawn at mean +/- std * std_window / 2).
            If None, nothing happens.  Default is None
        deter_range : bool
            flag to set xlims to deterministic value +/- std window.  If True, std_window must not be None.
            Default is False
        kwargs : dict
            "fig_title" (str) replaces the default text of the title page

    Returns:
        [`matplotlib.Figure`]: if `filename` is None, a list of figures (the first one is
        the title page).  Otherwise the figures are saved to `filename`, closed, and None
        is returned.

    Example::

        # plot prior and posterior par ensembles
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post},filename="ensemble.pdf")

        #plot prior and posterior simulated equivalents to observations with obs noise and obs vals
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        noise = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.obs+noise.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post,"r":noise},
                                         filename="ensemble.pdf",
                                         deter_vals=pst.observation_data.obsval.to_dict())


    """

    def _cols_in(en, cols):
        # the subset of `cols` that this ensemble actually has (ensembles may differ)
        return [c for c in cols if c in en.columns]

    logger = pyemu.Logger("ensemble_helper.log")
    logger.log("pyemu.plot_utils.ensemble_helper()")
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
    if len(ensembles) == 0:
        raise Exception("plot_utils.ensemble_helper() error processing `ensemble` arg")

    # apply any functions - on copies, so the caller's ensembles are not
    # transformed in place
    if func_dict is not None:
        logger.log("applying functions")
        for fc in list(ensembles.keys()):
            en = ensembles[fc]
            fcols = _cols_in(en, func_dict)
            if len(fcols) == 0:
                continue
            en = en.copy()
            for col in fcols:
                en.loc[:, col] = en.loc[:, col].apply(func_dict[col])
            ensembles[fc] = en
        logger.log("applying functions")

    # get a list of all cols (union)
    all_cols = set()
    for en in ensembles.values():
        all_cols.update(en.columns)
    if plot_cols is None:
        plot_cols = {c: [c] for c in all_cols}
    else:
        if isinstance(plot_cols, list):
            splot_cols = set(plot_cols)
            plot_cols = {c: [c] for c in plot_cols}
        elif isinstance(plot_cols, dict):
            splot_cols = set()
            for pcols in plot_cols.values():
                splot_cols.update(pcols)
        else:
            logger.lraise(
                "unrecognized plot_cols type: {0}, should be list or dict".format(
                    type(plot_cols)
                )
            )

        missing = splot_cols - all_cols
        if len(missing) > 0:
            logger.lraise(
                "the following plot_cols are missing: {0}".format(",".join(missing))
            )

    logger.statement("plotting {0} histograms".format(len(plot_cols)))

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_helper()\n at {0}".format(str(datetime.now())),
            ha="center",
        )
    labels = sorted(plot_cols.keys())
    logger.statement("saving pdf to {0}".format(filename))
    figs = []
    axes = None
    ax_count = 0

    for label in labels:
        plot_col = plot_cols[label]
        logger.log("plotting reals for {0}".format(label))
        # start a new page when the current one is full (or for the first label)
        if ax_count % (nr * nc) == 0:
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            for page_ax in axes:
                page_ax.set_yticks([])
            ax_count = 0

        ax = axes[ax_count]

        if sync_bins:
            mx, mn = -1.0e30, 1.0e30
            for en in ensembles.values():
                en_cols = _cols_in(en, plot_col)
                if len(en_cols) == 0:
                    continue
                en_vals = en.loc[:, en_cols].values
                mx = max(mx, np.nanmax(en_vals))
                mn = min(mn, np.nanmin(en_vals))
            # nan never compares greater/less, so the sentinels survive all-NaN data
            if mx == -1.0e30 and mn == 1.0e30:
                logger.warn("all NaNs for label: {0}".format(label))
                ax.set_title(
                    "{0}) {1}, count:{2} - all NaN".format(
                        abet[ax_count], label, len(plot_col)
                    ),
                    loc="left",
                )
                ax.set_yticks([])
                ax.set_xticks([])
                ax_count += 1
                logger.log("plotting reals for {0}".format(label))
                continue
            plot_bins = np.linspace(mn, mx, num=bins)
            logger.statement("{0} min:{1:5G}, max:{2:5G}".format(label, mn, mx))
        else:
            plot_bins = bins

        # draw every histogram first so the y-limits reflect all ensembles
        hist_vals = {}
        vals = None
        for fc, en in ensembles.items():
            en_cols = _cols_in(en, plot_col)
            if len(en_cols) == 0:
                continue
            vals = en.loc[:, en_cols].values.flatten()
            hist_vals[fc] = vals
            ax.hist(
                vals,
                bins=plot_bins,
                edgecolor="none",
                alpha=0.5,
                density=True,
                facecolor=fc,
            )

        # reference lines span the final y-range; set_ylim then pins it there
        ylim = ax.get_ylim()
        v = None
        if deter_vals is not None:
            for pc in plot_col:
                if pc in deter_vals:
                    v = deter_vals[pc]
                    ax.plot([v, v], ylim, "k--", lw=1.5)
                    ax.set_ylim(ylim)

        if std_window is not None:
            for fc, fc_vals in hist_vals.items():
                try:
                    # stats over all values plotted for this label/ensemble
                    fc_series = pd.Series(fc_vals)
                    mean = fc_series.mean()
                    st = fc_series.std() * (std_window / 2.0)
                    ax.plot([mean - st, mean - st], ylim, color=fc, lw=1.5, ls="--")
                    ax.plot([mean + st, mean + st], ylim, color=fc, lw=1.5, ls="--")
                    ax.set_ylim(ylim)
                    if deter_range and v is not None:
                        ax.set_xlim(v - st, v + st)
                except Exception as e:
                    logger.warn(
                        "error plotting std window for {0}: {1}".format(label, str(e))
                    )

        ax.grid()
        if len(ensembles) > 1:
            ax.set_title(
                "{0}) {1}, count: {2}".format(abet[ax_count], label, len(plot_col)),
                loc="left",
            )
        else:
            ax.set_title(
                "{0}) {1}, count:{2}\nmin:{3:3.1E}, max:{4:3.1E}".format(
                    abet[ax_count],
                    label,
                    len(plot_col),
                    np.nanmin(vals),
                    np.nanmax(vals),
                ),
                loc="left",
            )
        ax_count += 1
        logger.log("plotting reals for {0}".format(label))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("pyemu.plot_utils.ensemble_helper()")
        return None
    logger.log("pyemu.plot_utils.ensemble_helper()")
    return figs
1110

1111

1112
@apply_custom_font()
def ensemble_change_summary(
    ensemble1,
    ensemble2,
    pst,
    bins=10,
    facecolor="0.5",
    logger=None,
    filename=None,
    **kwargs
):
    """helper function to plot first and second moment change histograms between two
    ensembles

    For each group two histograms are drawn: the mean change (ensemble1 mean minus
    ensemble2 mean) and the percent reduction in standard deviation
    (100 * (std1 - std2) / std1).  When the ensemble columns are parameters, those
    with a "log" partrans are log10-transformed before the moments are computed.

    Args:
        ensemble1 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        ensemble2 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        pst (`pyemu.Pst`): control file
        bins (`int`): number of histogram bins.  Default is 10
        facecolor (`str`): the histogram facecolor.
        logger (`pyemu.Logger`): optional logger.  If None, a logger writing to
            "Default_Logger.log" is created.
        filename (`str`): the name of the multi-pdf to create. If None, return figs without saving.  Default is None.
        **kwargs: "fig_title" (`str`) replaces the title-page text.  Passing "grouper"
            raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:
        The ensembles passed in are copied before their columns are lower-cased and
        log-transformed, so the caller's objects are not modified.

    Example::

        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_change_summary(prior, post, pst)
        plt.show()


    """
    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot ensemble change")

    def _prep(en):
        # load from csv, or copy so that the in-place changes below never
        # reach the caller's ensemble
        if isinstance(en, str):
            en = pd.read_csv(en, index_col=0)
        else:
            en = en.copy()
        en.columns = en.columns.str.lower()
        # better to ensure this is caught by pestpp-ies ensemble csvs
        if any("unnamed:" in col for col in en.columns):
            # ensure unnamed col result of poor csv read only (ie last col)
            en = en.iloc[:, :-1].copy()
        return en

    ensemble1 = _prep(ensemble1)
    ensemble2 = _prep(ensemble2)

    d = set(ensemble1.columns).symmetric_difference(set(ensemble2.columns))
    if len(d) != 0:
        logger.lraise(
            "ensemble1 does not have the same columns as ensemble2: {0}".format(
                ",".join(d)
            )
        )
    if "grouper" in kwargs:
        raise NotImplementedError()

    en_cols = ensemble1.columns
    en_col_set = set(en_cols)
    if len(en_cols.difference(pst.par_names)) == 0:
        par = pst.parameter_data.loc[en_cols, :]
        grouper = par.groupby(par.pargp).groups
        # only names present in the ensembles, otherwise .loc raises KeyError
        grouper["all"] = [n for n in pst.adj_par_names if n in en_col_set]
        li = par.loc[par.partrans == "log", "parnme"]
        ensemble1.loc[:, li] = ensemble1.loc[:, li].apply(np.log10)
        ensemble2.loc[:, li] = ensemble2.loc[:, li].apply(np.log10)
    elif len(en_cols.difference(pst.obs_names)) == 0:
        obs = pst.observation_data.loc[en_cols, :]
        grouper = obs.groupby(obs.obgnme).groups
        grouper["all"] = [n for n in pst.nnz_obs_names if n in en_col_set]
    else:
        logger.lraise("could not match ensemble cols with par or obs...")

    en1_mn, en1_std = ensemble1.mean(axis=0), ensemble1.std(axis=0)
    en2_mn, en2_std = ensemble2.mean(axis=0), ensemble2.std(axis=0)

    # mean change: ensemble1 mean minus ensemble2 mean
    mn_diff = -1 * (en2_mn - en1_mn)
    # sigma change as percent reduction relative to ensemble1
    # (inf/nan where en1_std == 0)
    std_diff = 100 * ((en1_std - en2_std) / en1_std)

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_change_summary()\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    def _hist_panel(ax, series, xlabel, title):
        # one histogram panel with shared styling
        series.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel(xlabel, labelpad=0.1)
        ax.set_title(title, loc="left")
        ax.grid()

    figs = []
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting change for {0}".format(g))

        mn_g = mn_diff.loc[names]
        std_g = std_diff.loc[names]

        if mn_g.shape[0] == 0:
            logger.statement("no entries for group '{0}'".format(g))
            logger.log("plotting change for {0}".format(g))
            continue

        # each group uses two axes; start a new page when the current one is full
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        _hist_panel(
            axes[ax_count],
            mn_g,
            "mean change",
            "{0}) mean change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], mn_g.max(), mn_g.min()
            ),
        )
        ax_count += 1

        _hist_panel(
            axes[ax_count],
            std_g,
            "sigma percent reduction",
            "{0}) sigma change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, std_g.shape[0], std_g.max(), std_g.min()
            ),
        )
        ax_count += 1

        logger.log("plotting change for {0}".format(g))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot ensemble change")
    else:
        logger.log("plot ensemble change")
        return figs
1307

1308

1309
def _process_ensemble_arg(ensemble, facecolor, logger):
11✔
1310
    """private method to work out ensemble plot args"""
1311
    ensembles = {}
10✔
1312
    if isinstance(ensemble, pd.DataFrame) or isinstance(ensemble, pyemu.Ensemble):
10✔
1313
        if not isinstance(facecolor, str):
10✔
1314
            logger.lraise("facecolor must be str")
×
1315
        ensembles[facecolor] = ensemble
10✔
1316
    elif isinstance(ensemble, str):
10✔
1317
        if not isinstance(facecolor, str):
10✔
1318
            logger.lraise("facecolor must be str")
×
1319

1320
        logger.log("loading ensemble from csv file {0}".format(ensemble))
10✔
1321
        en = pd.read_csv(ensemble, index_col=0)
10✔
1322
        logger.statement("{0} shape: {1}".format(ensemble, en.shape))
10✔
1323
        ensembles[facecolor] = en
10✔
1324
        logger.log("loading ensemble from csv file {0}".format(ensemble))
10✔
1325

1326
    elif isinstance(ensemble, list):
10✔
1327
        if isinstance(facecolor, list):
10✔
1328
            if len(ensemble) != len(facecolor):
×
1329
                logger.lraise("facecolor list len != ensemble list len")
×
1330
        else:
1331
            colors = ["m", "c", "b", "r", "g", "y"]
10✔
1332

1333
            facecolor = [colors[i] for i in range(len(ensemble))]
10✔
1334
        ensembles = {}
10✔
1335
        for fc, en_arg in zip(facecolor, ensemble):
10✔
1336
            if isinstance(en_arg, str):
10✔
1337
                logger.log("loading ensemble from csv file {0}".format(en_arg))
10✔
1338
                en = pd.read_csv(en_arg, index_col=0)
10✔
1339
                logger.log("loading ensemble from csv file {0}".format(en_arg))
10✔
1340
                logger.statement("ensemble {0} gets facecolor {1}".format(en_arg, fc))
10✔
1341

1342
            elif isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
10✔
1343
                en = en_arg
10✔
1344
            else:
1345
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1346
            ensembles[fc] = en
10✔
1347

1348
    elif isinstance(ensemble, dict):
10✔
1349
        for fc, en_arg in ensemble.items():
10✔
1350
            if isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
10✔
1351
                ensembles[fc] = en_arg
10✔
1352
            elif isinstance(en_arg, str):
10✔
1353
                logger.log("loading ensemble from csv file {0}".format(en_arg))
10✔
1354
                en = pd.read_csv(en_arg, index_col=0)
10✔
1355
                logger.log("loading ensemble from csv file {0}".format(en_arg))
10✔
1356
                ensembles[fc] = en
10✔
1357
            else:
1358
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1359
    try:
10✔
1360
        for fc in ensembles:
10✔
1361
            ensembles[fc].columns = ensembles[fc].columns.str.lower()
10✔
1362
    except:
×
1363
        logger.lraise("error processing ensemble")
×
1364

1365
    return ensembles
10✔
1366

1367

1368
@apply_custom_font()
def ensemble_res_1to1(
    ensemble,
    pst,
    facecolor="0.5",
    logger=None,
    filename=None,
    skip_groups=None,
    base_ensemble=None,
    **kwargs
):
    """helper function to plot ensemble 1-to-1 plots showing the simulated range

    Args:
        ensemble (varies):  the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        pst (`pyemu.Pst`): a control file instance
        facecolor (`str`): the color of the simulated-range bars.  Only applies if `ensemble` is a single thing
        logger (`pyemu.Logger`): optional logger.  If None, a logger writing to
            "Default_Logger.log" is created.
        filename (`str`): the name of the pdf to create. If None, return figs
            without saving.  Default is None.
        skip_groups ([`str`]): observation group name(s) to leave out.  A single str is
            accepted.  Names that are not groups in `pst` are ignored with a warning.
            Default is None (plot all groups).
        base_ensemble (`varies`): an optional ensemble argument for the observations + noise ensemble.
            This will be plotted as a transparent red bar on the 1to1 plot.
        **kwargs: "fig_title" (`str`) replaces the title-page text; "include_zero" (any value
            other than False) also plots zero-weighted observations.  Passing "grouper"
            raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures if `filename` is None, otherwise None

    Note:

        the vertical bar on each plot shows the min-max range of each ensemble
        (simulated values on the left panel, residuals on the right panel)

    Example::


        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        pyemu.plot_utils.ensemble_res_1to1(ensemble={"0.5":prior, "b":post})
        plt.show()

    """

    def _get_plotlims(oen, ben, obsnames):
        # common 1to1 axis limits from the simulated ensembles (oen) and the
        # obsvals or obs+noise ensembles (ben), zooming in on the bulk of the
        # values when extremes are far apart
        if not isinstance(oen, dict):
            oen = {"g": oen.loc[:, obsnames]}
        if not isinstance(ben, dict):
            ben = {"g": ben.get(obsnames)}
        outofrange = False
        # work back from crazy values
        oemin = 1e32
        oemeanmin = 1e32
        oemax = -1e32
        oemeanmax = -1e32
        bemin = 1e32
        bemeanmin = 1e32
        bemax = -1e32
        bemeanmax = -1e32
        for _, oeni in oen.items():  # loop over ensembles
            oeni = oeni.loc[:, obsnames]  # slice group obs
            oemin = np.nanmin([oemin, oeni.min().min()])
            oemax = np.nanmax([oemax, oeni.max().max()])
            # get min and max of mean sim vals
            # (in case we want plot to ignore extremes)
            oemeanmin = np.nanmin([oemeanmin, oeni.mean().min()])
            oemeanmax = np.nanmax([oemeanmax, oeni.mean().max()])
        for _, beni in ben.items():  # same with base ensemble/obsval
            # work with either ensemble or obsval series
            beni = beni.get(obsnames)
            bemin = np.nanmin([bemin, beni.min().min()])
            bemax = np.nanmax([bemax, beni.max().max()])
            bemeanmin = np.nanmin([bemeanmin, beni.mean().min()])
            bemeanmax = np.nanmax([bemeanmax, beni.mean().max()])
        # get base ensemble range
        berange = bemax - bemin
        if berange == 0.0:  # only one obs value in group (probs)
            # expand a little; abs() keeps the padding positive for negative
            # values so the limits are not inverted
            berange = abs(bemeanmax) * 1.1
            if berange == 0.0:
                berange = 1.0
        # add buffer to obs endpoints
        bemin = bemin - (berange * 0.05)
        bemax = bemax + (berange * 0.05)
        if oemax < bemin:  # sim well below obs
            oemin = oemeanmin  # set min to mean min
            # (sim captured but not extremes)
            outofrange = True
        if oemin > bemax:  # sim well above obs
            oemax = oemeanmax
            outofrange = True
        oerange = oemax - oemin
        if bemax > oemax + (0.1 * oerange):  # obs max well above sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemax = oemax + (0.1 * oerange)
            else:  # use obs mean max
                bemax = bemeanmax
        if bemin < oemin - (0.1 * oerange):  # obs min well below sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemin = oemin - (0.1 * oerange)
            else:
                bemin = bemeanmin
        pmin = np.nanmin([oemin, bemin])
        pmax = np.nanmax([oemax, bemax])
        return pmin, pmax

    def _plot_base_range(ax, obs_gg, base_ens, residual):
        # obs+noise ensemble range as a transparent band (a single bar when the
        # group has one observation); obs_gg must be sorted by obsval
        for c, en in base_ens.items():
            en_g = en.loc[:, obs_gg.obsnme]
            if residual:
                en_g = en_g.subtract(obs_gg.obsval)
            emx = en_g.max()
            emn = en_g.min()
            if len(obs_gg.obsval) > 1:
                ax.fill_between(
                    obs_gg.obsval.values,
                    emn.values,
                    emx.values,
                    facecolor=c,
                    alpha=0.2,
                    zorder=2,
                )
            else:
                ax.plot(
                    [obs_gg.obsval.values, obs_gg.obsval.values],
                    [emn, emx],
                    color=c,
                    alpha=0.2,
                    zorder=2,
                )

    def _plot_ens_range(ax, obs_g, ens, residual):
        # one vertical min-max bar per observation and ensemble; returns the
        # per-ensemble min and max Series
        mins, maxs = [], []
        for c, en in ens.items():
            en_g = en.loc[:, obs_g.obsnme]
            if residual:
                en_g = en_g.subtract(obs_g.obsval, axis=1)
            emx = en_g.max()
            emn = en_g.min()
            mins.append(emn)
            maxs.append(emx)
            for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values):
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
        return mins, maxs

    if logger is None:
        logger = Logger("Default_Logger.log", echo=False)
    logger.log("plot res_1to1")
    obs = pst.observation_data
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)

    if base_ensemble is not None:
        base_ensemble = _process_ensemble_arg(base_ensemble, "r", logger)

    if "grouper" in kwargs:
        raise NotImplementedError()
    grouper = obs.groupby(obs.obgnme).groups
    if skip_groups is None:
        skip_groups = []
    elif isinstance(skip_groups, str):
        skip_groups = [skip_groups]
    for skip_group in skip_groups:
        if grouper.pop(skip_group, None) is None:
            logger.warn(
                "skip_group '{0}' is not an observation group".format(skip_group)
            )

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))
        # control file observation for group
        obs_g = obs.loc[names, :]
        # normally only look a non-zero weighted obs
        if kwargs.get("include_zero", False) is False:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue
        # each group uses two axes; start a new page when the current one is full
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # 1to1 (LHS plot)
        ax = axes[ax_count]
        if base_ensemble is None:
            # if obs not defined by obs+noise ensemble,
            # use min and max for obsval from control file
            pmin, pmax = _get_plotlims(ensembles, obs_g.obsval, obs_g.obsnme)
        else:
            # if obs defined by obs+noise use obs+noise min and max
            pmin, pmax = _get_plotlims(ensembles, base_ensemble, obs_g.obsnme)
            _plot_base_range(
                ax, obs_g.sort_values(by="obsval"), base_ensemble, residual=False
            )
        _plot_ens_range(ax, obs_g, ensembles, residual=False)
        ax.plot([pmin, pmax], [pmin, pmax], "k--", lw=1.0, zorder=3)
        ax.set_xlim(pmin, pmax)
        ax.set_ylim(pmin, pmax)

        if max(np.abs((pmin, pmax))) > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
            ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        # Residual (RHS plot)
        ax_count += 1
        ax = axes[ax_count]
        if base_ensemble is not None:
            _plot_base_range(
                ax, obs_g.sort_values(by="obsval"), base_ensemble, residual=True
            )
        omn, omx = _plot_ens_range(ax, obs_g, ensembles, residual=True)

        omn = pd.concat(omn).min()
        omx = pd.concat(omx).max()
        mx = np.nanmax([np.abs(omn), np.abs(omx)])  # ensure symmetric about y=0
        mx *= 1.05 if obs_g.shape[0] == 1 else 1.02
        if np.sign(omn) == np.sign(omx):
            # allow y axis asymm if all above or below
            mn = np.nanmin([0, np.sign(omn) * mx])
            mx = np.nanmax([0, np.sign(omn) * mx])
        else:
            mn = -mx
        ax.set_ylim(mn, mx)
        bmin = obs_g.obsval.values.min()
        bmax = obs_g.obsval.values.max()
        brange = bmax - bmin
        if brange == 0.0:
            # single obsval: pad by its magnitude (positive, so the x-limits
            # are not inverted for negative values)
            brange = abs(obs_g.obsval.values.mean())
            if brange == 0.0:
                brange = 1.0
        xlim = (bmin - 0.1 * brange, bmax + 0.1 * brange)
        # show a zero residuals line
        ax.plot(xlim, [0, 0], "k--", lw=1.0, zorder=3)

        ax.set_xlim(xlim)
        ax.set_ylabel("residual", labelpad=0.1)
        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax.grid()
        if ax.get_xlim()[1] > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))

        ax_count += 1

        logger.log("plotting 1to1 for {0}".format(g))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for page in figs:
                pdf.savefig(page)
                plt.close(page)
        logger.log("plot res_1to1")
    else:
        logger.log("plot res_1to1")
        return figs
1663

1664

1665
@apply_custom_font()
def plot_jac_test(
    csvin, csvout, targetobs=None, filetype=None, maxoutputpages=1, outputdirectory=None
):
    """helper function to plot results of the Jacobian test performed using the pest++
    program pestpp-swp.

    Args:
        csvin (`str`): name of csv file used as input to sweep, typically developed with
            static method pyemu.helpers.build_jac_test_csv()
        csvout (`str`): name of csv file with output generated by sweep, both input
            and output files can be specified in the pest++ control file
            with pyemu using: pest_object.pestpp_options["sweep_parameter_csv_file"] = jactest_in_file.csv
            pest_object.pestpp_options["sweep_output_csv_file"] = jactest_out_file.csv
        targetobs ([`str`]): list of observation names to plot, each parameter used for jactest can
            have up to 32 observations plotted per page, throws a warning if more than
            10 pages of output are requested per parameter. If None, all observations in
            the output csv file are used.
        filetype (`str`): file type (extension, e.g. "pdf" or "png") used to store output.
            If None, plt.show() is called instead of saving.
        maxoutputpages (`int`): maximum number of pages of output per parameter.  Each page can
            hold up to 32 observation derivatives.  If value > 10, set it to
            10 and issue a warning.  If observations in targetobs > 32*maxoutputpages,
            then a random set is selected from the targetobs list (or all observations
            in the csv file if targetobs=None).
        outputdirectory (`str`):  directory to store results, if None, current working directory is used.
            If a relative path is passed, it is joined to the current working directory;
            an absolute path is used directly. The directory (including any missing
            parents) is created if needed.

    Note:
        Used in conjunction with pyemu.helpers.build_jac_test_csv() and sweep to perform
        a Jacobian Test and then view the results. Can generate a lot of plots so easiest
        to put into a separate directory and view the files.

        One figure is produced per parameter per page, named
        ``<parameter>_jactest_<page>.<filetype>`` when saved.

    """
    # page layout: an 8 x 4 grid of subplots, one observation per subplot
    nrows, ncols = 8, 4
    obs_per_page = nrows * ncols

    localhome = os.getcwd()
    if outputdirectory is None:
        figures_dir = localhome
    else:
        # os.path.join returns outputdirectory unchanged when it is absolute
        figures_dir = os.path.join(localhome, outputdirectory)
        # makedirs (vs mkdir) also creates missing parent directories
        os.makedirs(figures_dir, exist_ok=True)

    # read the input and output files into pandas dataframes
    jactest_in_df = pd.read_csv(csvin, engine="python", index_col=0)
    jactest_in_df.index.name = "input_run_id"
    jactest_out_df = pd.read_csv(csvout, engine="python", index_col=1)

    # subtract the base run from every row, leaves the one parameter that
    # was perturbed in any row as only non-zero value. Set zeros to nan
    # so round-off doesn't get us and sum across rows to get a column of
    # the perturbation for each row, finally extract to a series. First
    # the input csv and then the output.
    base_par = jactest_in_df.loc["base"]
    delta_par_df = jactest_in_df.subtract(base_par, axis="columns")
    delta_par_df.replace(0, np.nan, inplace=True)
    delta_par_df.drop("base", axis="index", inplace=True)
    delta_par_df["change"] = delta_par_df.sum(axis="columns")
    delta_par = pd.Series(delta_par_df["change"])

    base_obs = jactest_out_df.loc["base"]
    delta_obs = jactest_out_df.subtract(base_obs)
    delta_obs.drop("base", axis="index", inplace=True)
    # if targetobs is None, then reset it to all the observations.
    # NOTE(review): assumes the first 8 columns of the sweep output are
    # bookkeeping (run ids, flags, phi components) rather than observations.
    if targetobs is None:
        targetobs = jactest_out_df.columns.tolist()[8:]
    delta_obs = delta_obs[targetobs]

    # get the Jacobian by dividing the change in observation by the change in parameter
    # for the perturbed parameters
    jacobian = delta_obs.divide(delta_par, axis="index")

    # use the index created by build_jac_test_csv to get a column of parameter names
    # and increments, then we can plot derivative vs. increment for each parameter
    extr_df = pd.Series(jacobian.index.values).str.extract(r"(.+)(_\d+$)", expand=True)
    extr_df[1] = pd.to_numeric(extr_df[1].str.replace("_", "")) + 1
    extr_df.rename(columns={0: "parameter", 1: "increment"}, inplace=True)
    extr_df.index = jacobian.index

    # make a dataframe for plotting the Jacobian by combining the parameter name
    # and increments frame with the Jacobian frame
    plotframe = pd.concat([extr_df, jacobian], axis=1, join="inner")

    # get a list of observations to keep based on maxoutputpages.
    if maxoutputpages > 10:
        warnings.warn(
            "more than 10 pages of output requested per parameter, "
            "maxoutputpages reset to 10.",
            PyemuWarning,
        )
        maxoutputpages = 10
    num_obs_plotted = min(maxoutputpages * obs_per_page, len(targetobs))
    if num_obs_plotted < len(targetobs):
        # get random sample
        index_plotted = np.random.choice(len(targetobs), num_obs_plotted, replace=False)
        obs_plotted = [targetobs[x] for x in index_plotted]
    else:
        obs_plotted = list(targetobs)
    # ceiling division: a completely full last page must not add an empty page
    real_pages = (num_obs_plotted + obs_per_page - 1) // obs_per_page

    # make a subplot of derivative vs. increment one plot for each of the
    # observations in targetobs, and outputs grouped by parameter.
    for param, group in plotframe.groupby("parameter"):
        # tick at the increments actually present rather than a fixed 1..5
        increments = sorted(group["increment"].dropna().unique())
        for page in range(real_pages):
            fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=(10, 15))
            for row in range(nrows):
                for col in range(ncols):
                    count = obs_per_page * page + ncols * row + col
                    if count >= num_obs_plotted:
                        continue
                    obsname = obs_plotted[count]
                    ax = axes[row, col]
                    ax.scatter(group["increment"], group[obsname])
                    ax.plot(group["increment"], group[obsname], "r")
                    ax.set_title(obsname)
                    ax.set_xticks(increments)
                    ax.tick_params(direction="in")
            # x axis is shared, so label only the bottom row of subplots
            for ax in axes[-1, :]:
                ax.set_xlabel("Increment")
            plt.tight_layout()

            if filetype is None:
                plt.show()
            else:
                fig.savefig(
                    os.path.join(
                        figures_dir, "{0}_jactest_{1}.{2}".format(param, page, filetype)
                    )
                )
            plt.close(fig)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc