• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

keflavich / image_registration / 193

pending completion
193

push

travis-ci

web-flow
Merge pull request #20 from keflavich/refactor_convolve

Refactor to use astropy's convolution

17 of 17 new or added lines in 3 files covered. (100.0%)

475 of 1329 relevant lines covered (35.74%)

0.36 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

56.76
/image_registration/fft_tools/convolve_nd.py
1
import numpy as np
1✔
2
import warnings
1✔
3
import itertools
1✔
4
from astropy.tests.helper import pytest
1✔
5

6
try:
    import fftw3
    has_fftw = True

    def fftwn(array, nthreads=1):
        """Return the forward n-dimensional FFT of *array*, computed with FFTW.

        *array* is cast to complex first; the input object is never modified.
        """
        # astype() always returns a new array, so no additional copy is needed.
        array = array.astype('complex')
        outarray = array.copy()
        fft_forward = fftw3.Plan(array, outarray, direction='forward',
                                 flags=['estimate'], nthreads=nthreads)
        fft_forward.execute()
        return outarray

    def ifftwn(array, nthreads=1):
        """Return the inverse n-dimensional FFT of *array*, computed with FFTW.

        The FFTW backward transform is unnormalized, so the result is divided
        by the number of elements to match ``numpy.fft.ifftn``.
        """
        array = array.astype('complex')
        outarray = array.copy()
        fft_backward = fftw3.Plan(array, outarray, direction='backward',
                                  flags=['estimate'], nthreads=nthreads)
        fft_backward.execute()
        return outarray / np.size(array)

    # Module-level defaults so that fftn/ifftn exist for either backend
    # (previously they were only defined in the numpy fallback branch).
    fftn = fftwn
    ifftn = ifftwn
except ImportError:
    fftn = np.fft.fftn
    ifftn = np.fft.ifftn
    has_fftw = False
29
# I performed some FFT speed tests and found that scipy is slower than numpy
# (http://code.google.com/p/agpy/source/browse/trunk/tests/test_ffts.py).
# However, the speed varied between machines - YMMV.  If someone finds that
# scipy's FFT is faster, we should add it as an option here (the exact
# mechanism is still to be decided).
33

34

35
__all__ = ['convolvend']


def _to_complex_array(values):
    """Return a complex128 copy of *values*, with masked elements set to NaN.

    The mask must be read *before* converting with ``np.asarray``/``np.array``,
    because that conversion silently discards a masked array's mask.  A new
    array is always returned, so the caller's input is never modified.
    """
    if np.ma.is_masked(values):
        values = np.ma.asarray(values, dtype=np.complex128).filled(np.nan)
    return np.array(values, dtype=np.complex128)


def convolvend(array, kernel, boundary='fill', fill_value=0, crop=True,
               return_fft=False, fftshift=True, fft_pad=True, psf_pad=False,
               interpolate_nan=False, quiet=False, ignore_edge_zeros=False,
               min_wt=0.0, normalize_kernel=False, use_numpy_fft=not has_fftw,
               nthreads=1):
    """
    Convolve an ndarray with an nd-kernel.  Returns a convolved image with shape =
    array.shape.  Assumes image & kernel are centered.

    Also note that the astropy.convolution convolver is a more up-to-date
    version of this one.

    The inputs are never modified: both are copied to complex arrays, and
    masked elements of a masked-array input are treated as NaN.

    Parameters
    ----------
    array: `numpy.ndarray`
          Array to be convolved with *kernel*
    kernel: `numpy.ndarray`
          Will be normalized if *normalize_kernel* is set.  Assumed to be
          centered (i.e., shifts may result if your kernel is asymmetric)
    boundary: str or None, optional
        A flag indicating how to handle boundaries:
            * 'fill' : set values outside the array boundary to fill_value
                       (default)
            * 'wrap' : periodic boundary
            * None : same as 'fill', but emits a warning
            * 'extend' : not implemented; raises NotImplementedError
        'fill' and None force ``psf_pad=True``; 'wrap' forces
        ``psf_pad=False`` and ``fft_pad=False``.  Any other value raises
        ValueError.
    fill_value: number
        Value used for the padded region when boundary='fill' (forced to 0
        for boundary='wrap').
    interpolate_nan: bool
        attempts to re-weight assuming NAN values are meant to be ignored, not
        treated as zero.  If this is off, all NaN values will be treated as
        zero.
    quiet: bool
        Suppress the warning issued when NaNs are present but
        *interpolate_nan* is off.
    ignore_edge_zeros: bool
        Ignore the zero-pad-created zeros.  This will effectively decrease
        the kernel area on the edges but will not re-normalize the kernel.
        This parameter may result in 'edge-brightening' effects if you're using
        a normalized kernel
    min_wt: float
        If ignoring NANs/zeros, force all grid points with a weight less than
        this value to NAN (the weight of a grid point with *no* ignored
        neighbors is 1.0).
        If `min_wt` == 0.0, then all zero-weight points will be set to zero
        instead of NAN (which they would be otherwise, because 1/0 = nan).
        See the examples below
    normalize_kernel: function or boolean
        if specified, function to divide kernel by to normalize it.  e.g.,
        normalize_kernel=np.sum means that kernel will be modified to be:
        kernel = kernel / np.sum(kernel).  If True, defaults to
        normalize_kernel = np.sum
    fft_pad: bool
        Default on.  Zero-pad image to the nearest 2^n
    psf_pad: bool
        Default off.  Zero-pad image to be at least the sum of the image sizes
        (in order to avoid edge-wrapping when smoothing).  Overridden by
        *boundary* (see above).
    crop: bool
        Default on.  Return an image of the size of the largest input image.
        If the images are asymmetric in opposite directions, will return the
        largest image in both directions.
        For example, if an input image has shape [100,3] but a kernel with shape
        [6,6] is used, the output will be [100,6].
    return_fft: bool
        Return the fft(image)*fft(kernel) instead of the convolution (which is
        ifft(fft(image)*fft(kernel))).  Useful for making PSDs.
    fftshift: bool
        If return_fft on, will shift & crop image to appropriate dimensions
    nthreads: int
        if fftw3 is installed, can specify the number of threads to allow FFTs
        to use.  Probably only helpful for large arrays
    use_numpy_fft: bool
        Force the code to use the numpy FFTs instead of FFTW even if FFTW is
        installed

    Returns
    -------
    default: `array` convolved with `kernel`
    if return_fft: fft(`array`) * fft(`kernel`)
      * if fftshift: Determines whether the fft will be shifted before
        returning
    if not(`crop`) : Returns the image, but with the fft-padded size
        instead of the input size

    Raises
    ------
    ValueError
        If *array* and *kernel* have a different number of dimensions, if
        *boundary* is not recognized, or if the FFT product contains NaNs
        (e.g. a NaN *fill_value*).
    NotImplementedError
        If boundary='extend'.

    Examples
    --------
    >>> convolvend([1,0,3],[1,1,1])
    array([ 1.,  4.,  3.])

    >>> convolvend([1,np.nan,3],[1,1,1],quiet=True)
    array([ 1.,  4.,  3.])

    >>> convolvend([1,0,3],[0,1,0])
    array([ 1.,  0.,  3.])

    >>> convolvend([1,2,3],[1])
    array([ 1.,  2.,  3.])

    >>> convolvend([1,np.nan,3],[0,1,0], interpolate_nan=True)
    array([ 1.,  0.,  3.])

    >>> convolvend([1,np.nan,3],[0,1,0], interpolate_nan=True, min_wt=1e-8)
    array([  1.,  nan,   3.])

    >>> convolvend([1,np.nan,3],[1,1,1], interpolate_nan=True)
    array([ 1.,  4.,  3.])

    >>> convolvend([1,np.nan,3],[1,1,1], interpolate_nan=True, normalize_kernel=True, ignore_edge_zeros=True)
    array([ 1.,  2.,  3.])

    """
    # FFTs have real & complex components, so both inputs become complex
    # copies (only the real part of the result is returned).  Masks are turned
    # into NaNs here, before the conversion can drop them.
    array = _to_complex_array(array)
    kernel = _to_complex_array(kernel)

    if array.ndim != kernel.ndim:
        raise ValueError("array and kernel have differing number of "
                         "dimensions")

    # Select the FFT implementation.  The module-level names are rebound so
    # that the FFTW wrappers pick up this call's *nthreads*.
    global fftn, ifftn
    if has_fftw and not use_numpy_fft:
        def fftn(*args, **kwargs):
            return fftwn(*args, nthreads=nthreads, **kwargs)

        def ifftn(*args, **kwargs):
            return ifftwn(*args, nthreads=nthreads, **kwargs)
    elif use_numpy_fft:
        fftn = np.fft.fftn
        ifftn = np.fft.ifftn

    # NaNs cannot pass through the FFT: zero them and remember where they
    # were (the arrays are private copies, so nothing needs restoring).
    nanmaskarray = np.isnan(array)
    array[nanmaskarray] = 0
    nanmaskkernel = np.isnan(kernel)
    kernel[nanmaskkernel] = 0
    if ((nanmaskarray.any() or nanmaskkernel.any()) and not interpolate_nan
            and not quiet):
        warnings.warn("NOT ignoring nan values even though they are present" +
                      " (they are treated as 0)")

    if normalize_kernel is True:
        kernel = kernel / kernel.sum()
        kernel_is_normalized = True
    elif normalize_kernel:
        # A callable is expected; anything else fails when called.
        kernel = kernel / normalize_kernel(kernel)
        kernel_is_normalized = True
    else:
        kernel_is_normalized = bool(np.abs(kernel.sum() - 1) < 1e-8)

    if boundary is None:
        message = ("The convolvend version of boundary=None is equivalent" +
                   " to the convolve boundary='fill'.  There is no FFT " +
                   " equivalent to convolve's zero-if-kernel-leaves-boundary")
        warnings.warn(message)
        psf_pad = True
    elif boundary == 'fill':
        # create a boundary region at least as large as the kernel
        psf_pad = True
    elif boundary == 'wrap':
        psf_pad = False
        fft_pad = False
        fill_value = 0  # force zero; it should not be used
    elif boundary == 'extend':
        raise NotImplementedError("The 'extend' option is not implemented " +
                                  "for fft-based convolution")
    else:
        raise ValueError("Invalid boundary %r; expected 'fill', 'wrap' or "
                         "None" % (boundary,))

    arrayshape = array.shape
    kernshape = kernel.shape
    ndim = array.ndim

    # Find the padded size for the FFT (power of 2 when fft_pad is on).
    # Shapes are tuples, so '+' concatenates them.
    if fft_pad:
        if psf_pad:
            # add the dimensions and then take the max (bigger)
            fsize = 2**np.ceil(np.log2(
                np.max(np.array(arrayshape) + np.array(kernshape))))
        else:
            # max over both shape tuples (smaller); also makes shapes square
            fsize = 2**np.ceil(np.log2(np.max(arrayshape + kernshape)))
        newshape = np.array([fsize] * ndim, dtype='int')
    elif psf_pad:
        # just add the dimensions
        newshape = (np.array(arrayshape, dtype='int') +
                    np.array(kernshape, dtype='int'))
    else:
        newshape = np.maximum(arrayshape, kernshape).astype('int')

    # Slices that centre the array and kernel inside the padded grid; they
    # also cut the result back to the input dimensions.  NumPy requires a
    # tuple (not a list) of slices for multidimensional indexing.
    arrayslices = []
    kernslices = []
    for newdimsize, arraydimsize, kerndimsize in zip(newshape, arrayshape,
                                                     kernshape):
        center = newdimsize - (newdimsize + 1) // 2
        arrayslices.append(slice(center - arraydimsize // 2,
                                 center + (arraydimsize + 1) // 2))
        kernslices.append(slice(center - kerndimsize // 2,
                                center + (kerndimsize + 1) // 2))
    arrayslices = tuple(arrayslices)
    kernslices = tuple(kernslices)

    bigarray = np.full(newshape, fill_value, dtype=np.complex128)
    bigkernel = np.zeros(newshape, dtype=np.complex128)
    bigarray[arrayslices] = array
    bigkernel[kernslices] = kernel
    arrayfft = fftn(bigarray)
    # need to shift the kernel so that, e.g., [0,0,1,0] -> [1,0,0,0] = unity
    kernfft = fftn(np.fft.ifftshift(bigkernel))
    fftmult = arrayfft * kernfft

    if np.isnan(fftmult).any():
        # this check should be unnecessary; call it an insanity check
        raise ValueError("Encountered NaNs in convolve.  This is disallowed.")

    if return_fft:
        if fftshift:  # default on
            shifted = np.fft.fftshift(fftmult)
            return shifted[arrayslices] if crop else shifted
        return fftmult

    if (interpolate_nan or ignore_edge_zeros) and kernel_is_normalized:
        # Weight map: 1 where a pixel contributes, 0 where it is ignored
        # (NaN pixels, and the padding when ignore_edge_zeros is set).  Kept
        # real so the comparisons below are ordinary float comparisons.
        if ignore_edge_zeros:
            bigimwt = np.zeros(newshape, dtype=np.float64)
        else:
            bigimwt = np.ones(newshape, dtype=np.float64)
        bigimwt[arrayslices] = 1.0 - nanmaskarray * bool(interpolate_nan)
        wtfft = fftn(bigimwt)
        # I think this one HAS to be normalized (i.e., the weights can't be
        # computed with a non-normalized kernel)
        wtfftmult = wtfft * kernfft / kernel.sum()
        wtsm = ifftn(wtfftmult)
        # need to re-zero weights outside of the image (if it is padded, we
        # still don't weight those regions)
        bigimwt[arrayslices] = wtsm.real[arrayslices]
        # at the floating-point limit the weights can come out slightly
        # negative; that would break the min_wt=0 "flag", so clamp them
        bigimwt[bigimwt < 0] = 0
    else:
        bigimwt = 1

    rifft = ifftn(fftmult)
    if interpolate_nan or ignore_edge_zeros:
        # Zero-weight entries divide by zero here, but every one of them is
        # overwritten just below, so the warnings are suppressed.
        with np.errstate(divide='ignore', invalid='ignore'):
            rifft = rifft / bigimwt
        if not np.isscalar(bigimwt):
            rifft[bigimwt < min_wt] = np.nan
            if min_wt == 0.0:
                rifft[bigimwt == 0.0] = 0.0

    if crop:
        return rifft[arrayslices].real
    return rifft.real
329

330

331
params = list(itertools.product((True, False), (True, False), (True, False)))


@pytest.mark.parametrize(('psf_pad', 'use_numpy_fft', 'force_ignore_zeros_off'),
                         params)
def test_3d(psf_pad, use_numpy_fft, force_ignore_zeros_off, debug=False,
            tolerance=1e-12):
    """Convolve two point sources with a normalized 5x5x5 box kernel.

    One source sits in the middle of the cube; the other sits on the y=0
    edge.  With edge zeros counted (or with periodic boundaries), every
    nearby output value is 1/125.  With zero padding and ignore_edge_zeros
    on, the edge source is re-weighted by the fraction of the kernel that
    lies inside the image: 3/5 of the kernel at y=0 gives 1/75, and 4/5 at
    y=1 gives 1/100.

    ``psf_pad`` selects the boundary, because ``convolvend`` overrides
    ``psf_pad`` from ``boundary`` ('fill' pads, 'wrap' does not).  The
    tolerance is well above FFT round-off.
    """
    array = np.zeros([32, 32, 32])
    array[15, 15, 15] = 1
    array[15, 0, 15] = 1
    kern = np.zeros([32, 32, 32])
    kern[14:19, 14:19, 14:19] = 1

    boundary = 'fill' if psf_pad else 'wrap'
    conv1 = convolvend(array, kern, boundary=boundary, psf_pad=psf_pad,
                       use_numpy_fft=use_numpy_fft,
                       ignore_edge_zeros=not force_ignore_zeros_off,
                       normalize_kernel=True)

    if debug:
        print("psf_pad=%s use_numpy=%s force_ignore_zeros_off=%s" % (psf_pad, use_numpy_fft, force_ignore_zeros_off))
        print("side,center: %g,%g" % (conv1[15, 0, 15], conv1[15, 15, 15]))
    if force_ignore_zeros_off or not psf_pad:
        assert np.abs(conv1[15, 0, 15] - 1. / 125.) < tolerance
        assert np.abs(conv1[15, 1, 15] - 1. / 125.) < tolerance
        assert np.abs(conv1[15, 15, 15] - 1. / 125.) < tolerance
    else:
        assert np.abs(conv1[15, 0, 15] - 1. / 75.) < tolerance
        assert np.abs(conv1[15, 1, 15] - 1. / 100.) < tolerance
        assert np.abs(conv1[15, 15, 15] - 1. / 125.) < tolerance
×
352

353

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2024 Coveralls, Inc