diff --git a/.coveragerc b/.coveragerc index b6fef2d7e..dd895487f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -6,6 +6,8 @@ omit = */version.py */pywt/tests/* */pywt/_doc_utils.py* + */pywt/_pytesttester.py* + */pywt/_pytest.py* */pywt/data/create_dat.py *.pxd stringsource diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..2c36646e3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +tidelift: "pypi/PyWavelets" diff --git a/.travis.yml b/.travis.yml index 0965345f5..47ec9d214 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,12 +19,15 @@ matrix: - CYTHONSPEC=cython - USE_WHEEL=1 - os: linux - python: 3.7-dev + python: 3.7 + dist: xenial # travis-ci/travis-ci/issues/9815 + sudo: true env: - NUMPYSPEC=numpy - MATPLOTLIBSPEC=matplotlib - CYTHONSPEC=cython - USE_SDIST=1 + - USE_SCIPY=1 - os: linux python: 3.5 env: @@ -59,6 +62,7 @@ before_install: - pip install pytest pytest-cov coverage codecov futures - set -o pipefail - if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi + - if [ "${USE_SCIPY}" == "1" ]; then pip install scipy; fi - | if [ "${REFGUIDE_CHECK}" == "1" ]; then pip install sphinx numpydoc @@ -90,7 +94,8 @@ script: CFLAGS="--coverage" python setup.py build --build-lib build/lib/ --build-temp build/tmp/ CFLAGS="--coverage" pip install -e . -v pushd demo - pytest --pyargs pywt --cov=pywt + pytest --pyargs pywt --cov=pywt --cov-config=../.coveragerc + cp .coverage .. popd fi diff --git a/LICENSE b/LICENSE index 47b60f4b6..c01d7d721 100644 --- a/LICENSE +++ b/LICENSE @@ -18,15 +18,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -The PyWavelets repository and source distributions bundle some code that is -adapted from compatibly licensed projects. We list these here. - -Name: NumPy -Files: pywt/_pytesttester.py -License: 3-clause BSD - -Name: SciPy -Files: setup.py, util/* -License: 3-clause BSD diff --git a/LICENSES_bundled.txt b/LICENSES_bundled.txt new file mode 100644 index 000000000..6b2ab7a02 --- /dev/null +++ b/LICENSES_bundled.txt @@ -0,0 +1,10 @@ +The PyWavelets repository and source distributions bundle some code that is +adapted from compatibly licensed projects. We list these here. + +Name: NumPy +Files: pywt/_pytesttester.py +License: 3-clause BSD + +Name: SciPy +Files: setup.py, util/* +License: 3-clause BSD diff --git a/MANIFEST.in b/MANIFEST.in index 0bfae58e3..7331f6d30 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -15,7 +15,7 @@ recursive-include demo * include cythonize.dat # Add build and testing tools -include tox.ini +include tox.ini pytest.ini recursive-include util * # Exclude what we don't want to include diff --git a/README.rst b/README.rst index 056cd008a..3083036a4 100644 --- a/README.rst +++ b/README.rst @@ -65,9 +65,11 @@ For more usage examples see the `demo`_ directory in the source package. Installation ------------ -PyWavelets supports `Python`_ >=3.5, and is only dependent on `Numpy`_ +PyWavelets supports `Python`_ >=3.5, and is only dependent on `NumPy`_ (supported versions are currently ``>= 1.13.3``). To pass all of the tests, -`Matplotlib`_ is also required. +`Matplotlib`_ is also required. `SciPy`_ is also an optional dependency. When +present, FFT-based continuous wavelet transforms will use FFTs from SciPy +rather than NumPy. There are binary wheels for Intel Linux, Windows and macOS / OSX on PyPi. If you are on one of these platforms, you should get a binary (precompiled) @@ -116,12 +118,17 @@ All contributions including bug reports, bug fixes, new feature implementations and documentation improvements are welcome. Moreover, developers with an interest in PyWavelets are very welcome to join the development team! +As of 2019, PyWavelets development is supported in part by Tidelift. +`Help support PyWavelets with the Tidelift Subscription `_ + Contact ------- Use `GitHub Issues`_ or the `mailing list`_ to post your comments or questions. +**Report a security vulnerability:** https://tidelift.com/security + License ------- @@ -146,7 +153,8 @@ the link in the badge below to Zenodo: .. _Anaconda: https://www.continuum.io .. _GitHub: https://github.com/PyWavelets/pywt .. _GitHub Issues: https://github.com/PyWavelets/pywt/issues -.. _Numpy: http://www.numpy.org +.. _NumPy: https://www.numpy.org +.. _SciPy: https://www.scipy.org .. _original developer: http://en.ig.ma .. _Python: http://python.org/ .. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/ diff --git a/benchmarks/README.rst b/benchmarks/README.rst index 1572fd4c4..b24123180 100644 --- a/benchmarks/README.rst +++ b/benchmarks/README.rst @@ -32,7 +32,7 @@ To record the results use: asv publish -And to see the results via a web broweser, run: +And to see the results via a web browser, run: asv preview diff --git a/benchmarks/benchmarks/cwt_benchmarks.py b/benchmarks/benchmarks/cwt_benchmarks.py index cf9cacd35..eda4e4f2e 100644 --- a/benchmarks/benchmarks/cwt_benchmarks.py +++ b/benchmarks/benchmarks/cwt_benchmarks.py @@ -6,20 +6,40 @@ class CwtTimeSuiteBase(object): """ Set-up for CWT timing. """ - params = ([32, 128, 512], + params = ([32, 128, 512, 2048], ['cmor', 'cgau4', 'fbsp', 'gaus4', 'mexh', 'morl', 'shan'], - [16, 64, 256]) - param_names = ('n', 'wavelet', 'max_scale') + [16, 64, 256], + [np.float32, np.float64], + ['conv', 'fft'], + ) + param_names = ('n', 'wavelet', 'max_scale', 'dtype', 'method') - def setup(self, n, wavelet, max_scale): + def setup(self, n, wavelet, max_scale, dtype, method): try: from pywt import cwt except ImportError: raise NotImplementedError("cwt not available") - self.data = np.ones(n, dtype='float') - self.scales = np.arange(1, max_scale+1) + self.data = np.ones(n, dtype=dtype) + self.batch_data = np.ones((5, n), dtype=dtype) + self.scales = np.arange(1, max_scale + 1) class CwtTimeSuite(CwtTimeSuiteBase): - def time_cwt(self, n, wavelet, max_scale): - pywt.cwt(self.data, self.scales, wavelet) + def time_cwt(self, n, wavelet, max_scale, dtype, method): + try: + pywt.cwt(self.data, self.scales, wavelet, method=method) + except TypeError: + # older PyWavelets does not support use of the method argument + if method == 'fft': + raise NotImplementedError( + "fft-based convolution not available.") + pywt.cwt(self.data, self.scales, wavelet) + + def time_cwt_batch(self, n, wavelet, max_scale, dtype, method): + try: + pywt.cwt(self.batch_data, self.scales, wavelet, method=method, + axis=-1) + except TypeError: + # older PyWavelets does not support the axis argument + raise NotImplementedError( + "axis argument not available.") diff --git a/codecov.yml b/codecov.yml index 915a21659..14cd53ae3 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,2 +1,6 @@ ignore: - "_doc_utils.py" # utilities only used for creating documentation figures + - "_pytest*.py" # pytest test utilities + - "create_dat.py" # raw data creation script + - "version.py" # generated by setup.py + diff --git a/demo/swt_variance.py b/demo/swt_variance.py new file mode 100644 index 000000000..738447559 --- /dev/null +++ b/demo/swt_variance.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +import numpy as np +import matplotlib.pyplot as plt + +import pywt +import pywt.data + +ecg = pywt.data.ecg() + +# set trim_approx to avoid keeping approximation coefficients for all levels + +# set norm=True to rescale the wavelets so that the transform partitions the +# variance of the input signal among the various coefficient arrays. + +coeffs = pywt.swt(ecg, wavelet='sym4', trim_approx=True, norm=True) +ca = coeffs[0] +details = coeffs[1:] + +print("Variance of the ecg signal = {}".format(np.var(ecg, ddof=1))) + +variances = [np.var(c, ddof=1) for c in coeffs] +detail_variances = variances[1:] +print("Sum of variance across all SWT coefficients = {}".format( + np.sum(variances))) + +# Create a plot using the same y axis limits for all coefficient arrays to +# illustrate the preservation of amplitude scale across levels when norm=True. +ylim = [ecg.min(), ecg.max()] + +fig, axes = plt.subplots(len(coeffs) + 1) +axes[0].set_title("normalized SWT decomposition") +axes[0].plot(ecg) +axes[0].set_ylabel('ECG Signal') +axes[0].set_xlim(0, len(ecg) - 1) +axes[0].set_ylim(ylim[0], ylim[1]) + +for i, x in enumerate(coeffs): + ax = axes[-i - 1] + ax.plot(coeffs[i], 'g') + if i == 0: + ax.set_ylabel("A%d" % (len(coeffs) - 1)) + else: + ax.set_ylabel("D%d" % (len(coeffs) - i)) + # Scale axes + ax.set_xlim(0, len(ecg) - 1) + ax.set_ylim(ylim[0], ylim[1]) + + +# reorder from first to last level of coefficients +level = np.arange(1, len(detail_variances) + 1) + +# create a plot of the variance as a function of level +plt.figure(figsize=(8, 6)) +fontdict = dict(fontsize=16, fontweight='bold') +plt.plot(level, detail_variances[::-1], 'k.') +plt.xlabel("Decomposition level", fontdict=fontdict) +plt.ylabel("Variance", fontdict=fontdict) +plt.title("Variances of detail coefficients", fontdict=fontdict) +plt.show() diff --git a/doc/release/1.1.0-notes.rst b/doc/release/1.1.0-notes.rst index 29f3bea3f..a3af908db 100644 --- a/doc/release/1.1.0-notes.rst +++ b/doc/release/1.1.0-notes.rst @@ -17,9 +17,25 @@ Deprecated features Backwards incompatible changes ============================== +When using complex-valued wavelets with the ``cwt``, the output will now be +the complex conjugate of the result that was produced by PyWavelets 1.0.x. +This was done to account for a bug described below. The magnitude of the +``cwt`` coefficients will still match those from previous releases. + Bugs Fixed ========== +For a ``cwt`` with complex wavelets, the results in PyWavelets 1.0.x releases +matched the output of Matlab R2012a's ``cwt``. Howveer, older Matlab releases +like R2012a had a phase that was of opposite sign to that given in textbook +definitions of the CWT (Eq. 2 of Torrence and Compo's review article, "A +Practical Guide to Wavelet Analysis"). Consequently, the wavelet coefficients +were the complex conjugates of the expected result. This was validated by +comparing the results of a transform using ``cmor1.0-1.0`` as compared to the +``cwt`` implementation available in Matlab R2017b as well as the function +``wt.m`` from the Lancaster University Physics department's +`MODA toolbox `_ + Other changes ============= diff --git a/doc/source/common_refs.rst b/doc/source/common_refs.rst index a8331b221..8fe04b957 100644 --- a/doc/source/common_refs.rst +++ b/doc/source/common_refs.rst @@ -5,7 +5,8 @@ .. _GitHub: https://github.com/PyWavelets/pywt .. _GitHub repository: https://github.com/PyWavelets/pywt .. _GitHub Issues: https://github.com/PyWavelets/pywt/issues -.. _Numpy: http://www.numpy.org +.. _NumPy: https://www.numpy.org +.. _SciPy: https://www.scipy.org .. _original developer: http://en.ig.ma .. _Python: http://python.org/ .. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/ diff --git a/doc/source/conf.py b/doc/source/conf.py index 5254eb8e8..b830351c5 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -36,6 +36,7 @@ extensions = ['sphinx.ext.doctest', 'sphinx.ext.autodoc', 'sphinx.ext.todo', 'sphinx.ext.extlinks', 'sphinx.ext.mathjax', 'sphinx.ext.autosummary', 'numpydoc', + 'sphinx.ext.intersphinx', 'matplotlib.sphinxext.plot_directive'] # Add any paths that contain templates here, relative to this directory. @@ -224,3 +225,9 @@ plot_formats = [('png', 96), 'pdf'] plot_html_show_formats = False plot_html_show_source_link = False + +# -- Options for intersphinx extension --------------------------------------- + +# Intersphinx to get Numpy and other targets +intersphinx_mapping = { + 'numpy': ('https://docs.scipy.org/doc/numpy/', None)} diff --git a/doc/source/index.rst b/doc/source/index.rst index bd4f77f5b..bced5c02e 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -51,12 +51,22 @@ Citing ------ If you use PyWavelets in a scientific publication, we would appreciate -citations of the project: +citations of the project via the following +JOSS publication: - Lee G, Gommers R, Wasilewski F, Wohlfahrt K, O'Leary A, Nahrstaedt H, - and Contributors, "PyWavelets - Wavelet Transforms in Python", 2006-, - https://github.com/PyWavelets/pywt [Online; accessed 2018-MM-DD]. + Gregory R. Lee, Ralf Gommers, Filip Wasilewski, Kai Wohlfahrt, Aaron + O'Leary (2019). PyWavelets: A Python package for wavelet analysis. Journal + of Open Source Software, 4(36), 1237, https://doi.org/10.21105/joss.01237. +.. image:: http://joss.theoj.org/papers/10.21105/joss.01237/status.svg + :target: https://doi.org/10.21105/joss.01237 + +Specific releases can also be cited via Zenodo. The DOI below will correspond +to the most recent release. DOIs for past versions can be found by following +the link in the badge below to Zenodo: + +.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.1407171.svg + :target: https://doi.org/10.5281/zenodo.1407171 Contents -------- diff --git a/doc/source/install.rst b/doc/source/install.rst index 7a82d00aa..f880a8ea8 100644 --- a/doc/source/install.rst +++ b/doc/source/install.rst @@ -39,11 +39,12 @@ PyWavelets source code directory (containing ``setup.py``) and type:: The requirements needed to build from source are: - Python_ 2.7 or >=3.4 - - Numpy_ >= 1.13.3 + - NumPy_ >= 1.13.3 - Cython_ >= 0.23.5 (if installing from git, not from a PyPI source release) To run all the tests for PyWavelets, you will also need to install the -Matplotlib_ package. +Matplotlib_ package. If SciPy_ is available, FFT-based continuous wavelet +transforms will use the FFT implementation from SciPy instead of NumPy. .. seealso:: :ref:`Development guide ` section contains more information on building and installing from source code. diff --git a/doc/source/pyplots/plot_boundary_modes.py b/doc/source/pyplots/plot_boundary_modes.py index cbb5f6eaa..940822f76 100644 --- a/doc/source/pyplots/plot_boundary_modes.py +++ b/doc/source/pyplots/plot_boundary_modes.py @@ -5,7 +5,7 @@ In practice, which signal extension mode is beneficial will depend on the signal characteristics. For this particular signal, some modes such as -"periodic", "antisymmetric" and "zeros" result in large discontinuities that +"periodic", "antisymmetric" and "zero" result in large discontinuities that would lead to large amplitude boundary coefficients in the detail coefficients of a discrete wavelet transform. """ @@ -28,5 +28,5 @@ boundary_mode_subplot(x, 'periodization', axes[5], symw=False) boundary_mode_subplot(x, 'smooth', axes[6], symw=False) boundary_mode_subplot(x, 'constant', axes[7], symw=False) -boundary_mode_subplot(x, 'zeros', axes[8], symw=False) +boundary_mode_subplot(x, 'zero', axes[8], symw=False) plt.show() diff --git a/doc/source/ref/cwt.rst b/doc/source/ref/cwt.rst index 9a94f3b6c..c7ebe8e76 100644 --- a/doc/source/ref/cwt.rst +++ b/doc/source/ref/cwt.rst @@ -87,10 +87,10 @@ correspond to the following wavelets: .. math:: \psi(t) = \sqrt{B} \frac{\sin(\pi B t)}{\pi B t} \exp^{\mathrm{j}2 \pi C t} -where :math:`B` is the bandwith and :math:`C` is the center frequency. +where :math:`B` is the bandwidth and :math:`C` is the center frequency. -Freuqency B-Spline Wavelets +Frequency B-Spline Wavelets ^^^^^^^^^^^^^^^^^^^^^^^^^^^ The frequency B-spline wavelets (``"fpspM-B-C"`` with integer M and floating point B, C) correspond to the following wavelets: diff --git a/doc/source/ref/signal-extension-modes.rst b/doc/source/ref/signal-extension-modes.rst index f8e466bb4..e05f2cce1 100644 --- a/doc/source/ref/signal-extension-modes.rst +++ b/doc/source/ref/signal-extension-modes.rst @@ -136,3 +136,15 @@ periodization per N/A antisymmetric asym, asymh N/A antireflect asymw reflect, reflect_type='odd' ================== ============= =========================== + +Padding using PyWavelets Signal Extension Modes - ``pad`` +--------------------------------------------------------- + +.. autofunction:: pad + +Pywavelets provides a function, :func:`pad`, that operate like +:func:`numpy.pad`, but supporting the PyWavelets signal extension modes +discussed above. For efficiency, the DWT routines in PyWavelets do not +expclitly create padded signals using this function. It can be used to manually +prepad signals to reduce boundary effects in functions such as :func:`cwt` and +:func:`swt` that do not currently support all of these signal extension modes. diff --git a/pywt/_cwt.py b/pywt/_cwt.py index a4e6ca536..a47cf9885 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -1,13 +1,40 @@ -import numpy as np +from math import floor, ceil from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet, Wavelet, _check_dtype) from ._functions import integrate_wavelet, scale2frequency + __all__ = ["cwt"] -def cwt(data, scales, wavelet, sampling_period=1.): +import numpy as np + +try: + # Prefer scipy.fft (new in SciPy 1.4) + import scipy.fft + fftmodule = scipy.fft + next_fast_len = fftmodule.next_fast_len +except ImportError: + try: + import scipy.fftpack + fftmodule = scipy.fftpack + next_fast_len = fftmodule.next_fast_len + except ImportError: + fftmodule = np.fft + + # provide a fallback so scipy is an optional requirement + def next_fast_len(n): + """Round up size to the nearest power of two. + + Given a number of samples `n`, returns the next power of two + following this number to take advantage of FFT speedup. + This fallback is less efficient than `scipy.fftpack.next_fast_len` + """ + return 2**ceil(np.log2(n)) + + +def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): """ cwt(data, scales, wavelet) @@ -19,7 +46,7 @@ def cwt(data, scales, wavelet, sampling_period=1.): Input signal scales : array_like The wavelet scales to use. One can use - ``f = scale2frequency(scale, wavelet)/sampling_period`` to determine + ``f = scale2frequency(wavelet, scale)/sampling_period`` to determine what physical frequency, ``f``. Here, ``f`` is in hertz when the ``sampling_period`` is given in seconds. wavelet : Wavelet object or name @@ -29,12 +56,27 @@ def cwt(data, scales, wavelet, sampling_period=1.): The values computed for ``coefs`` are independent of the choice of ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling period). + method : {'conv', 'fft'}, optional + The method used to compute the CWT. Can be any of: + - ``conv`` uses ``numpy.convolve``. + - ``fft`` uses frequency domain convolution. + - ``auto`` uses automatic selection based on an estimate of the + computational complexity at each scale. + + The ``conv`` method complexity is ``O(len(scale) * len(data))``. + The ``fft`` method is ``O(N * log2(N))`` with + ``N = len(scale) + len(data) - 1``. It is well suited for large size + signals but slightly slower than ``conv`` on small ones. + axis: int, optional + Axis over which to compute the CWT. If not given, the last axis is + used. Returns ------- coefs : array_like Continuous wavelet transform of the input signal for the given scales - and wavelet + and wavelet. The first axis of ``coefs`` corresponds to the scales. + The remaining axes match the shape of ``data``. frequencies : array_like If the unit of sampling period are seconds and given, than frequencies are in hertz. Otherwise, a sampling period of 1 is assumed. @@ -69,39 +111,93 @@ def cwt(data, scales, wavelet, sampling_period=1.): # accept array_like input; make a copy to ensure a contiguous array dt = _check_dtype(data) - data = np.array(data, dtype=dt) + data = np.asarray(data, dtype=dt) + dt_cplx = np.result_type(dt, np.complex64) if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) if np.isscalar(scales): scales = np.array([scales]) - if data.ndim == 1: - if wavelet.complex_cwt: - out = np.zeros((np.size(scales), data.size), dtype=complex) - else: - out = np.zeros((np.size(scales), data.size)) - precision = 10 - int_psi, x = integrate_wavelet(wavelet, precision=precision) - for i in np.arange(np.size(scales)): - step = x[1] - x[0] - j = np.floor( - np.arange(scales[i] * (x[-1] - x[0]) + 1) / (scales[i] * step)) - if np.max(j) >= np.size(int_psi): - j = np.delete(j, np.where((j >= np.size(int_psi)))[0]) - coef = - np.sqrt(scales[i]) * np.diff( - np.convolve(data, int_psi[j.astype(np.int)][::-1])) - d = (coef.size - data.size) / 2. - if d > 0: - out[i, :] = coef[int(np.floor(d)):int(-np.ceil(d))] - elif d == 0.: - out[i, :] = coef + if not np.isscalar(axis): + raise ValueError("axis must be a scalar.") + + dt_out = dt_cplx if wavelet.complex_cwt else dt + out = np.empty((np.size(scales),) + data.shape, dtype=dt_out) + precision = 10 + int_psi, x = integrate_wavelet(wavelet, precision=precision) + int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi + + # convert int_psi, x to the same precision as the data + dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt + int_psi = np.asarray(int_psi, dtype=dt_psi) + x = np.asarray(x, dtype=data.real.dtype) + + if method == 'fft': + size_scale0 = -1 + fft_data = None + elif not method == 'conv': + raise ValueError("method must be 'conv' or 'fft'") + + if data.ndim > 1: + # move axis to be transformed last (so it is contiguous) + data = data.swapaxes(-1, axis) + + # reshape to (n_batch, data.shape[-1]) + data_shape_pre = data.shape + data = data.reshape((-1, data.shape[-1])) + + for i, scale in enumerate(scales): + step = x[1] - x[0] + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + + if method == 'conv': + if data.ndim == 1: + conv = np.convolve(data, int_psi_scale) else: - raise ValueError( - "Selected scale of {} too small.".format(scales[i])) - frequencies = scale2frequency(wavelet, scales, precision) - if np.isscalar(frequencies): - frequencies = np.array([frequencies]) - for i in np.arange(len(frequencies)): - frequencies[i] /= sampling_period - return out, frequencies - else: - raise ValueError("Only dim == 1 supported") + # batch convolution via loop + conv_shape = list(data.shape) + conv_shape[-1] += int_psi_scale.size - 1 + conv_shape = tuple(conv_shape) + conv = np.empty(conv_shape, dtype=dt_out) + for n in range(data.shape[0]): + conv[n, :] = np.convolve(data[n], int_psi_scale) + else: + # The padding is selected for: + # - optimal FFT complexity + # - to be larger than the two signals length to avoid circular + # convolution + size_scale = next_fast_len( + data.shape[-1] + int_psi_scale.size - 1 + ) + if size_scale != size_scale0: + # Must recompute fft_data when the padding size changes. + fft_data = fftmodule.fft(data, size_scale, axis=-1) + size_scale0 = size_scale + fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1) + conv = fftmodule.ifft(fft_wav * fft_data, axis=-1) + conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1] + + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + if out.dtype.kind != 'c': + coef = coef.real + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - data.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: + raise ValueError( + "Selected scale of {} too small.".format(scale)) + if data.ndim > 1: + # restore original data shape and axis position + coef = coef.reshape(data_shape_pre) + coef = coef.swapaxes(axis, -1) + out[i, ...] = coef + + frequencies = scale2frequency(wavelet, scales, precision) + if np.isscalar(frequencies): + frequencies = np.array([frequencies]) + frequencies /= sampling_period + return out, frequencies diff --git a/pywt/_doc_utils.py b/pywt/_doc_utils.py index 20c3fafbb..ee906aeab 100644 --- a/pywt/_doc_utils.py +++ b/pywt/_doc_utils.py @@ -4,8 +4,10 @@ import numpy as np from matplotlib import pyplot as plt +from ._dwt import pad + __all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis', - 'draw_2d_fswavedecn_basis', 'pad', 'boundary_mode_subplot'] + 'draw_2d_fswavedecn_basis', 'boundary_mode_subplot'] def wavedec_keys(level): @@ -149,63 +151,6 @@ def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None, return fig, ax -def pad(x, pad_widths, mode): - """Extend a 1D signal using a given boundary mode. - - Like numpy.pad but supports all PyWavelets boundary modes. - """ - if np.isscalar(pad_widths): - pad_widths = (pad_widths, pad_widths) - - if x.ndim > 1: - raise ValueError("This padding function is only for 1D signals.") - - if mode in ['symmetric', 'reflect']: - xp = np.pad(x, pad_widths, mode=mode) - elif mode in ['periodic', 'periodization']: - if mode == 'periodization' and x.size % 2 == 1: - raise ValueError("periodization expects an even length signal.") - xp = np.pad(x, pad_widths, mode='wrap') - elif mode == 'zeros': - xp = np.pad(x, pad_widths, mode='constant', constant_values=0) - elif mode == 'constant': - xp = np.pad(x, pad_widths, mode='edge') - elif mode == 'smooth': - xp = np.pad(x, pad_widths, mode='linear_ramp', - end_values=(x[0] + pad_widths[0] * (x[0] - x[1]), - x[-1] + pad_widths[1] * (x[-1] - x[-2]))) - elif mode == 'antisymmetric': - # implement by flipping portions symmetric padding - npad_l, npad_r = pad_widths - xp = np.pad(x, pad_widths, mode='symmetric') - r_edge = npad_l + x.size - 1 - l_edge = npad_l - # width of each reflected segment - seg_width = x.size - # flip reflected segments on the right of the original signal - n = 1 - while r_edge <= xp.size: - segment_slice = slice(r_edge + 1, - min(r_edge + 1 + seg_width, xp.size)) - if n % 2: - xp[segment_slice] *= -1 - r_edge += seg_width - n += 1 - - # flip reflected segments on the left of the original signal - n = 1 - while l_edge >= 0: - segment_slice = slice(max(0, l_edge - seg_width), l_edge) - if n % 2: - xp[segment_slice] *= -1 - l_edge -= seg_width - n += 1 - elif mode == 'antireflect': - npad_l, npad_r = pad_widths - xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd') - return xp - - def boundary_mode_subplot(x, mode, ax, symw=True): """Plot an illustration of the boundary mode in a subplot axis.""" @@ -236,7 +181,7 @@ def boundary_mode_subplot(x, mode, ax, symw=True): left -= 0.5 step = len(x) rng = range(-2, 4) - if mode in ['smooth', 'constant', 'zeros']: + if mode in ['smooth', 'constant', 'zero']: rng = range(0, 2) for rep in rng: ax.plot((left + rep * step) * o2, [xp.min() - .5, xp.max() + .5], 'k-') diff --git a/pywt/_dwt.py b/pywt/_dwt.py index 56114566a..bf2a3bbb2 100644 --- a/pywt/_dwt.py +++ b/pywt/_dwt.py @@ -12,7 +12,7 @@ __all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level", - "dwt_coeff_len"] + "dwt_coeff_len", "pad"] def dwt_max_level(data_len, filter_len): @@ -92,8 +92,8 @@ def dwt_coeff_len(data_len, filter_len, mode): Data length. filter_len : int Filter length. - mode : str, optional (default: 'symmetric') - Signal extension mode, see Modes + mode : str, optional + Signal extension mode, see :ref:`Modes `. Returns ------- @@ -130,12 +130,11 @@ def dwt(data, wavelet, mode='symmetric', axis=-1): wavelet : Wavelet object or name Wavelet to use mode : str, optional - Signal extension mode, see Modes + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the DWT. If not given, the last axis is used. - Returns ------- (cA, cD) : tuple @@ -199,19 +198,18 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): ---------- cA : array_like or None Approximation coefficients. If None, will be set to array of zeros - with same shape as `cD`. + with same shape as ``cD``. cD : array_like or None Detail coefficients. If None, will be set to array of zeros - with same shape as `cA`. + with same shape as ``cA``. wavelet : Wavelet object or name Wavelet to use mode : str, optional (default: 'symmetric') - Signal extension mode, see Modes + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the inverse DWT. If not given, the last axis is used. - Returns ------- rec: array_like @@ -224,7 +222,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): >>> pywt.idwt(cA, cD, 'db2', 'smooth') array([ 1., 2., 3., 4., 5., 6.]) - One of the neat features of `idwt` is that one of the ``cA`` and ``cD`` + One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD`` arguments can be set to None. In that situation the reconstruction will be performed using only the other one. Mathematically speaking, this is equivalent to passing a zero-filled array as one of the arguments. @@ -300,7 +298,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1): Partial Discrete Wavelet Transform data decomposition. - Similar to `pywt.dwt`, but computes only one set of coefficients. + Similar to ``pywt.dwt``, but computes only one set of coefficients. Useful when you need only approximation or only details at the given level. Parameters @@ -316,7 +314,7 @@ def downcoef(part, data, wavelet, mode='symmetric', level=1): wavelet : Wavelet object or name Wavelet to use mode : str, optional - Signal extension mode, see `Modes`. Default is 'symmetric'. + Signal extension mode, see :ref:`Modes `. level : int, optional Decomposition level. Default is 1. @@ -401,3 +399,119 @@ def upcoef(part, coeffs, wavelet, level=1, take=0): if part not in 'ad': raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part) return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take)) + + +def pad(x, pad_widths, mode): + """Extend a 1D signal using a given boundary mode. + + This function operates like :func:`numpy.pad` but supports all signal + extension modes that can be used by PyWavelets discrete wavelet transforms. + + Parameters + ---------- + x : ndarray + The array to pad + pad_widths : {sequence, array_like, int} + Number of values padded to the edges of each axis. + ``((before_1, after_1), … (before_N, after_N))`` unique pad widths for + each axis. ``((before, after),)`` yields same before and after pad for + each axis. ``(pad,)`` or int is a shortcut for + ``before = after = pad width`` for all axes. + mode : str, optional + Signal extension mode, see :ref:`Modes `. + + Returns + ------- + pad : ndarray + Padded array of rank equal to array with shape increased according to + ``pad_widths``. + + Notes + ----- + The performance of padding in dimensions > 1 may be substantially slower + for modes ``'smooth'`` and ``'antisymmetric'`` as these modes are not + supported efficiently by the underlying :func:`numpy.pad` function. + + Note that the behavior of the ``'constant'`` mode here follows the + PyWavelets convention which is different from NumPy (it is equivalent to + ``mode='edge'`` in :func:`numpy.pad`). + """ + x = np.asanyarray(x) + + # process pad_widths exactly as in numpy.pad + pad_widths = np.array(pad_widths) + pad_widths = np.round(pad_widths).astype(np.intp, copy=False) + if pad_widths.min() < 0: + raise ValueError("pad_widths must be > 0") + pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist() + + if mode in ['symmetric', 'reflect']: + xp = np.pad(x, pad_widths, mode=mode) + elif mode in ['periodic', 'periodization']: + if mode == 'periodization': + # Promote odd-sized dimensions to even length by duplicating the + # last value. + edge_pad_widths = [(0, x.shape[ax] % 2) + for ax in range(x.ndim)] + x = np.pad(x, edge_pad_widths, mode='edge') + xp = np.pad(x, pad_widths, mode='wrap') + elif mode == 'zero': + xp = np.pad(x, pad_widths, mode='constant', constant_values=0) + elif mode == 'constant': + xp = np.pad(x, pad_widths, mode='edge') + elif mode == 'smooth': + def pad_smooth(vector, pad_width, iaxis, kwargs): + # smooth extension to left + left = vector[pad_width[0]] + slope_left = (left - vector[pad_width[0] + 1]) + vector[:pad_width[0]] = \ + left + np.arange(pad_width[0], 0, -1) * slope_left + + # smooth extension to right + right = vector[-pad_width[1] - 1] + slope_right = (right - vector[-pad_width[1] - 2]) + vector[-pad_width[1]:] = \ + right + np.arange(1, pad_width[1] + 1) * slope_right + return vector + xp = np.pad(x, pad_widths, pad_smooth) + elif mode == 'antisymmetric': + def pad_antisymmetric(vector, pad_width, iaxis, kwargs): + # smooth extension to left + # implement by flipping portions symmetric padding + npad_l, npad_r = pad_width + vsize_nonpad = vector.size - npad_l - npad_r + # Note: must modify vector in-place + vector[:] = np.pad(vector[pad_width[0]:-pad_width[-1]], + pad_width, mode='symmetric') + vp = vector + r_edge = npad_l + vsize_nonpad - 1 + l_edge = npad_l + # width of each reflected segment + seg_width = vsize_nonpad + # flip reflected segments on the right of the original signal + n = 1 + while r_edge <= vp.size: + segment_slice = slice(r_edge + 1, + min(r_edge + 1 + seg_width, vp.size)) + if n % 2: + vp[segment_slice] *= -1 + r_edge += seg_width + n += 1 + + # flip reflected segments on the left of the original signal + n = 1 + while l_edge >= 0: + segment_slice = slice(max(0, l_edge - seg_width), l_edge) + if n % 2: + vp[segment_slice] *= -1 + l_edge -= seg_width + n += 1 + return vector + xp = np.pad(x, pad_widths, pad_antisymmetric) + elif mode == 'antireflect': + xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd') + else: + raise ValueError( + ("unsupported mode: {}. The supported modes are {}").format( + mode, Modes.modes)) + return xp diff --git a/pywt/_extensions/_dwt.pyx b/pywt/_extensions/_dwt.pyx index 7af5d3920..d6a0b0b59 100644 --- a/pywt/_extensions/_dwt.pyx +++ b/pywt/_extensions/_dwt.pyx @@ -90,7 +90,7 @@ cpdef dwt_axis(np.ndarray data, Wavelet wavelet, MODE mode, unsigned int axis=0) input_shape = data.shape output_shape = input_shape.copy() - output_shape[axis] = common.dwt_buffer_length(data.shape[axis], wavelet.dec_len, mode) + output_shape[axis] = dwt_coeff_len(data.shape[axis], wavelet.dec_len, mode) cA = np.empty(output_shape, data.dtype) cD = np.empty(output_shape, data.dtype) diff --git a/pywt/_extensions/_swt.pyx b/pywt/_extensions/_swt.pyx index 718b7b0c3..0b3d82103 100644 --- a/pywt/_extensions/_swt.pyx +++ b/pywt/_extensions/_swt.pyx @@ -1,6 +1,7 @@ #cython: boundscheck=False, wraparound=False from . cimport common from . cimport c_wt +from cpython cimport bool import warnings import numpy as np @@ -9,6 +10,7 @@ cimport numpy as np from .common cimport pywt_index_t from ._pywt cimport c_wavelet_from_object, cdata_t, Wavelet, _check_dtype + include "config.pxi" def swt_max_level(size_t input_len): @@ -36,6 +38,8 @@ def swt_max_level(size_t input_len): multiple of ``2**n``. ``numpy.pad`` can be used to pad a signal up to an appropriate length as needed. """ + if input_len < 1: + raise ValueError("Cannot apply swt to a size 0 signal.") max_level = common.swt_max_level(input_len) if max_level == 0: warnings.warn( @@ -45,7 +49,8 @@ def swt_max_level(size_t input_len): return max_level -def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): +def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level, + bool trim_approx=False): cdef cdata_t[::1] cA, cD cdef Wavelet w cdef int retval @@ -54,6 +59,8 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): if data.size % 2: raise ValueError("Length of data must be even.") + if data.size < 1: + raise ValueError("Data must have non-zero size") if level < 1: raise ValueError("Level value must be greater than zero.") @@ -67,6 +74,7 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): common.swt_max_level(data.size) - start_level)) raise ValueError(msg) + output_len = common.swt_buffer_length(data.size) if output_len < 1: raise RuntimeError("Invalid output length.") @@ -137,14 +145,20 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level): raise RuntimeError("C swt failed.") data = cA - ret.append((cA, cD)) + if not trim_approx: + ret.append((np.asarray(cA), np.asarray(cD))) + else: + ret.append(np.asarray(cD)) + if trim_approx: + ret.append(np.asarray(cA)) ret.reverse() return ret cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, - size_t start_level, unsigned int axis=0): + size_t start_level, unsigned int axis=0, + bool trim_approx=False): # memory-views do not support n-dimensional arrays, use np.ndarray instead cdef common.ArrayInfo data_info, output_info cdef np.ndarray cD, cA @@ -153,8 +167,10 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, cdef int retval = -5 cdef size_t i - if data.size % 2: - raise ValueError("Length of data must be even.") + if data.shape[axis] % 2: + raise ValueError("Length of data must be even along the transform axis.") + if data.shape[axis] < 1: + raise ValueError("Data must have non-zero size along the transform axis.") if level < 1: raise ValueError("Level value must be greater than zero.") @@ -282,7 +298,10 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, if retval == -5: raise TypeError("Array must be floating point, not {}" .format(data.dtype)) - ret.append((cA, cD)) + if not trim_approx: + ret.append((cA, cD)) + else: + ret.append(cD) # previous approx coeffs are the data for the next level data = cA @@ -290,5 +309,8 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level, data_info.strides = data.strides data_info.shape = data.shape + if trim_approx: + ret.append(cA) + ret.reverse() return ret diff --git a/pywt/_multidim.py b/pywt/_multidim.py index 39d9dc2bf..3636d01c6 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -33,7 +33,7 @@ def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)): Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or 2-tuple of strings, optional - Signal extension mode, see Modes (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : 2-tuple of ints, optional @@ -84,13 +84,13 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): ---------- coeffs : tuple (cA, (cH, cV, cD)) A tuple with approximation coefficients and three - details coefficients 2D arrays like from `dwt2`. If any of these + details coefficients 2D arrays like from ``dwt2``. If any of these components are set to ``None``, it will be treated as zeros. wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or 2-tuple of strings, optional - Signal extension mode, see Modes (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : 2-tuple of ints, optional @@ -131,7 +131,7 @@ def dwtn(data, wavelet, mode='symmetric', axes=None): apply along each axis in ``axes``. mode : str or tuple of string, optional Signal extension mode used in the decomposition, - see Modes (default: 'symmetric'). This can also be a tuple of modes + see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the DWT. Repeated elements mean the DWT will @@ -233,7 +233,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): apply along each axis in ``axes``. mode : str or list of string, optional Signal extension mode used in the decomposition, - see Modes (default: 'symmetric'). This can also be a tuple of modes + see :ref:`Modes `. This can also be a tuple of modes specifying the mode to use on each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the IDWT. Repeated elements mean the IDWT diff --git a/pywt/_multilevel.py b/pywt/_multilevel.py index 1ec785eb1..5ece58437 100644 --- a/pywt/_multilevel.py +++ b/pywt/_multilevel.py @@ -57,10 +57,10 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1): wavelet : Wavelet object or name string Wavelet to use mode : str, optional - Signal extension mode, see `Modes` (default: 'symmetric') + Signal extension mode, see :ref:`Modes `. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axis: int, optional Axis over which to compute the DWT. If not given, the last axis is used. @@ -69,9 +69,10 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1): ------- [cA_n, cD_n, cD_n-1, ..., cD2, cD1] : list Ordered list of coefficients arrays - where `n` denotes the level of decomposition. The first element - (`cA_n`) of the result is approximation coefficients array and the - following elements (`cD_n` - `cD_1`) are details coefficients arrays. + where ``n`` denotes the level of decomposition. The first element + (``cA_n``) of the result is approximation coefficients array and the + following elements (``cD_n`` - ``cD_1``) are details coefficients + arrays. Examples -------- @@ -119,14 +120,14 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): wavelet : Wavelet object or name string Wavelet to use mode : str, optional - Signal extension mode, see `Modes` (default: 'symmetric') + Signal extension mode, see :ref:`Modes `. axis: int, optional Axis over which to compute the inverse DWT. If not given, the last axis is used. Notes ----- - It may sometimes be desired to run `waverec` with some sets of + It may sometimes be desired to run ``waverec`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list entries or setting them to None is not supported. @@ -156,6 +157,12 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): a, ds = coeffs[0], coeffs[1:] for d in ds: + if d is not None and not isinstance(d, np.ndarray): + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be arrays as returned by wavedec. If you are using " + "pywt.array_to_coeffs or pywt.unravel_coeffs, please specify " + "output_format='wavedec'").format(type(d))) if (a is not None) and (d is not None): try: if a.shape[axis] == d.shape[axis] + 1: @@ -164,10 +171,6 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): raise ValueError("coefficient shape mismatch") except IndexError: raise ValueError("Axis greater than coefficient dimensions") - except AttributeError: - raise AttributeError( - "Wrong coefficient format, if using 'array_to_coeffs' " - "please specify the 'output_format' parameter") a = idwt(a, d, wavelet, mode, axis) return a @@ -183,25 +186,26 @@ def wavedec2(data, wavelet, mode='symmetric', level=None, axes=(-2, -1)): 2D input data wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or 2-tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : 2-tuple of ints, optional Axes over which to compute the DWT. Repeated elements are not allowed. Returns ------- [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] : list - Coefficients list. For user-specified `axes`, `cH*` - corresponds to ``axes[0]`` while `cV*` corresponds to ``axes[1]``. + Coefficients list. For user-specified ``axes``, ``cH*`` + corresponds to ``axes[0]`` while ``cV*`` corresponds to ``axes[1]``. The first element returned is the approximation coefficients for the nth level of decomposition. Remaining elements are tuples of detail coefficients in descending order of decomposition level. - (i.e. `cH1` are the horizontal detail coefficients at the first level) + (i.e. ``cH1`` are the horizontal detail coefficients at the first + level) Examples -------- @@ -257,10 +261,10 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): Coefficients list [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or 2-tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. axes : 2-tuple of ints, optional Axes over which to compute the IDWT. Repeated elements are not allowed. @@ -270,7 +274,7 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): Notes ----- - It may sometimes be desired to run `waverec2` with some sets of + It may sometimes be desired to run ``waverec2`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list or tuple entries or setting them to None is not supported. @@ -310,6 +314,12 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): a = np.asarray(a) for d in ds: + if not isinstance(d, (list, tuple)) or len(d) != 3: + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be a 3-tuple of arrays as returned by wavedec2. If you " + "are using pywt.array_to_coeffs or pywt.unravel_coeffs, " + "please specify output_format='wavedec2'").format(type(d))) d = tuple(np.asarray(coeff) if coeff is not None else None for coeff in d) d_shapes = (coeff.shape for coeff in d if coeff is not None) @@ -357,13 +367,13 @@ def wavedecn(data, wavelet, mode='symmetric', level=None, axes=None): nD input data wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. The default is None, which means transform all axes @@ -373,16 +383,16 @@ def wavedecn(data, wavelet, mode='symmetric', level=None, axes=None): ------- [cAn, {details_level_n}, ... {details_level_1}] : list Coefficients list. Coefficients are listed in descending order of - decomposition level. `cAn` are the approximation coefficients at - level `n`. Each `details_level_i` element is a dictionary - containing detail coefficients at level `i` of the decomposition. As + decomposition level. ``cAn`` are the approximation coefficients at + level ``n``. Each ``details_level_i`` element is a dictionary + containing detail coefficients at level ``i`` of the decomposition. As a concrete example, a 3D decomposition would have the following set of - keys in each `details_level_i` dictionary:: + keys in each ``details_level_i`` dictionary:: {'aad', 'ada', 'daa', 'add', 'dad', 'dda', 'ddd'} where the order of the characters in each key map to the specified - `axes`. + ``axes``. Examples -------- @@ -456,10 +466,10 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): Coefficients list [cAn, {details_level_n}, ... {details_level_1}] wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the IDWT. Axes may not be repeated. @@ -469,7 +479,7 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): Notes ----- - It may sometimes be desired to run `waverecn` with some sets of + It may sometimes be desired to run ``waverecn`` with some sets of coefficients omitted. This can best be done by setting the corresponding arrays to zero arrays of matching shape and dtype. Explicitly removing list or dictionary entries or setting them to None is not supported. @@ -511,6 +521,14 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): a, ds = coeffs[0], coeffs[1:] + # this dictionary check must be prior to the call to _fix_coeffs + if len(ds) > 0 and not all([isinstance(d, dict) for d in ds]): + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be a dicionary of arrays as returned by wavedecn. If " + "you are using pywt.array_to_coeffs or pywt.unravel_coeffs, " + "please specify output_format='wavedecn'").format(type(ds[0]))) + # Raise error for invalid key combinations ds = list(map(_fix_coeffs, ds)) @@ -655,7 +673,7 @@ def _prepare_coeffs_axes(coeffs, axes): def coeffs_to_array(coeffs, padding=0, axes=None): """ - Arrange a wavelet coefficient list from `wavedecn` into a single array. + Arrange a wavelet coefficient list from ``wavedecn`` into a single array. Parameters ---------- @@ -665,7 +683,7 @@ def coeffs_to_array(coeffs, padding=0, axes=None): padding : float or None, optional If None, raise an error if the coefficients cannot be tightly packed. axes : sequence of ints, optional - Axes over which the DWT that created `coeffs` was performed. The + Axes over which the DWT that created ``coeffs`` was performed. The default value of None corresponds to all axes. Returns @@ -674,8 +692,8 @@ def coeffs_to_array(coeffs, padding=0, axes=None): Wavelet transform coefficient array. coeff_slices : list List of slices corresponding to each coefficient. As a 2D example, - `coeff_arr[coeff_slices[1]['dd']]` would extract the first level detail - coefficients from `coeff_arr`. + ``coeff_arr[coeff_slices[1]['dd']]`` would extract the first level + detail coefficients from ``coeff_arr``. See Also -------- @@ -773,17 +791,17 @@ def coeffs_to_array(coeffs, padding=0, axes=None): def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): """ Convert a combined array of coefficients back to a list compatible with - `waverecn`. + ``waverecn``. Parameters ---------- arr : array-like An array containing all wavelet coefficients. This should have been - generated via `coeffs_to_array`. + generated via ``coeffs_to_array``. coeff_slices : list of tuples List of slices corresponding to each coefficient as obtained from - `array_to_coeffs`. + ``array_to_coeffs``. output_format : {'wavedec', 'wavedec2', 'wavedecn'} Make the form of the coefficients compatible with this type of multilevel transform. @@ -800,7 +818,7 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): Notes ----- A single large array containing all coefficients will have subsets stored, - into a `waverecn` list, c, as indicated below:: + into a ``waverecn`` list, c, as indicated below:: +---------------+---------------+-------------------------------+ | | | | @@ -827,7 +845,8 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): >>> cam = pywt.data.camera() >>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3) >>> arr, coeff_slices = pywt.coeffs_to_array(coeffs) - >>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices) + >>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices, + ... output_format='wavedecn') >>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2') >>> assert_array_almost_equal(cam, cam_recon) @@ -867,13 +886,13 @@ def wavedecn_shapes(shape, wavelet, mode='symmetric', level=None, axes=None): The shape of the data to be transformed. wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see Modes (default: 'symmetric'). This can - also be a tuple containing a mode to apply along each axis in `axes`. + Signal extension mode, see :ref:`Modes `. This can + also be a tuple containing a mode to apply along each axis in ``axes``. level : int, optional Decomposition level (must be >= 0). If level is None (default) then it - will be calculated using the `dwt_max_level` function. + will be calculated using the ``dwt_max_level`` function. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. The default is None, which means transform all axes @@ -882,7 +901,7 @@ def wavedecn_shapes(shape, wavelet, mode='symmetric', level=None, axes=None): Returns ------- shapes : [cAn, {details_level_n}, ... {details_level_1}] : list - Coefficients shape list. Mirrors the output of `wavedecn`, except + Coefficients shape list. Mirrors the output of ``wavedecn``, except it contains only the shapes of the coefficient arrays rather than the arrays themselves. @@ -922,9 +941,9 @@ def wavedecn_size(shapes): Parameters ---------- shapes : list of coefficient shapes - A set of coefficient shapes as returned by `wavedecn_shapes`. + A set of coefficient shapes as returned by ``wavedecn_shapes``. Alternatively, the user can specify a set of coefficients as returned - by `wavedecn`. + by ``wavedecn``. Returns ------- @@ -944,7 +963,7 @@ def wavedecn_size(shapes): 3087 """ def _size(x): - """Size corresponding to `x` as either a shape tuple or an ndarray.""" + """Size corresponding to ``x`` as either a shape tuple or ndarray.""" if isinstance(x, np.ndarray): return x.size else: @@ -963,7 +982,7 @@ def dwtn_max_level(shape, wavelet, axes=None): """Compute the maximum level of decomposition for n-dimensional data. This returns the maximum number of levels of decomposition suitable for use - with `wavedec`, `wavedec2` or `wavedecn`. + with ``wavedec``, ``wavedec2`` or ``wavedecn``. Parameters ---------- @@ -971,7 +990,7 @@ def dwtn_max_level(shape, wavelet, axes=None): Input data shape. wavelet : Wavelet object or name string, or tuple of wavelets Wavelet to use. This can also be a tuple containing a wavelet to - apply along each axis in `axes`. + apply along each axis in ``axes``. axes : sequence of ints, optional Axes over which to compute the DWT. Axes may not be repeated. @@ -982,7 +1001,7 @@ def dwtn_max_level(shape, wavelet, axes=None): Notes ----- - The level returned is the smallest `dwt_max_level` over all axes. + The level returned is the smallest ``dwt_max_level`` over all axes. Examples -------- @@ -1009,9 +1028,11 @@ def ravel_coeffs(coeffs, axes=None): ---------- coeffs : array-like A list of multilevel wavelet coefficients as returned by - `wavedec`, `wavedec2` or `wavedecn`. + ``wavedec``, ``wavedec2`` or ``wavedecn``. This function is also + compatible with the output of ``swt``, ``swt2`` and ``swtn`` if those + functions were called with ``trim_approx=True``. axes : sequence of ints, optional - Axes over which the DWT that created `coeffs` was performed. The + Axes over which the DWT that created ``coeffs`` was performed. The default value of None corresponds to all axes. Returns @@ -1022,7 +1043,7 @@ def ravel_coeffs(coeffs, axes=None): coeff_slices : list List of slices corresponding to each coefficient. As a 2D example, ``coeff_arr[coeff_slices[1]['dd']]`` would extract the first level - detail coefficients from `coeff_arr`. + detail coefficients from ``coeff_arr``. coeff_shapes : list List of shapes corresponding to each coefficient. For example, in 2D, ``coeff_shapes[1]['dd']`` would contain the original shape of the first @@ -1092,23 +1113,24 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): ---------- arr : array-like An array containing all wavelet coefficients. This should have been - generated by applying `ravel_coeffs` to the output of `wavedec`, - `wavedec2` or `wavedecn`. + generated by applying ``ravel_coeffs`` to the output of ``wavedec``, + ``wavedec2`` or ``wavedecn`` (or via ``swt``, ``swt2`` or ``swtn`` + with ``trim_approx=True``). coeff_slices : list of tuples List of slices corresponding to each coefficient as obtained from - `ravel_coeffs`. + ``ravel_coeffs``. coeff_shapes : list of tuples List of shapes corresponding to each coefficient as obtained from - `ravel_coeffs`. - output_format : {'wavedec', 'wavedec2', 'wavedecn'}, optional + ``ravel_coeffs``. + output_format : {'wavedec', 'wavedec2', 'wavedecn', 'swt', 'swt2', 'swtn'}, optional Make the form of the unraveled coefficients compatible with this type - of multilevel transform. The default is 'wavedecn'. + of multilevel transform. The default is ``'wavedecn'``. Returns ------- coeffs: list List of wavelet transform coefficients. The specific format of the list - elements is determined by `output_format`. + elements is determined by ``output_format``. See Also -------- @@ -1121,7 +1143,8 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): >>> cam = pywt.data.camera() >>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3) >>> arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs) - >>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes) + >>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes, + ... output_format='wavedecn') >>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2') >>> assert_array_almost_equal(cam, cam_recon) @@ -1141,13 +1164,13 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): for n in range(1, len(coeff_slices)): slice_dict = coeff_slices[n] shape_dict = coeff_shapes[n] - if output_format == 'wavedec': + if output_format in ['wavedec', 'swt']: d = arr[slice_dict['d']].reshape(shape_dict['d']) - elif output_format == 'wavedec2': + elif output_format in ['wavedec2', 'swt2']: d = (arr[slice_dict['da']].reshape(shape_dict['da']), arr[slice_dict['ad']].reshape(shape_dict['ad']), arr[slice_dict['dd']].reshape(shape_dict['dd'])) - elif output_format == 'wavedecn': + elif output_format in ['wavedecn', 'swtn']: d = {} for k, v in coeff_slices[n].items(): d[k] = arr[v].reshape(shape_dict[k]) @@ -1362,12 +1385,12 @@ def fswavedecn(data, wavelet, mode='symmetric', levels=None, axes=None): Wavelet to use. This can also be a tuple containing a wavelet to apply along each axis in ``axes``. mode : str or tuple of str, optional - Signal extension mode, see `Modes` (default: 'symmetric'). This can + Signal extension mode, see :ref:`Modes `. This can also be a tuple containing a mode to apply along each axis in ``axes``. levels : int or sequence of ints, optional Decomposition levels along each axis (must be >= 0). If an integer is provided, the same number of levels are used for all axes. If - ``levels`` is None (default), `dwt_max_level` will be used to compute + ``levels`` is None (default), ``dwt_max_level`` will be used to compute the maximum number of levels possible for each axis. axes : sequence of ints, optional Axes over which to compute the transform. Axes may not be repeated. The @@ -1378,7 +1401,7 @@ def fswavedecn(data, wavelet, mode='symmetric', levels=None, axes=None): fswavedecn_result : FswavedecnResult object Contains the wavelet coefficients, slice objects to allow obtaining the coefficients per detail or approximation level, and more. - See `FswavedecnResult` for details. + See ``FswavedecnResult`` for details. Examples -------- diff --git a/pywt/_swt.py b/pywt/_swt.py index 472c2ec2d..575c4e803 100644 --- a/pywt/_swt.py +++ b/pywt/_swt.py @@ -6,7 +6,7 @@ from ._c99_config import _have_c99_complex from ._extensions._dwt import idwt_single from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis -from ._extensions._pywt import Modes, _check_dtype +from ._extensions._pywt import Wavelet, Modes, _check_dtype from ._multidim import idwt2, idwtn from ._utils import _as_wavelet, _wavelets_per_axis @@ -14,7 +14,18 @@ __all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn'] -def swt(data, wavelet, level=None, start_level=0, axis=-1): +def _rescale_wavelet_filterbank(wavelet, sf): + wav = Wavelet(wavelet.name + 'r', + [np.asarray(f) * sf for f in wavelet.filter_bank]) + + # copy attributes from the original wavelet + wav.orthogonal = wavelet.orthogonal + wav.biorthogonal = wavelet.biorthogonal + return wav + + +def swt(data, wavelet, level=None, start_level=0, axis=-1, + trim_approx=False, norm=False): """ Multilevel 1D stationary wavelet transform. @@ -33,6 +44,13 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): axis: int, optional Axis over which to compute the SWT. If not given, the last axis is used. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- @@ -49,20 +67,60 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): [(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)] + If ``trim_approx`` is ``True``, then the output list is exactly as in + ``pywt.wavedec``, where the first coefficient in the list is the + approximation coefficient at the final level and the rest are the + detail coefficients:: + + [cAn, cDn, ..., cD2, cD1] + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axis be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true: + + 1. The wavelet is orthogonal + 2. ``swt`` is called with ``norm=True`` + 3. ``swt`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + + 1. energy is conserved + 2. variance is partitioned across scales + + When used with ``norm=True``, this transform is closely related to the + multiple-overlap DWT (MODWT) as popularized for time-series analysis, + although the underlying implementation is slightly different from the one + published in [1]_. Specifically, the implementation used here requires a + signal that is a multiple of ``2**level`` in length. + + References + ---------- + .. [1] DB Percival and AT Walden. Wavelet Methods for Time Series Analysis. + Cambridge University Press, 2000. """ + if not _have_c99_complex and np.iscomplexobj(data): data = np.asarray(data) - coeffs_real = swt(data.real, wavelet, level, start_level) - coeffs_imag = swt(data.imag, wavelet, level, start_level) - coeffs_cplx = [] - for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag): - coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i)) + coeffs_real = swt(data.real, wavelet, level, start_level, trim_approx) + coeffs_imag = swt(data.imag, wavelet, level, start_level, trim_approx) + if not trim_approx: + coeffs_cplx = [] + for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag): + coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i)) + else: + coeffs_cplx = [cr + 1j*ci + for (cr, ci) in zip(coeffs_real, coeffs_imag)] return coeffs_cplx # accept array_like input; make a copy to ensure a contiguous array @@ -70,6 +128,12 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): data = np.array(data, dtype=dt) wavelet = _as_wavelet(wavelet) + if norm: + if not wavelet.orthogonal: + warnings.warn( + "norm=True, but the wavelet is not orthogonal: \n" + "\tThe conditions for energy preservation are not satisfied.") + wavelet = _rescale_wavelet_filterbank(wavelet, 1/np.sqrt(2)) if axis < 0: axis = axis + data.ndim @@ -80,13 +144,13 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1): level = swt_max_level(data.shape[axis]) if data.ndim == 1: - ret = _swt(data, wavelet, level, start_level) + ret = _swt(data, wavelet, level, start_level, trim_approx) else: - ret = _swt_axis(data, wavelet, level, start_level, axis) - return [(np.asarray(cA), np.asarray(cD)) for cA, cD in ret] + ret = _swt_axis(data, wavelet, level, start_level, axis, trim_approx) + return ret -def iswt(coeffs, wavelet): +def iswt(coeffs, wavelet, norm=False): """ Multilevel 1D inverse discrete stationary wavelet transform. @@ -101,6 +165,10 @@ def iswt(coeffs, wavelet): ``start_level`` from ``pywt.swt``. wavelet : Wavelet object or name string Wavelet to use + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swt`` to preserve the + energy of a round-trip transform. Returns ------- @@ -114,23 +182,41 @@ def iswt(coeffs, wavelet): array([ 1., 2., 3., 4., 5., 6., 7., 8.]) """ # copy to avoid modification of input data - dt = _check_dtype(coeffs[0][0]) - output = np.array(coeffs[0][0], dtype=dt, copy=True) + # If swt was called with trim_approx=False, first element is a tuple + trim_approx = not isinstance(coeffs[0], (tuple, list)) + + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0][0] + + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) if not _have_c99_complex and np.iscomplexobj(output): # compute real and imaginary separately then combine - coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs] - coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs] + if trim_approx: + coeffs_real = [c.real for c in coeffs] + coeffs_imag = [c.imag for c in coeffs] + else: + coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs] + coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs] return iswt(coeffs_real, wavelet) + 1j*iswt(coeffs_imag, wavelet) # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelet = _as_wavelet(wavelet) + if norm: + wavelet = _rescale_wavelet_filterbank(wavelet, np.sqrt(2)) mode = Modes.from_object('periodization') for j in range(num_levels, 0, -1): step_size = int(pow(2, j-1)) last_index = step_size - _, cD = coeffs[num_levels - j] + if trim_approx: + cD = coeffs[-j] + else: + _, cD = coeffs[-j] cD = np.asarray(cD, dtype=_check_dtype(cD)) if cD.dtype != output.dtype: # upcast to a common dtype (float64 or complex128) @@ -170,7 +256,8 @@ def iswt(coeffs, wavelet): return output -def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): +def swt2(data, wavelet, level, start_level=0, axes=(-2, -1), + trim_approx=False, norm=False): """ Multilevel 2D stationary wavelet transform. @@ -187,11 +274,20 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): The level at which the decomposition will start (default: 0) axes : 2-tuple of ints, optional Axes over which to compute the SWT. Repeated elements are not allowed. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- coeffs : list - Approximation and details coefficients (for ``start_level = m``):: + Approximation and details coefficients (for ``start_level = m``). + If ``trim_approx`` is ``True``, approximation coefficients are + retained for all levels:: [ (cA_m+level, @@ -209,12 +305,42 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): where cA is approximation, cH is horizontal details, cV is vertical details, cD is diagonal details and m is ``start_level``. + If ``trim_approx`` is ``False``, approximation coefficients are only + retained at the final level of decomposition. This matches the format + used by ``pywt.wavedec2``:: + + [ + cA_m+level, + (cH_m+level, cV_m+level, cD_m+level), + ..., + (cH_m+1, cV_m+1, cD_m+1), + (cH_m, cV_m, cD_m), + ] + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axes be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true: + + 1. The wavelet is orthogonal + 2. ``swt2`` is called with ``norm=True`` + 3. ``swt2`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + + 1. energy is conserved + 2. variance is partitioned across scales + """ axes = tuple(axes) data = np.asarray(data) @@ -226,15 +352,20 @@ def swt2(data, wavelet, level, start_level=0, axes=(-2, -1)): raise ValueError("Input array has fewer dimensions than the specified " "axes") - coefs = swtn(data, wavelet, level, start_level, axes) + coefs = swtn(data, wavelet, level, start_level, axes, trim_approx, norm) ret = [] + if trim_approx: + ret.append(coefs[0]) + coefs = coefs[1:] for c in coefs: - ret.append((c['aa'], (c['da'], c['ad'], c['dd']))) - + if trim_approx: + ret.append((c['da'], c['ad'], c['dd'])) + else: + ret.append((c['aa'], (c['da'], c['ad'], c['dd']))) return ret -def iswt2(coeffs, wavelet): +def iswt2(coeffs, wavelet, norm=False): """ Multilevel 2D inverse discrete stationary wavelet transform. @@ -262,6 +393,10 @@ def iswt2(coeffs, wavelet): wavelet : Wavelet object or name string, or 2-tuple of wavelets Wavelet to use. This can also be a 2-tuple of wavelets to apply per axis. + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swt2`` to preserve + the energy of a round-trip transform. Returns ------- @@ -281,9 +416,17 @@ def iswt2(coeffs, wavelet): """ + # If swt was called with trim_approx=False, first element is a tuple + trim_approx = not isinstance(coeffs[0], (tuple, list)) + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0][0] + # copy to avoid modification of input data - dt = _check_dtype(coeffs[0][0]) - output = np.array(coeffs[0][0], dtype=dt, copy=True) + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) if output.ndim != 2: raise ValueError( @@ -292,11 +435,17 @@ def iswt2(coeffs, wavelet): # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelets = _wavelets_per_axis(wavelet, axes=(0, 1)) + if norm: + wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2)) + for wav in wavelets] for j in range(num_levels): step_size = int(pow(2, num_levels-j-1)) last_index = step_size - _, (cH, cV, cD) = coeffs[j] + if trim_approx: + (cH, cV, cD) = coeffs[j] + else: + _, (cH, cV, cD) = coeffs[j] # We are going to assume cH, cV, and cD are of equal size if (cH.shape != cV.shape) or (cH.shape != cD.shape): raise RuntimeError( @@ -353,7 +502,8 @@ def iswt2(coeffs, wavelet): return output -def swtn(data, wavelet, level, start_level=0, axes=None): +def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False, + norm=False): """ n-dimensional stationary wavelet transform. @@ -371,6 +521,13 @@ def swtn(data, wavelet, level, start_level=0, axes=None): axes : sequence of ints, optional Axes over which to compute the SWT. A value of ``None`` (the default) selects all axes. Axes may not be repeated. + trim_approx : bool, optional + If True, approximation coefficients at the final level are retained. + norm : bool, optional + If True, transform is normalized so that the energy of the coefficients + will be equal to the energy of ``data``. In other words, + ``np.linalg.norm(data.ravel())`` will equal the norm of the + concatenated transform coefficients when ``trim_approx`` is True. Returns ------- @@ -391,19 +548,47 @@ def swtn(data, wavelet, level, start_level=0, axes=None): For user-specified ``axes``, the order of the characters in the dictionary keys map to the specified ``axes``. + If ``trim_approx`` is ``True``, the first element of the list contains + the array of approximation coefficients from the final level of + decomposition, while the remaining coefficient dictionaries contain + only detail coefficients. This matches the behavior of `pywt.wavedecn`. + Notes ----- The implementation here follows the "algorithm a-trous" and requires that the signal length along the transformed axes be a multiple of ``2**level``. If this is not the case, the user should pad up to an appropriate size using a function such as ``numpy.pad``. + + A primary benefit of this transform in comparison to its decimated + counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes + at cost of redundancy in the transform (the size of the output coefficients + is larger than the input). + + When the following three conditions are true: + + 1. The wavelet is orthogonal + 2. ``swtn`` is called with ``norm=True`` + 3. ``swtn`` is called with ``trim_approx=True`` + + the transform has the following additional properties that may be + desirable in applications: + + 1. energy is conserved + 2. variance is partitioned across scales + """ data = np.asarray(data) if not _have_c99_complex and np.iscomplexobj(data): - real = swtn(data.real, wavelet, level, start_level, axes) - imag = swtn(data.imag, wavelet, level, start_level, axes) - cplx = [] - for rdict, idict in zip(real, imag): + real = swtn(data.real, wavelet, level, start_level, axes, trim_approx) + imag = swtn(data.imag, wavelet, level, start_level, axes, trim_approx) + if trim_approx: + cplx = [real[0] + 1j * imag[0]] + offset = 1 + else: + cplx = [] + offset = 0 + for rdict, idict in zip(real[offset:], imag[offset:]): cplx.append( dict((k, rdict[k] + 1j * idict[k]) for k in rdict.keys())) return cplx @@ -421,7 +606,13 @@ def swtn(data, wavelet, level, start_level=0, axes=None): num_axes = len(axes) wavelets = _wavelets_per_axis(wavelet, axes) - + if norm: + if not np.all([wav.orthogonal for wav in wavelets]): + warnings.warn( + "norm=True, but the wavelets used are not orthogonal: \n" + "\tThe conditions for energy preservation are not satisfied.") + wavelets = [_rescale_wavelet_filterbank(wav, 1/np.sqrt(2)) + for wav in wavelets] ret = [] for i in range(start_level, start_level + level): coeffs = [('', data)] @@ -439,12 +630,15 @@ def swtn(data, wavelet, level, start_level=0, axes=None): # data for the next level is the approximation coeffs from this level data = coeffs['a' * num_axes] - + if trim_approx: + coeffs.pop('a' * num_axes) + if trim_approx: + ret.append(data) ret.reverse() return ret -def iswtn(coeffs, wavelet, axes=None): +def iswtn(coeffs, wavelet, axes=None, norm=False): """ Multilevel nD inverse discrete stationary wavelet transform. @@ -459,6 +653,10 @@ def iswtn(coeffs, wavelet, axes=None): Axes over which to compute the inverse SWT. Axes may not be repeated. The default is ``None``, which means transform all axes (``axes = range(data.ndim)``). + norm : bool, optional + Controls the normalization used by the inverse transform. This must + be set equal to the value that was used by ``pywt.swtn`` to preserve + the energy of a round-trip transform. Returns ------- @@ -479,11 +677,18 @@ def iswtn(coeffs, wavelet, axes=None): """ # key length matches the number of axes transformed - ndim_transform = max(len(key) for key in coeffs[0].keys()) + ndim_transform = max(len(key) for key in coeffs[-1].keys()) + + trim_approx = not isinstance(coeffs[0], dict) + if trim_approx: + cA = coeffs[0] + coeffs = coeffs[1:] + else: + cA = coeffs[0]['a'*ndim_transform] # copy to avoid modification of input data - dt = _check_dtype(coeffs[0]['a'*ndim_transform]) - output = np.array(coeffs[0]['a'*ndim_transform], dtype=dt, copy=True) + dt = _check_dtype(cA) + output = np.array(cA, dtype=dt, copy=True) ndim = output.ndim if axes is None: @@ -498,6 +703,9 @@ def iswtn(coeffs, wavelet, axes=None): # num_levels, equivalent to the decomposition level, n num_levels = len(coeffs) wavelets = _wavelets_per_axis(wavelet, axes) + if norm: + wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2)) + for wav in wavelets] # initialize various slice objects used in the loops below # these will remain slice(None) only on axes that aren't transformed @@ -509,7 +717,8 @@ def iswtn(coeffs, wavelet, axes=None): for j in range(num_levels): step_size = int(pow(2, num_levels-j-1)) last_index = step_size - a = coeffs[j].pop('a'*ndim_transform) # will restore later + if not trim_approx: + a = coeffs[j].pop('a'*ndim_transform) # will restore later details = coeffs[j] # make sure dtype matches the coarsest level approximation coefficients common_dtype = np.result_type(*( @@ -560,5 +769,6 @@ def iswtn(coeffs, wavelet, axes=None): output[tuple(indices)] += x ntransforms += 1 output[tuple(indices)] /= ntransforms # normalize - coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict + if not trim_approx: + coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict return output diff --git a/pywt/_utils.py b/pywt/_utils.py index 48f814e21..291e85070 100644 --- a/pywt/_utils.py +++ b/pywt/_utils.py @@ -2,6 +2,7 @@ # # See COPYING for license details. import inspect +import numpy as np import sys from collections.abc import Iterable @@ -17,7 +18,7 @@ def _as_wavelet(wavelet): - """Convert wavelet name to a Wavelet object""" + """Convert wavelet name to a Wavelet object.""" if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) if isinstance(wavelet, ContinuousWavelet): diff --git a/pywt/conftest.py b/pywt/conftest.py new file mode 100644 index 000000000..da8ae6455 --- /dev/null +++ b/pywt/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +def pytest_configure(config): + config.addinivalue_line("markers", + "slow: Tests that are slow.") diff --git a/pywt/tests/data/generate_matlab_data_cwt.py b/pywt/tests/data/generate_matlab_data_cwt.py index d1f771e2b..05b6e42a0 100644 --- a/pywt/tests/data/generate_matlab_data_cwt.py +++ b/pywt/tests/data/generate_matlab_data_cwt.py @@ -37,7 +37,7 @@ try: all_matlab_results = {} for wavelet in wavelets: - w = pywt.Wavelet(wavelet) + w = pywt.ContinuousWavelet(wavelet) if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0): mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency)) elif wavelet == 'fbsp': diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index 4372efc6a..9dcb65162 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -1,8 +1,10 @@ #!/usr/bin/env python from __future__ import division, print_function, absolute_import +from itertools import product from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal, - assert_raises) + assert_raises, assert_equal) +import pytest import numpy as np import pywt @@ -344,21 +346,65 @@ def test_cwt_parameters_in_names(): assert_raises(ValueError, func, 'fbsp1-1-1-1') -def test_cwt_complex(): - for dtype in [np.float32, np.float64]: - time, sst = pywt.data.nino() - sst = np.asarray(sst, dtype=dtype) - dt = time[1] - time[0] - wavelet = 'cmor1.5-1.0' - scales = np.arange(1, 32) +@pytest.mark.parametrize('dtype, tol, method', + [(np.float32, 1e-5, 'conv'), + (np.float32, 1e-5, 'fft'), + (np.float64, 1e-13, 'conv'), + (np.float64, 1e-13, 'fft')]) +def test_cwt_complex(dtype, tol, method): + time, sst = pywt.data.nino() + sst = np.asarray(sst, dtype=dtype) + dt = time[1] - time[0] + wavelet = 'cmor1.5-1.0' + scales = np.arange(1, 32) - # real-valued tranfsorm - [cfs, f] = pywt.cwt(sst, scales, wavelet, dt) + # real-valued tranfsorm as a reference + [cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method) - # complex-valued tranfsorm equals sum of the transforms of the real and - # imaginary components - [cfs_complex, f] = pywt.cwt(sst + 1j*sst, scales, wavelet, dt) - assert_almost_equal(cfs + 1j*cfs, cfs_complex) + # verify same precision + assert_equal(cfs.real.dtype, sst.dtype) + + # complex-valued transform equals sum of the transforms of the real + # and imaginary components + sst_complex = sst + 1j*sst + [cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt, + method=method) + assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol) + # verify dtype is preserved + assert_equal(cfs_complex.dtype, sst_complex.dtype) + + +@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft'])) +def test_cwt_batch(axis, method): + dtype = np.float64 + time, sst = pywt.data.nino() + n_batch = 8 + batch_axis = 1 - axis + sst1 = np.asarray(sst, dtype=dtype) + sst = np.stack((sst1, ) * n_batch, axis=batch_axis) + dt = time[1] - time[0] + wavelet = 'cmor1.5-1.0' + scales = np.arange(1, 32) + + # non-batch transform as reference + [cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis) + + shape_in = sst.shape + [cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis) + + # shape of input is not modified + assert_equal(shape_in, sst.shape) + + # verify same precision + assert_equal(cfs.real.dtype, sst.dtype) + + # verify expected shape + assert_equal(cfs.shape[0], len(scales)) + assert_equal(cfs.shape[1 + batch_axis], n_batch) + assert_equal(cfs.shape[1 + axis], sst.shape[axis]) + + # batch result on stacked input is the same as stacked 1d result + assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1)) def test_cwt_small_scales(): @@ -371,3 +417,18 @@ def test_cwt_small_scales(): # extremely short scale factors raise a ValueError assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh') + + +def test_cwt_method_fft(): + rstate = np.random.RandomState(1) + data = rstate.randn(50) + data[15] = 1. + scales = np.arange(1, 64) + wavelet = 'cmor1.5-1.0' + + # build a reference cwt with the legacy np.conv() method + cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv') + + # compare with the fft based convolution + cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft') + assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13) diff --git a/pywt/tests/test_dwt_idwt.py b/pywt/tests/test_dwt_idwt.py index 1fd17e042..7e9e206bf 100644 --- a/pywt/tests/test_dwt_idwt.py +++ b/pywt/tests/test_dwt_idwt.py @@ -2,8 +2,8 @@ from __future__ import division, print_function, absolute_import import numpy as np -from numpy.testing import assert_allclose, assert_, assert_raises - +from numpy.testing import (assert_allclose, assert_, assert_raises, + assert_array_equal) import pywt # Check that float32, float64, complex64, complex128 are preserved. @@ -223,3 +223,77 @@ def test_error_on_continuous_wavelet(): cA, cD = pywt.dwt(data, 'db1') assert_raises(ValueError, pywt.idwt, cA, cD, cwave) + + +def test_dwt_zero_size_axes(): + # raise on empty input array + assert_raises(ValueError, pywt.dwt, [], 'db2') + + # >1D case uses a different code path so check there as well + x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis + assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0) + + +def test_pad_1d(): + x = [1, 2, 3] + assert_array_equal(pywt.pad(x, (4, 6), 'periodization'), + [1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2]) + assert_array_equal(pywt.pad(x, (4, 6), 'periodic'), + [3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]) + assert_array_equal(pywt.pad(x, (4, 6), 'constant'), + [1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3]) + assert_array_equal(pywt.pad(x, (4, 6), 'zero'), + [0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0]) + assert_array_equal(pywt.pad(x, (4, 6), 'smooth'), + [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'), + [3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3]) + assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'), + [3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3]) + assert_array_equal(pywt.pad(x, (4, 6), 'reflect'), + [1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1]) + assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'), + [-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + # equivalence of various pad_width formats + assert_array_equal(pywt.pad(x, 4, 'periodic'), + pywt.pad(x, (4, 4), 'periodic')) + + assert_array_equal(pywt.pad(x, (4, ), 'periodic'), + pywt.pad(x, (4, 4), 'periodic')) + + assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'), + pywt.pad(x, (4, 4), 'periodic')) + + +def test_pad_errors(): + # negative pad width + x = [1, 2, 3] + assert_raises(ValueError, pywt.pad, x, -2, 'periodic') + + # wrong length pad width + assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic') + + # invalid mode name + assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode') + + +def test_pad_nd(): + for ndim in [2, 3]: + x = np.arange(4**ndim).reshape((4, ) * ndim) + if ndim == 2: + pad_widths = [(2, 1), (2, 3)] + else: + pad_widths = [(2, 1), ] * ndim + for mode in pywt.Modes.modes: + xp = pywt.pad(x, pad_widths, mode) + + # expected result is the same as applying along axes separably + xp_expected = x.copy() + for ax in range(ndim): + xp_expected = np.apply_along_axis(pywt.pad, + ax, + xp_expected, + pad_widths=[pad_widths[ax]], + mode=mode) + assert_array_equal(xp, xp_expected) diff --git a/pywt/tests/test_matlab_compatibility_cwt.py b/pywt/tests/test_matlab_compatibility_cwt.py index 9dc9e35bb..c2121fe44 100644 --- a/pywt/tests/test_matlab_compatibility_cwt.py +++ b/pywt/tests/test_matlab_compatibility_cwt.py @@ -147,6 +147,11 @@ def _check_accuracy(data, w, scales, coefs, wavelet, epsilon): # PyWavelets result coefs_pywt, freq = pywt.cwt(data, scales, w) + # coefs from Matlab are from R2012a which is missing the complex conjugate + # as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of + # the precomputed Matlab result to account for this. + coefs = np.conj(coefs) + # calculate error measures err = coefs_pywt - coefs rms = np.real(np.sqrt(np.mean(np.conj(err) * err))) diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py index 4223a56a7..6233cb71e 100644 --- a/pywt/tests/test_multilevel.py +++ b/pywt/tests/test_multilevel.py @@ -80,8 +80,9 @@ def test_waverec_invalid_inputs(): coeffs = pywt.wavedec(x, 'db1') arr, coeff_slices = pywt.coeffs_to_array(coeffs) coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices) - message = "Wrong coefficient format, if using 'array_to_coeffs' please specify the 'output_format' parameter" - assert_raises_regex(AttributeError, message, pywt.waverec, coeffs_from_arr, 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverec, coeffs_from_arr, + 'haar') def test_waverec_accuracies(): @@ -208,6 +209,13 @@ def test_waverec2_invalid_inputs(): # input list cannot be empty assert_raises(ValueError, pywt.waverec2, [], 'haar') + # coefficients from a difference decomposition used as input + for dec_func in [pywt.wavedec, pywt.wavedecn]: + coeffs = dec_func(np.ones((8, 8)), 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverec2, coeffs, + 'haar') + def test_waverec2_coeff_shape_mismatch(): x = np.ones((8, 8)) @@ -285,6 +293,16 @@ def test_waverecn_invalid_coeffs(): assert_raises(ValueError, pywt.waverecn, [], 'haar') +def test_waverecn_invalid_inputs(): + + # coefficients from a difference decomposition used as input + for dec_func in [pywt.wavedec, pywt.wavedec2]: + coeffs = dec_func(np.ones((8, 8)), 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverecn, coeffs, + 'haar') + + def test_waverecn_lists(): # support coefficient arrays specified as lists instead of arrays coeffs = [[[1.0]], {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}] diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py index a0e4d2956..bb70d3949 100644 --- a/pywt/tests/test_swt.py +++ b/pywt/tests/test_swt.py @@ -153,9 +153,15 @@ def test_swt_iswt_integration(): current_wavelet.rec_len)))) input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length) - coeffs = pywt.swt(X, current_wavelet, max_level) - Y = pywt.iswt(coeffs, current_wavelet) - assert_allclose(Y, X, rtol=1e-5, atol=1e-7) + for norm in [True, False]: + if norm and not current_wavelet.orthogonal: + # non-orthogonal wavelets to avoid warnings when norm=True + continue + for trim_approx in [True, False]: + coeffs = pywt.swt(X, current_wavelet, max_level, + trim_approx=trim_approx, norm=norm) + Y = pywt.iswt(coeffs, current_wavelet, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-7) def test_swt_dtypes(): @@ -235,9 +241,15 @@ def test_swt2_iswt2_integration(wavelets=None): input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length**2).reshape(input_length, input_length) - coeffs = pywt.swt2(X, current_wavelet, max_level) - Y = pywt.iswt2(coeffs, current_wavelet) - assert_allclose(Y, X, rtol=1e-5, atol=1e-5) + for norm in [True, False]: + if norm and not current_wavelet.orthogonal: + # non-orthogonal wavelets to avoid warnings when norm=True + continue + for trim_approx in [True, False]: + coeffs = pywt.swt2(X, current_wavelet, max_level, + trim_approx=trim_approx, norm=norm) + Y = pywt.iswt2(coeffs, current_wavelet, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-5) def test_swt2_iswt2_quick(): @@ -355,10 +367,16 @@ def test_swtn_iswtn_integration(wavelets=None): N = 2**(input_length_power + max_level - 1) X = np.arange(N**ndim).reshape((N, )*ndim) - coeffs = pywt.swtn(X, wav, max_level, axes=axes) - coeffs_copy = deepcopy(coeffs) - Y = pywt.iswtn(coeffs, wav, axes=axes) - assert_allclose(Y, X, rtol=1e-5, atol=1e-5) + for norm in [True, False]: + if norm and not wav.orthogonal: + # non-orthogonal wavelets to avoid warnings + continue + for trim_approx in [True, False]: + coeffs = pywt.swtn(X, wav, max_level, axes=axes, + trim_approx=trim_approx, norm=norm) + coeffs_copy = deepcopy(coeffs) + Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm) + assert_allclose(Y, X, rtol=1e-5, atol=1e-5) # verify the inverse transform didn't modify any coeffs for c, c2 in zip(coeffs, coeffs_copy): @@ -525,3 +543,91 @@ def test_iswtn_mixed_dtypes(): y = pywt.iswtn(coeffs, wav) assert_equal(output_dtype, y.dtype) assert_allclose(y, x, rtol=1e-3, atol=1e-3) + + +def test_swt_zero_size_axes(): + # raise on empty input array + assert_raises(ValueError, pywt.swt, [], 'db2') + + # >1D case uses a different code path so check there as well + x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis + assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,)) + + +def test_swt_variance_and_energy_preservation(): + """Verify that the 1D SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(256) + coeffs = pywt.swt(x, wav, trim_approx=True, norm=True) + variances = [np.var(c) for c in coeffs] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeffs))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True) + + +def test_swt2_variance_and_energy_preservation(): + """Verify that the 2D SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(64, 64) + coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True) + coeff_list = [coeffs[0].ravel()] + for d in coeffs[1:]: + for v in d: + coeff_list.append(v.ravel()) + variances = [np.var(v) for v in coeff_list] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeff_list))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True) + + +def test_swtn_variance_and_energy_preservation(): + """Verify that the nD SWT partitions variance among the coefficients.""" + # When norm is True and the wavelet is orthogonal, the sum of the + # variances of the coefficients should equal the variance of the signal. + wav = 'db2' + rstate = np.random.RandomState(5) + x = rstate.randn(64, 64) + coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True) + coeff_list = [coeffs[0].ravel()] + for d in coeffs[1:]: + for k, v in d.items(): + coeff_list.append(v.ravel()) + variances = [np.var(v) for v in coeff_list] + assert_allclose(np.sum(variances), np.var(x)) + + # also verify L2-norm energy preservation property + assert_allclose(np.linalg.norm(x), + np.linalg.norm(np.concatenate(coeff_list))) + + # non-orthogonal wavelet with norm=True raises a warning + assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True) + + +def test_swt_ravel_and_unravel(): + # When trim_approx=True, all swt functions can user pywt.ravel_coeffs + for ndim, _swt, _iswt, ravel_type in [ + (1, pywt.swt, pywt.iswt, 'swt'), + (2, pywt.swt2, pywt.iswt2, 'swt2'), + (3, pywt.swtn, pywt.iswtn, 'swtn')]: + x = np.ones((16, ) * ndim) + c = _swt(x, 'sym2', level=3, trim_approx=True) + arr, slices, shapes = pywt.ravel_coeffs(c) + c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type) + r = _iswt(c, 'sym2') + assert_allclose(x, r) diff --git a/util/authors.py b/util/authors.py index 9b3fe31b5..7bfbc8241 100755 --- a/util/authors.py +++ b/util/authors.py @@ -30,6 +30,7 @@ u('Helder'): u('Helder Oliveira'), u('Kai'): u('Kai Wohlfahrt'), u('asnt'): u('Alexandre Saint'), + u('pavleb'): u('Pavle Boškoski'), }