• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

QuantEcon / BasisMatrices.jl / 147

6 Dec 2017 - 21:20 coverage: 92.635%. First build
147

Pull #46

travis-ci

web-flow
ENH: full/proper support for multi-dimensional BasisParams

closes #45
Pull Request #46: ENH: full/proper support for multi-dimensional BasisParams

124 of 132 new or added lines in 7 files covered. (93.94%)

1283 of 1385 relevant lines covered (92.64%)

506206.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.62
/src/basis_structure.jl
1
# ------------ #
2
# BMOrder Type #
3
# ------------ #
4

5
"""
    BMOrder(dims, order)

Derivative-order bookkeeping for a `BasisMatrix`.

- `dims::Vector{Int}`: for each column of the basis matrix `vals`, how many
  columns (dimensions) of `order` it covers
- `order::Matrix{Int}`: derivative orders, one column per dimension of the basis
"""
struct BMOrder
    dims::Vector{Int}
    order::Matrix{Int}
end

# Equal when both the derivative orders and the per-entry dimension counts agree.
function Base.:(==)(bmo1::BMOrder, bmo2::BMOrder)
    return bmo1.order == bmo2.order && bmo1.dims == bmo2.dims
end

# `hash` must agree with `==`: the default hash of an immutable struct with
# array fields depends on the identity of those arrays, not their contents.
Base.hash(bmo::BMOrder, h::UInt) = hash(bmo.dims, hash(bmo.order, h))

# Build a `BMOrder` from a bare order matrix, treating every column as its own
# one-dimensional entry (so `dims` is all ones).
function Base.convert(::Type{BMOrder}, order::Matrix{Int})
    return BMOrder(ones(Int, size(order, 2)), order)
end
17

18
"""
    _dims_to_colspans(dims)

Return the column range of the order matrix belonging to each entry of `dims`.
Entry `i` covers `dims[i]` consecutive columns, starting right after entry `i-1`
(e.g. `[1, 2, 3]` gives `[1:1, 2:3, 4:6]`). An empty `dims` gives an empty vector.
"""
function _dims_to_colspans(dims::Vector{Int})
    # the last column of each span is the running total of the dims so far
    stops = cumsum(dims)
    return [(stop - dim + 1):stop for (stop, dim) in zip(stops, dims)]
end
28

29
# Column spans of each `vals` entry, computed from the `dims` stored in `bmo`.
function _dims_to_colspans(bmo::BMOrder)
    return _dims_to_colspans(bmo.dims)
end

# `size(bmo, i)` forwards to the underlying order matrix.
function Base.size(bmo::BMOrder, i::Int)
    return size(bmo.order, i)
end
10×
32

33
# ---------------- #
34
# BasisMatrix Type #
35
# ---------------- #
36

37
# Supertype of the storage representations a `BasisMatrix` can use.
abstract type AbstractBasisMatrixRep end

# Short alias used throughout the file.
const ABSR = AbstractBasisMatrixRep

# Singleton tags selecting the representation.
struct Tensor <: AbstractBasisMatrixRep end
struct Direct <: AbstractBasisMatrixRep end
struct Expanded <: AbstractBasisMatrixRep end
8×
43

44
# Basis functions evaluated at a set of points, stored in representation `BST`
# (`Tensor`, `Direct` or `Expanded`).
#   order : derivative-order bookkeeping for the entries of `vals`
#   vals  : evaluated basis matrices. In the Direct/Tensor constructors, column
#           `j` holds params `j` and row `i` the order in `order.order[i, cols_j]`.
mutable struct BasisMatrix{BST<:AbstractBasisMatrixRep,TM<:AbstractMatrix}
    order::BMOrder
    vals::Matrix{TM}
end
48

49
# Element type of the stored matrices (e.g. `Float64` for `Matrix{Float64}`).
function Base.eltype(::BasisMatrix{B,M}) where {B,M}
    return eltype(M)
end

function Base.show(io::IO, bm::BasisMatrix{B}) where {B}
    # build the text first, exactly as string interpolation would, then print it
    print(io, string("BasisMatrix{", B, "} of order ", bm.order))
end

# Number of dimensions = number of columns of the stored order matrix.
function Base.ndims(bm::BasisMatrix)
    return size(bm.order.order, 2)
end
5×
55

56
# Different representations (`BST1 != BST2`) are never equal.
function Base.:(==)(::BasisMatrix{BST1}, ::BasisMatrix{BST2}) where {BST1<:ABSR,BST2<:ABSR}
    return false
end

# Same representation but different matrix types (e.g. sparse vs dense) are not
# equal either, even if their entries would compare equal.
function Base.:(==)(::BasisMatrix{BST,TM1},
                    ::BasisMatrix{BST,TM2}) where {BST<:ABSR,TM1<:AbstractMatrix,TM2<:AbstractMatrix}
    return false
end

# If the type parameters are identical, two BasisMatrix values are equal iff
# their orders and stored matrices agree.
function Base.:(==)(b1::BasisMatrix{BST,TM},
                    b2::BasisMatrix{BST,TM}) where {BST<:ABSR,TM<:AbstractMatrix}
    return b1.order == b2.order && b1.vals == b2.vals
end

# Keep `hash` consistent with `==`. Equal values always have equal integer order
# data (`dims` and `order`), so hashing only that stays consistent and avoids
# hashing every stored matrix. The default `hash` of this mutable struct is
# identity-based.
function Base.hash(b::BasisMatrix, h::UInt)
    return hash(b.order.dims, hash(b.order.order, hash(:BasisMatrix, h)))
end
72

73
# -------------- #
74
# Internal Tools #
75
# -------------- #
76

77
# Matrix input: one row per point, so there must be one column per dimension.
@inline function _checkx(N, x::AbstractMatrix)
    size(x, 2) == N || error("Basis is $N dimensional, x must have $N columns")
    return x
end

# Vector input: either a list of points for a 1d basis, or a single N-d point.
@inline function _checkx(N, x::AbstractVector)
    # 1d basis: every element is its own evaluation point
    N == 1 && return x
    # N-d basis given exactly N values: one point, shaped as a 1×N row
    length(x) == N && return reshape(x, 1, N)
    error("Basis is $N dimensional, x must have $N elements")
end
96

97
# Tensor-form input (BasisMatrix{Tensor} family): one vector per dimension.
@inline function _checkx(N, x::TensorX)
    length(x) == N || error("Basis is $N dimensional, need one Vector per dimension")
    return x
end
106

107
"""
108
Do common transformations to all constructor of `BasisMatrix`
109

110
##### Arguments
111

112
- `N::Int`: The number of dimensions in the corresponding `Basis`
113
- `x::AbstractArray`: The points for which the `BasisMatrix` should be
114
constructed
115
- `order::Array{Int}`: The order of evaluation for each dimension of the basis
116

117
##### Returns
118

119
- `m::Int`: the total number of derivative order basis functions to compute.
120
This will be the number of rows in the matrix form of `order`
121
- `order::Matrix{Int}`: A `m �� N` matrix that, for each of the `m` desired
122
specifications, gives the derivative order along all `N` dimensions
123
- `minorder::Matrix{Int}`: A `1 �� N` matrix specifying the minimum desired
124
derivative order along each dimension
125
- `numbases::Matrix{Int}`: A `1 �� N` matrix specifying the total number of
126
distinct derivative orders along each dimension
127
- `x::AbstractArray`: The properly transformed points at which to evaluate
128
the basis
129

130
"""
131
function check_basis_structure(N::Int, x, order)
132
    order = _check_order(N, order)
240×
133

134
    # initialize basis structure (66-74)
135
    m = size(order, 1)  # by this time order is a matrix
239×
136
    if m > 1
239×
137
        minorder = minimum(order, 1)
2×
138
        numbases = (maximum(order, 1) - minorder) + 1
2×
139
    else
140
        minorder = order + zeros(Int, 1, N)
237×
141
        numbases = fill(1, 1, N)
237×
142
    end
143

144
    x = _checkx(N, x)
239×
145

146
    return m, order, minorder, numbases, x
239×
147
end
148

149
"""
    _unique_rows(mat)

Return the distinct rows of `mat` as a `Vector{Vector{T}}`, in order of first
appearance. Rows are compared with `==`.
"""
function _unique_rows(mat::AbstractMatrix{T}) where {T}
    seen = Vector{T}[]
    for r in 1:size(mat, 1)
        candidate = mat[r, :]
        candidate in seen || push!(seen, candidate)
    end
    # TODO: sorting the rows lexicographically here could let _extract_inds
    # exploit the ordering later.
    return seen
end
163

164
# --------------- #
165
# convert methods #
166
# --------------- #
167

168
# Same-representation conversion: return a new BasisMatrix holding (deep copies
# of) the matrices of `bs` for the derivative orders in `_order`.
function Base.convert(
        ::Type{T}, bs::BasisMatrix{T,TM}, _order=fill(0, 1, size(bs.order, 2))
    ) where {T,TM}
    order = _check_order(size(bs.order.order, 2), _order)

    # unflip the inds because I don't want to do kroneckers, I just want to
    # extract up the basis matrices as they are
    inds = flipdim(_extract_inds(bs, order), 2)

    nrow = size(order, 1)
    ncol = size(bs.vals, 2)

    vals = Array{TM}(nrow, ncol)
    for row in 1:nrow, col in 1:ncol
        # deepcopy so the result shares no matrices with `bs`
        vals[row, col] = deepcopy(bs.vals[inds[row, col]])
    end

    # Pass the BMOrder, not the raw matrix: implicit `convert(BMOrder, ::Matrix)`
    # would reset every entry of `dims` to 1 and lose multi-dim params info.
    bm_order = BMOrder(bs.order.dims, order)
    return BasisMatrix{T,TM}(bm_order, vals)
end
190

191
# Build the Expanded form of `bs`. For each requested derivative order (row of
# `_order`), fold the selected per-params matrices together with `reducer`
# (`row_kron` from Direct, `kron` from Tensor), giving one matrix per row.
function _to_expanded(bs::BasisMatrix{T,TM}, _order, reducer::Function) where {T,TM}
    order = _check_order(size(bs.order.order, 2), _order)
    inds = _extract_inds(bs, order)

    nrow = size(inds, 1)
    vals = Array{TM}(nrow, 1)

    for row in 1:nrow
        vals[row] = reduce(reducer, bs.vals[inds[row, :]])
    end

    # Pass the BMOrder, not the raw matrix: implicit `convert(BMOrder, ::Matrix)`
    # would reset every entry of `dims` to 1 and lose multi-dim params info.
    bm_order = BMOrder(bs.order.dims, order)
    return BasisMatrix{Expanded,TM}(bm_order, vals)
end
205

206

207
# funbconv from direct to expanded: per-params matrices combined with `row_kron`
Base.convert(::Type{Expanded}, bs::BasisMatrix{Direct}, order=0) =
    _to_expanded(bs, order, row_kron)

# funbconv from tensor to expanded: per-params matrices combined with `kron`
Base.convert(::Type{Expanded}, bs::BasisMatrix{Tensor}, order=0) =
    _to_expanded(bs, order, kron)
216

217
# funbconv from tensor to direct
# HACK: there is probably a more efficient way to do this, but since I don't
#       plan on doing it much, this will do for now. The basic point is that
#       we need to expand the rows of each element of `vals` so that all of
#       them have prod([size(v, 1) for v in bs.vals]) rows.
function Base.convert(
        ::Type{Direct}, bs::BasisMatrix{Tensor,TM},
        _order=fill(0, 1, size(bs.order, 2))
    ) where TM
    order = _check_order(size(bs.order.order, 2), _order)

    # unflip the inds because I don't want to do kroneckers, I just want to
    # expand basis matrices in place
    inds = flipdim(_extract_inds(bs, order), 2)

    ncol = size(bs.vals, 2)
    vals = Array{TM}(size(inds))

    for row in 1:size(inds, 1)
        # matrices requested for this row of `order`, one per column of `vals`
        selected = bs.vals[inds[row, :]]
        expansion_inds = gridmake(([size(v, 1) for v in selected]...))
        for col in 1:ncol
            # Read through `inds`: the requested order need not be stored in
            # row `row` of `bs.vals` (the old code used `bs.vals[row, col]`).
            vals[row, col] = selected[col][expansion_inds[:, col], :]
        end
    end

    bm_order = BMOrder(bs.order.dims, order)
    return BasisMatrix{Direct,TM}(bm_order, vals)
end
246

247
# ------------ #
248
# Constructors #
249
# ------------ #
250

251
# method to construct BasisMatrix in direct or expanded form based on
# a matrix of `x` values  -- funbasex
function BasisMatrix(
        ::Type{T2}, basis::Basis{N,BF}, ::Direct,
        _x::AbstractArray=nodes(basis)[1], _order=0
    ) where {N,BF,T2}
    _, order, _, numbases, x = check_basis_structure(N, _x, _order)
    Np = length(basis.params)

    val_type = bmat_type(T2, basis, x)

    # columns of `order` (and of `x`) belonging to each params
    order_dims = collect(ndims.(basis.params))
    colspans = _dims_to_colspans(order_dims)

    # Distinct derivative orders requested for each params, computed before
    # allocating. `numbases` counts orders column by column, which is too few rows
    # for a multi-dim params asked for e.g. [0 0; 1 0; 0 1] (3 distinct rows but
    # numbases == [2 2]). Use whichever is larger.
    orders_by_params = [_unique_rows(order[:, cols]) for cols in colspans]
    nrow = max(maximum(numbases), maximum(length.(orders_by_params)))

    vals = Array{val_type}(nrow, Np)
    # entries never written below keep the `typemax(Int)` fill value
    order_vals = fill(typemax(Int), nrow, N)

    for (i_params, params) in enumerate(basis.params)
        cols = colspans[i_params]
        orders_p = orders_by_params[i_params]

        if length(cols) == 1
            # 1d params: evaluate all requested orders in a single call
            _orders_1d = vcat(orders_p...)::Vector{Int}
            rows = 1:length(_orders_1d)
            vals[rows, i_params] = evalbase(T2, params, x[:, cols[1]], _orders_1d)
            order_vals[rows, cols[1]] = _orders_1d
        else  # multi-dim params: one evaluation per distinct order vector
            for (i, ord) in enumerate(orders_p)
                vals[i, i_params] = evalbase(T2, params, x[:, cols], ord)
                order_vals[i, cols] = ord
            end
        end
    end

    bm_order = BMOrder(order_dims, order_vals)
    return BasisMatrix{Direct,val_type}(bm_order, vals)
end
287

288
function BasisMatrix(::Type{T2}, basis::Basis, ::Expanded,
                     x::AbstractArray=nodes(basis)[1], order=0) where T2  # funbasex
    # create direct form, then convert to expanded
    bsd = BasisMatrix(T2, basis, Direct(), x, order)
    # Expand with the *requested* `order`, not `bsd.order.order`. The direct form
    # stores only the distinct orders of each params (padded with typemax(Int)),
    # so its rows need not match the requested combinations, e.g. [0 0; 1 0].
    return convert(Expanded, bsd, order)
end
294

295
function BasisMatrix(
        ::Type{T2}, basis::Basis{N,BT}, ::Tensor,
        _x::TensorX=nodes(basis)[2], _order=0
    ) where {N,BT,T2}

    _, order, _, numbases, x = check_basis_structure(N, _x, _order)
    Np = length(basis.params)

    val_type = bmat_type(T2, basis, x[1])

    # columns of `order` (and entries of `x`) belonging to each params
    order_dims = collect(ndims.(basis.params))
    colspans = _dims_to_colspans(order_dims)

    # Distinct orders per params, computed before allocating: `numbases` counts
    # column by column and is too small for multi-dim params with e.g.
    # [0 0; 1 0; 0 1]. Use whichever is larger.
    orders_by_params = [_unique_rows(order[:, cols]) for cols in colspans]
    nrow = max(maximum(numbases), maximum(length.(orders_by_params)))

    vals = Array{val_type}(nrow, Np)
    # entries never written below keep the `typemax(Int)` fill value
    order_vals = fill(typemax(Int), nrow, N)

    for (i_params, params) in enumerate(basis.params)
        cols = colspans[i_params]
        orders_p = orders_by_params[i_params]

        if length(cols) == 1
            # `x` holds one vector per *dimension*, so index it by this params'
            # column, not `i_params` (they differ after a multi-dim params)
            _orders_1d = vcat(orders_p...)::Vector{Int}
            rows = 1:length(_orders_1d)
            vals[rows, i_params] = evalbase(T2, params, x[cols[1]], _orders_1d)
            order_vals[rows, cols[1]] = _orders_1d
        else  # multi-dim params
            # NOTE(review): evaluate on the grid of this params' own dimension
            # vectors (a point matrix, as the Direct constructor passes). Confirm
            # the gridmake row order matches what _extract_inds/kron expect.
            x_p = gridmake(x[cols]...)
            for (i, ord) in enumerate(orders_p)
                vals[i, i_params] = evalbase(T2, params, x_p, ord)
                order_vals[i, cols] = ord
            end
        end
    end

    bm_order = BMOrder(order_dims, order_vals)
    return BasisMatrix{Tensor,val_type}(bm_order, vals)
end
331

332
# When the user doesn't supply a ABSR, we pick one from the type of `x`:
# an AbstractArray of points selects Direct, a TensorX selects Tensor.
BasisMatrix(::Type{T2}, basis::Basis, x::AbstractArray, order=0) where {T2} =
    BasisMatrix(T2, basis, Direct(), x, order)

BasisMatrix(::Type{T2}, basis::Basis, x::TensorX, order=0) where {T2} =
    BasisMatrix(T2, basis, Tensor(), x, order)
342

343

344
# methods allowing a representation *type* (e.g. `Direct`) in place of an
# instance: instantiate the singleton and forward
BasisMatrix(::Type{T2}, basis, ::Type{BST}, x::Union{AbstractArray,TensorX},
            order=0) where {BST<:ABSR,T2} = BasisMatrix(T2, basis, BST(), x, order)

BasisMatrix(basis, ::Type{BST}, x::Union{AbstractArray,TensorX},
            order=0) where {BST<:ABSR} = BasisMatrix(basis, BST(), x, order)
354

355
# methods without an explicit eltype argument: forward with `Void` in that slot
# (presumably letting `bmat_type` choose the matrix type — TODO confirm)
BasisMatrix(basis::Basis, tbm::TBM, x::Union{AbstractArray,TensorX},
            order=0) where {TBM<:ABSR} = BasisMatrix(Void, basis, tbm, x, order)

BasisMatrix(basis::Basis, x::Union{AbstractArray,TensorX}, order=0) =
    BasisMatrix(Void, basis, x, order)
364

365
# method without x: evaluate at the default points of the instance-based
# constructors (`nodes(basis)[1]` for Direct/Expanded, `nodes(basis)[2]` for Tensor)
function BasisMatrix(basis::Basis, tbm::Union{Type{TBM},TBM}) where TBM<:ABSR
    # `TBM()` turns both `Direct` and `Direct()` into an instance. Forwarding a
    # bare type would only match the `::Type{BST}` methods, which require `x`,
    # and so raise a MethodError.
    return BasisMatrix(Void, basis, TBM())
end
Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc