• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IDSIA / bayesRecon / #13

05 Mar 2026 06:50PM UTC coverage: 76.645% (-7.8%) from 84.472%
#13

push

travis-pro

web-flow
Merge 35070ae15 into 599d7d457

671 of 891 new or added lines in 10 files covered. (75.31%)

11 existing lines in 4 files now uncovered.

955 of 1246 relevant lines covered (76.65%)

61846.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.52
/R/utils.R
1
################################################################################
2
# IMPLEMENTED DISTRIBUTIONS
3

4
.DISTR_TYPES <- c("continuous", "discrete")
5
.DISCR_DISTR <- c("poisson", "nbinom")
6
.CONT_DISTR <- c("gaussian")
7

8
################################################################################
9
# PARAMETERS FOR PMF CONVOLUTION AND SMOOTHING
10

11
.TOLL <- 1e-15
12
.RTOLL <- 1e-9
13
.ALPHA_SMOOTHING <- 1e-9
14
.LAP_SMOOTHING <- FALSE
15

16
################################################################################
17
# OTHER PARAMETERS
18

19
.NEGBIN_TOLL <- 1e-6 
20
# used when fitting a Negative Binomial distribution
21
.L_SHRINK_RECONC_T <- 1e-4  
22
# used for shrinking the empirical covariance matrix in reconc_t
23
.MIN_FRACTION_SAMPLES_OK <- 0.5  
24
# used in reconc_TDcond: if the fraction of reconciled upper samples that lie 
25
# in the support of the bottom-up distribution is less than this threshold, the function stops
26

27
################################################################################
28
# CHECK INPUT
29

30
# Function to check values allowed in S.
31
.check_S <- function(S) {
  # Validate the matrix S. The checks, in order:
  #   1. S is numeric and its set of distinct entries is exactly {0, 1}
  #      (integer-typed matrices are accepted as well as double ones);
  #   2. every column sums to more than 1;
  #   3. repeated rows only trigger a warning;
  #   4. the rows that contain a single 1 cover every column.
  # Stops with an "Input error in S" message on the first failed check.
  entries <- unique(as.vector(S))
  if (!is.numeric(S) || !setequal(entries, c(0, 1))) {
    stop("Input error in S: S must be a matrix containing only 0s and 1s.")
  }

  if (!all(colSums(S) > 1)) {
    stop("Input error in S: all bottom level forecasts must aggregate into an upper.")
  }

  if (nrow(unique(S)) != nrow(S)) {
    warning("S has some repeated rows.")
  }

  # Check that each bottom has a corresponding row with one 1 and the rest 0s.
  # drop = FALSE keeps the selection a matrix even when a single row matches,
  # so that nrow() is always defined.
  unit_rows <- S[rowSums(S) == 1, , drop = FALSE]
  if (nrow(unique(unit_rows)) < ncol(S)) {
    stop("Input error in S: there is at least one bottom that does not have a row with one 1 and the rest 0s.")
  }
}
49

50
# Function to check aggregation matrix A
51
.check_A <- function(A) {
  # Validate the aggregation matrix A:
  #   - every entry must be 0 or 1 (error otherwise);
  #   - no column may be entirely zero (error otherwise);
  #   - duplicated rows are tolerated but reported with a warning.
  only_binary <- all(A %in% c(0, 1))
  if (!only_binary) {
    stop("Input error in A: A must be a matrix containing only 0s and 1s.")
  }

  empty_columns <- colSums(A) == 0
  if (any(empty_columns)) {
    stop("Input error in A: some columns do not have any 1.
          All bottom level forecasts must aggregate into an upper.")
  }

  n_distinct_rows <- nrow(unique(A))
  if (n_distinct_rows != nrow(A)) {
    warning("A has some repeated rows.")
  }
}
65

66
# Check if it is a covariance matrix (i.e. symmetric p.d.)
67
.check_cov <- function(cov_matrix, Sigma_str, pd_check = FALSE, symm_check = FALSE) {
  # Validate a candidate covariance matrix; Sigma_str is the label used as
  # prefix of every error message.
  #   - always: cov_matrix must be a square matrix with non-negative diagonal;
  #   - pd_check = TRUE: all eigenvalues (computed assuming symmetry) must be
  #     strictly positive. Note that a zero eigenvalue is rejected too, even
  #     though the error text says "positive semi-definite";
  #   - symm_check = TRUE: isSymmetric() must hold.
  # Returns TRUE when every requested check passes, stops otherwise.
  is_square <- is.matrix(cov_matrix) && nrow(cov_matrix) == ncol(cov_matrix)
  if (!is_square) {
    stop(paste0(Sigma_str, " is not square"))
  }

  if (pd_check) {
    spectrum <- eigen(cov_matrix, symmetric = TRUE)$values
    if (!all(spectrum > 0)) {
      stop(paste0(Sigma_str, " is not positive semi-definite"))
    }
  }

  if (symm_check && !isSymmetric(cov_matrix)) {
    stop(paste0(Sigma_str, " is not symmetric"))
  }

  variances <- diag(cov_matrix)
  if (any(variances < 0)) {
    stop(paste0(Sigma_str, ": some elements on the diagonal are negative"))
  }

  TRUE
}
93

94
# Checks if the input is a real number
95
.check_real_number <- function(x) {
  # TRUE when x has length one and is numeric (integer or double), FALSE
  # otherwise. NA, NaN and infinite numeric values are not excluded.
  is_scalar <- length(x) == 1
  is_scalar && is.numeric(x)
}
98

99
# Checks if the input is a positive number
100
.check_positive_number <- function(x) {
  # TRUE when x is a single, non-missing numeric value strictly greater than
  # zero; FALSE in every other case.
  # The explicit NA guard guarantees a plain TRUE/FALSE result: without it an
  # NA input would propagate NA and make `if (!.check_positive_number(x))`
  # fail with "missing value where TRUE/FALSE needed" instead of letting the
  # caller raise its own input error.
  length(x) == 1 && is.numeric(x) && !is.na(x) && x > 0
}
103

104
# Check that the distr is implemented
105
.check_implemented_distr <- function(distr) {
  # Stop unless distr is one of the names listed in .DISCR_DISTR or
  # .CONT_DISTR; the error message enumerates the accepted names.
  implemented <- c(.DISCR_DISTR, .CONT_DISTR)
  if (distr %in% implemented) {
    return(invisible(NULL))
  }
  allowed_names <- paste(implemented, collapse = ", ")
  stop(paste(
    "Input error: the distribution must be one of {",
    allowed_names, "}"
  ))
}
113

114
# Check the parameters of distr
115
.check_distr_params <- function(distr, params) {
  # Validate the parameter list of a distribution.
  #   distr:  name of the distribution; checked by .check_implemented_distr.
  #   params: list of parameters. Entries read:
  #             gaussian -> mean (single number), sd (positive number)
  #             poisson  -> lambda (positive number)
  #             nbinom   -> size (positive number) and exactly one of
  #                         prob (positive and <= 1) or mu (positive)
  # Stops with an "Input error" message on the first violation.
  #
  # Scalar conditions use && / || and isTRUE(): a malformed parameter (e.g. a
  # vector-valued prob, or a helper returning NA) therefore triggers the
  # intended input error instead of a failure of `if` itself.
  .check_implemented_distr(distr)
  if (!is.list(params)) {
    stop("Input error: the parameters of the distribution must be given as a list.")
  }
  switch(distr,
    "gaussian" = {
      mean <- params$mean
      sd <- params$sd
      if (!isTRUE(.check_real_number(mean))) {
        stop("Input error: mean of Gaussian must be a real number")
      }
      if (!isTRUE(.check_positive_number(sd))) {
        stop("Input error: sd of Gaussian must be a positive number")
      }
    },
    "poisson" = {
      lambda <- params$lambda
      if (!isTRUE(.check_positive_number(lambda))) {
        stop("Input error: lambda of Poisson must be a positive number")
      }
    },
    "nbinom" = {
      size <- params$size
      prob <- params$prob
      mu <- params$mu
      # size must be specified and be a positive number
      if (is.null(size)) {
        stop("Input error: size parameter for the nbinom distribution must be specified")
      }
      if (!isTRUE(.check_positive_number(size))) {
        stop("Input error: size of nbinom must be a positive number")
      }
      # exactly one of prob, mu must be specified
      has_prob <- !is.null(prob)
      has_mu <- !is.null(mu)
      if (has_prob && has_mu) {
        stop("Input error: prob and mu for the nbinom distribution are both specified ")
      } else if (!has_prob && !has_mu) {
        stop("Input error: either prob or mu must be specified")
      } else if (has_prob) {
        # || short-circuits, so prob > 1 is evaluated only for a valid scalar
        if (!isTRUE(.check_positive_number(prob)) || prob > 1) {
          stop("Input error: prob of nbinom must be positive and <= 1")
        }
      } else {
        if (!isTRUE(.check_positive_number(mu))) {
          stop("Input error: mu of nbinom must be positive")
        }
      }
    }
  )
}
167

168
# Check that the samples are discrete
169
.check_discrete_samples <- function(samples) {
  # Stop unless the samples coincide, up to the default tolerance of
  # all.equal(), with their integer truncation.
  flattened <- as.vector(samples)
  comparison <- all.equal(flattened, as.integer(samples))
  samples_are_integers <- isTRUE(as.vector(comparison))
  if (!samples_are_integers) {
    stop("Input error: samples are not all discrete")
  }
}
174

175
# Check input for BUIS (and for MH)
176
# base_fc, in_type, and distr must be list
177
.check_input_BUIS <- function(A, base_fc, in_type, distr) {
  # Validate the inputs shared by BUIS (and MH).
  #   A:       aggregation matrix, validated by .check_A.
  #   in_type: list of length nrow(A) + ncol(A); each entry is "params" or
  #            "samples".
  #   distr:   list of the same length; for "params" entries a distribution
  #            name accepted by .check_distr_params, for "samples" entries one
  #            of .DISTR_TYPES.
  #   base_fc: list of the same length; parameters or samples, depending on
  #            the corresponding in_type entry.
  # Stops with an "Input error" message on the first violation.
  .check_A(A)

  n_tot_A <- ncol(A) + nrow(A)
  # seq_len() yields an empty sequence when n_tot_A is 0 (1:0 would not)
  series_idx <- seq_len(n_tot_A)

  # Check in_type
  if (!is.list(in_type)) {
    stop("Input error: in_type must be a list")
  }
  if (!(n_tot_A == length(in_type))) {
    stop("Input error: ncol(A)+nrow(A) != length(in_type)")
  }
  for (i in series_idx) {
    # isTRUE() also rejects entries that are NULL or have length > 1
    if (!isTRUE(in_type[[i]] %in% c("params", "samples"))) {
      stop("Input error: in_type[[", i, "]] must be either 'samples' or 'params'")
    }
  }

  # Check distr and base forecasts
  if (!is.list(distr)) {
    stop("Input error: distr must be a list")
  }
  if (!(n_tot_A == length(distr))) {
    stop("Input error: ncol(A)+nrow(A) != length(distr)")
  }
  if (!is.list(base_fc)) {
    stop("Input error: base_fc must be a list")
  }
  if (!(n_tot_A == length(base_fc))) {
    stop("Input error: ncol(A)+nrow(A) != length(base_fc)")
  }
  for (i in series_idx) {
    if (in_type[[i]] == "params") {
      .check_distr_params(distr[[i]], base_fc[[i]])
    } else {
      # in_type[[i]] is "samples": every entry was validated in the loop above
      if (!isTRUE(distr[[i]] %in% .DISTR_TYPES)) {
        stop(paste(
          "Input error: the distribution must be one of {",
          paste(.DISTR_TYPES, collapse = ", "), "}"
        ))
      }
      if (distr[[i]] == "discrete") {
        .check_discrete_samples(base_fc[[i]])
      }
      # TODO: check sample size?
    }
  }
}
227

228
# Check input for TDcond
229
.check_input_TD <- function(A, base_fc_bottom, base_fc_upper,
                            bottom_in_type, distr,
                            return_type) {
  # Validate the inputs of TDcond.
  #   A:              aggregation matrix, validated by .check_A.
  #   base_fc_bottom: one entry per column of A.
  #   base_fc_upper:  list with entries mean (length nrow(A)) and cov
  #                   (nrow(A) x nrow(A); a single number is accepted and
  #                   treated as a 1 x 1 matrix).
  #   bottom_in_type: "pmf", "samples" or "params".
  #   distr:          required when bottom_in_type is "params" (must be in
  #                   .DISCR_DISTR); ignored with a warning otherwise.
  #   return_type:    "pmf", "samples" or "all".
  # Stops with an "Input error" message on the first violation.
  .check_A(A)

  n_b <- ncol(A) # number of bottom TS
  n_u <- nrow(A) # number of upper TS

  if (!(bottom_in_type %in% c("pmf", "samples", "params"))) {
    stop("Input error: bottom_in_type must be either 'pmf', 'samples', or 'params'")
  }
  if (!(return_type %in% c("pmf", "samples", "all"))) {
    stop("Input error: return_type must be either 'pmf', 'samples', or 'all'")
  }
  if (length(base_fc_bottom) != n_b) {
    stop("Input error: length of base_fc_bottom does not match with A")
  }
  # If cov is a number, transform into a matrix (local copy only)
  if (length(base_fc_upper$cov) == 1) {
    base_fc_upper$cov <- as.matrix(base_fc_upper$cov)
  }
  # Check the dimensions of mean and cov; || skips the dim comparison when the
  # length of mean is already wrong
  if (length(base_fc_upper$mean) != n_u || any(dim(base_fc_upper$cov) != c(n_u, n_u))) {
    stop("Input error: the dimensions of the upper parameters do not match with A")
  }
  # Check that cov is square, symmetric and has a non-negative diagonal
  .check_cov(base_fc_upper$cov, "Upper covariance matrix", symm_check = TRUE)

  # If bottom_in_type is not "params" but distr is specified, throw a warning
  if (bottom_in_type %in% c("pmf", "samples") && !is.null(distr)) {
    warning(paste0("Since bottom_in_type = '", bottom_in_type, "', the input distr is ignored"))
  }
  # If bottom_in_type is params, distr must be one of the implemented discrete
  # distributions; the parameters of every bottom series are checked as well
  if (bottom_in_type == "params") {
    if (is.null(distr)) {
      stop("Input error: if bottom_in_type = 'params', distr must be specified")
    }
    if (!(distr %in% .DISCR_DISTR)) {
      stop(paste0(
        "Input error: distr must be one of {",
        paste(.DISCR_DISTR, collapse = ", "), "}"
      ))
    }
    for (i in seq_len(n_b)) {
      .check_distr_params(distr, base_fc_bottom[[i]])
    }
  }
}
278

279
.check_input_t <- function(A, base_fc_mean, y_train, residuals, ...) {
  # Validate the inputs of the t-reconciliation.
  #   A:            aggregation matrix, validated by .check_A.
  #   base_fc_mean: vector of length nrow(A) + ncol(A).
  #   y_train:      matrix of training data (may be NULL, see the cases below).
  #   residuals:    matrix of residuals (may be NULL, see the cases below).
  #   ...:          optional named arguments prior, posterior, freq, criterion.
  # Three cases are handled:
  #   1.  posterior given: must be a list with nu and Psi; y_train, residuals
  #       and prior are ignored (with a warning if supplied).
  #   2a. posterior missing, prior given: residuals are required; prior must be
  #       a list with nu and Psi; y_train is ignored (with a warning).
  #   2b. neither given: residuals and y_train are required; freq and
  #       criterion are validated when supplied.
  # Scalar conditions use the short-circuit operator ||, so that a malformed
  # argument (for example a character freq) produces the intended
  # "Input error" message rather than an arithmetic or `if` failure.
  .check_A(A)

  n_b <- ncol(A) # number of bottom TS
  n_u <- nrow(A) # number of upper TS

  if (!is.vector(base_fc_mean)) {
    stop("Input error: base_fc_mean must be a vector")
  }
  n <- length(base_fc_mean)
  if (n_u + n_b != n) {
    stop("Input error: the length of base_fc_mean must be equal to nrow(A) + ncol(A)")
  }

  add_args <- list(...)
  prior <- add_args$prior
  posterior <- add_args$posterior
  freq <- add_args$freq
  criterion <- add_args$criterion

  ##############################################################################
  ### CASE 1 ###
  # If posterior is provided, check that it is a list with entries nu and Psi
  if (!is.null(posterior)) {
    if (!is.list(posterior)) {
      stop("Input error: posterior must be a list with entries nu and Psi")
    }
    nu_post <- posterior$nu
    Psi_post <- posterior$Psi
    if (is.null(nu_post) || is.null(Psi_post)) {
      stop("Input error: posterior must be a list with entries nu and Psi")
    } else if (!is.numeric(nu_post) || length(nu_post) != 1 || is.na(nu_post) || nu_post <= n - 1) {
      stop("Input error: nu in posterior must be a number greater than n. of series - 1")
    } else if (!is.matrix(Psi_post) || any(dim(Psi_post) != c(n, n))) {
      stop("Input error: Psi in posterior must be a matrix with dimensions compatible with base_fc_mean")
    }
    # If posterior is provided, then y_train, residuals, and prior are ignored:
    # if they are provided, throw a warning
    if (!is.null(y_train)) {
      warning("Input warning: posterior is provided, ignoring y_train")
    }
    if (!is.null(residuals)) {
      warning("Input warning: posterior is provided, ignoring residuals")
    }
    if (!is.null(prior)) {
      warning("Input warning: posterior is provided, ignoring prior")
    }

    ### CASE 2 ###
    # If posterior is not provided, first check that residuals are provided
  } else {
    if (is.null(residuals)) {
      stop("Input error: either posterior or residuals must be provided")
    }
    if (!is.matrix(residuals)) {
      stop("Input error: residuals must be a matrix")
    }
    if (ncol(residuals) != n) {
      stop("Input error: number of columns of residuals must be equal to length of base_fc_mean")
    }

    L <- nrow(residuals) # number of residual samples (i.e., training length)
    if (L < 10) {
      warning("Warning: number of rows of residuals is less than 10, covariance estimation may be inaccurate")
    }
    # TODO: implement fallback

    ### CASE 2a ###
    # If prior is provided, check that it is a list with entries nu and Psi
    if (!is.null(prior)) {
      if (!is.list(prior)) {
        stop("Input error: prior must be a list with entries nu and Psi")
      }
      nu_prior <- prior$nu
      Psi_prior <- prior$Psi
      if (is.null(nu_prior) || is.null(Psi_prior)) {
        stop("Input error: prior must be a list with entries nu and Psi")
      } else if (!is.numeric(nu_prior) || length(nu_prior) != 1 || is.na(nu_prior) || nu_prior <= n - 1) {
        stop("Input error: nu in prior must be a number greater than n. of series - 1")
      } else if (!is.matrix(Psi_prior) || any(dim(Psi_prior) != c(n, n))) {
        stop("Input error: Psi in prior must be a matrix with dimensions compatible with base_fc_mean")
      }
      # If prior is provided, then y_train is ignored: if it is provided, throw a warning
      if (!is.null(y_train)) {
        warning("Input warning: prior is provided, ignoring y_train")
      }
    }

    ### CASE 2b ###
    # If prior is not provided, y_train is needed (see the checks below);
    # freq and criterion are optional and validated only when supplied
    else {
      if (is.null(y_train)) {
        stop("Input error: y_train must be provided when neither prior nor posterior are given")
      }
      if (!is.matrix(y_train)) {
        stop("Input error: y_train must be a matrix")
      }
      # this works also if y_train is a mts; TODO: allow for other types, e.g. data.frame
      if (ncol(y_train) != n) {
        stop("Input error: number of columns of y_train must be equal to length of base_fc_mean")
      }
      if (nrow(y_train) != L) {
        warning("Numbers of rows of y_train and of residuals are different!")
      }

      if (!is.null(freq)) {
        # || stops at the first failing test, so freq < 1 and freq %% 1 are
        # only evaluated on a single non-missing number
        if (!is.numeric(freq) || length(freq) != 1 || is.na(freq) || freq < 1 || (freq %% 1) != 0) {
          stop("Input error: freq must be a positive integer")
        }
      }

      # Current logic:
      # * if y_train is a mts object: warn when freq is provided and differs
      #   from the frequency of y_train
      # * if y_train is not a mts object: warn when freq is not provided
      if (stats::is.mts(y_train)) {
        if (!is.null(freq) && freq != stats::frequency(y_train)) {
          warning("Input warning: the provided freq is different from the frequency of y_train.\n                   The frequency of y_train will be ignored.")
        }
      } else {
        if (is.null(freq)) {
          warning("Input warning: y_train is not a mts object and freq is not provided. \n                   The prior will be set assuming no seasonality.")
        }
      }

      if (!is.null(criterion)) {
        # isTRUE() also rejects a criterion of length different from one
        if (!isTRUE(criterion %in% c("RSS", "seas-test"))) {
          stop("Input error: criterion must be either 'RSS' or 'seas-test'")
        }
      }
    }
  }
}
420

421
# Check importance sampling weights
422
.check_weights <- function(w, n_eff_min = 200, p_n_eff = 0.01) {
  # Diagnose a vector of importance sampling weights.
  #   w:         vector of (unnormalised) weights.
  #   n_eff_min: absolute threshold on the effective sample size.
  #   p_n_eff:   threshold on the effective sample size, expressed as a
  #              fraction of length(w).
  # Returns a list with
  #   warning:      TRUE if at least one issue was detected;
  #   warning_code: 1 = all weights are zero, 2 = effective sample size below
  #                 n_eff_min, 3 = effective sample size below
  #                 p_n_eff * length(w) (NULL when no issue is found);
  #   warning_msg:  one message per code (NULL when no issue is found);
  #   n_eff:        1 / sum of squared normalised weights, or length(w) when
  #                 all the weights are zero.
  codes <- c()
  msgs <- c()

  n <- length(w)
  n_eff <- n

  if (all(w == 0)) {
    # 1. every weight is zero: the effective sample size cannot be computed
    codes <- c(codes, 1)
    msgs <- c(
      msgs,
      "Importance Sampling: all the weights are zeros. This is probably caused by a strong incoherence between bottom and upper base forecasts."
    )
  } else {
    # Effective sample size of the normalised weights
    w <- w / sum(w)
    n_eff <- 1 / sum(w^2)

    # 2. effective sample size below the absolute threshold
    if (n_eff < n_eff_min) {
      codes <- c(codes, 2)
      msgs <- c(
        msgs,
        paste0("Importance Sampling: effective_sample_size= ", round(n_eff, 2), " (< ", n_eff_min, ").")
      )
    }

    # 3. effective sample size below the fraction p_n_eff of the sample size
    if (n_eff < p_n_eff * n) {
      codes <- c(codes, 3)
      msgs <- c(
        msgs,
        paste0("Importance Sampling: effective_sample_size= ", round(n_eff, 2), " (< ", round(p_n_eff * 100, 2), "%).")
      )
    }
  }

  list(
    warning = length(codes) > 0,
    warning_code = codes,
    warning_msg = msgs,
    n_eff = n_eff
  )
}
472

473
################################################################################
474
# SAMPLE
475

476
# Sample from one of the implemented distributions
477
.distr_sample <- function(params, distr, n) {
  # Draw n samples from the distribution named distr ("gaussian", "poisson"
  # or "nbinom") with the parameters stored in the list params. The inputs
  # are validated by .check_distr_params before sampling. For "nbinom" the
  # prob parameterisation is used when prob is present, mu otherwise.
  .check_distr_params(distr, params)
  samples <- switch(distr,
    "gaussian" = stats::rnorm(n = n, mean = params$mean, sd = params$sd),
    "poisson" = stats::rpois(n = n, lambda = params$lambda),
    "nbinom" = {
      if (!is.null(params$prob)) {
        stats::rnbinom(n = n, size = params$size, prob = params$prob)
      } else if (!is.null(params$mu)) {
        stats::rnbinom(n = n, size = params$size, mu = params$mu)
      }
    }
  )
  return(samples)
}
502

503
# Sample from a multivariate Gaussian distribution with specified mean and cov. matrix
504
.MVN_sample <- function(n_samples, mu, Sigma) {
  # Draw n_samples points from a multivariate Gaussian with mean vector mu and
  # covariance matrix Sigma. Returns a matrix with n_samples rows and
  # length(mu) columns. Standard normal draws are transformed with the
  # Cholesky factor of Sigma; a failing Cholesky is re-raised with a hint.
  n <- length(mu)
  dims_match <- !any(dim(Sigma) != c(n, n))
  if (!dims_match) {
    stop("Dimensions of mean and covariance matrix are not compatible!")
  }
  .check_cov(Sigma, "Covariance matrix", pd_check = FALSE, symm_check = FALSE)

  # The standard normal draws are generated before the factorisation
  std_normal <- matrix(stats::rnorm(n * n_samples), ncol = n)

  chol_factor <- tryCatch(base::chol(Sigma),
    error = function(e) stop(paste0(e, "check the covariance in .MVN_sample, the Cholesky fails."))
  )

  # One copy of mu per row, added to the correlated draws
  mean_rows <- matrix(mu, nrow = n_samples, ncol = n, byrow = TRUE)
  std_normal %*% chol_factor + mean_rows
}
520

521
# Compute the MVN density
522
.MVN_density <- function(x, mu, Sigma, max_size_x = 5e3, suppress_warnings = TRUE) {
  # Evaluate the multivariate Gaussian density with mean mu and covariance
  # Sigma.
  #   x:                 matrix with length(mu) columns (one point per row), or
  #                      a vector of length length(mu) (a single point).
  #   max_size_x:        maximum number of rows processed by one backsolve;
  #                      larger inputs are evaluated in batches to save memory.
  #   suppress_warnings: if FALSE, a warning reports that batching is used.
  # Returns the vector of density values, one per row of x.
  n <- length(mu)

  # Check Sigma
  if (any(dim(Sigma) != c(n, n))) {
    stop("Dimension of mu and Sigma are not compatible!")
  }
  .check_cov(Sigma, "Sigma", pd_check = FALSE, symm_check = FALSE)

  # x must be a matrix with ncol = n (nrow is the number of points to evaluate)
  # or a vector with length n (in which case it is transformed into a matrix)
  if (is.vector(x)) {
    if (length(x) != n) stop("Length of x must be the same of mu")
    x <- matrix(x, ncol = length(x))
  } else if (is.matrix(x)) {
    if (ncol(x) != n) stop("The number of columns of x must be equal to the length of mu")
  } else {
    stop("x must be either a vector or a matrix")
  }

  # Compute Cholesky of Sigma
  chol_S <- tryCatch(base::chol(Sigma),
    error = function(e) stop(paste0(e, "check the covariance in .MVN_density, the Cholesky fails."))
  )

  # Constant of the loglikelihood (computed here because it is always the same)
  const <- -sum(log(diag(chol_S))) - 0.5 * n * log(2 * pi)

  # Log-density of the rows of a matrix: one backsolve against the transposed
  # Cholesky factor, then the squared norm of every column
  batch_logdens <- function(rows) {
    tmp <- backsolve(chol_S, t(rows) - mu, transpose = TRUE)
    const - 0.5 * colSums(tmp^2)
  }

  rows_x <- nrow(x)
  if (rows_x <= max_size_x) {
    return(exp(batch_logdens(x)))
  }

  # Batched evaluation; the last batch may be shorter than max_size_x
  batch_starts <- seq(1, rows_x, by = max_size_x)

  if (!suppress_warnings) {
    warning_msg <- paste0("x has ", rows_x, " rows, the density evaluation is broken down into ", length(batch_starts), " pieces for memory preservation.")
    warning(warning_msg)
  }

  logval <- numeric(rows_x)
  for (first in batch_starts) {
    idx_to_select <- first:min(first + max_size_x - 1, rows_x)
    # drop = FALSE keeps a one-row batch as a matrix, so that t() yields an
    # n x 1 matrix rather than a 1 x n one
    logval[idx_to_select] <- batch_logdens(x[idx_to_select, , drop = FALSE])
  }

  exp(logval)
}
594

595
# Resample from weighted sample
596
.resample <- function(S_, weights, num_samples = NA) {
  # Resample, with replacement, the rows of S_ with probabilities proportional
  # to weights. num_samples defaults to length(weights). Stops when the number
  # of rows of S_ differs from the number of weights. The result is
  # S_[idx, ], so the usual dimension dropping of `[` applies.
  n_weights <- length(weights)
  n_draws <- if (is.na(num_samples)) n_weights else num_samples

  if (nrow(S_) != n_weights) {
    stop("Error in .resample: nrow(S_) != length(weights)")
  }

  picked_rows <- sample(
    x = seq_len(nrow(S_)), size = n_draws,
    replace = TRUE, prob = weights
  )
  S_[picked_rows, ]
}
608

609
################################################################################
610
# Miscellaneous
611

612
# Compute the pmf of the distribution specified by distr and params at the points x
613
.distr_pmf <- function(x, params, distr) {
  # Evaluate at the points x the density (gaussian) or probability mass
  # function (poisson, nbinom) of the distribution named distr, with the
  # parameters stored in the list params. The inputs are validated by
  # .check_distr_params first. For "nbinom" the prob parameterisation is used
  # when prob is present, mu otherwise.
  .check_distr_params(distr, params)
  pmf <- switch(distr,
    "gaussian" = stats::dnorm(x = x, mean = params$mean, sd = params$sd),
    "poisson" = stats::dpois(x = x, lambda = params$lambda),
    "nbinom" = {
      if (!is.null(params$prob)) {
        stats::dnbinom(x = x, size = params$size, prob = params$prob)
      } else if (!is.null(params$mu)) {
        stats::dnbinom(x = x, size = params$size, mu = params$mu)
      }
    }
  )
  return(pmf)
}
638

639
.shape <- function(m) {
  # Print the dimensions of m in the form "(rows,cols)".
  dims_label <- paste0("(", nrow(m), ",", ncol(m), ")")
  print(dims_label)
}
642

643
################################################################################
644
# Functions for tests
645

646
.gen_gaussian <- function(params_file, seed = NULL) {
  # Read a header-less CSV file whose first column holds the means and whose
  # second column holds the standard deviations, and draw 1e6 Gaussian samples
  # for each row. Returns a list with one numeric vector per row.
  # If seed is not NULL, the random number generator is seeded first.
  if (!is.null(seed)) set.seed(seed)
  params <- utils::read.csv(file = params_file, header = FALSE)
  means <- params[[1]]
  sds <- params[[2]]
  lapply(seq_len(nrow(params)), function(row) {
    stats::rnorm(n = 1e6, mean = means[row], sd = sds[row])
  })
}
655

656
.gen_poisson <- function(params_file, seed = NULL) {
  # Read a header-less CSV file whose first column holds the Poisson rates and
  # draw 1e6 Poisson samples for each row. Returns a list with one integer
  # vector per row.
  # If seed is not NULL, the random number generator is seeded first.
  if (!is.null(seed)) set.seed(seed)
  params <- utils::read.csv(file = params_file, header = FALSE)
  rates <- params[[1]]
  lapply(seq_len(nrow(params)), function(row) {
    stats::rpois(n = 1e6, lambda = rates[row])
  })
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc