Skip to content

Commit 48c77a6

Browse files
committed
Merge pull request numpy#3908 from juliantaylor/median-percentile
add extended axis and keepdims support to percentile and median
2 parents 50b60fe + 7d53c81 commit 48c77a6

File tree

2 files changed

+235
-13
lines changed

2 files changed

+235
-13
lines changed

numpy/lib/function_base.py

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import warnings
1414
import sys
1515
import collections
16+
import operator
1617

1718
import numpy as np
1819
import numpy.core.numeric as _nx
@@ -2694,7 +2695,67 @@ def msort(a):
26942695
return b
26952696

26962697

2697-
def median(a, axis=None, out=None, overwrite_input=False):
2698+
def _ureduce(a, func, **kwargs):
2699+
"""
2700+
Internal Function.
2701+
Call `func` with `a` as first argument swapping the axes to use extended
2702+
axis on functions that don't support it natively.
2703+
2704+
Returns result and a.shape with axis dims set to 1.
2705+
2706+
Parameters
2707+
----------
2708+
a : array_like
2709+
Input array or object that can be converted to an array.
2710+
func : callable
2711+
Reduction function Kapable of receiving an axis argument.
2712+
It is is called with `a` as first argument followed by `kwargs`.
2713+
kwargs : keyword arguments
2714+
additional keyword arguments to pass to `func`.
2715+
2716+
Returns
2717+
-------
2718+
result : tuple
2719+
Result of func(a, **kwargs) and a.shape with axis dims set to 1
2720+
which can be used to reshape the result to the same shape a ufunc with
2721+
keepdims=True would produce.
2722+
2723+
"""
2724+
a = np.asanyarray(a)
2725+
axis = kwargs.get('axis', None)
2726+
if axis is not None:
2727+
keepdim = list(a.shape)
2728+
nd = a.ndim
2729+
try:
2730+
axis = operator.index(axis)
2731+
if axis >= nd or axis < -nd:
2732+
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
2733+
keepdim[axis] = 1
2734+
except TypeError:
2735+
sax = set()
2736+
for x in axis:
2737+
if x >= nd or x < -nd:
2738+
raise IndexError("axis %d out of bounds (%d)" % (x, nd))
2739+
if x in sax:
2740+
raise ValueError("duplicate value in axis")
2741+
sax.add(x % nd)
2742+
keepdim[x] = 1
2743+
keep = sax.symmetric_difference(frozenset(range(nd)))
2744+
nkeep = len(keep)
2745+
# swap axis that should not be reduced to front
2746+
for i, s in enumerate(sorted(keep)):
2747+
a = a.swapaxes(i, s)
2748+
# merge reduced axis
2749+
a = a.reshape(a.shape[:nkeep] + (-1,))
2750+
kwargs['axis'] = -1
2751+
else:
2752+
keepdim = [1] * a.ndim
2753+
2754+
r = func(a, **kwargs)
2755+
return r, keepdim
2756+
2757+
2758+
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
26982759
"""
26992760
Compute the median along the specified axis.
27002761
@@ -2704,9 +2765,10 @@ def median(a, axis=None, out=None, overwrite_input=False):
27042765
----------
27052766
a : array_like
27062767
Input array or object that can be converted to an array.
2707-
axis : int, optional
2768+
axis : int or sequence of int, optional
27082769
Axis along which the medians are computed. The default (axis=None)
27092770
is to compute the median along a flattened version of the array.
2771+
A sequence of axes is supported since version 1.9.0.
27102772
out : ndarray, optional
27112773
Alternative output array in which to place the result. It must have
27122774
the same shape and buffer length as the expected output, but the
@@ -2719,6 +2781,13 @@ def median(a, axis=None, out=None, overwrite_input=False):
27192781
will probably be fully or partially sorted. Default is False. Note
27202782
that, if `overwrite_input` is True and the input is not already an
27212783
ndarray, an error will be raised.
2784+
keepdims : bool, optional
2785+
If this is set to True, the axes which are reduced are left
2786+
in the result as dimensions with size one. With this option,
2787+
the result will broadcast correctly against the original `arr`.
2788+
2789+
.. versionadded:: 1.9.0
2790+
27222791
27232792
Returns
27242793
-------
@@ -2768,6 +2837,16 @@ def median(a, axis=None, out=None, overwrite_input=False):
27682837
>>> assert not np.all(a==b)
27692838
27702839
"""
2840+
r, k = _ureduce(a, func=_median, axis=axis, out=out,
2841+
overwrite_input=overwrite_input)
2842+
if keepdims:
2843+
return r.reshape(k)
2844+
else:
2845+
return r
2846+
2847+
def _median(a, axis=None, out=None, overwrite_input=False):
2848+
# can't be reasonably be implemented in terms of percentile as we have to
2849+
# call mean to not break astropy
27712850
a = np.asanyarray(a)
27722851
if axis is not None and axis >= a.ndim:
27732852
raise IndexError(
@@ -2817,7 +2896,7 @@ def median(a, axis=None, out=None, overwrite_input=False):
28172896

28182897

28192898
def percentile(a, q, axis=None, out=None,
2820-
overwrite_input=False, interpolation='linear'):
2899+
overwrite_input=False, interpolation='linear', keepdims=False):
28212900
"""
28222901
Compute the qth percentile of the data along the specified axis.
28232902
@@ -2829,9 +2908,10 @@ def percentile(a, q, axis=None, out=None,
28292908
Input array or object that can be converted to an array.
28302909
q : float in range of [0,100] (or sequence of floats)
28312910
Percentile to compute which must be between 0 and 100 inclusive.
2832-
axis : int, optional
2911+
axis : int or sequence of int, optional
28332912
Axis along which the percentiles are computed. The default (None)
28342913
is to compute the percentiles along a flattened version of the array.
2914+
A sequence of axes is supported since version 1.9.0.
28352915
out : ndarray, optional
28362916
Alternative output array in which to place the result. It must
28372917
have the same shape and buffer length as the expected output,
@@ -2857,6 +2937,12 @@ def percentile(a, q, axis=None, out=None,
28572937
* midpoint: (`i` + `j`) / 2.
28582938
28592939
.. versionadded:: 1.9.0
2940+
keepdims : bool, optional
2941+
If this is set to True, the axes which are reduced are left
2942+
in the result as dimensions with size one. With this option,
2943+
the result will broadcast correctly against the original `arr`.
2944+
2945+
.. versionadded:: 1.9.0
28602946
28612947
Returns
28622948
-------
@@ -2913,19 +2999,40 @@ def percentile(a, q, axis=None, out=None,
29132999
array([ 3.5])
29143000
29153001
"""
3002+
q = asarray(q, dtype=np.float64)
3003+
r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out,
3004+
overwrite_input=overwrite_input,
3005+
interpolation=interpolation)
3006+
if keepdims:
3007+
if q.ndim == 0:
3008+
return r.reshape(k)
3009+
else:
3010+
return r.reshape([len(q)] + k)
3011+
else:
3012+
return r
3013+
3014+
3015+
def _percentile(a, q, axis=None, out=None,
3016+
overwrite_input=False, interpolation='linear', keepdims=False):
29163017
a = asarray(a)
2917-
q = asarray(q)
29183018
if q.ndim == 0:
29193019
# Do not allow 0-d arrays because following code fails for scalar
29203020
zerod = True
29213021
q = q[None]
29223022
else:
29233023
zerod = False
29243024

2925-
q = q / 100.0
2926-
if (q < 0).any() or (q > 1).any():
2927-
raise ValueError(
2928-
"Percentiles must be in the range [0,100]")
3025+
# avoid expensive reductions, relevant for arrays with < O(1000) elements
3026+
if q.size < 10:
3027+
for i in range(q.size):
3028+
if q[i] < 0. or q[i] > 100.:
3029+
raise ValueError("Percentiles must be in the range [0,100]")
3030+
q[i] /= 100.
3031+
else:
3032+
# faster than any()
3033+
if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.):
3034+
raise ValueError("Percentiles must be in the range [0,100]")
3035+
q /= 100.
29293036

29303037
# prepare a for partioning
29313038
if overwrite_input:

numpy/lib/tests/test_function_base.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,8 @@ def test_exception(self):
16881688
interpolation='foobar')
16891689
assert_raises(ValueError, np.percentile, [1], 101)
16901690
assert_raises(ValueError, np.percentile, [1], -1)
1691+
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [101])
1692+
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [-0.1])
16911693

16921694
def test_percentile_list(self):
16931695
assert_equal(np.percentile([1, 2, 3], 0), 1)
@@ -1779,26 +1781,85 @@ def test_percentile_overwrite(self):
17791781
b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True)
17801782
assert_equal(b, np.array([2.5]))
17811783

1784+
def test_extended_axis(self):
1785+
o = np.random.normal(size=(71, 23))
1786+
x = np.dstack([o] * 10)
1787+
assert_equal(np.percentile(x, 30, axis=(0, 1)), np.percentile(o, 30))
1788+
x = np.rollaxis(x, -1, 0)
1789+
assert_equal(np.percentile(x, 30, axis=(-2, -1)), np.percentile(o, 30))
1790+
x = x.swapaxes(0, 1).copy()
1791+
assert_equal(np.percentile(x, 30, axis=(0, -1)), np.percentile(o, 30))
1792+
x = x.swapaxes(0, 1).copy()
1793+
1794+
assert_equal(np.percentile(x, [25, 60], axis=(0, 1, 2)),
1795+
np.percentile(x, [25, 60], axis=None))
1796+
assert_equal(np.percentile(x, [25, 60], axis=(0,)),
1797+
np.percentile(x, [25, 60], axis=0))
1798+
1799+
d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
1800+
np.random.shuffle(d)
1801+
assert_equal(np.percentile(d, 25, axis=(0, 1, 2))[0],
1802+
np.percentile(d[:, :, :, 0].flatten(), 25))
1803+
assert_equal(np.percentile(d, [10, 90], axis=(0, 1, 3))[:, 1],
1804+
np.percentile(d[:, :, 1, :].flatten(), [10, 90]))
1805+
assert_equal(np.percentile(d, 25, axis=(3, 1, -4))[2],
1806+
np.percentile(d[:, :, 2, :].flatten(), 25))
1807+
assert_equal(np.percentile(d, 25, axis=(3, 1, 2))[2],
1808+
np.percentile(d[2, :, :, :].flatten(), 25))
1809+
assert_equal(np.percentile(d, 25, axis=(3, 2))[2, 1],
1810+
np.percentile(d[2, 1, :, :].flatten(), 25))
1811+
assert_equal(np.percentile(d, 25, axis=(1, -2))[2, 1],
1812+
np.percentile(d[2, :, :, 1].flatten(), 25))
1813+
assert_equal(np.percentile(d, 25, axis=(1, 3))[2, 2],
1814+
np.percentile(d[2, :, 2, :].flatten(), 25))
1815+
1816+
def test_extended_axis_invalid(self):
1817+
d = np.ones((3, 5, 7, 11))
1818+
assert_raises(IndexError, np.percentile, d, axis=-5, q=25)
1819+
assert_raises(IndexError, np.percentile, d, axis=(0, -5), q=25)
1820+
assert_raises(IndexError, np.percentile, d, axis=4, q=25)
1821+
assert_raises(IndexError, np.percentile, d, axis=(0, 4), q=25)
1822+
assert_raises(ValueError, np.percentile, d, axis=(1, 1), q=25)
1823+
1824+
def test_keepdims(self):
1825+
d = np.ones((3, 5, 7, 11))
1826+
assert_equal(np.percentile(d, 7, axis=None, keepdims=True).shape,
1827+
(1, 1, 1, 1))
1828+
assert_equal(np.percentile(d, 7, axis=(0, 1), keepdims=True).shape,
1829+
(1, 1, 7, 11))
1830+
assert_equal(np.percentile(d, 7, axis=(0, 3), keepdims=True).shape,
1831+
(1, 5, 7, 1))
1832+
assert_equal(np.percentile(d, 7, axis=(1,), keepdims=True).shape,
1833+
(3, 1, 7, 11))
1834+
assert_equal(np.percentile(d, 7, (0, 1, 2, 3), keepdims=True).shape,
1835+
(1, 1, 1, 1))
1836+
assert_equal(np.percentile(d, 7, axis=(0, 1, 3), keepdims=True).shape,
1837+
(1, 1, 7, 1))
1838+
1839+
assert_equal(np.percentile(d, [1, 7], axis=(0, 1, 3),
1840+
keepdims=True).shape, (2, 1, 1, 7, 1))
1841+
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
1842+
keepdims=True).shape, (2, 1, 5, 7, 1))
17821843

17831844

17841845
class TestMedian(TestCase):
17851846
def test_basic(self):
17861847
a0 = np.array(1)
17871848
a1 = np.arange(2)
17881849
a2 = np.arange(6).reshape(2, 3)
1789-
assert_allclose(np.median(a0), 1)
1850+
assert_equal(np.median(a0), 1)
17901851
assert_allclose(np.median(a1), 0.5)
17911852
assert_allclose(np.median(a2), 2.5)
17921853
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
1793-
assert_allclose(np.median(a2, axis=1), [1, 4])
1854+
assert_equal(np.median(a2, axis=1), [1, 4])
17941855
assert_allclose(np.median(a2, axis=None), 2.5)
17951856

17961857
a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
17971858
assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
17981859
a = np.array([0.0463301, 0.0444502, 0.141249])
1799-
assert_almost_equal(a[0], np.median(a))
1860+
assert_equal(a[0], np.median(a))
18001861
a = np.array([0.0444502, 0.141249, 0.0463301])
1801-
assert_almost_equal(a[-1], np.median(a))
1862+
assert_equal(a[-1], np.median(a))
18021863

18031864
def test_axis_keyword(self):
18041865
a3 = np.array([[2, 3],
@@ -1872,6 +1933,60 @@ def mean(self, axis=None, dtype=None, out=None):
18721933
a = MySubClass([1,2,3])
18731934
assert_equal(np.median(a), -7)
18741935

1936+
def test_extended_axis(self):
1937+
o = np.random.normal(size=(71, 23))
1938+
x = np.dstack([o] * 10)
1939+
assert_equal(np.median(x, axis=(0, 1)), np.median(o))
1940+
x = np.rollaxis(x, -1, 0)
1941+
assert_equal(np.median(x, axis=(-2, -1)), np.median(o))
1942+
x = x.swapaxes(0, 1).copy()
1943+
assert_equal(np.median(x, axis=(0, -1)), np.median(o))
1944+
1945+
assert_equal(np.median(x, axis=(0, 1, 2)), np.median(x, axis=None))
1946+
assert_equal(np.median(x, axis=(0, )), np.median(x, axis=0))
1947+
assert_equal(np.median(x, axis=(-1, )), np.median(x, axis=-1))
1948+
1949+
d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
1950+
np.random.shuffle(d)
1951+
assert_equal(np.median(d, axis=(0, 1, 2))[0],
1952+
np.median(d[:, :, :, 0].flatten()))
1953+
assert_equal(np.median(d, axis=(0, 1, 3))[1],
1954+
np.median(d[:, :, 1, :].flatten()))
1955+
assert_equal(np.median(d, axis=(3, 1, -4))[2],
1956+
np.median(d[:, :, 2, :].flatten()))
1957+
assert_equal(np.median(d, axis=(3, 1, 2))[2],
1958+
np.median(d[2, :, :, :].flatten()))
1959+
assert_equal(np.median(d, axis=(3, 2))[2, 1],
1960+
np.median(d[2, 1, :, :].flatten()))
1961+
assert_equal(np.median(d, axis=(1, -2))[2, 1],
1962+
np.median(d[2, :, :, 1].flatten()))
1963+
assert_equal(np.median(d, axis=(1, 3))[2, 2],
1964+
np.median(d[2, :, 2, :].flatten()))
1965+
1966+
def test_extended_axis_invalid(self):
1967+
d = np.ones((3, 5, 7, 11))
1968+
assert_raises(IndexError, np.median, d, axis=-5)
1969+
assert_raises(IndexError, np.median, d, axis=(0, -5))
1970+
assert_raises(IndexError, np.median, d, axis=4)
1971+
assert_raises(IndexError, np.median, d, axis=(0, 4))
1972+
assert_raises(ValueError, np.median, d, axis=(1, 1))
1973+
1974+
def test_keepdims(self):
1975+
d = np.ones((3, 5, 7, 11))
1976+
assert_equal(np.median(d, axis=None, keepdims=True).shape,
1977+
(1, 1, 1, 1))
1978+
assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape,
1979+
(1, 1, 7, 11))
1980+
assert_equal(np.median(d, axis=(0, 3), keepdims=True).shape,
1981+
(1, 5, 7, 1))
1982+
assert_equal(np.median(d, axis=(1,), keepdims=True).shape,
1983+
(3, 1, 7, 11))
1984+
assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape,
1985+
(1, 1, 1, 1))
1986+
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
1987+
(1, 1, 7, 1))
1988+
1989+
18751990

18761991
class TestAdd_newdoc_ufunc(TestCase):
18771992

0 commit comments

Comments
 (0)