diff --git a/CHANGELOG.md b/CHANGELOG.md index eaa3c9e478b..f223ed698d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ * `Address`: optimize `functionCall` by calling `functionCallWithValue` directly. ([#3468](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3468)) * `Address`: optimize `functionCall` functions by checking contract size only if there is no returned data. ([#3469](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3469)) * `GovernorCompatibilityBravo`: remove unused `using` statements ([#3506](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3506)) + * `Arrays`: add `sort` function ## 4.7.0 (2022-06-29) diff --git a/contracts/mocks/ArraysImpl.sol b/contracts/mocks/ArraysImpl.sol index f720524b808..4edda4656e7 100644 --- a/contracts/mocks/ArraysImpl.sol +++ b/contracts/mocks/ArraysImpl.sol @@ -16,4 +16,9 @@ contract ArraysImpl { function findUpperBound(uint256 element) external view returns (uint256) { return _array.findUpperBound(element); } + + function sort(uint256[] memory array) external pure returns (uint256[] memory sorted) { + array.sort(); + return (array); + } } diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index 0783614cd3e..565771cbb7d 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -45,4 +45,84 @@ library Arrays { return low; } } + + /** + * @dev Sorts `array` of integers in an ascending order. + * + * Sorting is done in-place using the heap sort algorithm. + * Examples of gas cost with optimizer enabled for 200 runs: + * - 10 random items: ~8K gas + * - 100 random items: ~156K gas + */ + function sort(uint256[] memory array) internal pure { + unchecked { + uint256 length = array.length; + if (length < 2) return; + // Heapify the array + for (uint256 i = length / 2; i-- > 0; ) { + _siftDown(array, length, i, _arrayLoad(array, i)); + } + // Drain all elements from highest to lowest and put them at the end of the array + while (--length != 0) { + uint256 val = _arrayLoad(array, 0); + _siftDown(array, length, 0, _arrayLoad(array, length)); + _arrayStore(array, length, val); + } + } + } + + /** + * @dev Insert a `inserted` value into an empty space in a binary heap. + * Makes sure that the space and all items below it still form a valid heap. + * Index `empty` is considered empty and will be overwritten. + */ + function _siftDown( + uint256[] memory array, + uint256 length, + uint256 emptyIdx, + uint256 inserted + ) private pure { + unchecked { + while (true) { + // The first child of empty, one level deeper in the heap + uint256 childIdx = (emptyIdx << 1) + 1; + // Empty has no children + if (childIdx >= length) break; + uint256 childVal = _arrayLoad(array, childIdx); + uint256 otherChildIdx = childIdx + 1; + // Pick the larger child + if (otherChildIdx < length) { + uint256 otherChildVal = _arrayLoad(array, otherChildIdx); + if (otherChildVal > childVal) { + childIdx = otherChildIdx; + childVal = otherChildVal; + } + } + // No child is larger than the inserted value + if (childVal <= inserted) break; + // Move the larger child one level up and keep sifting down + _arrayStore(array, emptyIdx, childVal); + emptyIdx = childIdx; + } + _arrayStore(array, emptyIdx, inserted); + } + } + + function _arrayLoad(uint256[] memory array, uint256 idx) private pure returns (uint256 val) { + /// @solidity memory-safe-assembly + assembly { + val := mload(add(32, add(array, shl(5, idx)))) + } + } + + function _arrayStore( + uint256[] memory array, + uint256 idx, + uint256 val + ) private pure { + /// @solidity memory-safe-assembly + assembly { + mstore(add(32, add(array, shl(5, idx))), val) + } + } } diff --git a/test/utils/Arrays.test.js b/test/utils/Arrays.test.js index 67128fac26f..333a0deb817 100644 --- a/test/utils/Arrays.test.js +++ b/test/utils/Arrays.test.js @@ -1,6 +1,7 @@ require('@openzeppelin/test-helpers'); const { expect } = require('chai'); +const { config } = require('hardhat'); const ArraysImpl = artifacts.require('ArraysImpl'); @@ -84,4 +85,80 @@ contract('Arrays', function (accounts) { }); }); }); + + describe('sort', function () { + let arraysImpl; + + before(async function () { + arraysImpl = await ArraysImpl.new([]); + }); + + async function testSort (array) { + const sorted = await arraysImpl.sort(array); + const sortedNum = sorted.map(val => val.toNumber()); + array.sort((a, b) => a - b); + expect(sortedNum).to.deep.equal(array, 'Invalid sorting result'); + } + + function arrayOf (length, generator) { + return Array.from(Array(length), generator); + } + + const randomItems = () => Math.floor(Math.random() * 1000); + const sameItems = () => 1; + const sortedItems = (_, idx) => idx; + const reverseSortedItems = (_, idx) => 1000 - idx; + + it('accepts zero length arrays', async function () { + await testSort([]); + }); + + it('accepts one length arrays', async function () { + await testSort([1]); + }); + + it('handles sorted data', async function () { + await testSort([1, 2, 3]); + }); + + it('handles reverse sorted data', async function () { + await testSort([3, 2, 1]); + }); + + it('handles scrambled data', async function () { + await testSort([2, 1, 3]); + }); + + it('handles 10 random items', async function () { + await testSort(arrayOf(10, randomItems)); + }); + + it('handles 10 same items', async function () { + await testSort(arrayOf(10, sameItems)); + }); + + it('handles 10 sorted items', async function () { + await testSort(arrayOf(10, sortedItems)); + }); + + it('handles 10 reverse sorted items', async function () { + await testSort(arrayOf(10, reverseSortedItems)); + }); + + it('handles 100 random items', async function () { + await testSort(arrayOf(100, randomItems)); + }); + + it('handles 100 same items', async function () { + await testSort(arrayOf(100, sameItems)); + }); + + it('handles 100 sorted items', async function () { + await testSort(arrayOf(100, sortedItems)); + }); + + it('handles 100 reverse sorted items', async function () { + await testSort(arrayOf(100, reverseSortedItems)); + }); + }); });