@@ -27,6 +27,7 @@ def quicksort_inplace(array, beg, end): # 注意这里我们都用左闭右
2727
2828
2929def partition (array , beg , end ):
30+ """对给定数组执行 partition 操作,返回新的 pivot 位置"""
3031 pivot_index = beg
3132 pivot = array [pivot_index ]
3233 left = pivot_index + 1
@@ -55,7 +56,7 @@ def test_partition():
5556 l = [1 , 2 , 3 , 4 ]
5657 assert partition (l , 0 , len (l )) == 0
5758 l = [4 , 3 , 2 , 1 ]
58- assert partition (l , 0 , len (l ))
59+ assert partition (l , 0 , len (l )) == 3
5960
6061
6162def test_quicksort_inplace ():
@@ -65,3 +66,31 @@ def test_quicksort_inplace():
6566 sorted_seq = sorted (seq )
6667 quicksort_inplace (seq , 0 , len (seq ))
6768 assert seq == sorted_seq
69+
70+
71+ def nth_element (array , beg , end , nth ):
72+ """查找一个数组第 n 大元素"""
73+ if beg < end :
74+ pivot_idx = partition (array , beg , end )
75+ if pivot_idx == nth - 1 : # 数组小标从 0 开始
76+ return array [pivot_idx ]
77+ elif pivot_idx > nth - 1 :
78+ return nth_element (array , beg , pivot_idx , nth )
79+ else :
80+ return nth_element (array , pivot_idx + 1 , end , nth )
81+
82+
83+ def test_nth_element ():
84+ l1 = [3 , 5 , 4 , 2 , 1 ]
85+ assert nth_element (l1 , 0 , len (l1 ), 3 ) == 3
86+ assert nth_element (l1 , 0 , len (l1 ), 2 ) == 2
87+
88+ l = [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ]
89+ for i in l :
90+ assert nth_element (l , 0 , len (l ), i ) == i
91+ for i in reversed (l ):
92+ assert nth_element (l , 0 , len (l ), i ) == i
93+
94+
95+ if __name__ == '__main__' :
96+ test_nth_element ()
0 commit comments