• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pypest / pyemu / 9494064744

13 Jun 2024 04:54AM UTC coverage: 79.302% (+0.1%) from 79.195%
9494064744

push

github

briochh
Merge branch 'develop'

248 of 268 new or added lines in 6 files covered. (92.54%)

17 existing lines in 6 files now uncovered.

11812 of 14895 relevant lines covered (79.3%)

8.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.69
/pyemu/plot/plot_utils.py
1
"""Plotting functions for various PEST(++) and pyemu operations"""
11✔
2
import os
11✔
3
import numpy as np
11✔
4
import pandas as pd
11✔
5
import warnings
11✔
6
from datetime import datetime
11✔
7
import string
11✔
8
from pyemu.logger import Logger
11✔
9
from pyemu.pst import pst_utils
11✔
10
from ..pyemu_warnings import PyemuWarning
11✔
11

12
font = {"size": 6}
11✔
13
try:
11✔
14
    import matplotlib
11✔
15

16
    matplotlib.rc("font", **font)
11✔
17

18
    import matplotlib.pyplot as plt
11✔
19
    from matplotlib.backends.backend_pdf import PdfPages
11✔
20
    from matplotlib.gridspec import GridSpec
11✔
21
except Exception as e:
×
22
    # raise Exception("error importing matplotlib: {0}".format(str(e)))
23
    warnings.warn("error importing matplotlib: {0}".format(str(e)), PyemuWarning)
×
24

25
import pyemu
11✔
26

27
# default (width, height) handed to `plt.figure` by the multi-page plotters
figsize = (8, 10.5)

# number of axes rows and columns laid out on each page
nr = 4
nc = 2
# page_gs = GridSpec(nr,nc)

# upper-case letters used to tag the axes on a page ("A) ...", "B) ...")
abet = string.ascii_uppercase
11✔
32

33

34
def plot_summary_distributions(
    df,
    ax=None,
    label_post=False,
    label_prior=False,
    subplots=False,
    figsize=(11, 8.5),
    pt_color="b",
):
    """helper function to plot gaussian distributions from prior and posterior
    means and standard deviations

    Args:
        df (`pandas.DataFrame`): a dataframe or the name of a csv file.  Must
            have columns named: 'prior_mean','prior_stdev','post_mean','post_stdev'
            ('prior_var'/'post_var' are used to derive the standard deviations
            if the '*_stdev' columns are missing; 'prior_expt'/'post_expt' are
            used instead of the '*_mean' columns if present).  If loaded
            from a csv file, column 0 is assumed to be the index.  A dataframe
            passed in is not modified.
        ax (`matplotlib.pyplot.axis`): If None, and not subplots, then one is created
            and all distributions are plotted on a single plot
        label_post (`bool`): flag to add text labels to the peak of the posterior
        label_prior (`bool`): flag to add text labels to the peak of the prior
        subplots (`bool`): flag to use subplots.  If True, then 6 axes per page
            are used and a single prior and posterior is plotted on each
        figsize (`tuple`): matplotlib figure size
        pt_color (`str`): matplotlib color used to fill the posterior
            distributions.  Default is "b"

    Returns:
        tuple containing:

        - **[`matplotlib.figure`]**: list of figures
        - **[`matplotlib.axis`]**: list of axes

    Note:
        This is useful for demystifying FOSM results

        if subplots is False, a single axis is returned

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pyemu.plot_utils.plot_summary_distributions("pest.par.usum.csv")
        plt.show()

    """
    import matplotlib.pyplot as plt

    if isinstance(df, str):
        df = pd.read_csv(df, index_col=0)
    else:
        # the derived columns added below must not leak into the caller's frame
        df = df.copy()
    if ax is None and not subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(111)
        ax.grid()

    # fill in whichever of the stdev / expectation columns are missing
    if "post_stdev" not in df.columns and "post_var" in df.columns:
        df["post_stdev"] = df.post_var.apply(np.sqrt)
    if "prior_stdev" not in df.columns and "prior_var" in df.columns:
        df["prior_stdev"] = df.prior_var.apply(np.sqrt)
    if "prior_expt" not in df.columns and "prior_mean" in df.columns:
        df["prior_expt"] = df.prior_mean
    if "post_expt" not in df.columns and "post_mean" in df.columns:
        df["post_expt"] = df.post_mean

    if subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(2, 3, 1)
        ax_per_page = 6
        ax_count = 0
        axes = []
        figs = []

    # iterate rows by position so repeated index labels are each plotted once
    # and the "last row" test cannot fire early
    last_row = df.shape[0] - 1
    for irow, (name, row) in enumerate(df.iterrows()):
        x, y = gaussian_distribution(row["post_expt"], row["post_stdev"])
        ax.fill_between(x, 0, y, facecolor=pt_color, edgecolor="none", alpha=0.25)
        if label_post:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        x, y = gaussian_distribution(row["prior_expt"], row["prior_stdev"])
        ax.plot(x, y, color="0.5", lw=3.0, dashes=(2, 1))
        if label_prior:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        if subplots:
            ax.set_title(name)
            ax_count += 1
            ax.set_yticklabels([])
            axes.append(ax)
            if irow == last_row:
                # nothing left to plot: do not open another axis or page
                break
            if ax_count >= ax_per_page:
                # page is full: stash it and start a new figure
                figs.append(fig)
                fig = plt.figure(figsize=figsize)
                ax_count = 0
            ax = plt.subplot(2, 3, ax_count + 1)
    if subplots:
        figs.append(fig)
        return figs, axes
    # single-axis mode: leave head room above the tallest curve for labels
    ylim = list(ax.get_ylim())
    ylim[1] *= 1.2
    ylim[0] = 0.0
    ax.set_ylim(ylim)
    ax.set_yticklabels([])
    return ax
10✔
145

146

147
def gaussian_distribution(mean, stdev, num_pts=50, num_stdev=4.0):
    """get an x and y numpy.ndarray that spans the +/- `num_stdev`
    standard deviation range of a gaussian distribution with
    a given mean and standard deviation. useful for plotting

    Args:
        mean (`float`): the mean of the distribution
        stdev (`float`): the standard deviation of the distribution
        num_pts (`int`): the number of points in the returned ndarrays.
            Default is 50
        num_stdev (`float`): the number of standard deviations on either
            side of the mean that the x-values span.  Default is 4.0

    Returns:
        tuple containing:

        - **numpy.ndarray**: the x-values of the distribution
        - **numpy.ndarray**: the y-values of the distribution

    Example::

        mean,std = 1.0, 2.0
        x,y = pyemu.plot.gaussian_distribution(mean,std)
        plt.fill_between(x,0,y)
        plt.show()


    """
    variance = stdev * stdev
    half_width = num_stdev * stdev
    x = np.linspace(mean - half_width, mean + half_width, num_pts)
    # normal probability density evaluated at every x
    y = (1.0 / np.sqrt(2.0 * np.pi * variance)) * np.exp(
        -1.0 * ((x - mean) ** 2) / (2.0 * variance)
    )
    return x, y
10✔
180

181

182
def pst_helper(pst, kind=None, **kwargs):
    """`pyemu.Pst` plot helper - takes the
    handoff from `pyemu.Pst.plot()`

    Args:
        pst (`pyemu.Pst`): a control file instance
        kind (`str`): the kind of plot to make.  One of "prior", "1to1",
            "phi_pie" or "phi_progress".  If None, every kind is made and
            each is saved to "<case>.<kind>.pdf"; in that case `kwargs` are
            not forwarded to the plotting functions.
        **kwargs (`dict`): keyword arguments to pass to the
            plotting function and ultimately to `matplotlib`.  "echo"
            (default False) is also used to set up the logger.

    Returns:
        varies: usually a combination of `matplotlib.figure` (s) and/or
        `matplotlib.axis` (s)

    Example::

        pst = pyemu.Pst("pest.pst") #assumes pest.res or pest.rei is found
        pst.plot(kind="1to1")
        plt.show()
        pst.plot(kind="phi_pie")
        plt.show()
        pst.plot(kind="prior")
        plt.show()

    """

    echo = kwargs.get("echo", False)
    logger = pyemu.Logger("plot_pst_helper.log", echo=echo)
    logger.statement("plot_utils.pst_helper()")

    kinds = {
        "prior": pst_prior,
        "1to1": res_1to1,
        "phi_pie": res_phi_pie,
        "phi_progress": phi_progress,
    }

    if kind is None:
        returns = []
        base_filename = pst.filename
        if pst.new_filename is not None:
            base_filename = pst.new_filename
        # drop only the trailing extension; a blanket str.replace would also
        # mangle a ".pst" that appears earlier in the path
        if base_filename.endswith(".pst"):
            base_filename = base_filename[: -len(".pst")]
        for name, func in kinds.items():
            plt_name = base_filename + "." + name + ".pdf"
            returns.append(func(pst, logger=logger, filename=plt_name))

        return returns
    if kind not in kinds:
        logger.lraise(
            "unrecognized kind:{0}, should one of {1}".format(
                kind, ",".join(list(kinds.keys()))
            )
        )
    return kinds[kind](pst, logger, **kwargs)
10✔
236

237

238
def phi_progress(pst, logger=None, filename=None, **kwargs):
    """make plot of phi vs number of model runs - requires
    available  ".iobj" file generated by a PESTPP-GLM run.

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): filename to save the figure to.  If None, the
            figure is not saved.  Default is None
        kwargs (`dict`): optional keyword args.  Accepts 'ax', an existing
            axis to plot on; if not passed a new figure and axis are created

    Returns:
        `matplotlib.axis`: the axis the plot was made on

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.phi_progress(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot phi_progress")

    pst_file = pst.filename
    if pst_file.endswith(".pst"):
        # swap only the trailing extension so a ".pst" earlier in the path
        # is left untouched
        iobj_file = pst_file[: -len(".pst")] + ".iobj"
    else:
        iobj_file = pst_file.replace(".pst", ".iobj")
    if not os.path.exists(iobj_file):
        logger.lraise("couldn't find iobj file {0}".format(iobj_file))
    df = pd.read_csv(iobj_file)
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)
    ax.plot(df.model_runs_completed, df.total_phi, marker=".")
    ax.set_xlabel("model runs")
    ax.set_ylabel(r"$\phi$")
    ax.grid()
    if filename is not None:
        # save the figure that owns `ax`; the pyplot "current figure" may be
        # a different one when the caller supplied the axis
        ax.get_figure().savefig(filename)
    logger.log("plot phi_progress")
    return ax
10✔
282

283

284
def _get_page_axes(count=nr * nc):
    """Create the axes for one page on the current figure.

    Args:
        count (`int`): number of axes wanted; capped at `nr * nc`.

    Returns:
        [`matplotlib.axis`]: the axes, in subplot order on an `nr` x `nc` grid
    """
    n_axes = min(count, nr * nc)
    page_axes = []
    for position in range(1, n_axes + 1):
        page_axes.append(plt.subplot(nr, nc, position))
    # [ax.set_yticks([]) for ax in axes]
    return page_axes
10✔
288

289

290
def _pad_limits(lower, upper, frac=0.1):
    """Widen the interval [lower, upper] outward by `frac` of each bound's magnitude."""
    return lower - frac * abs(lower), upper + frac * abs(upper)


def res_1to1(
    pst, logger=None, filename=None, plot_hexbin=False, histogram=False, **kwargs
):
    """make 1-to-1 plots and also observed vs residual by observation group

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): PDF filename to save figures to.  If None, figures
            are returned.  Default is None
        plot_hexbin (`bool`): flag to use the hexbinning for large numbers of residuals.
            Default is False (use `matplotlib.pyplot.scatter` )
        histogram (`bool`): flag to plot residual histograms instead of obs vs residual.
            Default is False
        kwargs (`dict`): optional keyword args.  Accepts 'ensemble' (passed to
            `pst_utils.res_from_en` to build the residuals instead of using
            `pst.res`), 'include_zero' (flag to keep zero-weighted
            observations) and 'fig_title' (text for the title page).
            'grouper' is not implemented.

    Returns:
        [`matplotlib.figure`]: the list of figures (title page first) if
        `filename` is None.  Otherwise the figures are written to `filename`,
        closed, and None is returned.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_1to1(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_1to1")

    if "ensemble" in kwargs:
        # load once, inside the try, so a failure is reported through the logger
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception as e:
            logger.lraise("res_1to1: error loading ensemble file: {0}".format(str(e)))
    else:
        try:
            res = pst.res
        except Exception:
            logger.lraise("res_1to1: pst.res is None, couldn't find residuals file")

    obs = pst.observation_data

    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        grouper = obs.groupby(obs.obgnme).groups

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    # stays empty if every group is skipped, so the clean-up loop below is safe
    axes = []
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))

        # explicit copy: columns are added to this frame below
        obs_g = obs.loc[names, :].copy()
        obs_g.loc[:, "sim"] = res.loc[names, "modelled"]
        logger.statement("using control file obsvals to calculate residuals")
        obs_g.loc[:, "res"] = obs_g.sim - obs_g.obsval
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue

        if ax_count % (nr * nc) == 0:
            # current page is full (or this is the first group): start a new page
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        ax = axes[ax_count]

        # pad the data range outward; scaling the bounds directly would pull a
        # negative minimum (or maximum) inward and clip points
        mn, mx = _pad_limits(
            min(obs_g.obsval.min(), obs_g.sim.min()),
            max(obs_g.obsval.max(), obs_g.sim.max()),
        )
        ax.axis("square")
        if plot_hexbin:
            ax.hexbin(
                obs_g.obsval.values,
                obs_g.sim.values,
                mincnt=1,
                gridsize=(75, 75),
                extent=(mn, mx, mn, mx),
                bins="log",
                edgecolors=None,
            )
        else:
            ax.scatter(obs_g.obsval.values, obs_g.sim.values,
                       marker=".", s=10, color="b")

        # 1:1 reference line
        ax.plot([mn, mx], [mn, mx], "k--", lw=1.0)
        xlim = (mn, mx)
        ax.set_xlim(mn, mx)
        ax.set_ylim(mn, mx)
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        ax_count += 1

        if not histogram:
            ax = axes[ax_count]
            ax.scatter(obs_g.obsval.values, obs_g.res.values,
                       marker=".", s=10, color="b")
            # symmetric y-range about zero
            ylim = ax.get_ylim()
            rmax = max(np.abs(ylim[0]), np.abs(ylim[1]))
            if obs_g.shape[0] == 1:
                rmax *= 1.1
            ax.set_ylim(-rmax, rmax)
            # show a zero residuals line
            ax.plot(xlim, [0, 0], "k--", lw=1.0)
            meanres = obs_g.res.mean()
            # show mean residuals line
            ax.plot(xlim, [meanres, meanres], "r-", lw=1.0)
            ax.set_xlim(xlim)
            ax.set_ylabel("residual", labelpad=0.1)
            ax.set_xlabel("observed", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        else:
            # need max and min res to set xlim, otherwise wonky figsize
            rlim = _pad_limits(obs_g.res.min(), obs_g.res.max())

            ax = axes[ax_count]
            ax.hist(obs_g.res, bins=50, color="b")
            meanres = obs_g.res.mean()
            ax.axvline(meanres, color="r", lw=1)
            b, t = ax.get_ylim()
            ax.text(meanres + meanres / 10, t - t / 10, "Mean: {:.2f}".format(meanres))
            ax.set_xlim(rlim)
            ax.set_ylabel("count", labelpad=0.1)
            ax.set_xlabel("residual", labelpad=0.1)
            ax.set_title(
                "{0}) group:{1}, {2} observations".format(
                    abet[ax_count], g, obs_g.shape[0]
                ),
                loc="left",
            )
            ax.grid()
            ax_count += 1
        logger.log("plotting 1to1 for {0}".format(g))

    # blank out the axes left unused on the last page
    for unused_ax in axes[ax_count:]:
        unused_ax.set_axis_off()
        unused_ax.set_yticks([])
        unused_ax.set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot res_1to1")
    else:
        logger.log("plot res_1to1")
        return figs
10✔
499

500

501
def plot_id_bar(id_df, nsv=None, logger=None, **kwargs):
    """Plot a stacked bar chart of identifiability based on
    the `pyemu.ErrVar.get_identifiability_dataframe()` dataframe

    Args:
        id_df (`pandas.DataFrame`) : dataframe of identifiability.  It is
            copied, not modified.
        nsv (`int`): number of singular values to consider.  If None, or
            larger than the number of (non-'ident') columns, all columns
            are used
        logger (`pyemu.Logger`, optional): a logger.  If None, a generic
            one is created
        kwargs (`dict`): a dict of keyword arguments.  Accepts 'figsize'
            (default (8, 10.5)) and 'ax', an existing axis to plot on

    Returns:
        `matplotlib.Axis`: the axis with the plot

    Example::

        import pyemu
        pest_obj = pyemu.Pst(pest_control_file)
        ev = pyemu.ErrVar(jco='freyberg_jac.jcb')
        id_df = ev.get_identifiability_dataframe(singular_value=48)
        pyemu.plot_utils.plot_id_bar(id_df, nsv=12, figsize=(12,4))

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot id bar")

    df = id_df.copy()

    # drop the final `ident` column
    if "ident" in df.columns:
        df.drop("ident", inplace=True, axis=1)

    if nsv is None or nsv > len(df.columns):
        nsv = len(df.columns)
        # one-off message: `log` is used in start/finish pairs elsewhere in
        # this module, so an unpaired call would leave an entry open
        logger.statement("set number of SVs and number in the dataframe")

    # explicit copy so the temporary `ident` column is added to our own frame
    df = df[df.columns[:nsv]].copy()

    # order the bars by total identifiability, largest first
    df["ident"] = df.sum(axis=1)
    df.sort_values(by="ident", inplace=True, ascending=False)
    df.drop("ident", inplace=True, axis=1)

    figsize = kwargs.get("figsize", (8, 10.5))
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)

    # plot the stacked bar chart (the easy part!)
    df.plot.bar(stacked=True, cmap="jet_r", legend=False, ax=ax)

    #
    # horrible shenanigans to make a colorbar rather than a legend
    #

    # special case colormap just dark red if one SV
    if nsv == 1:
        tcm = matplotlib.colors.LinearSegmentedColormap.from_list(
            "one_sv", [plt.get_cmap("jet_r")(0)] * 2, N=2
        )
        sm = plt.cm.ScalarMappable(
            cmap=tcm, norm=matplotlib.colors.Normalize(vmin=0, vmax=nsv + 1)
        )
    # or typically just rock the jet_r colormap over the range of SVs
    else:
        sm = plt.cm.ScalarMappable(
            cmap=plt.get_cmap("jet_r"),
            norm=matplotlib.colors.Normalize(vmin=1, vmax=nsv),
        )
    # give the mappable an (empty) array through the public API
    sm.set_array([])

    # now, if too many ticks for the colorbar, summarize them
    if nsv < 20:
        ticks = range(1, nsv + 1)
    else:
        # the integer step is 0 for 20 <= nsv <= 28, which np.arange rejects;
        # never let it drop below 1
        step = max(1, int((nsv + 1) / 30))
        ticks = np.arange(1, nsv + 1, step)

    cb = plt.colorbar(sm, ax=ax)
    cb.set_ticks(ticks)

    logger.log("plot id bar")

    return ax
10✔
590

591

592
def res_phi_pie(pst, logger=None, **kwargs):
    """plot current phi components as a pie chart.

    Args:
        pst (`pyemu.Pst`): a control file instance with the residual dataframe
            instance available.
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created
        kwargs (`dict`): a dict of plotting options. Accepts 'include_zero'
            as a flag to include phi groups with only zero-weight obs (not
            sure why anyone would do this, but whatevs).

            Also accepts 'label_comps': list of components for the labels. Options are
            ['name', 'phi_comp', 'phi_percent']. Labels will use those three components
            in the order of the 'label_comps' list.

            Also accepts 'ax' (an existing axis to plot on), 'figsize' (used
            when a new figure is created) and 'filename' (the figure is
            saved to this file).

    Returns:
        `matplotlib.Axis`: the axis with the plot.

    Note:
        The pie is always built from `pst.phi`, `pst.phi_components` and
        `pst.phi_components_normalized`; an 'ensemble' kwarg is only
        loaded as a check and a failure to load it is logged, not raised.

    Example::

        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_phi_pie(pst,figsize=(12,4))
        pyemu.plot_utils.res_phi_pie(pst,label_comps = ['name','phi_percent'], figsize=(12,4))


    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_phi_pie")

    if "ensemble" in kwargs:
        # best effort: the loaded residuals are not used below
        try:
            pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception:
            logger.statement(
                "res_phi_pie: could not find ensemble file {0}".format(
                    kwargs["ensemble"]
                )
            )
    else:
        # fail early with a clear message if the residuals are unavailable
        try:
            pst.res
        except Exception:
            logger.lraise("res_phi_pie: pst.res is None, couldn't find residuals file")

    phi = pst.phi
    phi_comps = pst.phi_components
    norm_phi_comps = pst.phi_components_normalized
    keys = list(phi_comps.keys())
    if "include_zero" not in kwargs or kwargs["include_zero"] is False:
        # keep only the groups that contribute to phi
        phi_comps = {k: phi_comps[k] for k in keys if phi_comps[k] > 0.0}
        keys = list(phi_comps.keys())
        norm_phi_comps = {k: norm_phi_comps[k] for k in keys}
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        # honour a 'figsize' kwarg, falling back to the module default
        plt.figure(figsize=kwargs.get("figsize", figsize))
        ax = plt.subplot(1, 1, 1, aspect="equal")

    if "label_comps" not in kwargs:
        labels = [
            "{0}\n{1:4G}\n({2:3.1f}%)".format(
                k, phi_comps[k], 100.0 * (phi_comps[k] / phi)
            )
            for k in keys
        ]
    else:
        # make sure the components for the labels are in a list
        if not isinstance(kwargs["label_comps"], list):
            fmtchoices = [kwargs["label_comps"]]
        else:
            fmtchoices = kwargs["label_comps"]
        # assemble all possible label components
        labfmts = {
            "name": ["{}\n", keys],
            "phi_comp": ["{:4G}\n", [phi_comps[k] for k in keys]],
            "phi_percent": ["({:3.1f}%)", [100.0 * (phi_comps[k] / phi) for k in keys]],
        }
        # the percent format has no trailing newline; add one when it leads
        if fmtchoices[0] == "phi_percent":
            labfmts["phi_percent"][0] = "{}\n".format(labfmts["phi_percent"][0])
        # make the string format
        labfmtstr = "".join([labfmts[k][0] for k in fmtchoices])
        # pull it together
        labels = [
            labfmtstr.format(*k) for k in zip(*[labfmts[j][1] for j in fmtchoices])
        ]

    ax.pie([float(norm_phi_comps[k]) for k in keys], labels=labels)
    logger.log("plot res_phi_pie")
    if kwargs.get("filename") is not None:
        # save the figure that owns `ax`; the pyplot "current figure" may be
        # a different one when the caller supplied the axis
        ax.get_figure().savefig(kwargs["filename"])
    return ax
10✔
687

688

689
def pst_prior(pst, logger=None, filename=None, **kwargs):
    """helper to plot prior parameter histograms implied by
    parameter bounds. Saves a multipage pdf if `filename` is passed

    Args:
        pst (`pyemu.Pst`): control file
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created.
        filename (`str`):  PDF filename to save plots to.
            If None, return figs without saving.  Default is None.
        kwargs (`dict`): additional plotting options. Accepts 'unique_only' to
            only show unique mean-stdev combinations within a given group
            and 'fig_title' (text for the title page).  'grouper' is not
            implemented; parameters are grouped by parameter group.

    Returns:
        [`matplotlib.Figure`]: a list of figures created (title page first)
        if `filename` is None.  Otherwise the figures are written to
        `filename`, closed, and None is returned.

    Example::

        pst = pyemu.Pst("pest.pst")
        pyemu.plot_utils.pst_prior(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot pst_prior")
    par = pst.parameter_data

    if "parcov_filename" in pst.pestpp_options:
        logger.warn("ignoring parcov_filename, using parameter bounds for prior cov")
    logger.log("loading cov from parameter data")
    cov = pyemu.Cov.from_parameter_data(pst)
    logger.log("loading cov from parameter data")

    logger.log("building mean parameter values")
    li = par.partrans.loc[cov.names] == "log"
    mean = par.parval1.loc[cov.names]
    info = par.loc[cov.names, :].copy()
    info.loc[:, "mean"] = mean
    # log-transformed parameters are plotted in log10 space
    info.loc[li, "mean"] = mean[li].apply(np.log10)
    logger.log("building mean parameter values")

    logger.log("building stdev parameter values")
    if cov.isdiagonal:
        std = cov.x.flatten()
    else:
        std = np.diag(cov.x)
    std = np.sqrt(std)
    # check before the assignment below so a mismatch gives this message
    if std.shape != mean.shape:
        logger.lraise("mean.shape {0} != std.shape {1}".format(mean.shape, std.shape))
    info.loc[:, "prior_std"] = std

    logger.log("building stdev parameter values")

    if "grouper" in kwargs:
        raise NotImplementedError()
        # check for consistency here

    else:
        par_adj = par.loc[par.partrans.isin(["log", "none"]), :]
        grouper = par_adj.groupby(par_adj.pargp).groups

    if len(grouper) == 0:
        raise Exception("no adustable parameters to plot")

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='prior')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    ax_count = 0
    for g in sorted(grouper.keys()):
        names = grouper[g]
        logger.log("plotting priors for {0}".format(",".join(list(names))))
        if ax_count % (nr * nc) == 0:
            # current page is full (or this is the first group): start a new page
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # transform status of *this group's* parameters only
        islog = False
        vc = info.loc[names, "partrans"].value_counts()
        if vc.shape[0] > 1:
            logger.warn("mixed partrans for group {0}".format(g))
        elif "log" in vc.index:
            islog = True
        ax = axes[ax_count]
        if "unique_only" in kwargs and kwargs["unique_only"]:

            ms = (
                info.loc[names, :]
                .apply(lambda x: (x["mean"], x["prior_std"]), axis=1)
                .unique()
            )
            for (m, s) in ms:
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")

        else:
            for m, s in zip(info.loc[names, "mean"], info.loc[names, "prior_std"]):
                x, y = gaussian_distribution(m, s)
                ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        ax.set_title(
            "{0}) group:{1}, {2} parameters".format(abet[ax_count], g, names.shape[0]),
            loc="left",
        )

        ax.set_yticks([])
        if islog:
            ax.set_xlabel("$log_{10}$ parameter value", labelpad=0.1)
        else:
            ax.set_xlabel("parameter value", labelpad=0.1)
        logger.log("plotting priors for {0}".format(",".join(list(names))))

        ax_count += 1

    # blank out the axes left unused on the last page
    for a in range(ax_count, nr * nc):
        axes[a].set_axis_off()
        axes[a].set_yticks([])
        axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            # write every page, not just the last figure
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot pst_prior")
    else:
        logger.log("plot pst_prior")
        return figs
10✔
840

841

842
def ensemble_helper(
    ensemble,
    bins=10,
    facecolor="0.5",
    plot_cols=None,
    filename=None,
    func_dict=None,
    sync_bins=True,
    deter_vals=None,
    std_window=None,
    deter_range=False,
    **kwargs
):
    """helper function to plot ensemble histograms

    Args:
        ensemble : varies
            the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        bins : int
            number of histogram bins (passed to `numpy.linspace` as the number of
            bin edges when `sync_bins` is True).  Default is 10
        facecolor : str
            the histogram facecolor.  Only applies if ensemble is a single thing
        plot_cols : enumerable
            a collection of columns (in form of a list of parameters, or a dict with keys for
            parsing plot axes and values of parameters) from the ensemble(s) to plot.  If None,
            (the union of) all cols are plotted. Default is None
        filename : str
            the name of the pdf to create.  If None, return figs without saving.  Default is None.
        func_dict : dict
            a dictionary of unary functions (e.g., `np.log10`) to apply to columns.  Key is
            column name.  Default is None
        sync_bins : bool
            flag to use the same bin edges for all ensembles. Only applies if more than
            one ensemble is being plotted.  Default is True
        deter_vals : dict
            dict of deterministic values to plot as a vertical line. key is ensemble column name
        std_window : float
            the number of standard deviations around the mean to mark as vertical lines.  If None,
            nothing happens.  Default is None
        deter_range : bool
            flag to set xlims to deterministic value +/- std window.  If True, std_window must not be None.
            Default is False
        **kwargs : dict
            "fig_title" (str) replaces the default text on the title page.

    Returns:
        [`matplotlib.Figure`]: a list of figures (title page first).  Returns None if
        `filename` is not None (the figures are written to the pdf and closed).

    Example::

        # plot prior and posterior par ensembles
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post},filename="ensemble.pdf")

        #plot prior and posterior simulated equivalents to observations with obs noise and obs vals
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        noise = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.obs+noise.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post,"r":noise},
                                         filename="ensemble.pdf",
                                         deter_vals=pst.observation_data.obsval.to_dict())


    """
    logger = pyemu.Logger("ensemble_helper.log")
    logger.log("pyemu.plot_utils.ensemble_helper()")
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
    if len(ensembles) == 0:
        raise Exception("plot_uitls.ensemble_helper() error processing `ensemble` arg")
    # apply any functions
    if func_dict is not None:
        logger.log("applying functions")
        for col, func in func_dict.items():
            for fc, en in ensembles.items():
                if col in en.columns:
                    en.loc[:, col] = en.loc[:, col].apply(func)
        logger.log("applying functions")

    # get a list of all cols (union)
    all_cols = set()
    for fc, en in ensembles.items():
        all_cols.update(set(en.columns))
    if plot_cols is None:
        # one axis per column, labelled by the column name
        plot_cols = {col: [col] for col in all_cols}
    else:
        if isinstance(plot_cols, list):
            splot_cols = set(plot_cols)
            plot_cols = {col: [col] for col in plot_cols}
        elif isinstance(plot_cols, dict):
            splot_cols = []
            for label, pcols in plot_cols.items():
                splot_cols.extend(list(pcols))
            splot_cols = set(splot_cols)
        else:
            logger.lraise(
                "unrecognized plot_cols type: {0}, should be list or dict".format(
                    type(plot_cols)
                )
            )

        missing = splot_cols - all_cols
        if len(missing) > 0:
            logger.lraise(
                "the following plot_cols are missing: {0}".format(",".join(missing))
            )

    logger.statement("plotting {0} histograms".format(len(plot_cols)))

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_helper()\n at {0}".format(str(datetime.now())),
            ha="center",
        )
    labels = list(plot_cols.keys())
    labels.sort()
    logger.statement("saving pdf to {0}".format(filename))
    figs = []

    # stays None when there is nothing to plot so the clean-up loop can be skipped
    axes = None
    ax_count = 0

    for label in labels:
        plot_col = plot_cols[label]
        logger.log("plotting reals for {0}".format(label))
        if ax_count % (nr * nc) == 0:
            # current page is full (or is the title page): stash it and start a new one
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            [ax.set_yticks([]) for ax in axes]
            ax_count = 0

        ax = axes[ax_count]

        if sync_bins:
            # shared bin edges spanning the min/max over every ensemble
            mx, mn = -1.0e30, 1.0e30
            for fc, en in ensembles.items():
                emn = np.nanmin(en.loc[:, plot_col].values)
                emx = np.nanmax(en.loc[:, plot_col].values)
                mx = max(mx, emx)
                mn = min(mn, emn)
            if mx == -1.0e30 and mn == 1.0e30:
                # sentinels untouched -> every value was NaN
                logger.warn("all NaNs for label: {0}".format(label))
                ax.set_title(
                    "{0}) {1}, count:{2} - all NaN".format(
                        abet[ax_count], label, len(plot_col)
                    ),
                    loc="left",
                )
                ax.set_yticks([])
                ax.set_xticks([])
                ax_count += 1
                # close the log event opened at the top of this iteration
                logger.log("plotting reals for {0}".format(label))
                continue
            plot_bins = np.linspace(mn, mx, num=bins)
            logger.statement("{0} min:{1:5G}, max:{2:5G}".format(label, mn, mx))
        else:
            plot_bins = bins
        for fc, en in ensembles.items():
            vals = en.loc[:, plot_col].values.flatten()

            ax.hist(
                vals,
                bins=plot_bins,
                edgecolor="none",
                alpha=0.5,
                density=True,
                facecolor=fc,
            )
            v = None
            if deter_vals is not None:
                for pc in plot_col:
                    if pc in deter_vals:
                        ylim = ax.get_ylim()
                        v = deter_vals[pc]
                        ax.plot([v, v], ylim, "k--", lw=1.5)
                        ax.set_ylim(ylim)

            if std_window is not None:
                try:
                    ylim = ax.get_ylim()
                    # Window is computed from the same pooled values that the
                    # histogram shows.  ddof=1 matches the pandas `.std()`
                    # default.  (Previously this read the loop variable `pc`,
                    # which is unbound when `deter_vals` is None.)
                    center = np.nanmean(vals)
                    half_width = np.nanstd(vals, ddof=1) * (std_window / 2.0)

                    lower, upper = center - half_width, center + half_width
                    ax.plot([lower, lower], ylim, color=fc, lw=1.5, ls="--")
                    ax.plot([upper, upper], ylim, color=fc, lw=1.5, ls="--")
                    ax.set_ylim(ylim)
                    if deter_range and v is not None:
                        ax.set_xlim(v - half_width, v + half_width)
                except Exception as e:
                    # best effort: the std window is decoration, don't kill the plot
                    logger.warn(
                        "error plotting std window for {0}: {1}".format(label, str(e))
                    )
        ax.grid()
        if len(ensembles) > 1:
            ax.set_title(
                "{0}) {1}, count: {2}".format(abet[ax_count], label, len(plot_col)),
                loc="left",
            )
        else:
            # single ensemble: `vals` still holds its values from the loop above
            ax.set_title(
                "{0}) {1}, count:{2}\nmin:{3:3.1E}, max:{4:3.1E}".format(
                    abet[ax_count],
                    label,
                    len(plot_col),
                    np.nanmin(vals),
                    np.nanmax(vals),
                ),
                loc="left",
            )
        ax_count += 1
        logger.log("plotting reals for {0}".format(label))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("pyemu.plot_utils.ensemble_helper()")
    else:
        logger.log("pyemu.plot_utils.ensemble_helper()")
        return figs
1095

1096

1097
def ensemble_change_summary(
    ensemble1,
    ensemble2,
    pst,
    bins=10,
    facecolor="0.5",
    logger=None,
    filename=None,
    **kwargs
):
    """helper function to plot first and second moment change histograms between two
    ensembles

    Args:
        ensemble1 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        ensemble2 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        pst (`pyemu.Pst`): control file
        bins (`int`): number of histogram bins.  Default is 10
        facecolor (`str`): the histogram facecolor.
        logger (`pyemu.Logger`): logger instance to use.  If None, a default
            (non-echoing) logger is created.
        filename (`str`): the name of the multi-pdf to create. If None, return figs without saving.  Default is None.
        **kwargs (`dict`): "fig_title" (str) replaces the default text on the title
            page.  "grouper" is not implemented and raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:
        the ensembles passed in are copied before columns are lower-cased and
        log-transformed, so the caller's objects are left untouched.

    Example::

        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_change_summary(prior,post,pst)
        plt.show()


    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot ensemble change")

    # Work on copies: the lower-casing and log10 transform below would otherwise
    # modify the caller's ensembles (and double-log them on a second call).
    if isinstance(ensemble1, str):
        ensemble1 = pd.read_csv(ensemble1, index_col=0)
    else:
        ensemble1 = ensemble1.copy()
    ensemble1.columns = ensemble1.columns.str.lower()

    if isinstance(ensemble2, str):
        ensemble2 = pd.read_csv(ensemble2, index_col=0)
    else:
        ensemble2 = ensemble2.copy()
    ensemble2.columns = ensemble2.columns.str.lower()

    # better to ensure this is caught by pestpp-ies ensemble csvs
    unnamed1 = [col for col in ensemble1.columns if "unnamed:" in col]
    if len(unnamed1) != 0:
        ensemble1 = ensemble1.iloc[
            :, :-1
        ]  # ensure unnamed col result of poor csv read only (ie last col)
    unnamed2 = [col for col in ensemble2.columns if "unnamed:" in col]
    if len(unnamed2) != 0:
        ensemble2 = ensemble2.iloc[
            :, :-1
        ]  # ensure unnamed col result of poor csv read only (ie last col)

    d = set(ensemble1.columns).symmetric_difference(set(ensemble2.columns))

    if len(d) != 0:
        logger.lraise(
            "ensemble1 does not have the same columns as ensemble2: {0}".format(
                ",".join(d)
            )
        )
    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        en_cols = ensemble1.columns
        if len(en_cols.difference(pst.par_names)) == 0:
            # parameter ensembles: group by pargp, compare log-transformed pars in log space
            par = pst.parameter_data.loc[en_cols, :]
            grouper = par.groupby(par.pargp).groups
            grouper["all"] = pst.adj_par_names
            li = par.loc[par.partrans == "log", "parnme"]
            ensemble1.loc[:, li] = ensemble1.loc[:, li].apply(np.log10)
            ensemble2.loc[:, li] = ensemble2.loc[:, li].apply(np.log10)
        elif len(en_cols.difference(pst.obs_names)) == 0:
            # observation ensembles: group by obgnme
            obs = pst.observation_data.loc[en_cols, :]
            grouper = obs.groupby(obs.obgnme).groups
            grouper["all"] = pst.nnz_obs_names
        else:
            logger.lraise("could not match ensemble cols with par or obs...")

    en1_mn, en1_std = ensemble1.mean(axis=0), ensemble1.std(axis=0)
    en2_mn, en2_std = ensemble2.mean(axis=0), ensemble2.std(axis=0)

    mn_diff = -1 * (en2_mn - en1_mn)
    std_diff = 100 * (((en1_std - en2_std) / en1_std))
    # a zero ensemble1 std gives +/-inf here, which the histogram cannot bin;
    # treat those entries as missing (NaN is dropped by `Series.hist`)
    std_diff = std_diff.replace([np.inf, -np.inf], np.nan)

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_change_summary()\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    # stays None when no group is plotted so the clean-up loop can be skipped
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting change for {0}".format(g))

        mn_g = mn_diff.loc[names]
        std_g = std_diff.loc[names]

        if mn_g.shape[0] == 0:
            logger.statement("no entries for group '{0}'".format(g))
            logger.log("plotting change for {0}".format(g))
            continue

        if ax_count % (nr * nc) == 0:
            # current page is full (or is the title page): stash it and start a new one
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # left-hand axis: change in the mean
        ax = axes[ax_count]
        mn_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("mean change", labelpad=0.1)
        ax.set_title(
            "{0}) mean change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], mn_g.max(), mn_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        # right-hand axis: percent reduction in standard deviation
        ax = axes[ax_count]
        std_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("sigma percent reduction", labelpad=0.1)
        ax.set_title(
            "{0}) sigma change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], std_g.max(), std_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        logger.log("plotting change for {0}".format(g))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot ensemble change")
    else:
        logger.log("plot ensemble change")
        return figs
1291

1292

1293
def _process_ensemble_arg(ensemble, facecolor, logger):
    """private method to work out ensemble plot args

    Args:
        ensemble (varies): a `pandas.DataFrame`/`pyemu.Ensemble`, a csv filename
            (`str`), a list of those, or a dict of those keyed by facecolor
        facecolor (varies): facecolor `str` for a single ensemble.  For a list
            of ensembles, either a list of facecolors of the same length or
            anything else, in which case default colors are assigned in order.
            Ignored when `ensemble` is a dict.
        logger: object providing `log`, `statement` and `lraise` methods

    Returns:
        `dict`: ensembles keyed by facecolor, with column names lower-cased
        (in place, on the objects that were passed in or loaded).  Empty if the
        type of `ensemble` is not recognized.

    """

    def _is_frame(arg):
        return isinstance(arg, pd.DataFrame) or isinstance(arg, pyemu.Ensemble)

    def _read_csv(fname):
        # the paired log() calls open and close the timing event
        logger.log("loading ensemble from csv file {0}".format(fname))
        loaded = pd.read_csv(fname, index_col=0)
        logger.log("loading ensemble from csv file {0}".format(fname))
        return loaded

    ensembles = {}
    # the four accepted shapes are mutually exclusive; builtin types are tested first
    if isinstance(ensemble, str):
        if not isinstance(facecolor, str):
            logger.lraise("facecolor must be str")
        en = _read_csv(ensemble)
        logger.statement("{0} shape: {1}".format(ensemble, en.shape))
        ensembles[facecolor] = en

    elif isinstance(ensemble, list):
        if isinstance(facecolor, list):
            if len(ensemble) != len(facecolor):
                logger.lraise("facecolor list len != ensemble list len")
        else:
            colors = ["m", "c", "b", "r", "g", "y"]
            if len(ensemble) > len(colors):
                # Cycling the colors would silently overwrite dict entries, so
                # fail with an actionable message instead of an IndexError.
                logger.lraise(
                    "only {0} default facecolors available for {1} ensembles, "
                    "pass `facecolor` as a list or `ensemble` as a dict".format(
                        len(colors), len(ensemble)
                    )
                )
            facecolor = colors[: len(ensemble)]
        for fc, en_arg in zip(facecolor, ensemble):
            if isinstance(en_arg, str):
                en = _read_csv(en_arg)
                logger.statement("ensemble {0} gets facecolor {1}".format(en_arg, fc))
            elif _is_frame(en_arg):
                en = en_arg
            else:
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
            ensembles[fc] = en

    elif isinstance(ensemble, dict):
        for fc, en_arg in ensemble.items():
            if isinstance(en_arg, str):
                ensembles[fc] = _read_csv(en_arg)
            elif _is_frame(en_arg):
                ensembles[fc] = en_arg
            else:
                logger.lraise("unrecognized ensemble dict arg:{0}".format(en_arg))

    elif _is_frame(ensemble):
        if not isinstance(facecolor, str):
            logger.lraise("facecolor must be str")
        ensembles[facecolor] = ensemble

    try:
        for fc in ensembles:
            ensembles[fc].columns = ensembles[fc].columns.str.lower()
    except Exception as e:
        # e.g. non-string column labels; surface the cause instead of hiding it
        logger.lraise("error processing ensemble: {0}".format(str(e)))

    return ensembles
1350

1351

1352
def ensemble_res_1to1(
    ensemble,
    pst,
    facecolor="0.5",
    logger=None,
    filename=None,
    skip_groups=[],
    base_ensemble=None,
    **kwargs
):
    """helper function to plot ensemble 1-to-1 plots showing the simulated range

    Args:
        ensemble (varies):  the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        pst (`pyemu.Pst`): a control file instance
        facecolor (`str`): the histogram facecolor.  Only applies if `ensemble` is a single thing
        logger (`pyemu.Logger`): logger instance to use.  If None, a default
            (non-echoing) logger is created.
        filename (`str`): the name of the pdf to create. If None, return figs
            without saving.  Default is None.
        skip_groups ([`str`]): observation group names to leave out of the plots.
            Names that are not groups in the control file are ignored.
        base_ensemble (`varies`): an optional ensemble argument for the observations + noise ensemble.
            This will be plotted as a transparent red bar on the 1to1 plot.
        **kwargs (`dict`): "fig_title" (str) replaces the default text on the title
            page.  "include_zero" (bool) also plots zero-weighted observations.
            "grouper" is not implemented and raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:

        the vertical bar on each plot is the min-max range

    Example::


        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        pyemu.plot_utils.ensemble_res_1to1(ensemble={"0.5":prior, "b":post})
        plt.show()

    """
    def _get_plotlims(oen, ben, obsnames):
        """work out shared x/y limits for the 1to1 axis from the simulated
        ensemble(s) `oen` and the base ensemble(s) or obsval series `ben`"""
        if not isinstance(oen, dict):
            oen = {'g': oen.loc[:, obsnames]}
        if not isinstance(ben, dict):
            ben = {'g': ben.get(obsnames)}
        outofrange = False
        # work back from crazy values
        oemin = 1e32
        oemeanmin = 1e32
        oemax = -1e32
        oemeanmax = -1e32
        bemin = 1e32
        bemeanmin = 1e32
        bemax = -1e32
        bemeanmax = -1e32
        for _, oeni in oen.items():  # loop over ensembles
            oeni = oeni.loc[:, obsnames]  # slice group obs
            oemin = np.nanmin([oemin, oeni.min().min()])
            oemax = np.nanmax([oemax, oeni.max().max()])
            # get min and max of mean sim vals
            # (incase we want plot to ignore extremes)
            oemeanmin = np.nanmin([oemeanmin, oeni.mean().min()])
            oemeanmax = np.nanmax([oemeanmax, oeni.mean().max()])
        for _, beni in ben.items():  # same with base ensemble/obsval
            # work with either ensemble or obsval series
            beni = beni.get(obsnames)
            bemin = np.nanmin([bemin, beni.min().min()])
            bemax = np.nanmax([bemax, beni.max().max()])
            bemeanmin = np.nanmin([bemeanmin, beni.mean().min()])
            bemeanmax = np.nanmax([bemeanmax, beni.mean().max()])
        # get base ensemble range
        berange = bemax-bemin
        if berange == 0.:  # only one obs in group (probs)
            # a range is a width: use the magnitude so that a negative
            # observation value does not flip the buffer applied below
            berange = abs(bemeanmax) * 1.1  # expand a little
        # add buffer to obs endpoints
        bemin = bemin - (berange*0.05)
        bemax = bemax + (berange*0.05)
        if oemax < bemin:  # sim well below obs
            oemin = oemeanmin  # set min to mean min
            # (sim captured but not extremes)
            outofrange = True
        if oemin > bemax:  # sim well above obs
            oemax = oemeanmax
            outofrange = True
        oerange = oemax - oemin
        if bemax > oemax + (0.1*oerange):  # obs max well above sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemax = oemax + (0.1*oerange)
            else:  # use obs mean max
                bemax = bemeanmax
        if bemin < oemin - (0.1 * oerange):  # obs min well below sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemin = oemin - (0.1 * oerange)
            else:
                bemin = bemeanmin
        pmin = np.nanmin([oemin, bemin])
        pmax = np.nanmax([oemax, bemax])
        return pmin, pmax


    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_1to1")
    obs = pst.observation_data
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)

    if base_ensemble is not None:
        base_ensemble = _process_ensemble_arg(base_ensemble, "r", logger)

    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        grouper = obs.groupby(obs.obgnme).groups
        for skip_group in skip_groups:
            # tolerate names that are not groups in this control file
            grouper.pop(skip_group, None)

    # title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    # stays None when no group is plotted so the clean-up loop can be skipped
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))
        # control file observation for group
        obs_g = obs.loc[names, :]
        # normally only look a non-zero weighted obs
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue
        # if the first axis in page
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0
        ax = axes[ax_count]

        if base_ensemble is None:
            # if obs not defined by obs+noise ensemble,
            # use min and max for obsval from control file
            pmin, pmax = _get_plotlims(ensembles, obs_g.obsval, obs_g.obsnme)
        else:
            # if obs defined by obs+noise use obs+noise min and max
            pmin, pmax = _get_plotlims(ensembles, base_ensemble, obs_g.obsnme)
            # sort by obsval so the filled band is drawn left to right
            obs_gg = obs_g.sort_values(by="obsval")
            for c, en in base_ensemble.items():
                en_g = en.loc[:, obs_gg.obsnme]
                emx = en_g.max()
                emn = en_g.min()

                # update y min and max for obs+noise ensembles
                if len(obs_gg.obsval) > 1:

                    emx = np.zeros(obs_gg.shape[0]) + emx
                    emn = np.zeros(obs_gg.shape[0]) + emn
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
                                    facecolor=c, alpha=0.2, zorder=2)
                else:
                    # a single obs cannot form a band: draw a vertical bar instead
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c, alpha=0.2, zorder=2)
        for c, en in ensembles.items():
            en_g = en.loc[:, obs_g.obsnme]
            # output mins and maxs
            emx = en_g.max()
            emn = en_g.min()
            # one vertical min-max bar per observation
            [
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
            ]
        ax.plot([pmin, pmax], [pmin, pmax], "k--", lw=1.0, zorder=3)
        xlim = (pmin, pmax)
        ax.set_xlim(pmin, pmax)
        ax.set_ylim(pmin, pmax)

        if max(np.abs(xlim)) > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
            ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        # Residual (RHS plot)
        ax_count += 1
        ax = axes[ax_count]

        if base_ensemble is not None:
            obs_gg = obs_g.sort_values(by="obsval")
            for c, en in base_ensemble.items():
                en_g = en.loc[:, obs_gg.obsnme].subtract(obs_gg.obsval)
                emx = en_g.max()
                emn = en_g.min()
                if len(obs_gg.obsval) > 1:
                    emx = np.zeros(obs_gg.shape[0]) + emx
                    emn = np.zeros(obs_gg.shape[0]) + emn
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
                                    facecolor=c, alpha=0.2, zorder=2)
                else:
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx], color=c,
                            alpha=0.2, zorder=2)
        omn = []
        omx = []
        for c, en in ensembles.items():
            en_g = en.loc[:, obs_g.obsnme].subtract(obs_g.obsval, axis=1)
            emx = en_g.max()
            emn = en_g.min()
            omn.append(emn)
            omx.append(emx)
            [
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
                for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values)
            ]

        omn = pd.concat(omn).min()
        omx = pd.concat(omx).max()
        mx = np.nanmax([np.abs(omn), np.abs(omx)])  # ensure symmetric about y=0
        if obs_g.shape[0] == 1:
            mx *= 1.05
        else:
            mx *= 1.02
        if np.sign(omn) == np.sign(omx):
            # allow y axis asymm if all above or below
            mn = np.nanmin([0, np.sign(omn) * mx])
            mx = np.nanmax([0, np.sign(omn) * mx])
        else:
            mn = -mx
        ax.set_ylim(mn, mx)
        bmin = obs_g.obsval.values.min()
        bmax = obs_g.obsval.values.max()
        brange = (bmax - bmin)
        if brange == 0.:
            # use the magnitude: a negative obsval would otherwise make
            # bmin > bmax and draw the x-axis reversed
            brange = abs(obs_g.obsval.values.mean())
        bmin = bmin - 0.1*brange
        bmax = bmax + 0.1*brange
        xlim = (bmin, bmax)
        # show a zero residuals line
        ax.plot(xlim, [0, 0], "k--", lw=1.0, zorder=3)

        ax.set_xlim(xlim)
        ax.set_ylabel("residual", labelpad=0.1)
        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax.grid()
        if ax.get_xlim()[1] > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))

        ax_count += 1

        logger.log("plotting 1to1 for {0}".format(g))

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot res_1to1")
    else:
        logger.log("plot res_1to1")
        return figs
1646

1647

1648
def plot_jac_test(
    csvin, csvout, targetobs=None, filetype=None, maxoutputpages=1, outputdirectory=None
):
    """helper function to plot results of the Jacobian test performed using the pest++
    program pestpp-swp.

    Args:
        csvin (`str`): name of csv file used as input to sweep, typically developed with
            static method pyemu.helpers.build_jac_test_csv()
        csvout (`str`): name of csv file with output generated by sweep, both input
            and output files can be specified in the pest++ control file
            with pyemu using: pest_object.pestpp_options["sweep_parameter_csv_file"] = jactest_in_file.csv
            pest_object.pestpp_options["sweep_output_csv_file"] = jactest_out_file.csv
        targetobs ([`str`]): list of observation names (columns of `csvout`) to plot.
            Each parameter used for jactest can have up to 32 observations plotted
            per page.  If None, all columns of the output csv file after the
            first eight are used.
        filetype (`str`): file type to store output, if None, plt.show() is called.
        maxoutputpages (`int`): maximum number of pages of output per parameter.  Each page can
            hold up to 32 observation derivatives.  If value > 10, it is reset to
            10 and a warning message is printed.  If observations in
            targetobs > 32*maxoutputpages, then a random set is selected from the
            targetobs list (or all observations in the csv file if targetobs=None).
        outputdirectory (`str`):  directory to store results, if None, current working directory is used.
            If a relative path is passed, it is joined to the current working directory and
            created if needed. An absolute path is used directly.

    Note:
        Used in conjunction with pyemu.helpers.build_jac_test_csv() and sweep to perform
        a Jacobian Test and then view the results. Can generate a lot of plots so easiest
        to put into a separate directory and view the files.

        One figure is produced per parameter and page; when `filetype` is given each
        is saved as "<parameter>_jactest_<page>.<filetype>" in the output directory.

    """
    # layout of every output page
    nrows, ncols = 8, 4
    obs_per_page = nrows * ncols

    localhome = os.getcwd()
    if outputdirectory is None:
        figures_dir = localhome
    else:
        # os.path.join discards localhome when outputdirectory is absolute
        figures_dir = os.path.join(localhome, outputdirectory)
        # makedirs (not mkdir) so nested output directories also work
        os.makedirs(figures_dir, exist_ok=True)

    # read the input and output files into pandas dataframes
    jactest_in_df = pd.read_csv(csvin, engine="python", index_col=0)
    jactest_in_df.index.name = "input_run_id"
    jactest_out_df = pd.read_csv(csvout, engine="python", index_col=1)

    # subtract the base run from every row, leaves the one parameter that
    # was perturbed in any row as only non-zero value. Set zeros to nan
    # so round-off doesn't get us and sum across rows to get a series of
    # the perturbation for each row. First the input csv and then the output.
    base_par = jactest_in_df.loc["base"]
    delta_par_df = jactest_in_df.subtract(base_par, axis="columns")
    delta_par_df = delta_par_df.replace(0, np.nan)
    delta_par_df = delta_par_df.drop("base", axis="index")
    delta_par = delta_par_df.sum(axis="columns")

    base_obs = jactest_out_df.loc["base"]
    delta_obs = jactest_out_df.subtract(base_obs)
    delta_obs = delta_obs.drop("base", axis="index")
    # if targetobs is None, then reset it to all the observations.
    if targetobs is None:
        # NOTE(review): the first 8 columns are presumably sweep bookkeeping
        # columns rather than observations - confirm against pestpp-swp output
        targetobs = jactest_out_df.columns.tolist()[8:]
    delta_obs = delta_obs[targetobs]

    # get the Jacobian by dividing the change in observation by the change in parameter
    # for the perturbed parameters
    jacobian = delta_obs.divide(delta_par, axis="index")

    # use the index created by build_jac_test_csv to get a column of parameter names
    # and increments, then we can plot derivative vs. increment for each parameter
    extr_df = pd.Series(jacobian.index.values).str.extract(r"(.+)(_\d+$)", expand=True)
    extr_df[1] = pd.to_numeric(extr_df[1].str.replace("_", "")) + 1
    extr_df = extr_df.rename(columns={0: "parameter", 1: "increment"})
    extr_df.index = jacobian.index

    # make a dataframe for plotting the Jacobian by combining the parameter name
    # and increments frame with the Jacobian frame
    plotframe = pd.concat([extr_df, jacobian], axis=1, join="inner")

    # get a list of observations to keep based on maxoutputpages.
    if maxoutputpages > 10:
        print("WARNING, more than 10 pages of output requested per parameter")
        print("maxoutputpage reset to 10.")
        maxoutputpages = 10
    num_obs_plotted = min(maxoutputpages * obs_per_page, len(targetobs))
    if num_obs_plotted < len(targetobs):
        # too many observations for the page budget: plot a random sample
        index_plotted = np.random.choice(len(targetobs), num_obs_plotted, replace=False)
        obs_plotted = [targetobs[x] for x in index_plotted]
    else:
        obs_plotted = targetobs
    # number of pages actually needed (ceiling division); must be derived from
    # the observation count - the list itself cannot be divided
    real_pages = -(-num_obs_plotted // obs_per_page)

    # make a subplot of derivative vs. increment one plot for each of the
    # observations in targetobs, and outputs grouped by parameter.
    for param, group in plotframe.groupby("parameter"):
        for page in range(real_pages):
            fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=(10, 15))
            first = page * obs_per_page
            page_obs = obs_plotted[first : first + obs_per_page]
            # axes.flat walks the grid row by row, matching the page ordering
            for ax, obs_name in zip(axes.flat, page_obs):
                ax.scatter(group["increment"], group[obs_name])
                ax.plot(group["increment"], group[obs_name], "r")
                ax.set_title(obs_name)
                ax.set_xticks([1, 2, 3, 4, 5])
                ax.tick_params(direction="in")
            # with sharex=True the tick labels are drawn on the bottom row,
            # so that is where the axis label belongs
            for ax in axes[-1, :]:
                ax.set_xlabel("Increment")
            fig.tight_layout()

            if filetype is None:
                plt.show()
            else:
                fig.savefig(
                    os.path.join(
                        figures_dir, "{0}_jactest_{1}.{2}".format(param, page, filetype)
                    )
                )
            plt.close(fig)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc