• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jungtaekkim / bayeso / 8255597116

12 Mar 2024 08:56PM UTC coverage: 0.0% (-100.0%) from 100.0%
8255597116

push

github

0 of 2492 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/bayeso/utils/utils_plotting.py
1
#
2
# author: Jungtaek Kim (jtkim@postech.ac.kr)
3
# last updated: September 24, 2020
4
#
5
"""It is utilities for plotting figures."""
×
6

7
import os
×
8
import numpy as np
×
9
try:
    import matplotlib.pyplot as plt
except ImportError:
    # matplotlib is optional: every public plotting function checks for
    # ``plt is None`` and only logs a message instead of failing.
    plt = None
try:
    import pylab
except ImportError:
    # pylab is provided by matplotlib; it is treated as optional for the
    # same reason (legend-only figures are created through it).
    pylab = None
×
17

18
from bayeso.utils import utils_common
×
19
from bayeso.utils import utils_logger
×
20
from bayeso import constants
×
21

22
logger = utils_logger.get_logger('utils_plotting')
×
23

24

25
@utils_common.validate_types
def _set_font_config(use_tex: bool) -> constants.TYPE_NONE: # pragma: no cover
    """
    It updates the global matplotlib font configuration through ``plt.rc``.

    With LaTeX enabled, text is rendered via ``usetex``; otherwise PDF
    output is switched to embed TrueType (Type 42) fonts.

    :param use_tex: flag for using latex.
    :type use_tex: bool.

    :returns: None.
    :rtype: NoneType

    """

    if not use_tex:
        # Type 42 (TrueType) embedding instead of the Type 3 default.
        plt.rc('pdf', fonttype=42)
        return

    plt.rc('text', usetex=True)
42

43
@utils_common.validate_types
def _set_ax_config(ax: 'matplotlib.axes._subplots.AxesSubplot', str_x_axis: str, str_y_axis: str,
    size_labels: int=32,
    size_ticks: int=22,
    xlim_min: constants.TYPING_UNION_FLOAT_NONE=None,
    xlim_max: constants.TYPING_UNION_FLOAT_NONE=None,
    draw_box: bool=True,
    draw_zero_axis: bool=False,
    draw_grid: bool=True,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It sets an axis configuration: labels, tick sizes, x limits, spines,
    and grid.

    :param ax: axes to configure.
    :type ax: matplotlib.axes._subplots.AxesSubplot
    :param str_x_axis: the name of x axis. If None, the x label is not set.
    :type str_x_axis: str.
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str.
    :param size_labels: label size.
    :type size_labels: int., optional
    :param size_ticks: tick size.
    :type size_ticks: int., optional
    :param xlim_min: None, or minimum for x limit. If None, the lower x
        limit is left unchanged.
    :type xlim_min: NoneType or float, optional
    :param xlim_max: None, or maximum for x limit. If None, the upper x
        limit is left unchanged.
    :type xlim_max: NoneType or float, optional
    :param draw_box: flag for drawing a box. If False, the top and right
        spines are hidden.
    :type draw_box: bool., optional
    :param draw_zero_axis: flag for drawing a zero axis, i.e., moving the
        bottom spine to y = 0.
    :type draw_zero_axis: bool., optional
    :param draw_grid: flag for drawing grids.
    :type draw_grid: bool., optional

    :returns: None.
    :rtype: NoneType

    """

    if str_x_axis is not None:
        ax.set_xlabel(str_x_axis, fontsize=size_labels)
    ax.set_ylabel(str_y_axis, fontsize=size_labels)
    ax.tick_params(labelsize=size_ticks)

    if not draw_box:
        # Keep only the left and bottom spines visible.
        ax.spines['top'].set_color('none')
        ax.spines['right'].set_color('none')
    if xlim_min is not None or xlim_max is not None:
        # matplotlib leaves a bound unchanged when it is passed as None,
        # so either limit can be applied on its own.
        ax.set_xlim(left=xlim_min, right=xlim_max)
    if draw_zero_axis:
        # Move the bottom spine (the drawn x axis line) to y = 0.
        ax.spines['bottom'].set_position('zero')
    if draw_grid:
        ax.grid()
96

97
@utils_common.validate_types
def _save_figure(path_save: str, str_postfix: str,
    str_prefix: str=''
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It writes the current matplotlib figure to
    ``path_save/<str_prefix><str_postfix>.pdf`` as a transparent PDF with a
    tight bounding box. Nothing is written if ``path_save`` or
    ``str_postfix`` is None.

    :param path_save: path for saving a figure.
    :type path_save: str.
    :param str_postfix: the name of postfix.
    :type str_postfix: str.
    :param str_prefix: the name of prefix.
    :type str_prefix: str., optional

    :returns: None.
    :rtype: NoneType

    """

    if path_save is None or str_postfix is None:
        return

    name_file = str_prefix + str_postfix + '.pdf'
    path_file = os.path.join(path_save, name_file)
    plt.savefig(
        path_file,
        format='pdf',
        transparent=True,
        bbox_inches='tight',
    )
120

121
@utils_common.validate_types
def _show_figure(pause_figure: bool, time_pause: constants.TYPING_UNION_INT_FLOAT
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It displays the open figures, either for a fixed time or through
    ``plt.show()``. It does nothing when ``pause_figure`` is False.

    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool.
    :param time_pause: pausing time. A value that is not below ``np.inf``
        hands display over to ``plt.show()``.
    :type time_pause: int. or float

    :returns: None.
    :rtype: NoneType

    """

    if not pause_figure:
        return

    if time_pause < np.inf:
        # Finite pause: draw in interactive mode, wait, then close everything.
        plt.ion()
        plt.pause(time_pause)
        plt.close('all')
        return

    # Otherwise, display the figures with plt.show() instead of a timed pause.
    plt.show()
144

145
@utils_common.validate_types
def plot_gp_via_sample(X: np.ndarray, Ys: np.ndarray,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='x',
    str_y_axis: str='y',
    use_tex: bool=False,
    draw_zero_axis: bool=False,
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    colors: np.ndarray=constants.COLORS,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It draws functions sampled from a multivariate distribution as
    semi-transparent curves over shared one-dimensional inputs.

    Each row of ``Ys`` becomes one curve drawn in ``colors[0]``. If both
    ``path_save`` and ``str_postfix`` are given, the figure is saved with the
    prefix ``gp_sampled_``. If matplotlib or pylab is unavailable, a message
    is logged and nothing is drawn.

    :param X: inputs shared by all samples. Shape: (n, 1).
    :type X: numpy.ndarray
    :param Ys: sampled function values, one sample per row. Shape: (m, n).
    :type Ys: numpy.ndarray
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param draw_zero_axis: flag for drawing a zero axis.
    :type draw_zero_axis: bool., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param colors: array of colors.
    :type colors: np.ndarray, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    # Argument types.
    assert isinstance(X, np.ndarray)
    assert isinstance(Ys, np.ndarray)
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(use_tex, bool)
    assert isinstance(draw_zero_axis, bool)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(colors, np.ndarray)

    # Shapes: X is a column of n inputs; Ys holds samples over those n inputs.
    assert X.ndim == 2
    assert Ys.ndim == 2
    assert X.shape[1] == 1
    assert X.shape[0] == Ys.shape[1]

    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    plt.figure(figsize=(8, 6))
    ax = plt.gca()

    inputs = X.flatten()
    for sample in Ys:
        ax.plot(inputs, sample, c=colors[0], lw=4, alpha=0.3)

    _set_ax_config(ax, str_x_axis, str_y_axis,
        xlim_min=np.min(X),
        xlim_max=np.max(X),
        draw_zero_axis=draw_zero_axis,
    )
    plt.tight_layout()

    if path_save is not None and str_postfix is not None:
        _save_figure(path_save, str_postfix, str_prefix='gp_sampled_')
    _show_figure(pause_figure, time_pause)
229

230
@utils_common.validate_types
def plot_gp_via_distribution(X_train: np.ndarray, Y_train: np.ndarray,
    X_test: np.ndarray, mean_test: np.ndarray, std_test: np.ndarray,
    Y_test: constants.TYPING_UNION_ARRAY_NONE=None,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='x',
    str_y_axis: str='y',
    use_tex: bool=False,
    draw_zero_axis: bool=False,
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    range_shade: float=constants.RANGE_SHADE,
    colors: np.ndarray=constants.COLORS,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It plots a one-dimensional Gaussian process regression result: the
    posterior mean with a shaded band of ``range_shade`` standard deviations,
    the optional true test function, and the training points as crosses.

    Colors: ``colors[0]`` for training points, ``colors[1]`` for ``Y_test``,
    ``colors[2]`` for the mean and band. If both ``path_save`` and
    ``str_postfix`` are given, the figure is saved with the prefix ``gp_``.
    If matplotlib or pylab is unavailable, a message is logged and nothing
    is drawn.

    :param X_train: training inputs. Shape: (n, 1).
    :type X_train: numpy.ndarray
    :param Y_train: training outputs. Shape: (n, 1).
    :type Y_train: numpy.ndarray
    :param X_test: test inputs. Shape: (m, 1).
    :type X_test: numpy.ndarray
    :param mean_test: posterior predictive mean function values over `X_test`.
        Shape: (m, 1).
    :type mean_test: numpy.ndarray
    :param std_test: posterior predictive standard deviation function values
        over `X_test`. Shape: (m, 1).
    :type std_test: numpy.ndarray
    :param Y_test: None, or true test outputs. Shape: (m, 1).
    :type Y_test: NoneType or numpy.ndarray, optional
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param draw_zero_axis: flag for drawing a zero axis.
    :type draw_zero_axis: bool., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param range_shade: shade range for standard deviation.
    :type range_shade: float, optional
    :param colors: array of colors.
    :type colors: np.ndarray, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    # Argument types.
    assert isinstance(X_train, np.ndarray)
    assert isinstance(Y_train, np.ndarray)
    assert isinstance(X_test, np.ndarray)
    assert isinstance(mean_test, np.ndarray)
    assert isinstance(std_test, np.ndarray)
    assert isinstance(Y_test, (np.ndarray, type(None)))
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(use_tex, bool)
    assert isinstance(draw_zero_axis, bool)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(range_shade, float)
    assert isinstance(colors, np.ndarray)

    # Every array is a 2-D column; train and test sets must be consistent.
    assert X_train.ndim == 2
    assert X_test.ndim == 2
    assert Y_train.ndim == 2
    assert mean_test.ndim == 2
    assert std_test.ndim == 2
    assert X_train.shape[1] == X_test.shape[1] == 1
    assert Y_train.shape[1] == 1
    assert X_train.shape[0] == Y_train.shape[0]
    assert mean_test.shape[1] == 1
    assert std_test.shape[1] == 1
    assert X_test.shape[0] == mean_test.shape[0] == std_test.shape[0]
    if Y_test is not None:
        assert Y_test.ndim == 2
        assert Y_test.shape[1] == 1
        assert X_test.shape[0] == Y_test.shape[0]

    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    plt.figure(figsize=(8, 6))
    ax = plt.gca()

    grid = X_test.flatten()
    center = mean_test.flatten()
    half_width = range_shade * std_test.flatten()

    if Y_test is not None:
        ax.plot(grid, Y_test.flatten(), c=colors[1], linewidth=4, marker='None')
    ax.plot(grid, center, c=colors[2], linewidth=4, marker='None')
    ax.fill_between(grid, center - half_width, center + half_width,
        color=colors[2], alpha=0.3)
    ax.plot(X_train.flatten(), Y_train.flatten(), 'x',
        c=colors[0], markersize=10, mew=4)

    _set_ax_config(ax, str_x_axis, str_y_axis,
        xlim_min=np.min(X_test),
        xlim_max=np.max(X_test),
        draw_zero_axis=draw_zero_axis,
    )
    plt.tight_layout()

    if path_save is not None and str_postfix is not None:
        _save_figure(path_save, str_postfix, str_prefix='gp_')
    _show_figure(pause_figure, time_pause)
358

359
@utils_common.validate_types
def plot_minimum_vs_iter(minima: np.ndarray, list_str_label: constants.TYPING_LIST[str],
    num_init: int, draw_std: bool,
    include_marker: bool=True,
    include_legend: bool=False,
    use_tex: bool=False,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='Iteration',
    str_y_axis: str='Minimum function value',
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    range_shade: float=constants.RANGE_SHADE,
    markers: np.ndarray=constants.MARKERS,
    colors: np.ndarray=constants.COLORS,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It is for plotting optimization results of Bayesian optimization, in
    terms of iterations.

    For each experiment, the curve is the mean returned by
    ``utils_common.get_minimum`` for that experiment, plotted against the
    iteration index. With ``draw_std``, a band of ``range_shade`` standard
    deviations (also from ``get_minimum``) is shaded around it. Colors and
    markers are cycled when there are more experiments than entries.

    If both ``path_save`` and ``str_postfix`` are given, the plot is saved as
    ``minimum_mean_std_<postfix>.pdf`` (or ``minimum_mean_only_<postfix>.pdf``
    without ``draw_std``). A standalone legend is also saved as
    ``legend_<postfix>.pdf`` and its figure is closed afterwards. If
    matplotlib or pylab is unavailable, a message is logged and nothing is
    drawn.

    :param minima: function values over acquired examples. Shape: (b, r, n)
        where b is the number of experiments, r is the number of rounds,
        and n is the number of iterations per round.
    :type minima: numpy.ndarray
    :param list_str_label: list of label strings. Shape: (b, ).
    :type list_str_label: list
    :param num_init: the number of initial examples < n.
    :type num_init: int.
    :param draw_std: flag for drawing standard deviations.
    :type draw_std: bool.
    :param include_marker: flag for drawing markers.
    :type include_marker: bool., optional
    :param include_legend: flag for drawing a legend.
    :type include_legend: bool., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param range_shade: shade range for standard deviation.
    :type range_shade: float, optional
    :param markers: array of markers.
    :type markers: np.ndarray, optional
    :param colors: array of colors.
    :type colors: np.ndarray, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    assert isinstance(minima, np.ndarray)
    assert isinstance(list_str_label, list)
    assert isinstance(num_init, int)
    assert isinstance(draw_std, bool)
    assert isinstance(include_marker, bool)
    assert isinstance(include_legend, bool)
    assert isinstance(use_tex, bool)
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(range_shade, float)
    assert isinstance(markers, np.ndarray)
    assert isinstance(colors, np.ndarray)
    assert len(minima.shape) == 3
    assert minima.shape[0] == len(list_str_label)
    assert minima.shape[2] >= num_init

    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    _ = plt.figure(figsize=(8, 6))
    ax = plt.gca()

    for ind_minimum, arr_minimum in enumerate(minima):
        # Cycle colors and markers when there are more experiments than entries.
        ind_color = ind_minimum % len(colors)
        ind_marker = ind_minimum % len(markers)

        _, mean_min, std_min, _ = utils_common.get_minimum(arr_minimum, num_init)
        x_data = range(0, mean_min.shape[0])

        if include_marker:
            marker_kwargs = {'marker': markers[ind_marker], 'markersize': 10, 'mew': 3}
        else:
            marker_kwargs = {'marker': 'None'}

        ax.plot(x_data, mean_min,
            label=list_str_label[ind_minimum],
            c=colors[ind_color],
            linewidth=4,
            **marker_kwargs,
        )

        if draw_std:
            ax.fill_between(x_data,
                mean_min - range_shade * std_min,
                mean_min + range_shade * std_min,
                color=colors[ind_color],
                alpha=0.3)
    # Handles are collected here so they can be reused for the standalone legend.
    lines, _ = ax.get_legend_handles_labels()

    # x limits come from the last experiment; all experiments share one shape.
    _set_ax_config(ax, str_x_axis, str_y_axis, xlim_min=0, xlim_max=mean_min.shape[0]-1)

    if include_legend:
        plt.legend(loc='upper right', fancybox=False, edgecolor='black', fontsize=24)

    plt.tight_layout()

    if path_save is not None and str_postfix is not None:
        if draw_std:
            str_figure = 'minimum_mean_std_' + str_postfix
        else:
            str_figure = 'minimum_mean_only_' + str_postfix
        _save_figure(path_save, str_figure)

        fig_legend = pylab.figure(figsize=(3, 2))
        # ``loc`` must be passed by keyword; positional ``loc`` is deprecated.
        fig_legend.legend(lines, list_str_label, loc='center', fancybox=False,
            edgecolor='black', fontsize=32)
        fig_legend.savefig(os.path.join(path_save, f'legend_{str_postfix}.pdf'),
            format='pdf', transparent=True, bbox_inches='tight')
        # The legend-only figure is just an export artifact. Close it so it
        # does not leak or show up as an extra window in _show_figure.
        plt.close(fig_legend)

    _show_figure(pause_figure, time_pause)
503

504
@utils_common.validate_types
def plot_minimum_vs_time(times: np.ndarray, minima: np.ndarray,
    list_str_label: constants.TYPING_LIST[str], num_init: int, draw_std: bool,
    include_marker: bool=True,
    include_legend: bool=False,
    use_tex: bool=False,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='Time (sec.)',
    str_y_axis: str='Minimum function value',
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    range_shade: float=constants.RANGE_SHADE,
    markers: np.ndarray=constants.MARKERS,
    colors: np.ndarray=constants.COLORS,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It is for plotting optimization results of Bayesian optimization, in terms of execution time.

    For each experiment, the curve is the mean returned by
    ``utils_common.get_minimum``, plotted against the times returned by
    ``utils_common.get_time``. With ``draw_std``, a band of ``range_shade``
    standard deviations is shaded around it. Colors and markers are cycled
    when there are more experiments than entries. The x limits span the
    smallest to largest time over all experiments.

    If both ``path_save`` and ``str_postfix`` are given, the plot is saved as
    ``minimum_time_mean_std_<postfix>.pdf`` (or
    ``minimum_time_mean_only_<postfix>.pdf`` without ``draw_std``). A
    standalone legend is also saved as ``legend_<postfix>.pdf`` and its
    figure is closed afterwards. If matplotlib or pylab is unavailable, a
    message is logged and nothing is drawn.

    :param times: execution times. Shape: (b, r, n), or (b, r, `num_init` + n)
        where b is the number of experiments, r is the number of rounds,
        and n is the number of iterations per round.
    :type times: numpy.ndarray
    :param minima: function values over acquired examples. Shape: (b, r, `num_init` + n)
        where b is the number of experiments, r is the number of rounds,
        and n is the number of iterations per round.
    :type minima: numpy.ndarray
    :param list_str_label: list of label strings. Shape: (b, ).
    :type list_str_label: list
    :param num_init: the number of initial examples.
    :type num_init: int.
    :param draw_std: flag for drawing standard deviations.
    :type draw_std: bool.
    :param include_marker: flag for drawing markers.
    :type include_marker: bool., optional
    :param include_legend: flag for drawing a legend.
    :type include_legend: bool., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param range_shade: shade range for standard deviation.
    :type range_shade: float, optional
    :param markers: array of markers.
    :type markers: np.ndarray, optional
    :param colors: array of colors.
    :type colors: np.ndarray, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    assert isinstance(times, np.ndarray)
    assert isinstance(minima, np.ndarray)
    assert isinstance(list_str_label, list)
    assert isinstance(num_init, int)
    assert isinstance(draw_std, bool)
    assert isinstance(include_marker, bool)
    assert isinstance(include_legend, bool)
    assert isinstance(use_tex, bool)
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(range_shade, float)
    assert isinstance(markers, np.ndarray)
    assert isinstance(colors, np.ndarray)
    assert len(times.shape) == 3
    assert len(minima.shape) == 3
    assert times.shape[0] == minima.shape[0] == len(list_str_label)
    assert times.shape[1] == minima.shape[1]
    assert minima.shape[2] >= num_init
    # Times either cover every example, or omit the initial ones.
    assert times.shape[2] == minima.shape[2] or times.shape[2] + num_init == minima.shape[2]

    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    _ = plt.figure(figsize=(8, 6))
    ax = plt.gca()

    # Whether ``times`` already includes the initial examples (see the assert above).
    includes_init = times.shape[2] == minima.shape[2]

    list_x_data = []
    for ind_minimum, arr_minimum in enumerate(minima):
        # Cycle colors and markers when there are more experiments than entries.
        ind_color = ind_minimum % len(colors)
        ind_marker = ind_minimum % len(markers)

        _, mean_min, std_min, _ = utils_common.get_minimum(arr_minimum, num_init)
        x_data = utils_common.get_time(times[ind_minimum], num_init, includes_init)

        if include_marker:
            marker_kwargs = {'marker': markers[ind_marker], 'markersize': 10, 'mew': 3}
        else:
            marker_kwargs = {'marker': 'None'}

        ax.plot(x_data, mean_min,
            label=list_str_label[ind_minimum],
            c=colors[ind_color],
            linewidth=4,
            **marker_kwargs,
        )

        if draw_std:
            ax.fill_between(x_data,
                mean_min - range_shade * std_min,
                mean_min + range_shade * std_min,
                color=colors[ind_color],
                alpha=0.3)
        list_x_data.append(x_data)
    # Handles are collected here so they can be reused for the standalone legend.
    lines, _ = ax.get_legend_handles_labels()

    _set_ax_config(ax, str_x_axis, str_y_axis, xlim_min=np.min(list_x_data),
        xlim_max=np.max(list_x_data))

    if include_legend:
        plt.legend(loc='upper right', fancybox=False, edgecolor='black', fontsize=24)

    plt.tight_layout()

    if path_save is not None and str_postfix is not None:
        if draw_std:
            str_figure = 'minimum_time_mean_std_' + str_postfix
        else:
            str_figure = 'minimum_time_mean_only_' + str_postfix
        _save_figure(path_save, str_figure)

        fig_legend = pylab.figure(figsize=(3, 2))
        # ``loc`` must be passed by keyword; positional ``loc`` is deprecated.
        fig_legend.legend(lines, list_str_label, loc='center', fancybox=False,
            edgecolor='black', fontsize=32)
        fig_legend.savefig(os.path.join(path_save, f'legend_{str_postfix}.pdf'),
            format='pdf', transparent=True, bbox_inches='tight')
        # The legend-only figure is just an export artifact. Close it so it
        # does not leak or show up as an extra window in _show_figure.
        plt.close(fig_legend)

    _show_figure(pause_figure, time_pause)
659

660
@utils_common.validate_types
def plot_bo_step(X_train: np.ndarray, Y_train: np.ndarray,
    X_test: np.ndarray, Y_test: np.ndarray,
    mean_test: np.ndarray, std_test: np.ndarray,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='x',
    str_y_axis: str='y',
    num_init: constants.TYPING_UNION_INT_NONE=None,
    use_tex: bool=False,
    draw_zero_axis: bool=False,
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    range_shade: float=constants.RANGE_SHADE,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It plots a single step of one-dimensional Bayesian optimization.

    The true function is drawn in green, and the posterior mean in blue with
    a shaded band of ``range_shade`` standard deviations. Training points
    are drawn as follows:

    - Initial examples (the first ``num_init``) are brown crosses.
    - Later examples are red crosses.
    - The most recent example is an orange plus.

    Without ``num_init``, all but the last point are red crosses. If both
    ``path_save`` and ``str_postfix`` are given, the figure is saved with the
    prefix ``bo_step_``. If matplotlib or pylab is unavailable, a message is
    logged and nothing is drawn.

    :param X_train: training inputs. Shape: (n, 1).
    :type X_train: numpy.ndarray
    :param Y_train: training outputs. Shape: (n, 1).
    :type Y_train: numpy.ndarray
    :param X_test: test inputs. Shape: (m, 1).
    :type X_test: numpy.ndarray
    :param Y_test: true test outputs. Shape: (m, 1).
    :type Y_test: numpy.ndarray
    :param mean_test: posterior predictive mean function values over `X_test`.
        Shape: (m, 1).
    :type mean_test: numpy.ndarray
    :param std_test: posterior predictive standard deviation function values
        over `X_test`. Shape: (m, 1).
    :type std_test: numpy.ndarray
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param num_init: None, or the number of initial examples.
    :type num_init: NoneType or int., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param draw_zero_axis: flag for drawing a zero axis.
    :type draw_zero_axis: bool., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param range_shade: shade range for standard deviation.
    :type range_shade: float, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    # Argument types.
    assert isinstance(X_train, np.ndarray)
    assert isinstance(Y_train, np.ndarray)
    assert isinstance(X_test, np.ndarray)
    assert isinstance(Y_test, np.ndarray)
    assert isinstance(mean_test, np.ndarray)
    assert isinstance(std_test, np.ndarray)
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(num_init, (int, type(None)))
    assert isinstance(use_tex, bool)
    assert isinstance(draw_zero_axis, bool)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(range_shade, float)

    # Every array is a 2-D column; train and test sets must be consistent.
    assert X_train.ndim == 2
    assert X_test.ndim == 2
    assert Y_train.ndim == 2
    assert Y_test.ndim == 2
    assert mean_test.ndim == 2
    assert std_test.ndim == 2
    assert X_train.shape[1] == X_test.shape[1] == 1
    assert Y_train.shape[1] == Y_test.shape[1] == 1
    assert X_train.shape[0] == Y_train.shape[0]
    assert mean_test.shape[1] == 1
    assert std_test.shape[1] == 1
    assert X_test.shape[0] == Y_test.shape[0] == mean_test.shape[0] == std_test.shape[0]
    if num_init is not None:
        assert X_train.shape[0] >= num_init

    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    plt.figure(figsize=(8, 6))
    ax = plt.gca()

    # True function (green) and posterior mean (blue) with its shaded band.
    ax.plot(X_test, Y_test, 'g', linewidth=4)
    ax.plot(X_test, mean_test, 'b', linewidth=4)
    grid = X_test.flatten()
    center = mean_test.flatten()
    half_width = range_shade * std_test.flatten()
    ax.fill_between(grid, center - half_width, center + half_width,
        color='b', alpha=0.3)

    def _mark_latest():
        # Highlight the most recently acquired example with an orange plus.
        ax.plot(X_train[-1, :], Y_train[-1, :],
            c='orange', marker='+', ms=18, markeredgewidth=6)

    if num_init is None:
        ax.plot(X_train[:-1, :], Y_train[:-1, :], 'rx', ms=14, markeredgewidth=6)
        _mark_latest()
    elif X_train.shape[0] > num_init:
        ax.plot(X_train[:num_init, :], Y_train[:num_init, :], 'x',
            c='saddlebrown', ms=14, markeredgewidth=6)
        ax.plot(X_train[num_init:-1, :], Y_train[num_init:-1, :], 'rx',
            ms=14, markeredgewidth=6)
        _mark_latest()
    else:
        # Only initial examples so far: nothing has been acquired yet.
        ax.plot(X_train, Y_train, 'x', c='saddlebrown', ms=14, markeredgewidth=6)

    _set_ax_config(ax, str_x_axis, str_y_axis,
        xlim_min=np.min(X_test),
        xlim_max=np.max(X_test),
        draw_zero_axis=draw_zero_axis,
    )
    plt.tight_layout()

    if path_save is not None and str_postfix is not None:
        _save_figure(path_save, str_postfix, str_prefix='bo_step_')
    _show_figure(pause_figure, time_pause)
792

793
@utils_common.validate_types
def plot_bo_step_with_acq(X_train: np.ndarray, Y_train: np.ndarray,
    X_test: np.ndarray, Y_test: np.ndarray, mean_test: np.ndarray,
    std_test: np.ndarray, acq_test: np.ndarray,
    path_save: constants.TYPING_UNION_STR_NONE=None,
    str_postfix: constants.TYPING_UNION_STR_NONE=None,
    str_x_axis: str='x',
    str_y_axis: str='y',
    str_acq_axis: str='acq.',
    num_init: constants.TYPING_UNION_INT_NONE=None,
    use_tex: bool=False,
    draw_zero_axis: bool=False,
    pause_figure: bool=True,
    time_pause: constants.TYPING_UNION_INT_FLOAT=constants.TIME_PAUSE,
    range_shade: float=constants.RANGE_SHADE,
) -> constants.TYPE_NONE: # pragma: no cover
    """
    It is for plotting Bayesian optimization results step by step, together
    with the acquisition function values in a separate lower panel.

    The upper panel shows the true function (green), the posterior predictive
    mean (blue), a shaded band of ``range_shade`` standard deviations around
    the mean, and the training points. When ``num_init`` is given and there
    are more than ``num_init`` training points, the first ``num_init`` points
    are drawn as brown crosses, the following points as red crosses, and the
    last point as an orange plus. When there are no points beyond
    ``num_init``, all points are drawn as brown crosses. When ``num_init`` is
    None, all but the last point are red crosses and the last point is an
    orange plus. The lower panel shows ``acq_test`` over ``X_test``.

    :param X_train: training inputs. Shape: (n, 1).
    :type X_train: numpy.ndarray
    :param Y_train: training outputs. Shape: (n, 1).
    :type Y_train: numpy.ndarray
    :param X_test: test inputs. Shape: (m, 1).
    :type X_test: numpy.ndarray
    :param Y_test: true test outputs. Shape: (m, 1).
    :type Y_test: numpy.ndarray
    :param mean_test: posterior predictive mean function values over `X_test`.
        Shape: (m, 1).
    :type mean_test: numpy.ndarray
    :param std_test: posterior predictive standard deviation function values
        over `X_test`. Shape: (m, 1).
    :type std_test: numpy.ndarray
    :param acq_test: acquisition function values over `X_test`. Shape: (m, 1).
    :type acq_test: numpy.ndarray
    :param path_save: None, or path for saving a figure.
    :type path_save: NoneType or str., optional
    :param str_postfix: None, or the name of postfix.
    :type str_postfix: NoneType or str., optional
    :param str_x_axis: the name of x axis.
    :type str_x_axis: str., optional
    :param str_y_axis: the name of y axis.
    :type str_y_axis: str., optional
    :param str_acq_axis: the name of acquisition function axis.
    :type str_acq_axis: str., optional
    :param num_init: None, or the number of initial examples.
    :type num_init: NoneType or int., optional
    :param use_tex: flag for using latex.
    :type use_tex: bool., optional
    :param draw_zero_axis: flag for drawing a zero axis.
    :type draw_zero_axis: bool., optional
    :param pause_figure: flag for pausing before closing a figure.
    :type pause_figure: bool., optional
    :param time_pause: pausing time.
    :type time_pause: int. or float, optional
    :param range_shade: shade range for standard deviation.
    :type range_shade: float, optional

    :returns: None.
    :rtype: NoneType

    :raises: AssertionError

    """

    assert isinstance(X_train, np.ndarray)
    assert isinstance(Y_train, np.ndarray)
    assert isinstance(X_test, np.ndarray)
    assert isinstance(Y_test, np.ndarray)
    assert isinstance(mean_test, np.ndarray)
    assert isinstance(std_test, np.ndarray)
    # acq_test was previously only checked through `.shape`; a non-ndarray
    # would then fail with an AttributeError instead of the documented
    # AssertionError.
    assert isinstance(acq_test, np.ndarray)
    assert isinstance(path_save, (str, type(None)))
    assert isinstance(str_postfix, (str, type(None)))
    assert isinstance(str_x_axis, str)
    assert isinstance(str_y_axis, str)
    assert isinstance(str_acq_axis, str)
    assert isinstance(num_init, (int, type(None)))
    assert isinstance(use_tex, bool)
    assert isinstance(draw_zero_axis, bool)
    assert isinstance(pause_figure, bool)
    assert isinstance(time_pause, (int, float))
    assert isinstance(range_shade, float)
    assert len(X_train.shape) == 2
    assert len(X_test.shape) == 2
    assert len(Y_train.shape) == 2
    assert len(Y_test.shape) == 2
    assert len(mean_test.shape) == 2
    assert len(std_test.shape) == 2
    assert len(acq_test.shape) == 2
    assert X_train.shape[1] == X_test.shape[1] == 1
    assert Y_train.shape[1] == Y_test.shape[1] == 1
    assert X_train.shape[0] == Y_train.shape[0]
    assert mean_test.shape[1] == 1
    assert std_test.shape[1] == 1
    assert acq_test.shape[1] == 1
    assert X_test.shape[0] == Y_test.shape[0] == mean_test.shape[0] \
        == std_test.shape[0] == acq_test.shape[0]
    if num_init is not None:
        assert X_train.shape[0] >= num_init

    # Both plotting imports are optional at module load; skip plotting quietly
    # when either is missing.
    if plt is None or pylab is None:
        logger.info('matplotlib or pylab is not installed.')
        return
    _set_font_config(use_tex)

    # Two stacked panels: the posterior panel is three times as tall as the
    # acquisition panel.
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), gridspec_kw={'height_ratios':[3, 1]})

    ax1.plot(X_test, Y_test, 'g', linewidth=4)
    ax1.plot(X_test, mean_test, 'b', linewidth=4)
    # fill_between expects 1-D sequences, so the (m, 1) columns are flattened.
    ax1.fill_between(X_test.flatten(),
        mean_test.flatten() - range_shade * std_test.flatten(),
        mean_test.flatten() + range_shade * std_test.flatten(),
        color='b',
        alpha=0.3)
    if num_init is not None:
        if X_train.shape[0] > num_init:
            # Initial design, subsequent queries, and the latest query are
            # drawn with distinct markers.
            ax1.plot(X_train[:num_init, :], Y_train[:num_init, :], 'x',
                c='saddlebrown', ms=14, markeredgewidth=6)
            ax1.plot(X_train[num_init:-1, :], Y_train[num_init:-1, :], 'rx',
                ms=14, markeredgewidth=6)
            ax1.plot(X_train[-1, :], Y_train[-1, :],
                c='orange', marker='+', ms=18, markeredgewidth=6)
        else:
            # Only initial points exist so far.
            ax1.plot(X_train, Y_train, 'x', c='saddlebrown', ms=14, markeredgewidth=6)
    else:
        ax1.plot(X_train[:-1, :], Y_train[:-1, :],
            'rx', ms=14, markeredgewidth=6)
        ax1.plot(X_train[-1, :], Y_train[-1, :],
            c='orange', marker='+', ms=18, markeredgewidth=6)

    # None is passed as the x-axis name for the upper panel, presumably so
    # that only the lower panel carries an x label. This depends on how
    # `_set_ax_config` treats None; TODO confirm.
    _set_ax_config(ax1, None, str_y_axis, xlim_min=np.min(X_test),
        xlim_max=np.max(X_test), draw_zero_axis=draw_zero_axis)

    ax2.plot(X_test, acq_test, 'b', linewidth=4)
    _set_ax_config(ax2, str_x_axis, str_acq_axis, xlim_min=np.min(X_test),
        xlim_max=np.max(X_test), draw_zero_axis=draw_zero_axis)

    plt.tight_layout()

    # The figure is saved only when both a directory and a postfix are given.
    if path_save is not None and str_postfix is not None:
        _save_figure(path_save, str_postfix, str_prefix='bo_step_acq_')
    _show_figure(pause_figure, time_pause)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc