From 591dd8a26c4269c328c8747c62ffa39c9245d3e3 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Fri, 26 Jul 2019 01:28:47 -0400 Subject: [PATCH 1/2] ENH: cwt runs in single precision for single precision inputs --- pywt/_cwt.py | 16 ++++++++++------ pywt/tests/test_cwt_wavelets.py | 30 +++++++++++++++++++----------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 6a1e2b18..58d1aceb 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -106,19 +106,23 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'): # accept array_like input; make a copy to ensure a contiguous array dt = _check_dtype(data) - data = np.array(data, dtype=dt) + data = np.asarray(data, dtype=dt) + dt_cplx = np.result_type(dt, np.complex64) if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) if np.isscalar(scales): scales = np.array([scales]) - dt_out = None # TODO: fix in/out dtype consistency in a subsequent PR if data.ndim == 1: - if wavelet.complex_cwt: - dt_out = complex + dt_out = dt_cplx if wavelet.complex_cwt else dt out = np.empty((np.size(scales), data.size), dtype=dt_out) precision = 10 int_psi, x = integrate_wavelet(wavelet, precision=precision) + # convert int_psi, x to the same precision as the data + dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt + int_psi = np.asarray(int_psi, dtype=dt_psi) + x = np.asarray(x, dtype=data.real.dtype) + if method == 'fft': size_scale0 = -1 fft_data = None @@ -150,8 +154,8 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'): conv = conv[:data.size + int_psi_scale.size - 1] coef = - np.sqrt(scale) * np.diff(conv) - if not np.iscomplexobj(out): - coef = np.real(coef) + if out.dtype.kind != 'c': + coef = coef.real d = (coef.size - data.size) / 2. if d > 0: out[i, :] = coef[floor(d):-ceil(d)] diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index e2cafcb2..acdc8653 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -2,7 +2,7 @@ from __future__ import division, print_function, absolute_import from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal, - assert_raises) + assert_raises, assert_equal) import numpy as np import pywt @@ -345,20 +345,28 @@ def test_cwt_parameters_in_names(): def test_cwt_complex(): - for dtype in [np.float32, np.float64]: + for dtype, tol in [(np.float32, 1e-5), (np.float64, 1e-13)]: time, sst = pywt.data.nino() sst = np.asarray(sst, dtype=dtype) dt = time[1] - time[0] wavelet = 'cmor1.5-1.0' scales = np.arange(1, 32) - # real-valued tranfsorm - [cfs, f] = pywt.cwt(sst, scales, wavelet, dt) + for method in ['conv', 'fft']: + # real-valued tranfsorm as a reference + [cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method) - # complex-valued tranfsorm equals sum of the transforms of the real and - # imaginary components - [cfs_complex, f] = pywt.cwt(sst + 1j*sst, scales, wavelet, dt) - assert_almost_equal(cfs + 1j*cfs, cfs_complex) + # verify same precision + assert_equal(cfs.real.dtype, sst.dtype) + + # complex-valued transform equals sum of the transforms of the real + # and imaginary components + sst_complex = sst + 1j*sst + [cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt, + method=method) + assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol) + # verify dtype is preserved + assert_equal(cfs_complex.dtype, sst_complex.dtype) def test_cwt_small_scales(): @@ -377,12 +385,12 @@ def test_cwt_method_fft(): rstate = np.random.RandomState(1) data = rstate.randn(50) data[15] = 1. - scales = np.arange(1, 64) - wavelet = 'cmor1.5-1.0' + scales = np.arange(1, 64) + wavelet = 'cmor1.5-1.0' # build a reference cwt with the legacy np.conv() method cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv') # compare with the fft based convolution - cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft') + cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft') assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13) From fb4b030f4a06c63bebfdf10075538c222518bcfa Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Fri, 26 Jul 2019 01:40:34 -0400 Subject: [PATCH 2/2] add float32 cases to cwt benchmarks --- benchmarks/benchmarks/cwt_benchmarks.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks/cwt_benchmarks.py b/benchmarks/benchmarks/cwt_benchmarks.py index 0c063b3b..f02b6333 100644 --- a/benchmarks/benchmarks/cwt_benchmarks.py +++ b/benchmarks/benchmarks/cwt_benchmarks.py @@ -9,20 +9,22 @@ class CwtTimeSuiteBase(object): params = ([32, 128, 512, 2048], ['cmor', 'cgau4', 'fbsp', 'gaus4', 'mexh', 'morl', 'shan'], [16, 64, 256], - ['conv', 'fft']) - param_names = ('n', 'wavelet', 'max_scale', 'method') + [np.float32, np.float64], + ['conv', 'fft'], + ) + param_names = ('n', 'wavelet', 'max_scale', 'dtype', 'method') - def setup(self, n, wavelet, max_scale, method): + def setup(self, n, wavelet, max_scale, dtype, method): try: from pywt import cwt except ImportError: raise NotImplementedError("cwt not available") - self.data = np.ones(n, dtype='float') - self.scales = np.arange(1, max_scale+1) + self.data = np.ones(n, dtype=dtype) + self.scales = np.arange(1, max_scale + 1) class CwtTimeSuite(CwtTimeSuiteBase): - def time_cwt(self, n, wavelet, max_scale, method): + def time_cwt(self, n, wavelet, max_scale, dtype, method): try: pywt.cwt(self.data, self.scales, wavelet, method=method) except TypeError: