Skip to content

Commit

Permalink
Merge pull request #462 from grlee77/iswtn_axis_fix
Browse files Browse the repository at this point in the history
MAINT: fix bug in iswtn for data of arbitrary shape when using user-specified axes
  • Loading branch information
rgommers committed Feb 18, 2019
2 parents dbdd4be + ac54fc7 commit 2e3a224
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
7 changes: 5 additions & 2 deletions pywt/_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,16 +516,19 @@ def iswtn(coeffs, wavelet, axes=None):
[dt, ] + [v.dtype for v in details.values()]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)

# We assume all coefficient arrays are of equal size
shapes = [v.shape for k, v in details.items()]
dshape = shapes[0]
if len(set(shapes)) != 1:
raise RuntimeError(
"Mismatch in shape of intermediate coefficient arrays")

# shape of a single coefficient array, excluding non-transformed axes
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])

# nested loop over all combinations of axis offsets at this level
for firsts in product(*([range(last_index), ]*ndim_transform)):
for first, sh, ax in zip(firsts, dshape, axes):
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
indices[ax] = slice(first, sh, step_size)
even_indices[ax] = slice(first, sh, 2*step_size)
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)
Expand Down
18 changes: 16 additions & 2 deletions pywt/tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

import warnings
from copy import deepcopy
from itertools import combinations
from itertools import combinations, permutations
import numpy as np
from numpy.testing import (run_module_suite, dec, assert_allclose, assert_,
assert_equal, assert_raises, assert_array_equal,
assert_warns)

import pywt
from pywt._extensions._swt import swt_axis
from pywt._extensions._pywt import _check_dtype

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
Expand Down Expand Up @@ -387,6 +386,21 @@ def test_iswtn_errors():
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)


def test_swtn_iswtn_unique_shape_per_axis():
# test case for gh-460
_shape = (1, 48, 32) # unique shape per axis
wav = 'sym2'
max_level = 3
rstate = np.random.RandomState(0)
for shape in permutations(_shape):
# transform only along the non-singleton axes
axes = [ax for ax, s in enumerate(shape) if s != 1]
x = rstate.standard_normal(shape)
c = pywt.swtn(x, wav, max_level, axes=axes)
r = pywt.iswtn(c, wav, axes=axes)
assert_allclose(x, r, rtol=1e-10, atol=1e-10)


def test_per_axis_wavelets():
# tests seperate wavelet for each axis.
rstate = np.random.RandomState(1234)
Expand Down

0 comments on commit 2e3a224

Please sign in to comment.