• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pypest / pyemu / 5887625428

17 Aug 2023 06:23AM UTC coverage: 79.857% (+1.5%) from 78.319%
5887625428

push

github

briochh
Merge branch 'develop'

11386 of 14258 relevant lines covered (79.86%)

6.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.77
/pyemu/plot/plot_utils.py
1
"""Plotting functions for various PEST(++) and pyemu operations"""
9✔
2
import os
9✔
3
import numpy as np
9✔
4
import pandas as pd
9✔
5
import warnings
9✔
6
from datetime import datetime
9✔
7
import string
9✔
8
from pyemu.logger import Logger
9✔
9
from pyemu.pst import pst_utils
9✔
10
from ..pyemu_warnings import PyemuWarning
9✔
11

12
font = {"size": 6}
9✔
13
try:
9✔
14
    import matplotlib
9✔
15

16
    matplotlib.rc("font", **font)
9✔
17

18
    import matplotlib.pyplot as plt
9✔
19
    from matplotlib.backends.backend_pdf import PdfPages
9✔
20
    from matplotlib.gridspec import GridSpec
9✔
21
except Exception as e:
×
22
    # raise Exception("error importing matplotlib: {0}".format(str(e)))
23
    warnings.warn("error importing matplotlib: {0}".format(str(e)), PyemuWarning)
×
24

25
import pyemu
9✔
26

27
# default (width, height) handed to plt.figure() by the plotting helpers in this module
figsize = (8, 10.5)
# number of subplot rows and columns laid out on each multi-axes page
nr, nc = 4, 2
# page_gs = GridSpec(nr,nc)

# uppercase letters used to prefix subplot titles, e.g. "A) group:..."
abet = string.ascii_uppercase
32

33

34
def plot_summary_distributions(
    df,
    ax=None,
    label_post=False,
    label_prior=False,
    subplots=False,
    figsize=(11, 8.5),
    pt_color="b",
):
    """helper function to plot gaussian distributions from prior and posterior
    means and standard deviations

    Args:
        df (`pandas.DataFrame` or `str`): a dataframe or the name of a csv file.
            Must supply the prior and posterior mean ('prior_mean'/'post_mean' or
            'prior_expt'/'post_expt') and spread ('prior_stdev'/'post_stdev' or
            'prior_var'/'post_var').  If loaded from a csv file, column 0 is
            assumed to be the index.  A dataframe argument is not modified.
        ax (`matplotlib.pyplot.axis`): If None, and not subplots, then one is created
            and all distributions are plotted on a single plot.  Ignored if
            `subplots` is True.
        label_post (`bool`): flag to add text labels to the peak of the posterior
        label_prior (`bool`): flag to add text labels to the peak of the prior
        subplots (`bool`): flag to use subplots.  If True, then 6 axes per page
            are used and a single prior and posterior is plotted on each
        figsize (`tuple`): matplotlib figure size
        pt_color (`str`): matplotlib color used to fill the posterior distributions

    Returns:
        if `subplots` is True, a tuple containing:

        - **[`matplotlib.figure`]**: list of figures
        - **[`matplotlib.axis`]**: list of axes

        otherwise the single `matplotlib.axis` that was plotted on

    Note:
        This is useful for demystifying FOSM results

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pyemu.plot_utils.plot_summary_distributions("pest.par.usum.csv")
        plt.show()

    """
    import matplotlib.pyplot as plt

    if isinstance(df, str):
        df = pd.read_csv(df, index_col=0)
    else:
        # the derived columns added below must not leak into the caller's dataframe
        df = df.copy()
    if ax is None and not subplots:
        plt.figure(figsize=figsize)
        ax = plt.subplot(111)
        ax.grid()

    # fill in whichever of the stdev/expectation columns are missing from
    # their variance/mean equivalents
    if "post_stdev" not in df.columns and "post_var" in df.columns:
        df.loc[:, "post_stdev"] = df.post_var.apply(np.sqrt)
    if "prior_stdev" not in df.columns and "prior_var" in df.columns:
        df.loc[:, "prior_stdev"] = df.prior_var.apply(np.sqrt)
    if "prior_expt" not in df.columns and "prior_mean" in df.columns:
        df.loc[:, "prior_expt"] = df.prior_mean
    if "post_expt" not in df.columns and "post_mean" in df.columns:
        df.loc[:, "post_expt"] = df.post_mean

    if subplots:
        fig = plt.figure(figsize=figsize)
        ax = plt.subplot(2, 3, 1)
        ax_per_page = 6
        ax_count = 0
        axes = []
        figs = []

    # iterate by position so repeated index labels cannot end the loop early
    last_pos = len(df.index) - 1
    rows = zip(
        df.index,
        df["post_expt"],
        df["post_stdev"],
        df["prior_expt"],
        df["prior_stdev"],
    )
    for pos, (name, post_mean, post_std, prior_mean, prior_std) in enumerate(rows):
        # posterior: filled curve
        x, y = gaussian_distribution(post_mean, post_std)
        ax.fill_between(x, 0, y, facecolor=pt_color, edgecolor="none", alpha=0.25)
        if label_post:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        # prior: dashed grey line
        x, y = gaussian_distribution(prior_mean, prior_std)
        ax.plot(x, y, color="0.5", lw=3.0, dashes=(2, 1))
        if label_prior:
            mx_idx = np.argmax(y)
            xtxt, ytxt = x[mx_idx], y[mx_idx] * 1.001
            ax.text(xtxt, ytxt, name, ha="center", alpha=0.5)

        if subplots:
            ax.set_title(name)
            ax_count += 1
            ax.set_yticklabels([])
            axes.append(ax)
            # stop before creating an empty trailing axis (or page)
            if pos == last_pos:
                break
            if ax_count >= ax_per_page:
                figs.append(fig)
                fig = plt.figure(figsize=figsize)
                ax_count = 0
            ax = plt.subplot(2, 3, ax_count + 1)
    if subplots:
        figs.append(fig)
        return figs, axes

    # single-axis mode: anchor the y axis at zero and leave head room for labels
    ylim = list(ax.get_ylim())
    ylim[1] *= 1.2
    ylim[0] = 0.0
    ax.set_ylim(ylim)
    ax.set_yticklabels([])
    return ax
145

146

147
def gaussian_distribution(mean, stdev, num_pts=50):
    """build the x and y arrays of a gaussian probability density
    sampled evenly over mean +/- 4 standard deviations. useful for plotting

    Args:
        mean (`float`): the mean of the distribution
        stdev (`float`): the standard deviation of the distribution
        num_pts (`int`): the number of points in the returned ndarrays.
            Default is 50

    Returns:
        tuple containing:

        - **numpy.ndarray**: the x-values of the distribution
        - **numpy.ndarray**: the y-values of the distribution

    Example::

        mean,std = 1.0, 2.0
        x,y = pyemu.plot.gaussian_distribution(mean,std)
        plt.fill_between(x,0,y)
        plt.show()

    """
    half_width = 4.0 * stdev
    x = np.linspace(mean - half_width, mean + half_width, num_pts)
    # normalizing constant and exponent of the gaussian density
    norm_const = 1.0 / np.sqrt(2.0 * np.pi * stdev * stdev)
    exponent = -1.0 * ((x - mean) ** 2) / (2.0 * stdev * stdev)
    y = norm_const * np.exp(exponent)
    return x, y
180

181

182
def pst_helper(pst, kind=None, **kwargs):
    """`pyemu.Pst` plot helper - takes the
    handoff from `pyemu.Pst.plot()`

    Args:
        pst (`pyemu.Pst`): a control file instance
        kind (`str`): the kind of plot to make.  One of 'prior', '1to1',
            'phi_pie' or 'phi_progress'.  If None, every kind is made and
            each is saved to "<case>.<kind>.pdf".  Default is None
        **kwargs (`dict`): keyword arguments to pass to the
            plotting function.  'echo' (`bool`) is also read here to set
            whether the logger echoes to the screen.  Only used when `kind`
            is not None

    Returns:
        varies: usually a combination of `matplotlib.figure` (s) and/or
        `matplotlib.axis` (s).  If `kind` is None, a list with the return
        value of each plotting function

    Example::

        pst = pyemu.Pst("pest.pst") #assumes pest.res or pest.rei is found
        pst.plot(kind="1to1")
        plt.show()
        pst.plot(kind="phi_pie")
        plt.show()
        pst.plot(kind="prior")
        plt.show()

    """

    echo = kwargs.get("echo", False)
    logger = pyemu.Logger("plot_pst_helper.log", echo=echo)
    logger.statement("plot_utils.pst_helper()")

    kinds = {
        "prior": pst_prior,
        "1to1": res_1to1,
        "phi_pie": res_phi_pie,
        "phi_progress": phi_progress,
    }

    if kind is None:
        base_filename = pst.filename
        if pst.new_filename is not None:
            base_filename = pst.new_filename
        # drop only the trailing extension; a blanket str.replace would also
        # eat ".pst" occurring elsewhere in the path
        if base_filename.endswith(".pst"):
            base_filename = base_filename[: -len(".pst")]
        return [
            func(pst, logger=logger, filename=base_filename + "." + name + ".pdf")
            for name, func in kinds.items()
        ]
    if kind not in kinds:
        logger.lraise(
            "unrecognized kind:{0}, should be one of {1}".format(
                kind, ",".join(list(kinds.keys()))
            )
        )
    return kinds[kind](pst, logger, **kwargs)
236

237

238
def phi_progress(pst, logger=None, filename=None, **kwargs):
    """make plot of phi vs number of model runs - requires
    available  ".iobj" file generated by a PESTPP-GLM run.

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): filename to save the figure to.  If None, nothing
            is saved.  Default is None
        kwargs (`dict`): optional keyword args.  Accepts 'ax', an existing
            `matplotlib.axis` to plot on; if not passed a new figure and axis
            are created

    Returns:
        `matplotlib.axis`: the axis the plot was made on

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.phi_progress(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot phi_progress")

    # the objective function record is expected alongside the control file
    iobj_file = pst.filename.replace(".pst", ".iobj")
    if not os.path.exists(iobj_file):
        logger.lraise("couldn't find iobj file {0}".format(iobj_file))
    df = pd.read_csv(iobj_file)
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)
    ax.plot(df.model_runs_completed, df.total_phi, marker=".")
    ax.set_xlabel("model runs")
    ax.set_ylabel(r"$\phi$")
    ax.grid()
    if filename is not None:
        # save the figure that owns `ax`; the current pyplot figure may be a
        # different one when the caller supplied the axis
        ax.figure.savefig(filename)
    logger.log("plot phi_progress")
    return ax
282

283

284
def _get_page_axes(count=nr * nc):
    """add up to `nr` * `nc` subplots to the current figure and return them as a list"""
    n_axes = min(count, nr * nc)
    page_axes = []
    # subplot positions are 1-based
    for position in range(1, n_axes + 1):
        page_axes.append(plt.subplot(nr, nc, position))
    return page_axes
288

289

290
def res_1to1(
    pst, logger=None, filename=None, plot_hexbin=False, histogram=False, **kwargs
):
    """make 1-to-1 plots and also observed vs residual by observation group

    Args:
        pst (`pyemu.Pst`): a control file instance
        logger (`pyemu.Logger`):  if None, a generic one is created.  Default is None
        filename (`str`): PDF filename to save figures to.  If None, figures
            are returned.  Default is None
        plot_hexbin (`bool`): flag to use the hexbinning for large numbers of residuals.
            Default is False (use `matplotlib.pyplot.scatter` )
        histogram (`bool`): flag to plot residual histograms instead of obs vs residual.
            Default is False
        kwargs (`dict`): optional keyword args.  Accepts 'ensemble' (passed to
            `pyemu.pst_utils.res_from_en()` to get the residuals instead of
            using `pst.res`), 'fig_title' (text for the leading title page) and
            'include_zero' (if True, zero-weighted observations are plotted as
            well; by default they are dropped).  'grouper' is not implemented.

    Returns:
        [`matplotlib.Figure`]: the list of figures created (a title page first)
        if `filename` is None.  Otherwise the figures are written to `filename`,
        closed, and None is returned.

    Example::

        import matplotlib.pyplot as plt
        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_1to1(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_1to1")

    if "ensemble" in kwargs:
        try:
            res = pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception as e:
            logger.lraise("res_1to1: error loading ensemble file: {0}".format(str(e)))
    else:
        try:
            res = pst.res
        except Exception:
            logger.lraise("res_1to1: pst.res is None, couldn't find residuals file")

    obs = pst.observation_data

    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        grouper = obs.groupby(obs.obgnme).groups

    # the first figure is a text-only title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    # zero-weighted observations are dropped unless explicitly requested
    drop_zero = "include_zero" not in kwargs or kwargs["include_zero"] is False

    figs = []
    axes = []  # stays empty if no group has anything to plot
    ax_count = 0
    for g, names in grouper.items():
        logger.log("plotting 1to1 for {0}".format(g))

        # copy so the sim/res columns are not written onto a slice of `obs`
        obs_g = obs.loc[names, :].copy()
        obs_g.loc[:, "sim"] = res.loc[names, "modelled"]
        logger.statement("using control file obsvals to calculate residuals")
        obs_g.loc[:, "res"] = obs_g.sim - obs_g.obsval
        if drop_zero:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log("plotting 1to1 for {0}".format(g))
            continue

        # start a new page when the current one is full (or on the first group)
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        ax = axes[ax_count]

        mx = max(obs_g.obsval.max(), obs_g.sim.max())
        mn = min(obs_g.obsval.min(), obs_g.sim.min())
        # widen the range by 10% of each bound's magnitude; plain scaling
        # (x1.1 / x0.9) would pull negative bounds inward and clip the data
        mx += 0.1 * abs(mx)
        mn -= 0.1 * abs(mn)
        ax.axis("square")
        if plot_hexbin:
            ax.hexbin(
                obs_g.obsval.values,
                obs_g.sim.values,
                mincnt=1,
                gridsize=(75, 75),
                extent=(mn, mx, mn, mx),
                bins="log",
                edgecolors=None,
            )
        else:
            ax.scatter(
                obs_g.obsval.values, obs_g.sim.values, marker=".", s=10, color="b"
            )

        # the 1-to-1 line
        ax.plot([mn, mx], [mn, mx], "k--", lw=1.0)
        xlim = (mn, mx)
        ax.set_xlim(mn, mx)
        ax.set_ylim(mn, mx)
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        ax_count += 1

        # the companion axis: observed vs residual, or a residual histogram
        ax = axes[ax_count]
        meanres = obs_g.res.mean()
        if not histogram:
            ax.scatter(obs_g.obsval, obs_g.res, marker=".", s=10, color="b")
            ylim = ax.get_ylim()
            # symmetric y range about zero
            mx = max(np.abs(ylim[0]), np.abs(ylim[1]))
            if obs_g.shape[0] == 1:
                mx *= 1.1
            ax.set_ylim(-mx, mx)
            # show a zero residuals line
            ax.plot(xlim, [0, 0], "k--", lw=1.0)
            # show mean residuals line
            ax.plot(xlim, [meanres, meanres], "r-", lw=1.0)
            ax.set_xlim(xlim)
            ax.set_ylabel("residual", labelpad=0.1)
            ax.set_xlabel("observed", labelpad=0.1)
        else:
            # need max and min res to set xlim, otherwise wonky figsize
            mxr = obs_g.res.max()
            mnr = obs_g.res.min()
            # pad outward regardless of sign (residuals are often negative)
            mxr += 0.1 * abs(mxr)
            mnr -= 0.1 * abs(mnr)
            rlim = (mnr, mxr)

            ax.hist(obs_g.res, bins=50, color="b")
            ax.axvline(meanres, color="r", lw=1)
            b, t = ax.get_ylim()
            ax.text(meanres + meanres / 10, t - t / 10, "Mean: {:.2f}".format(meanres))
            ax.set_xlim(rlim)
            ax.set_ylabel("count", labelpad=0.1)
            ax.set_xlabel("residual", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1
        logger.log("plotting 1to1 for {0}".format(g))

    # blank out whatever axes are left unused on the last page
    for unused_ax in axes[ax_count:]:
        unused_ax.set_axis_off()
        unused_ax.set_yticks([])
        unused_ax.set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot res_1to1")
    else:
        logger.log("plot res_1to1")
        return figs
497

498

499
def plot_id_bar(id_df, nsv=None, logger=None, **kwargs):
    """Plot a stacked bar chart of identifiability based on
    the `pyemu.ErrVar.get_identifiability_dataframe()` dataframe

    Args:
        id_df (`pandas.DataFrame`) : dataframe of identifiability.  It is
            copied, not modified
        nsv (`int`): number of singular values to consider.  If None or
            larger than the number of columns in `id_df` (not counting
            'ident'), all columns are used
        logger (`pyemu.Logger`, optional): a logger.  If None, a generic
            one is created
        kwargs (`dict`): a dict of keyword arguments.  Accepts 'figsize'
            (default (8, 10.5)) and 'ax', an existing axis to plot on

    Returns:
        `matplotlib.Axis`: the axis with the plot

    Example::

        import pyemu
        pest_obj = pyemu.Pst(pest_control_file)
        ev = pyemu.ErrVar(jco='freyberg_jac.jcb')
        id_df = ev.get_identifiability_dataframe(singular_value=48)
        pyemu.plot_utils.plot_id_bar(id_df, nsv=12, figsize=(12,4))

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot id bar")

    df = id_df.copy()

    # drop the final `ident` column
    if "ident" in df.columns:
        df.drop("ident", inplace=True, axis=1)

    if nsv is None or nsv > len(df.columns):
        nsv = len(df.columns)
        # a one-off note, so `statement` rather than an unclosed `log` entry
        logger.statement("set number of SVs and number in the dataframe")

    # copy so the temporary sort column is not written onto a column slice
    df = df[df.columns[:nsv]].copy()

    # order the bars by total identifiability, largest first
    df["ident"] = df.sum(axis=1)
    df.sort_values(by="ident", inplace=True, ascending=False)
    df.drop("ident", inplace=True, axis=1)

    figsize = kwargs.get("figsize", (8, 10.5))
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        plt.figure(figsize=figsize)
        ax = plt.subplot(1, 1, 1)

    # plot the stacked bar chart (the easy part!)
    df.plot.bar(stacked=True, cmap="jet_r", legend=False, ax=ax)

    #
    # horrible shenanigans to make a colorbar rather than a legend
    #

    # special case colormap just dark red if one SV
    if nsv == 1:
        tcm = matplotlib.colors.LinearSegmentedColormap.from_list(
            "one_sv", [plt.get_cmap("jet_r")(0)] * 2, N=2
        )
        sm = plt.cm.ScalarMappable(
            cmap=tcm, norm=matplotlib.colors.Normalize(vmin=0, vmax=nsv + 1)
        )
    # or typically just rock the jet_r colormap over the range of SVs
    else:
        sm = plt.cm.ScalarMappable(
            cmap=plt.get_cmap("jet_r"),
            norm=matplotlib.colors.Normalize(vmin=1, vmax=nsv),
        )
    sm._A = []

    # now, if too many ticks for the colorbar, summarize them
    if nsv < 20:
        ticks = range(1, nsv + 1)
    else:
        # the integer stride is 0 for 20 <= nsv <= 28, which np.arange
        # cannot use, so never step by less than 1
        step = max(1, int((nsv + 1) / 30))
        ticks = np.arange(1, nsv + 1, step)

    cb = plt.colorbar(sm, ax=ax)
    cb.set_ticks(ticks)

    logger.log("plot id bar")

    return ax
588

589

590
def res_phi_pie(pst, logger=None, **kwargs):
    """plot current phi components as a pie chart.

    Args:
        pst (`pyemu.Pst`): a control file instance with the residual dataframe
            instance available.
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created
        kwargs (`dict`): a dict of plotting options. Accepts 'include_zero'
            as a flag to include phi groups with only zero-weight obs (not
            sure why anyone would do this, but whatevs).

            Also accepts 'label_comps': a component name or list of components
            for the labels. Options are ['name', 'phi_comp', 'phi_percent'].
            Labels will use those components, one per line, in the order of
            the 'label_comps' list.

            Also accepts 'ax' (an existing axis to plot on), 'figsize' (size
            of the figure created when 'ax' is not passed), 'filename' (file
            to save the figure to) and 'ensemble' (passed to
            `pyemu.pst_utils.res_from_en()`).

    Returns:
        `matplotlib.Axis`: the axis with the plot.

    Example::

        import pyemu
        pst = pyemu.Pst("my.pst")
        pyemu.plot_utils.res_phi_pie(pst,figsize=(12,4))
        pyemu.plot_utils.res_phi_pie(pst,label_comps = ['name','phi_percent'], figsize=(12,4))


    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_phi_pie")

    # NOTE(review): the residuals loaded here are not referenced again in
    # this function - the pie is built from pst.phi / pst.phi_components.
    # Confirm whether 'ensemble' is meant to change what gets plotted.
    if "ensemble" in kwargs:
        try:
            pst_utils.res_from_en(pst, kwargs["ensemble"])
        except Exception:
            logger.statement(
                "res_phi_pie: could not find ensemble file {0}".format(
                    kwargs["ensemble"]
                )
            )
    else:
        try:
            pst.res
        except Exception:
            logger.lraise("res_phi_pie: pst.res is None, couldn't find residuals file")

    phi = pst.phi
    phi_comps = pst.phi_components
    norm_phi_comps = pst.phi_components_normalized
    keys = list(phi_comps.keys())
    if "include_zero" not in kwargs or kwargs["include_zero"] is False:
        phi_comps = {k: phi_comps[k] for k in keys if phi_comps[k] > 0.0}
        keys = list(phi_comps.keys())
        norm_phi_comps = {k: norm_phi_comps[k] for k in keys}
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        # honor a caller-supplied figsize, else the module default
        plt.figure(figsize=kwargs.get("figsize", figsize))
        ax = plt.subplot(1, 1, 1, aspect="equal")

    label_comps = kwargs.get("label_comps")
    if label_comps is None:
        labels = [
            "{0}\n{1:4G}\n({2:3.1f}%)".format(
                k, phi_comps[k], 100.0 * (phi_comps[k] / phi)
            )
            for k in keys
        ]
    else:
        # make sure the components for the labels are in a list
        if isinstance(label_comps, str):
            fmtchoices = [label_comps]
        else:
            fmtchoices = list(label_comps)
        # assemble all possible label components: (format, per-group values)
        labfmts = {
            "name": ("{}", keys),
            "phi_comp": ("{:4G}", [phi_comps[k] for k in keys]),
            "phi_percent": (
                "({:3.1f}%)",
                [100.0 * (phi_comps[k] / phi) for k in keys],
            ),
        }
        # one component per line, whatever the order, with no trailing newline
        labfmtstr = "\n".join([labfmts[c][0] for c in fmtchoices])
        # pull it together
        labels = [
            labfmtstr.format(*vals)
            for vals in zip(*[labfmts[c][1] for c in fmtchoices])
        ]

    ax.pie([float(norm_phi_comps[k]) for k in keys], labels=labels)
    logger.log("plot res_phi_pie")
    if kwargs.get("filename") is not None:
        # save the figure that owns `ax`, which need not be the current one
        ax.figure.savefig(kwargs["filename"])
    return ax
685

686

687
def pst_prior(pst, logger=None, filename=None, **kwargs):
    """helper to plot prior parameter histograms implied by
    parameter bounds. Optionally saves a multipage pdf

    Args:
        pst (`pyemu.Pst`): control file
        logger (`pyemu.Logger`): a logger.  If None, a generic one is created.
        filename (`str`):  PDF filename to save plots to.
            If None, return figs without saving.  Default is None.
        kwargs (`dict`): additional plotting options. Accepts 'unique_only' to
            only show unique mean-stdev combinations within a given group and
            'fig_title' (text for the leading title page).  'grouper' is not
            implemented; parameters are grouped by parameter group.

    Returns:
        [`matplotlib.Figure`]: a list of figures created (a title page first)
        if `filename` is None.  Otherwise the figures are written to
        `filename`, closed, and None is returned.

    Example::

        pst = pyemu.Pst("pest.pst")
        pyemu.plot_utils.pst_prior(pst)
        plt.show()

    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot pst_prior")
    par = pst.parameter_data

    if "parcov_filename" in pst.pestpp_options:
        logger.warn("ignoring parcov_filename, using parameter bounds for prior cov")
    logger.log("loading cov from parameter data")
    cov = pyemu.Cov.from_parameter_data(pst)
    logger.log("loading cov from parameter data")

    logger.log("building mean parameter values")
    li = par.partrans.loc[cov.names] == "log"
    mean = par.parval1.loc[cov.names]
    info = par.loc[cov.names, :].copy()
    info.loc[:, "mean"] = mean
    # log-transformed parameters are plotted in log10 space
    info.loc[li, "mean"] = mean[li].apply(np.log10)
    logger.log("building mean parameter values")

    logger.log("building stdev parameter values")
    if cov.isdiagonal:
        std = cov.x.flatten()
    else:
        std = np.diag(cov.x)
    std = np.sqrt(std)
    # check before the assignment below, which would otherwise fail first
    if std.shape != mean.shape:
        logger.lraise("mean.shape {0} != std.shape {1}".format(mean.shape, std.shape))
    info.loc[:, "prior_std"] = std

    logger.log("building stdev parameter values")

    if "grouper" in kwargs:
        raise NotImplementedError()
        # check for consistency here

    else:
        par_adj = par.loc[par.partrans.isin(["log", "none"]), :]
        grouper = par_adj.groupby(par_adj.pargp).groups

    if len(grouper) == 0:
        raise Exception("no adjustable parameters to plot")

    # the first figure is a text-only title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='prior')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )
    figs = []
    ax_count = 0
    for g in sorted(grouper.keys()):
        names = grouper[g]
        logger.log("plotting priors for {0}".format(",".join(list(names))))
        # start a new page when the current one is full (or on the first group)
        if ax_count % (nr * nc) == 0:
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # the x label depends on the transform used within this group only
        islog = False
        vc = info.loc[names, "partrans"].value_counts()
        if vc.shape[0] > 1:
            logger.warn("mixed partrans for group {0}".format(g))
        elif "log" in vc.index:
            islog = True
        ax = axes[ax_count]
        if "unique_only" in kwargs and kwargs["unique_only"]:
            # one curve per distinct (mean, stdev) pair in the group
            ms = (
                info.loc[names, :]
                .apply(lambda x: (x["mean"], x["prior_std"]), axis=1)
                .unique()
            )
        else:
            ms = zip(info.loc[names, "mean"], info.loc[names, "prior_std"])
        for m, s in ms:
            x, y = gaussian_distribution(m, s)
            ax.fill_between(x, 0, y, facecolor="0.5", alpha=0.5, edgecolor="none")
        ax.set_title(
            "{0}) group:{1}, {2} parameters".format(abet[ax_count], g, names.shape[0]),
            loc="left",
        )

        ax.set_yticks([])
        if islog:
            ax.set_xlabel("$log_{10}$ parameter value", labelpad=0.1)
        else:
            ax.set_xlabel("parameter value", labelpad=0.1)
        logger.log("plotting priors for {0}".format(",".join(list(names))))

        ax_count += 1

    # blank out whatever axes are left unused on the last page
    for unused_ax in axes[ax_count:]:
        unused_ax.set_axis_off()
        unused_ax.set_yticks([])
        unused_ax.set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        # write every page, not just the last figure
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot pst_prior")
    else:
        logger.log("plot pst_prior")
        return figs
838

839

840
def ensemble_helper(
    ensemble,
    bins=10,
    facecolor="0.5",
    plot_cols=None,
    filename=None,
    func_dict=None,
    sync_bins=True,
    deter_vals=None,
    std_window=None,
    deter_range=False,
    **kwargs
):
    """helper function to plot ensemble histograms

    Args:
        ensemble : varies
            the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        bins : int
            histogram bins.  If `sync_bins` is True, `bins` evenly spaced bin edges
            spanning the pooled min/max of all ensembles are used; otherwise `bins` is
            passed straight to `ax.hist()`.  Default is 10
        facecolor : str
            the histogram facecolor.  Only applies if ensemble is a single thing
        plot_cols : enumerable
            a collection of columns (in form of a list of parameters, or a dict with keys for
            parsing plot axes and values of parameters) from the ensemble(s) to plot.  If None,
            (the union of) all cols are plotted. Default is None
        filename : str
            the name of the pdf to create.  If None, return figs without saving.  Default is None.
        func_dict : dict
            a dictionary of unary functions (e.g., `np.log10`) to apply to columns.  Key is
            column name.  Default is None
        sync_bins : bool
            flag to use the same bin edges for all ensembles. Only applies if more than
            one ensemble is being plotted.  Default is True
        deter_vals : dict
            dict of deterministic values to plot as a vertical line. key is ensemble column name
        std_window : float
            the number of standard deviations around the mean to mark as vertical lines.  If None,
            nothing happens.  Default is None
        deter_range : bool
            flag to set xlims to deterministic value +/- std window.  If True, std_window must not be None.
            Default is False
        **kwargs : dict
            "fig_title" (str) replaces the default text on the title page.

    Returns:
        [`matplotlib.Figure`]: a list of figures (title page first).  Returns None if
        `filename` is not None

    Note:
        Functions in `func_dict` are applied to the ensemble columns in place, and
        column names are lower-cased by `_process_ensemble_arg()`.

    Example::

        # plot prior and posterior par ensembles
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post},filename="ensemble.pdf")

        #plot prior and posterior simulated equivalents to observations with obs noise and obs vals
        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        noise = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.obs+noise.jcb")
        pyemu.plot_utils.ensemble_helper(ensemble={"0.5":prior, "b":post,"r":noise},
                                         filename="ensemble.pdf",
                                         deter_vals=pst.observation_data.obsval.to_dict())

    """
    logger = pyemu.Logger("ensemble_helper.log")
    logger.log("pyemu.plot_utils.ensemble_helper()")
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)
    if len(ensembles) == 0:
        raise Exception("plot_uitls.ensemble_helper() error processing `ensemble` arg")

    # apply any functions
    if func_dict is not None:
        logger.log("applying functions")
        for col, func in func_dict.items():
            for fc, en in ensembles.items():
                if col in en.columns:
                    en.loc[:, col] = en.loc[:, col].apply(func)
        logger.log("applying functions")

    # get a list of all cols (union)
    all_cols = set()
    for fc, en in ensembles.items():
        all_cols.update(set(en.columns))

    # normalize plot_cols to {axis label: [column names]}
    if plot_cols is None:
        plot_cols = {c: [c] for c in all_cols}
    else:
        if isinstance(plot_cols, list):
            splot_cols = set(plot_cols)
            plot_cols = {c: [c] for c in plot_cols}
        elif isinstance(plot_cols, dict):
            splot_cols = set()
            for pcols in plot_cols.values():
                splot_cols.update(pcols)
        else:
            logger.lraise(
                "unrecognized plot_cols type: {0}, should be list or dict".format(
                    type(plot_cols)
                )
            )

        missing = splot_cols - all_cols
        if len(missing) > 0:
            logger.lraise(
                "the following plot_cols are missing: {0}".format(",".join(missing))
            )

    logger.statement("plotting {0} histograms".format(len(plot_cols)))

    # the first figure is a text-only title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_helper()\n at {0}".format(str(datetime.now())),
            ha="center",
        )
    labels = list(plot_cols.keys())
    labels.sort()
    logger.statement("saving pdf to {0}".format(filename))
    figs = []

    # axes stays None when there is nothing to plot, so the clean-up loop
    # after the main loop can be skipped instead of failing on an unbound name
    axes = None
    ax_count = 0

    for label in labels:
        plot_col = plot_cols[label]
        log_phrase = "plotting reals for {0}".format(label)
        logger.log(log_phrase)
        if ax_count % (nr * nc) == 0:
            # current page is full (or is the title page): stash it, start a new one
            plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            for page_ax in axes:
                page_ax.set_yticks([])
            ax_count = 0

        ax = axes[ax_count]

        if sync_bins:
            # pooled min/max across every ensemble for this label
            mx, mn = -1.0e30, 1.0e30
            for fc, en in ensembles.items():
                emn = np.nanmin(en.loc[:, plot_col].values)
                emx = np.nanmax(en.loc[:, plot_col].values)
                mx = max(mx, emx)
                mn = min(mn, emn)
            if mx == -1.0e30 and mn == 1.0e30:
                # sentinels untouched: every value was NaN
                logger.warn("all NaNs for label: {0}".format(label))
                ax.set_title(
                    "{0}) {1}, count:{2} - all NaN".format(
                        abet[ax_count], label, len(plot_col)
                    ),
                    loc="left",
                )
                ax.set_yticks([])
                ax.set_xticks([])
                ax_count += 1
                logger.log(log_phrase)
                continue
            plot_bins = np.linspace(mn, mx, num=bins)
            logger.statement("{0} min:{1:5G}, max:{2:5G}".format(label, mn, mx))
        else:
            plot_bins = bins

        for fc, en in ensembles.items():
            vals = en.loc[:, plot_col].values.flatten()

            ax.hist(
                vals,
                bins=plot_bins,
                edgecolor="none",
                alpha=0.5,
                density=True,
                facecolor=fc,
            )
            v = None
            if deter_vals is not None:
                for pc in plot_col:
                    if pc in deter_vals:
                        ylim = ax.get_ylim()
                        v = deter_vals[pc]
                        ax.plot([v, v], ylim, "k--", lw=1.5)
                        ax.set_ylim(ylim)

            if std_window is not None:
                try:
                    ylim = ax.get_ylim()
                    # use the values that were actually histogrammed; the
                    # deter_vals loop variable is not bound when deter_vals
                    # is None and otherwise only names the last column
                    plotted = pd.Series(vals)
                    center = plotted.mean()
                    half_width = plotted.std() * (std_window / 2.0)

                    ax.plot(
                        [center - half_width, center - half_width],
                        ylim,
                        color=fc,
                        lw=1.5,
                        ls="--",
                    )
                    ax.plot(
                        [center + half_width, center + half_width],
                        ylim,
                        color=fc,
                        lw=1.5,
                        ls="--",
                    )
                    ax.set_ylim(ylim)
                    if deter_range and v is not None:
                        ax.set_xlim(v - half_width, v + half_width)
                except Exception as e:
                    # best effort: the std window is decoration only
                    logger.warn(
                        "error plotting std window for {0}: {1}".format(label, str(e))
                    )
        ax.grid()
        if len(ensembles) > 1:
            ax.set_title(
                "{0}) {1}, count: {2}".format(abet[ax_count], label, len(plot_col)),
                loc="left",
            )
        else:
            ax.set_title(
                "{0}) {1}, count:{2}\nmin:{3:3.1E}, max:{4:3.1E}".format(
                    abet[ax_count],
                    label,
                    len(plot_col),
                    np.nanmin(vals),
                    np.nanmax(vals),
                ),
                loc="left",
            )
        ax_count += 1
        logger.log(log_phrase)

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        plt.tight_layout()
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("pyemu.plot_utils.ensemble_helper()")
    else:
        logger.log("pyemu.plot_utils.ensemble_helper()")
        return figs
1093

1094

1095
def ensemble_change_summary(
    ensemble1,
    ensemble2,
    pst,
    bins=10,
    facecolor="0.5",
    logger=None,
    filename=None,
    **kwargs
):
    """helper function to plot first and second moment change histograms between two
    ensembles

    Args:
        ensemble1 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        ensemble2 (varies): filename or `pandas.DataFrame` or `pyemu.Ensemble`
        pst (`pyemu.Pst`): control file
        bins (`int`): number of histogram bins.  Default is 10
        facecolor (`str`): the histogram facecolor.
        logger (`pyemu.Logger`): logger to use.  If None, a non-echoing default
            logger is created.
        filename (`str`): the name of the multi-pdf to create. If None, return figs without saving.  Default is None.
        **kwargs (`dict`): "fig_title" (`str`) replaces the default text on the
            title page.  Passing "grouper" raises `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:
        The mean change is `ensemble1` mean minus `ensemble2` mean; the sigma change
        is the percent reduction in standard deviation relative to `ensemble1`.

        Ensembles passed as objects are modified in place: column names are
        lower-cased and, for parameter ensembles, log-transformed parameters are
        replaced by their log10 values.

    Example::

        pst = pyemu.Pst("my.pst")
        prior = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="prior.jcb")
        post = pyemu.ParameterEnsemble.from_binary(pst=pst, filename="my.3.par.jcb")
        pyemu.plot_utils.ensemble_change_summary(prior,post)
        plt.show()


    """
    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot ensemble change")

    if isinstance(ensemble1, str):
        ensemble1 = pd.read_csv(ensemble1, index_col=0)
    ensemble1.columns = ensemble1.columns.str.lower()

    if isinstance(ensemble2, str):
        ensemble2 = pd.read_csv(ensemble2, index_col=0)
    ensemble2.columns = ensemble2.columns.str.lower()

    # better to ensure this is caught by pestpp-ies ensemble csvs
    unnamed1 = [col for col in ensemble1.columns if "unnamed:" in col]
    if len(unnamed1) != 0:
        # ensure unnamed col result of poor csv read only (ie last col)
        ensemble1 = ensemble1.iloc[:, :-1]
    unnamed2 = [col for col in ensemble2.columns if "unnamed:" in col]
    if len(unnamed2) != 0:
        # ensure unnamed col result of poor csv read only (ie last col)
        ensemble2 = ensemble2.iloc[:, :-1]

    d = set(ensemble1.columns).symmetric_difference(set(ensemble2.columns))

    if len(d) != 0:
        logger.lraise(
            "ensemble1 does not have the same columns as ensemble2: {0}".format(
                ",".join(d)
            )
        )
    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        # decide whether the columns are parameters or observations and group them
        en_cols = ensemble1.columns
        if len(en_cols.difference(pst.par_names)) == 0:
            par = pst.parameter_data.loc[en_cols, :]
            grouper = par.groupby(par.pargp).groups
            grouper["all"] = pst.adj_par_names
            # compare log-transformed parameters in log10 space
            li = par.loc[par.partrans == "log", "parnme"]
            ensemble1.loc[:, li] = ensemble1.loc[:, li].apply(np.log10)
            ensemble2.loc[:, li] = ensemble2.loc[:, li].apply(np.log10)
        elif len(en_cols.difference(pst.obs_names)) == 0:
            obs = pst.observation_data.loc[en_cols, :]
            grouper = obs.groupby(obs.obgnme).groups
            grouper["all"] = pst.nnz_obs_names
        else:
            logger.lraise("could not match ensemble cols with par or obs...")

    en1_mn, en1_std = ensemble1.mean(axis=0), ensemble1.std(axis=0)
    en2_mn, en2_std = ensemble2.mean(axis=0), ensemble2.std(axis=0)

    mn_diff = -1 * (en2_mn - en1_mn)
    std_diff = 100 * (((en1_std - en2_std) / en1_std))

    # the first figure is a text-only title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.plot_utils.ensemble_change_summary()\n"
            "from pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    # axes stays None when every group is empty, so the clean-up loop after
    # the main loop can be skipped instead of failing on an unbound name
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        log_phrase = "plotting change for {0}".format(g)
        logger.log(log_phrase)

        mn_g = mn_diff.loc[names]
        std_g = std_diff.loc[names]

        if mn_g.shape[0] == 0:
            logger.statement("no entries for group '{0}'".format(g))
            logger.log(log_phrase)
            continue

        if ax_count % (nr * nc) == 0:
            # current page is full (or is the title page): stash it, start a new one
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0

        # left-hand panel: change in the mean
        ax = axes[ax_count]
        mn_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("mean change", labelpad=0.1)
        ax.set_title(
            "{0}) mean change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], mn_g.max(), mn_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        # right-hand panel: percent reduction in standard deviation
        ax = axes[ax_count]
        std_g.hist(ax=ax, facecolor=facecolor, alpha=0.5, edgecolor=None, bins=bins)
        ax.set_yticklabels([])
        ax.set_xlabel("sigma percent reduction", labelpad=0.1)
        ax.set_title(
            "{0}) sigma change group:{1}, {2} entries\nmax:{3:10G}, min:{4:10G}".format(
                abet[ax_count], g, mn_g.shape[0], std_g.max(), std_g.min()
            ),
            loc="left",
        )
        ax.grid()
        ax_count += 1

        logger.log(log_phrase)

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        plt.tight_layout()
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot ensemble change")
    else:
        logger.log("plot ensemble change")
        return figs
1289

1290

1291
def _process_ensemble_arg(ensemble, facecolor, logger):
9✔
1292
    """private method to work out ensemble plot args"""
1293
    ensembles = {}
9✔
1294
    if isinstance(ensemble, pd.DataFrame) or isinstance(ensemble, pyemu.Ensemble):
9✔
1295
        if not isinstance(facecolor, str):
9✔
1296
            logger.lraise("facecolor must be str")
×
1297
        ensembles[facecolor] = ensemble
9✔
1298
    elif isinstance(ensemble, str):
8✔
1299
        if not isinstance(facecolor, str):
8✔
1300
            logger.lraise("facecolor must be str")
×
1301

1302
        logger.log("loading ensemble from csv file {0}".format(ensemble))
8✔
1303
        en = pd.read_csv(ensemble, index_col=0)
8✔
1304
        logger.statement("{0} shape: {1}".format(ensemble, en.shape))
8✔
1305
        ensembles[facecolor] = en
8✔
1306
        logger.log("loading ensemble from csv file {0}".format(ensemble))
8✔
1307

1308
    elif isinstance(ensemble, list):
8✔
1309
        if isinstance(facecolor, list):
8✔
1310
            if len(ensemble) != len(facecolor):
×
1311
                logger.lraise("facecolor list len != ensemble list len")
×
1312
        else:
1313
            colors = ["m", "c", "b", "r", "g", "y"]
8✔
1314

1315
            facecolor = [colors[i] for i in range(len(ensemble))]
8✔
1316
        ensembles = {}
8✔
1317
        for fc, en_arg in zip(facecolor, ensemble):
8✔
1318
            if isinstance(en_arg, str):
8✔
1319
                logger.log("loading ensemble from csv file {0}".format(en_arg))
8✔
1320
                en = pd.read_csv(en_arg, index_col=0)
8✔
1321
                logger.log("loading ensemble from csv file {0}".format(en_arg))
8✔
1322
                logger.statement("ensemble {0} gets facecolor {1}".format(en_arg, fc))
8✔
1323

1324
            elif isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
8✔
1325
                en = en_arg
8✔
1326
            else:
1327
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1328
            ensembles[fc] = en
8✔
1329

1330
    elif isinstance(ensemble, dict):
8✔
1331
        for fc, en_arg in ensemble.items():
8✔
1332
            if isinstance(en_arg, pd.DataFrame) or isinstance(en_arg, pyemu.Ensemble):
8✔
1333
                ensembles[fc] = en_arg
8✔
1334
            elif isinstance(en_arg, str):
8✔
1335
                logger.log("loading ensemble from csv file {0}".format(en_arg))
8✔
1336
                en = pd.read_csv(en_arg, index_col=0)
8✔
1337
                logger.log("loading ensemble from csv file {0}".format(en_arg))
8✔
1338
                ensembles[fc] = en
8✔
1339
            else:
1340
                logger.lraise("unrecognized ensemble list arg:{0}".format(en_arg))
×
1341
    try:
9✔
1342
        for fc in ensembles:
9✔
1343
            ensembles[fc].columns = ensembles[fc].columns.str.lower()
9✔
1344
    except:
×
1345
        logger.lraise("error processing ensemble")
×
1346

1347
    return ensembles
9✔
1348

1349

1350
def ensemble_res_1to1(
    ensemble,
    pst,
    facecolor="0.5",
    logger=None,
    filename=None,
    skip_groups=None,
    base_ensemble=None,
    **kwargs
):
    """helper function to plot ensemble 1-to-1 plots showing the simulated range

    Args:
        ensemble (varies):  the ensemble argument can be a pandas.DataFrame or derived type or a str, which
            is treated as a filename.  Optionally, ensemble can be a list of these types or
            a dict, in which case, the keys are treated as facecolor str (e.g., 'b', 'y', etc).
        pst (`pyemu.Pst`): a control file instance
        facecolor (`str`): the histogram facecolor.  Only applies if `ensemble` is a single thing
        logger (`pyemu.Logger`): logger to use.  If None, a non-echoing default
            logger is created.
        filename (`str`): the name of the pdf to create. If None, return figs
            without saving.  Default is None.
        skip_groups ([`str`]): observation group names to leave out of the plots.
            Names that are not groups in `pst` are ignored.  Default is None (no
            groups skipped).
        base_ensemble (`varies`): an optional ensemble argument for the observations + noise ensemble.
            This will be plotted as a transparent red bar on the 1to1 plot.
        **kwargs (`dict`): "fig_title" (`str`) replaces the default text on the
            title page; "include_zero" (`bool`) set to anything other than False
            keeps zero-weighted observations.  Passing "grouper" raises
            `NotImplementedError`.

    Returns:
        [`matplotlib.Figure`]: a list of figures.  Returns None if
        `filename` is not None

    Note:

        the vertical bar on each plot is the min-max range

    Example::


        pst = pyemu.Pst("my.pst")
        prior = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.0.obs.jcb")
        post = pyemu.ObservationEnsemble.from_binary(pst=pst, filename="my.3.obs.jcb")
        pyemu.plot_utils.ensemble_res_1to1(ensemble={"0.5":prior, "b":post})
        plt.show()

    """

    def _get_plotlims(oen, ben, obsnames):
        """work out common (min, max) axis limits for the 1to1 panel from the
        simulated ensembles `oen` and the observed values/ensembles `ben`"""
        if not isinstance(oen, dict):
            oen = {'g': oen.loc[:, obsnames]}
        if not isinstance(ben, dict):
            ben = {'g': ben.get(obsnames)}
        outofrange = False
        # work back from crazy values
        oemin = 1e32
        oemeanmin = 1e32
        oemax = -1e32
        oemeanmax = -1e32
        bemin = 1e32
        bemeanmin = 1e32
        bemax = -1e32
        bemeanmax = -1e32
        for _, oeni in oen.items():  # loop over ensembles
            oeni = oeni.loc[:, obsnames]  # slice group obs
            oemin = np.min([oemin, oeni.min().min()])
            oemax = np.max([oemax, oeni.max().max()])
            # get min and max of mean sim vals
            # (incase we want plot to ignore extremes)
            oemeanmin = np.min([oemeanmin, oeni.mean().min()])
            oemeanmax = np.max([oemeanmax, oeni.mean().max()])
        for _, beni in ben.items():  # same with base ensemble/obsval
            # work with either ensemble or obsval series
            beni = beni.get(obsnames)
            bemin = np.min([bemin, beni.min().min()])
            bemax = np.max([bemax, beni.max().max()])
            bemeanmin = np.min([bemeanmin, beni.mean().min()])
            bemeanmax = np.max([bemeanmax, beni.mean().max()])
        # get base ensemble range
        berange = bemax - bemin
        if berange == 0.:  # only one obs in group (probs)
            # expand a little; abs() keeps the buffer positive for negative
            # observation values so bemin stays below bemax
            berange = np.abs(bemeanmax) * 1.1
        # add buffer to obs endpoints
        bemin = bemin - (berange * 0.05)
        bemax = bemax + (berange * 0.05)
        if oemax < bemin:  # sim well below obs
            oemin = oemeanmin  # set min to mean min
            # (sim captured but not extremes)
            outofrange = True
        if oemin > bemax:  # sim well above obs
            oemax = oemeanmax
            outofrange = True
        oerange = oemax - oemin
        if bemax > oemax + (0.1 * oerange):  # obs max well above sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemax = oemax + (0.1 * oerange)
            else:  # use obs mean max
                bemax = bemeanmax
        if bemin < oemin - (0.1 * oerange):  # obs min well below sim
            if not outofrange:  # but sim still in range
                # zoom to sim
                bemin = oemin - (0.1 * oerange)
            else:
                bemin = bemeanmin
        pmin = np.min([oemin, bemin])
        pmax = np.max([oemax, bemax])
        return pmin, pmax

    if logger is None:
        logger = Logger("Default_Loggger.log", echo=False)
    logger.log("plot res_1to1")
    obs = pst.observation_data
    ensembles = _process_ensemble_arg(ensemble, facecolor, logger)

    if base_ensemble is not None:
        base_ensemble = _process_ensemble_arg(base_ensemble, "r", logger)

    if "grouper" in kwargs:
        raise NotImplementedError()
    else:
        grouper = obs.groupby(obs.obgnme).groups
        for skip_group in (skip_groups if skip_groups is not None else []):
            # a group that does not exist is simply not there to skip
            grouper.pop(skip_group, None)

    # the first figure is a text-only title page
    fig = plt.figure(figsize=figsize)
    if "fig_title" in kwargs:
        plt.figtext(0.5, 0.5, kwargs["fig_title"])
    else:
        plt.figtext(
            0.5,
            0.5,
            "pyemu.Pst.plot(kind='1to1')\nfrom pest control file '{0}'\n at {1}".format(
                pst.filename, str(datetime.now())
            ),
            ha="center",
        )

    figs = []
    # axes stays None when no group has anything to plot, so the clean-up loop
    # after the main loop can be skipped instead of failing on an unbound name
    axes = None
    ax_count = 0
    for g, names in grouper.items():
        log_phrase = "plotting 1to1 for {0}".format(g)
        logger.log(log_phrase)
        # control file observation for group
        obs_g = obs.loc[names, :]
        # normally only look a non-zero weighted obs
        if "include_zero" not in kwargs or kwargs["include_zero"] is False:
            obs_g = obs_g.loc[obs_g.weight > 0, :]
        if obs_g.shape[0] == 0:
            logger.statement("no non-zero obs for group '{0}'".format(g))
            logger.log(log_phrase)
            continue
        # if the first axis in page
        if ax_count % (nr * nc) == 0:
            if ax_count > 0:
                plt.tight_layout()
            figs.append(fig)
            fig = plt.figure(figsize=figsize)
            axes = _get_page_axes()
            ax_count = 0
        ax = axes[ax_count]

        # 1to1 (LHS plot)
        if base_ensemble is None:
            # if obs not defined by obs+noise ensemble,
            # use min and max for obsval from control file
            pmin, pmax = _get_plotlims(ensembles, obs_g.obsval, obs_g.obsnme)
        else:
            # if obs defined by obs+noise use obs+noise min and max
            pmin, pmax = _get_plotlims(ensembles, base_ensemble, obs_g.obsnme)
            obs_gg = obs_g.sort_values(by="obsval")
            for c, en in base_ensemble.items():
                en_g = en.loc[:, obs_gg.obsnme]
                emx = en_g.max()
                emn = en_g.min()

                # update y min and max for obs+noise ensembles
                if len(obs_gg.obsval) > 1:
                    emx = np.zeros(obs_gg.shape[0]) + emx
                    emn = np.zeros(obs_gg.shape[0]) + emn
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
                                    facecolor=c, alpha=0.2, zorder=2)
                else:
                    # a band needs two points; draw a single bar instead
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx],
                            color=c, alpha=0.2, zorder=2)
        for c, en in ensembles.items():
            en_g = en.loc[:, obs_g.obsnme]
            # output mins and maxs
            emx = en_g.max()
            emn = en_g.min()
            # one vertical min-max bar per observation
            for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values):
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)
        ax.plot([pmin, pmax], [pmin, pmax], "k--", lw=1.0, zorder=3)
        xlim = (pmin, pmax)
        ax.set_xlim(pmin, pmax)
        ax.set_ylim(pmin, pmax)

        if max(np.abs(xlim)) > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
            ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))
        ax.grid()

        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_ylabel("simulated", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )

        # Residual (RHS plot)
        ax_count += 1
        ax = axes[ax_count]

        if base_ensemble is not None:
            obs_gg = obs_g.sort_values(by="obsval")
            for c, en in base_ensemble.items():
                en_g = en.loc[:, obs_gg.obsnme].subtract(obs_gg.obsval)
                emx = en_g.max()
                emn = en_g.min()
                if len(obs_gg.obsval) > 1:
                    emx = np.zeros(obs_gg.shape[0]) + emx
                    emn = np.zeros(obs_gg.shape[0]) + emn
                    ax.fill_between(obs_gg.obsval.values, emn.values, emx.values,
                                    facecolor=c, alpha=0.2, zorder=2)
                else:
                    ax.plot([obs_gg.obsval.values, obs_gg.obsval.values], [emn, emx],
                            color=c, alpha=0.2, zorder=2)
        omn = []
        omx = []
        for c, en in ensembles.items():
            en_g = en.loc[:, obs_g.obsnme].subtract(obs_g.obsval, axis=1)
            emx = en_g.max()
            emn = en_g.min()
            omn.append(emn)
            omx.append(emx)
            # one vertical min-max residual bar per observation
            for ov, een, eex in zip(obs_g.obsval.values, emn.values, emx.values):
                ax.plot([ov, ov], [een, eex], color=c, zorder=1)

        omn = pd.concat(omn).min()
        omx = pd.concat(omx).max()
        mx = max(np.abs(omn), np.abs(omx))  # ensure symmetric about y=0
        if obs_g.shape[0] == 1:
            mx *= 1.05
        else:
            mx *= 1.02
        if np.sign(omn) == np.sign(omx):
            # allow y axis asymm if all above or below
            mn = np.min([0, np.sign(omn) * mx])
            mx = np.max([0, np.sign(omn) * mx])
        else:
            mn = -mx
        ax.set_ylim(mn, mx)
        bmin = obs_g.obsval.values.min()
        bmax = obs_g.obsval.values.max()
        brange = (bmax - bmin)
        if brange == 0.:
            # abs() keeps the buffer positive for negative observation values
            # so the x axis is not inverted
            brange = np.abs(obs_g.obsval.values.mean())
        bmin = bmin - 0.1 * brange
        bmax = bmax + 0.1 * brange
        xlim = (bmin, bmax)
        # show a zero residuals line
        ax.plot(xlim, [0, 0], "k--", lw=1.0, zorder=3)

        ax.set_xlim(xlim)
        ax.set_ylabel("residual", labelpad=0.1)
        ax.set_xlabel("observed", labelpad=0.1)
        ax.set_title(
            "{0}) group:{1}, {2} observations".format(
                abet[ax_count], g, obs_g.shape[0]
            ),
            loc="left",
        )
        ax.grid()
        if ax.get_xlim()[1] > 1.0e5:
            ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%1.0e"))

        ax_count += 1

        logger.log(log_phrase)

    # blank out the unused axes on the last page
    if axes is not None:
        for a in range(ax_count, nr * nc):
            axes[a].set_axis_off()
            axes[a].set_yticks([])
            axes[a].set_xticks([])

    plt.tight_layout()
    figs.append(fig)
    if filename is not None:
        plt.tight_layout()
        with PdfPages(filename) as pdf:
            for fig in figs:
                pdf.savefig(fig)
                plt.close(fig)
        logger.log("plot res_1to1")
    else:
        logger.log("plot res_1to1")
        return figs
1644

1645

1646
def plot_jac_test(
    csvin, csvout, targetobs=None, filetype=None, maxoutputpages=1, outputdirectory=None
):
    """helper function to plot results of the Jacobian test performed using the pest++
    program pestpp-swp.

    Args:
        csvin (`str`): name of csv file used as input to sweep, typically developed with
            static method pyemu.helpers.build_jac_test_csv()
        csvout (`str`): name of csv file with output generated by sweep, both input
            and output files can be specified in the pest++ control file
            with pyemu using: pest_object.pestpp_options["sweep_parameter_csv_file"] = jactest_in_file.csv
            pest_object.pestpp_options["sweep_output_csv_file"] = jactest_out_file.csv
        targetobs ([`str`]): list of observation names (columns of `csvout`) to plot.
            Each parameter used for the jactest can have up to 32 observations plotted
            per page.  If None, all columns of the output csv file after the first
            eight are used.
        filetype (`str`): file type (extension) to store output, if None, plt.show() is called.
        maxoutputpages (`int`): maximum number of pages of output per parameter.  Each page can
            hold up to 32 observation derivatives.  If value > 10, it is reset to
            10 and a warning is printed.  If observations in targetobs > 32*maxoutputpages,
            then a random set is selected from the targetobs list (or all observations
            in the csv file if targetobs=None).
        outputdirectory (`str`):  directory to store results, if None, current working directory is used.
            If a relative path is passed, it is joined to the current working directory and
            created if needed. An absolute path is used directly.

    Note:
        Used in conjunction with pyemu.helpers.build_jac_test_csv() and sweep to perform
        a Jacobian Test and then view the results. Can generate a lot of plots so easiest
        to put into a separate directory and view the files.

        Figures are written as "<parameter>_jactest_<page>.<filetype>" in the output
        directory.  Nothing is returned.

    """
    nrows, ncols = 8, 4
    obs_per_page = 32  # 8 rows x 4 columns of subplots per figure

    localhome = os.getcwd()
    # resolve the output directory and make it if it does not exist yet;
    # os.path.join returns `outputdirectory` unchanged when it is absolute
    if outputdirectory is None:
        figures_dir = localhome
    else:
        figures_dir = os.path.join(localhome, outputdirectory)
        if not os.path.exists(figures_dir):
            # makedirs so that nested output directories also work
            os.makedirs(figures_dir)

    # read the input and output files into pandas dataframes
    jactest_in_df = pd.read_csv(csvin, engine="python", index_col=0)
    jactest_in_df.index.name = "input_run_id"
    jactest_out_df = pd.read_csv(csvout, engine="python", index_col=1)

    # subtract the base run from every row, leaves the one parameter that
    # was perturbed in any row as only non-zero value. Set zeros to nan
    # so round-off doesn't get us and sum across rows to get a column of
    # the perturbation for each row, finally extract to a series. First
    # the input csv and then the output.
    base_par = jactest_in_df.loc["base"]
    delta_par_df = jactest_in_df.subtract(base_par, axis="columns")
    delta_par_df.replace(0, np.nan, inplace=True)
    delta_par_df.drop("base", axis="index", inplace=True)
    delta_par_df["change"] = delta_par_df.sum(axis="columns")
    delta_par = pd.Series(delta_par_df["change"])

    base_obs = jactest_out_df.loc["base"]
    delta_obs = jactest_out_df.subtract(base_obs)
    delta_obs.drop("base", axis="index", inplace=True)
    # if targetobs is None, then reset it to all the observations.
    if targetobs is None:
        # NOTE(review): the first 8 columns are skipped - presumably sweep
        # bookkeeping columns rather than observations; confirm against the
        # pestpp-swp output format
        targetobs = jactest_out_df.columns.tolist()[8:]
    delta_obs = delta_obs[targetobs]

    # get the Jacobian by dividing the change in observation by the change in parameter
    # for the perturbed parameters
    jacobian = delta_obs.divide(delta_par, axis="index")

    # use the index created by build_jac_test_csv to get a column of parameter names
    # and increments, then we can plot derivative vs. increment for each parameter
    extr_df = pd.Series(jacobian.index.values).str.extract(r"(.+)(_\d+$)", expand=True)
    # the run-name suffix is zero-based; shift so increments start at 1
    extr_df[1] = pd.to_numeric(extr_df[1].str.replace("_", "")) + 1
    extr_df.rename(columns={0: "parameter", 1: "increment"}, inplace=True)
    extr_df.index = jacobian.index

    # make a dataframe for plotting the Jacobian by combining the parameter name
    # and increments frame with the Jacobian frame
    plotframe = pd.concat([extr_df, jacobian], axis=1, join="inner")

    # get a list of observations to keep based on maxoutputpages.
    if maxoutputpages > 10:
        print("WARNING, more than 10 pages of output requested per parameter")
        print("maxoutputpage reset to 10.")
        maxoutputpages = 10
    num_obs_plotted = min(maxoutputpages * obs_per_page, len(targetobs))
    if num_obs_plotted < len(targetobs):
        # more observations than fit on the allowed pages: get random sample
        index_plotted = np.random.choice(len(targetobs), num_obs_plotted, replace=False)
        obs_plotted = [targetobs[x] for x in index_plotted]
    else:
        obs_plotted = targetobs
    # number of pages actually needed (ceiling division on the observation
    # count, so an exact multiple of 32 does not produce a blank extra page)
    real_pages = (num_obs_plotted + obs_per_page - 1) // obs_per_page

    # make a subplot of derivative vs. increment one plot for each of the
    # observations in targetobs, and outputs grouped by parameter.
    for param, group in plotframe.groupby("parameter"):
        for page in range(real_pages):
            fig, axes = plt.subplots(nrows, ncols, sharex=True, figsize=(10, 15))
            for row in range(nrows):
                for col in range(ncols):
                    # position of this subplot in the flat list of observations
                    count = obs_per_page * page + ncols * row + col
                    if count >= num_obs_plotted:
                        continue
                    ax = axes[row, col]
                    obs_name = obs_plotted[count]
                    ax.scatter(group["increment"], group[obs_name])
                    ax.plot(group["increment"], group[obs_name], "r")
                    ax.set_title(obs_name)
                    # NOTE(review): ticks assume at most five increments per
                    # parameter - confirm against build_jac_test_csv
                    ax.set_xticks([1, 2, 3, 4, 5])
                    ax.tick_params(direction="in")
                    # NOTE(review): the x label is placed on row index 3 of an
                    # 8-row grid, not the bottom row - confirm this is intended
                    if row == 3:
                        ax.set_xlabel("Increment")
            plt.tight_layout()

            if filetype is None:
                plt.show()
            else:
                fig.savefig(
                    os.path.join(
                        figures_dir, "{0}_jactest_{1}.{2}".format(param, page, filetype)
                    )
                )
            # release the figure explicitly so many pages do not pile up in memory
            plt.close(fig)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc