Coveralls logo
Coveralls logo
  • Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

theogf / KLDivergences.jl / 1045860141

19 Jul 2021 - 16:14 coverage: 60.0% (-17.8%) from 77.778%
1045860141

Pull #8

github

GitHub
Merge 226539b8a into cdc031c57
Pull Request #8: Loosen MvNormal and add StatsBase inheritance

1 of 7 new or added lines in 2 files covered. (14.29%)

12 of 20 relevant lines covered (60.0%)

496221.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

66.67% (coverage of this file)
/src/KLDivergences.jl
1
module KLDivergences
2

3
using Distributions: StatsBase
4
using Distributions
5
using LinearAlgebra
6
using PDMats
7
using Distances
8
using SpecialFunctions
9
using StatsBase: StatsBase, kldivergence
10

11

12
export KL, kldivergence

"""
    KL(p::Distribution, q::Distribution) -> T
    KL(p::Distribution, q::Distribution, n_samples::Int) -> T

Return the Kullback-Leibler divergence `KL(p || q)`, computed either analytically (when a
specialized method is defined for the pair of distributions) or by Monte Carlo sampling.

The generic sampling fallback is defined for pairs of univariate and for pairs of
multivariate distributions. It draws `n_samples` (default `1_000`) samples from `p` and
averages `logpdf(p, x) - logpdf(q, x)` over them, so its result is stochastic. For
multivariate distributions it throws a `DimensionMismatch` when `length(p) != length(q)`.
"""
function KL end
21

NEW
22
# NOTE(review): this is type piracy — `kldivergence` is owned by StatsBase and `Sampleable`
# by Distributions, so this module adds a method to a foreign function on foreign types.
# Every package that loads KLDivergences sees it; confirm it cannot clash with or change
# the behavior of `kldivergence` methods that Distributions itself provides.
function StatsBase.kldivergence(p::Sampleable, q::Sampleable)
    # Route StatsBase's generic entry point to this package's `KL`.
    return KL(p, q)
end
!
23

24
# Pointwise log-density ratio `log p(x) - log q(x)`. Averaging it over samples drawn from
# `p` (as the sampling fallbacks of `KL` do) gives a Monte Carlo estimate of KL(p || q).
function KLbase(p, q, x)
    logp = logpdf(p, x)
    logq = logpdf(q, x)
    return logp - logq
end
5,412,048×
25

26
## Generic sampling fallback for univariate distributions.
# Monte Carlo estimate of KL(p || q): draw `n_samples` points from `p` (using the default
# RNG via `rand`) and average the log-density ratio `logpdf(p, x) - logpdf(q, x)`.
function KL(p::UnivariateDistribution, q::UnivariateDistribution, n_samples = 1_000)
    # Reject non-positive sample counts up front with a clear message, instead of letting
    # `rand`/`mean` fail on an invalid or empty sample.
    n_samples > 0 || throw(ArgumentError("n_samples must be positive, got $n_samples"))
    return mean(x -> KLbase(p, q, x), rand(p, n_samples))
end
30

31
## Generic sampling fallback for multivariate distributions.
# Monte Carlo estimate of KL(p || q): draw `n_samples` vectors from `p` (using the default
# RNG via `rand`) and average the log-density ratio `logpdf(p, x) - logpdf(q, x)`.
# Throws `DimensionMismatch` if `p` and `q` have different lengths.
function KL(p::MultivariateDistribution, q::MultivariateDistribution, n_samples = 1_000)
    length(p) == length(q) ||
        throw(DimensionMismatch("Dimensions of p and q do not match"))
    # Reject non-positive sample counts up front with a clear message, instead of letting
    # `rand`/`mean` fail on an invalid or empty sample.
    n_samples > 0 || throw(ArgumentError("n_samples must be positive, got $n_samples"))
    # `rand(p, n)` stores one sample per column, hence iterating with `eachcol`.
    return mean(x -> KLbase(p, q, x), eachcol(rand(p, n_samples)))
end
36

37
"""
    symmetricKL(p, q, args...)

Return the symmetrised KL divergence `KL(p || q) + KL(q || p)`.

Any extra positional arguments (for example a sample count for the sampling fallback of
`KL`) are forwarded unchanged to both `KL` calls; with none, this is `KL(p, q) + KL(q, p)`.
"""
function symmetricKL(p, q, args...)
    return KL(p, q, args...) + KL(q, p, args...)
end
40

41
include("univariate.jl")
42
include("multivariate.jl")
43

44
end
Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc