Skip to content

Commit 9ce883e

Browse files
authored
Merge pull request #506 from grlee77/guard_against_size0_axes
Guard against trying to transform along size 0 axes
2 parents 13a4725 + 3a9969e commit 9ce883e

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

pywt/_extensions/_dwt.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ cpdef dwt_axis(np.ndarray data, Wavelet wavelet, MODE mode, unsigned int axis=0)
9090

9191
input_shape = <size_t [:data.ndim]> <size_t *> data.shape
9292
output_shape = input_shape.copy()
93-
output_shape[axis] = common.dwt_buffer_length(data.shape[axis], wavelet.dec_len, mode)
93+
output_shape[axis] = dwt_coeff_len(data.shape[axis], wavelet.dec_len, mode)
9494

9595
cA = np.empty(output_shape, data.dtype)
9696
cD = np.empty(output_shape, data.dtype)

pywt/_extensions/_swt.pyx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def swt_max_level(size_t input_len):
3636
multiple of ``2**n``. ``numpy.pad`` can be used to pad a signal up to an
3737
appropriate length as needed.
3838
"""
39+
if input_len < 1:
40+
raise ValueError("Cannot apply swt to a size 0 signal.")
3941
max_level = common.swt_max_level(input_len)
4042
if max_level == 0:
4143
warnings.warn(
@@ -54,6 +56,8 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
5456

5557
if data.size % 2:
5658
raise ValueError("Length of data must be even.")
59+
if data.size < 1:
60+
raise ValueError("Data must have non-zero size")
5761

5862
if level < 1:
5963
raise ValueError("Level value must be greater than zero.")
@@ -67,6 +71,7 @@ def swt(cdata_t[::1] data, Wavelet wavelet, size_t level, size_t start_level):
6771
common.swt_max_level(data.size) - start_level))
6872
raise ValueError(msg)
6973

74+
7075
output_len = common.swt_buffer_length(data.size)
7176
if output_len < 1:
7277
raise RuntimeError("Invalid output length.")
@@ -153,8 +158,10 @@ cpdef swt_axis(np.ndarray data, Wavelet wavelet, size_t level,
153158
cdef int retval = -5
154159
cdef size_t i
155160

156-
if data.size % 2:
157-
raise ValueError("Length of data must be even.")
161+
if data.shape[axis] % 2:
162+
raise ValueError("Length of data must be even along the transform axis.")
163+
if data.shape[axis] < 1:
164+
raise ValueError("Data must have non-zero size along the transform axis.")
158165

159166
if level < 1:
160167
raise ValueError("Level value must be greater than zero.")

pywt/tests/test_dwt_idwt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,13 @@ def test_error_on_continuous_wavelet():
223223

224224
cA, cD = pywt.dwt(data, 'db1')
225225
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
226+
227+
228+
def test_dwt_zero_size_axes():
229+
# raise on empty input array
230+
assert_raises(ValueError, pywt.dwt, [], 'db2')
231+
232+
# >1D case uses a different code path so check there as well
233+
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
234+
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
235+

pywt/tests/test_swt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,3 +525,13 @@ def test_iswtn_mixed_dtypes():
525525
y = pywt.iswtn(coeffs, wav)
526526
assert_equal(output_dtype, y.dtype)
527527
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
528+
529+
530+
def test_swt_zero_size_axes():
531+
# raise on empty input array
532+
assert_raises(ValueError, pywt.swt, [], 'db2')
533+
534+
# >1D case uses a different code path so check there as well
535+
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
536+
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
537+

0 commit comments

Comments
 (0)