Coveralls logo
Coveralls logo
  • Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

QuantEcon / BasisMatrices.jl / 147

6 Dec 2017 - 13:20 coverage: 92.635%. First build
147

Pull #46

travis-ci

web-flow
ENH: full/proper support for multi-dimensional BasisParams

closes #45
Pull Request #46: ENH: full/proper support for multi-dimensional BasisParams

124 of 132 new or added lines in 7 files covered. (93.94%)

1283 of 1385 relevant lines covered (92.64%)

506206.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.71
/src/interp.jl
1
# ---------------- #
2
# Fitting routines #
3
# ---------------- #
4

5
# Tensor representation, single-function method: delegate to the deep
# routine below and return its single coefficient column as a vector
function get_coefs(basis::Basis, bs::BasisMatrix{Tensor}, y::Vector{T}) where T
    coef_matrix = _get_coefs_deep(basis, bs, y)
    return coef_matrix[:, 1]
end
9

10
# Tensor representation, multiple-function method: the deep routine below
# already returns one column of coefficients per column of `y`
function get_coefs(basis::Basis, bs::BasisMatrix{Tensor}, y::Matrix{T}) where T
    return _get_coefs_deep(basis, bs, y)
end
14

15
# Compute Tensor-form coefficients from the order-0 basis matrices (the
# first row of `bs.vals`) via `ckronxi`, processing dimensions from last to
# first. Errors unless every entry of the first order row is 0.
function _get_coefs_deep(basis::Basis, bs::BasisMatrix{Tensor}, y)
    first_orders = bs.order.order[1, :]
    if !all(first_orders .== 0)
        error("invalid basis structure - first elements must be order 0")
    end
    order0_vals = bs.vals[1, :]  # 68
    return ckronxi(order0_vals, y, ndims(basis):-1:1)  # 66
end
22

23
# Direct form: convert to Expanded form and dispatch to that method
function get_coefs(basis::Basis, bs::BasisMatrix{Direct}, y)
    expanded = convert(Expanded, bs)
    return get_coefs(basis, expanded, y)
end
27

28
# Expanded form: solve `B \ y` with the single expanded basis matrix `B`
function get_coefs(basis::Basis, bs::BasisMatrix{Expanded}, y)
    B = bs.vals[1]
    return B \ y
end
29

30
# Shared input validation run at the top of each `funfitxy` method.
# Returns `size(y, 1)` (the number of data points) and errors when the basis
# has more functions than there are points. `x` is accepted but not used.
function check_funfit(basis::Basis, x, y)
    npoints = size(y, 1)
    if length(basis) > npoints
        error("Can't be more basis funcs than points in y")
    end
    return npoints
end
36

37
# `get_coefs(::Basis, ::BasisMatrix, ::Array)` does almost all of the work
# for the `funfitxy` methods. With a precomputed BasisMatrix, just validate,
# solve, and return `(coefs, bs)`.
function funfitxy(basis::Basis, bs::BasisMatrix, y)
    check_funfit(basis, bs, y)
    coefs = get_coefs(basis, bs, y)
    return coefs, bs
end
44

45
# Tensor-form points: build an order-0 Tensor BasisMatrix at `x`, fit, and
# return `(coefs, bs)`
function funfitxy(basis::Basis, x::TensorX, y)
    check_funfit(basis, x, y)
    bs = BasisMatrix(basis, Tensor(), x, 0)
    return get_coefs(basis, bs, y), bs
end
53

54
# Generic points: `x` must have exactly one row per row of `y`. Fits via an
# order-0 Direct BasisMatrix and returns `(coefs, bs)`.
function funfitxy(basis::Basis, x, y)
    npoints = check_funfit(basis, x, y)
    size(x, 1) == npoints || error("x and y are incompatible")

    bs = BasisMatrix(basis, Direct(), x, 0)
    coefs = get_coefs(basis, bs, y)
    return coefs, bs
end
65

66
# Fit the basis to `f(X, args...)`, where `X` is the first output of
# `nodes(basis)`; the fit itself uses the second (tensor) output. Returns
# only the coefficients.
function funfitf(basis::Basis, f::Function, args...)
    grid, tensor_nodes = nodes(basis)
    vals = f(grid, args...)
    return first(funfitxy(basis, tensor_nodes, vals))
end
71

72
# `basis \ y`: fit `y` at the second output of `nodes(b)` and return only the
# coefficients
function Base.:\(b::Basis, y::AbstractArray)
    tensor_nodes = nodes(b)[2]
    return first(funfitxy(b, tensor_nodes, y))
end
76

77
# `basis \ f`: fit the function `f` via `funfitf`; returns the coefficients
function Base.:\(b::Basis, f::Function)
    return funfitf(b, f)
end
78

79
# ---------- #
80
# Evaluation #
81
# ---------- #
82
# For each requested order row in `order` and each column of `bm.vals`,
# find the row of `bm.order.order` whose entries over that column's span of
# `order` columns match the request, and record the linear index of the
# matching `bm.vals` entry.
#
# Returns an `Int` matrix with one row per row of `order` and one column per
# column span, with the columns reversed (`flipdim(out, 2)`), presumably so
# products are formed as B(d), B(d-1), ..., B(1) -- confirm against callers.
# Errors if a requested order is not present in `bm`.
function _extract_inds(bm::BasisMatrix, order::AbstractMatrix{Int})
    # column ranges of `order` covered by each entry in `bm.vals`
    cols = _dims_to_colspans(bm.order)
    have = bm.order.order   # orders stored in `bm`, one row per stored order
    val_size = size(bm.vals)

    out = Array{Int}(size(order, 1), length(cols))
    for row in 1:size(out, 1), (chunk_ix, chunk) in enumerate(cols)
        want = order[row, chunk]
        row_have = findfirst(r -> have[r, chunk] == want, 1:size(have, 1))
        # `findfirst` returns 0 (Julia 0.6) or `nothing` (Julia >= 0.7) on a miss
        if row_have === nothing || row_have == 0
            error("Couldn't find $(want) in BasisMatrix")
        end
        out[row, chunk_ix] = sub2ind(val_size, row_have, chunk_ix)
    end
    flipdim(out, 2)
end
110

111
# Evaluate the function(s) with coefficients `c` from a Tensor-form
# BasisMatrix, once per requested order row in `order`.
# Returns an array of size (points, size(c, 2), size(order, 1)).
function _funeval(c, bs::BasisMatrix{Tensor}, order::AbstractMatrix{Int})  # funeval1
    kk = size(order, 1)  # 95 -- number of requested order rows

    # 98 reverse the order of evaluation: B(d) ⊗ B(d-1) ⊗ ⋯ ⊗ B(1)
    inds = _extract_inds(bs, order)

    # 99 number of evaluation points: product of the row counts of the
    # order-0 basis matrices. Iterate over the columns of `bs.vals`, not over
    # the columns of `order`. `_extract_inds` maps each column *span* of
    # `order` to a single column of `bs.vals`, so with a multi-dimensional
    # BasisParams `size(order, 2) > size(bs.vals, 2)` and indexing
    # `bs.vals[1, j]` for `j in 1:size(order, 2)` would run out of bounds.
    nx = prod(size(bs.vals[1, j], 1) for j in 1:size(bs.vals, 2))

    _T = promote_type(eltype(c), eltype(bs))
    f = Array{_T,3}(nx, size(c, 2), kk)  # 100

    for i in 1:kk
        f[:, :, i] = ckronx(bs.vals, c, inds[i, :])  # 102
    end
    return f
end
128

129
# Evaluate the function(s) with coefficients `c` from a Direct-form
# BasisMatrix via `cdprodx`, once per requested order row.
# Returns an array of size (points, size(c, 2), size(order, 1)).
function _funeval(c, bs::BasisMatrix{Direct}, order::AbstractMatrix{Int})  # funeval2
    norders = size(order, 1)  # 95
    # 114 reverse the order of evaluation: B(d)xB(d-1)x...xB(1)
    inds = _extract_inds(bs, order)

    npoints = size(bs.vals[1], 1)
    out_T = promote_type(eltype(c), eltype(bs))
    result = Array{out_T,3}(npoints, size(c, 2), norders)  # 116

    for k in 1:norders
        result[:, :, k] = cdprodx(bs.vals, c, inds[k, :])  # 118
    end
    return result
end
142

143
# Evaluate the function(s) with coefficients `c` from an Expanded-form
# BasisMatrix: for each requested order row, locate the matching stored order
# in `bs.order.order` and multiply its basis matrix by `c`.
# Returns an array of size (points, size(c, 2), size(order, 1)).
# Errors if a requested order is not stored in `bs`.
function _funeval(c, bs::BasisMatrix{Expanded}, order::AbstractMatrix{Int})  # funeval3
    nx = size(bs.vals[1], 1)
    kk = size(order, 1)
    have = bs.order.order   # one row per order stored in `bs`

    _T = promote_type(eltype(c), eltype(bs))
    f = Array{_T,3}(nx, size(c, 2), kk)
    for i in 1:kk
        this_order = order[i, :]
        # search every stored order; the number of *requested* orders (`kk`)
        # is unrelated to how many orders `bs` holds
        ind = findfirst(r -> have[r, :] == this_order, 1:size(have, 1))
        # `findfirst` returns 0 (Julia 0.6) or `nothing` (Julia >= 0.7) on a miss
        if ind === nothing || ind == 0
            msg = string("Requested order $(this_order) not in BasisMatrix ",
                         "with order $(bs.order)")
            error(msg)
        end
        f[:, :, i] = bs.vals[ind] * c  # 154
    end

    return f
end
162

163
# 1d basis + x::Number + c::Mat => 1 point, many funcs ==> out 1d
function funeval(c::AbstractMatrix, basis::Basis{1}, x::Real, order=0)
    xmat = fill(x, 1, 1)  # the single point as a 1×1 matrix
    return vec(funeval(c, basis, xmat, order))
end
166

167
# 1d basis + x::Number + c::Vec => 1 point, 1 func ==> out scalar
function funeval(c::AbstractVector, basis::Basis{1}, x::Real, order=0)
    xmat = fill(x, 1, 1)  # the single point as a 1×1 matrix
    return funeval(c, basis, xmat, order)[1]
end
170

171
# 1d basis + x::Vec + c::Mat => many points, many funcs ==> out 2d
function funeval(c::AbstractMatrix, basis::Basis{1}, x::AbstractVector{T},
                 order=0) where {T<:Number}
    xcol = x[:, :]  # n×1 matrix: one row per point
    return funeval(c, basis, xcol, order)
end
174

175
# 1d basis + x::Vec + c::Vec => many points, 1 func ==> out 1d
function funeval(c::AbstractVector, basis::Basis{1}, x::AbstractVector{T},
                 order=0) where {T<:Number}
    xcol = reshape(x, length(x), 1)  # n×1: one row per point
    return vec(funeval(c, basis, xcol, order))
end
178

179
# N(>1)d basis + x::Vec + c::Vec ==> 1 point, 1 func ==> out scalar
function funeval(c::AbstractVector, basis::Basis{N}, x::AbstractVector{T},
                 order=0) where {N,T<:Number}
    point = reshape(x, 1, N)  # 1×N: a single point as one row
    return funeval(c, basis, point, order)[1]
end
182

183
# N(>1)d basis + x::Vec + c::Mat ==> 1 point, many funcs ==> out vec
function funeval(c::AbstractMatrix, basis::Basis{N}, x::AbstractVector{T},
                 order=0) where {N,T<:Number}
    point = reshape(x, 1, N)  # 1×N: a single point as one row
    return vec(funeval(c, basis, point, order))
end
186

187
# Evaluate at Tensor-form points `x` (which must have `N` elements) with an
# integer order. Only `order == 0` is supported here; pass an order matrix to
# the matrix-order method for derivatives. Errors on bad inputs.
function funeval(c, basis::Basis{N}, x::TensorX, order::Int=0) where N
    # check inputs
    size(x, 1) == N || error("x must have d=$N elements")

    if order != 0
        # The message must not interpolate `$(order=0)`: that is an assignment
        # expression, which rebinds `order` and prints just "0".
        msg = string("passing order as integer only allowed for `order=0`.",
                     " Try calling the version where `order` is a matrix")
        error(msg)
    end

    _order = fill(0, 1, N)
    bs = BasisMatrix(SplineSparse, basis, Tensor(), x, _order)  # 67
    return funeval(c, bs, _order)
end
201

202
# Evaluate at Tensor-form points `x` for every order row of `_order` (after
# validation by `_check_order`), via a Tensor-form BasisMatrix.
function funeval(c, basis::Basis{N}, x::TensorX, _order::AbstractMatrix) where N
    # validate the point container and the requested orders
    if size(x, 1) != N
        error("x must have d=$N elements")
    end
    order = _check_order(N, _order)

    # build the Tensor-form BasisMatrix and dispatch to the BasisMatrix method
    tensor_bs = BasisMatrix(SparseMatrixCSC, basis, Tensor(), x, order)  # 67
    return funeval(c, tensor_bs, order)
end
213

214
# Evaluate at the rows of `x` (one point per row, `N` columns) with an
# integer order. Only `order == 0` is supported here; pass an order matrix to
# the matrix-order method for derivatives. Returns the order-0 slice of the
# evaluation result. Errors on bad inputs.
function funeval(c, basis::Basis{N}, x::AbstractMatrix, order::Int=0) where N
    # check inputs -- a plain check rather than `@boundscheck`, so validation
    # of user input cannot be elided by a caller's `@inbounds`
    size(x, 2) == N || error("x must have d=$(N) columns")

    if order != 0
        # The message must not interpolate `$(order=0)`: that is an assignment
        # expression, which rebinds `order` and prints just "0".
        msg = string("passing order as integer only allowed for `order=0`.",
                     " Try calling the version where `order` is a matrix")
        error(msg)
    end

    _order = fill(0, 1, N)
    bs = BasisMatrix(SplineSparse, basis, Direct(), x, _order)  # 67
    _out = funeval(c, bs, _order)

    # we only had one order, so collapse the third dimension of `_out`
    return _out[:, :, 1]
end
231

232
# Evaluate at the rows of `x` (one point per row, `N` columns) for every order
# row of `_order` (after validation by `_check_order`), via a Direct-form
# BasisMatrix. Errors on bad inputs.
function funeval(c, basis::Basis{N}, x::AbstractMatrix, _order::AbstractMatrix) where N
    # check that inputs are conformable -- a plain check rather than
    # `@boundscheck`, so a caller's `@inbounds` cannot elide it
    size(x, 2) == N || error("x must have d=$(N) columns")  # 62
    order = _check_order(N, _order)

    # construct BasisMatrix in Direct form
    bs = BasisMatrix(SplineSparse, basis, Direct(), x, order)  # 67

    # pass off to the specialized BasisMatrix method
    return funeval(c, bs, order)
end
243

244
# Single function (coefficient vector): drop the singleton function
# dimension, giving a (points × orders) matrix
function funeval(c::AbstractVector, bs::BasisMatrix, order::AbstractMatrix{Int})
    evaluated = _funeval(c, bs, order)
    return evaluated[:, 1, :]
end
247

248
# Many functions (coefficient matrix): return the full
# (points × functions × orders) array unchanged
function funeval(c::AbstractMatrix, bs::BasisMatrix, order::AbstractMatrix{Int})
    return _funeval(c, bs, order)
end
251

252
# default method for one function: a single order vector (by default order 0
# in every column of `bs.order.order`); returns a vector over points
function funeval(c::AbstractVector, bs::BasisMatrix,
                 order::Vector{Int}=fill(0, size(bs.order.order, 2)))
    order_row = reshape(order, 1, length(order))
    return _funeval(c, bs, order_row)[:, 1, 1]
end
257

258
# Many functions with a single order vector; returns a (points × functions)
# matrix. The default is order 0 in every column of `bs.order.order`
# (previously `fill(1, ...)`, i.e. a first derivative, which disagreed with
# the single-function default of order 0).
function funeval(c::AbstractMatrix, bs::BasisMatrix,
                 order::Vector{Int}=fill(0, size(bs.order.order, 2)))
    order_row = reshape(order, 1, length(order))
    return _funeval(c, bs, order_row)[:, :, 1]
end
262

263
# ------------------------------ #
264
# Convenience `Interpoland` type #
265
# ------------------------------ #
266

267
# Convenience wrapper bundling a basis, its coefficients, and a Tensor-form
# BasisMatrix so callers don't have to carry the coefficient array around.
mutable struct Interpoland{TB<:Basis,TC<:AbstractArray,TBM<:BasisMatrix{Tensor}}
    # the basis; not meant to be reassigned after construction
    basis::TB
    # coefficient array; overwritten in place by `update_coefs!` / `fit!`
    coefs::TC
    # Tensor-form BasisMatrix used for fitting (built at the nodes of `basis`
    # by the convenience constructors); not meant to be reassigned
    bmat::TBM
end
272

273
# Fit coefficients for `y` with a precomputed Tensor-form BasisMatrix and
# bundle everything into an Interpoland
function Interpoland(basis::Basis, bs::BasisMatrix{Tensor}, y::AbstractArray)
    coefs = get_coefs(basis, bs, y)
    return Interpoland(basis, coefs, bs)
end
277

278
# Compute the Tensor-form BasisMatrix of `basis`, then hand off to the
# (basis, bs, y) constructor
function Interpoland(basis::Basis, y::AbstractArray)
    return Interpoland(basis, BasisMatrix(basis, Tensor()), y)
end
283

"""
    Interpoland(basis::Basis, f::Function)

Construct an Interpoland from a function.

`f` is called once, on the first output of `nodes(basis)`, and must have the
signature `f(::AbstractMatrix)::AbstractArray`, where each column of the
input matrix is a vector of values along a single dimension.
"""
function Interpoland(basis::Basis, f::Function)
    grid, tensor_nodes = nodes(basis)
    fvals = f(grid)
    bs = BasisMatrix(basis, Tensor(), tensor_nodes)
    return Interpoland(basis, bs, fvals)
end
297

298
# Build the Basis from its parameters, then construct from the function `f`
function Interpoland(p::BasisParams, f::Function)
    basis = Basis(p)
    return Interpoland(basis, f)
end
299

300
# Calling an Interpoland, `itp(x, order)`, forwards to `funeval`, which
# handles the order and the shape of `x`. This only spares the user from
# keeping track of the coefficient array.
function (itp::Interpoland)(x, order=0)
    return funeval(itp.coefs, itp.basis, x, order)
end
303

304
# Given new `y` data, refit the coefficients using the stored BasisMatrix and
# overwrite `interp.coefs` in place. Returns the (updated) coefficient array.
function update_coefs!(interp::Interpoland, y::AbstractArray)
    # reuse the BasisMatrix kept in the Interpoland instead of rebuilding it
    new_coefs = first(funfitxy(interp.basis, interp.bmat, y))
    # copy into the existing array rather than rebinding the `coefs` field
    return copy!(interp.coefs, new_coefs)
end
310

311
# Function version: evaluate `f` on the first output of `nodes` and hand the
# values to the array method
function update_coefs!(interp::Interpoland, f::Function)
    grid = nodes(interp.basis)[1]
    return update_coefs!(interp, f(grid))
end
314

315
# `fit!` is an alias for `update_coefs!` (data-array and function forms)
function fit!(interp::Interpoland, y::AbstractArray)
    return update_coefs!(interp, y)
end

function fit!(interp::Interpoland, f::Function)
    return update_coefs!(interp, f)
end
318

319
# Compact display: "<N> dimensional interpoland", where `N` is the
# dimension parameter of the wrapped `Basis{N}`. The previous signature
# (`Interpoland{T,N,BST} where BST<:ABSR`) did not match Interpoland's
# actual type parameters `{TB<:Basis, TC, TBM<:BasisMatrix{Tensor}}`: it
# either never applied or labelled the coefficient array *type* as `N`.
Base.show(io::IO, ::Interpoland{TB}) where {N,TB<:Basis{N}} =
    print(io, "$N dimensional interpoland")
Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc