From de2ee36345493bb15a5347ebbb3c01d8cd5eb10c Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 14 Feb 2019 14:32:09 -0500 Subject: [PATCH 1/2] fix iswtn axis-specific transform bug --- pywt/_swt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pywt/_swt.py b/pywt/_swt.py index ea29f3ff..472c2ec2 100644 --- a/pywt/_swt.py +++ b/pywt/_swt.py @@ -516,16 +516,19 @@ def iswtn(coeffs, wavelet, axes=None): [dt, ] + [v.dtype for v in details.values()])) if output.dtype != common_dtype: output = output.astype(common_dtype) + # We assume all coefficient arrays are of equal size shapes = [v.shape for k, v in details.items()] - dshape = shapes[0] if len(set(shapes)) != 1: raise RuntimeError( "Mismatch in shape of intermediate coefficient arrays") + # shape of a single coefficient array, excluding non-transformed axes + coeff_trans_shape = tuple([shapes[0][ax] for ax in axes]) + # nested loop over all combinations of axis offsets at this level for firsts in product(*([range(last_index), ]*ndim_transform)): - for first, sh, ax in zip(firsts, dshape, axes): + for first, sh, ax in zip(firsts, coeff_trans_shape, axes): indices[ax] = slice(first, sh, step_size) even_indices[ax] = slice(first, sh, 2*step_size) odd_indices[ax] = slice(first+step_size, sh, 2*step_size) From ac54fc753ca0cd7feaf81207ffe5a1acf68f6380 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 14 Feb 2019 14:55:36 -0500 Subject: [PATCH 2/2] TST: add tests for round-trip swtn/iswtn with non-uniform shape --- pywt/tests/test_swt.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py index cf9fc32e..0dfd9e01 100644 --- a/pywt/tests/test_swt.py +++ b/pywt/tests/test_swt.py @@ -4,7 +4,7 @@ import warnings from copy import deepcopy -from itertools import combinations +from itertools import combinations, permutations import numpy as np from numpy.testing import (run_module_suite, dec, assert_allclose, assert_, assert_equal, assert_raises, assert_array_equal, @@ -12,7 +12,6 @@ import pywt from pywt._extensions._swt import swt_axis -from pywt._extensions._pywt import _check_dtype # Check that float32 and complex64 are preserved. Other real types get # converted to float64. @@ -387,6 +386,21 @@ def test_iswtn_errors(): assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes) +def test_swtn_iswtn_unique_shape_per_axis(): + # test case for gh-460 + _shape = (1, 48, 32) # unique shape per axis + wav = 'sym2' + max_level = 3 + rstate = np.random.RandomState(0) + for shape in permutations(_shape): + # transform only along the non-singleton axes + axes = [ax for ax, s in enumerate(shape) if s != 1] + x = rstate.standard_normal(shape) + c = pywt.swtn(x, wav, max_level, axes=axes) + r = pywt.iswtn(c, wav, axes=axes) + assert_allclose(x, r, rtol=1e-10, atol=1e-10) + + def test_per_axis_wavelets(): # tests seperate wavelet for each axis. rstate = np.random.RandomState(1234)