From af03cc128eb5d809ce84d6988be8b9dc878d7cef Mon Sep 17 00:00:00 2001 From: Vladislav Volosnikov Date: Mon, 28 Aug 2023 13:43:53 +0300 Subject: [PATCH] Rely on strict increasing order in checkpoints lookups --- contracts/utils/structs/Checkpoints.sol | 24 +++++++++++++++++++---- scripts/generate/templates/Checkpoints.js | 12 ++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/contracts/utils/structs/Checkpoints.sol b/contracts/utils/structs/Checkpoints.sol index 383f01af8fa..79b28b64a87 100644 --- a/contracts/utils/structs/Checkpoints.sol +++ b/contracts/utils/structs/Checkpoints.sol @@ -161,10 +161,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key > key) { + uint32 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -184,10 +188,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key < key) { + uint32 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high; @@ -348,10 +356,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key > key) { + uint96 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -371,10 +383,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key < key) { + uint96 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high; diff --git a/scripts/generate/templates/Checkpoints.js b/scripts/generate/templates/Checkpoints.js index 73c9ab53e76..f91ce5ca33e 100644 --- a/scripts/generate/templates/Checkpoints.js +++ b/scripts/generate/templates/Checkpoints.js @@ -182,10 +182,14 @@ function _upperBinaryLookup( ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid).${opts.keyFieldName} > key) { + ${opts.keyTypeName} currentKey = _unsafeAccess(self, mid).${opts.keyFieldName}; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -205,10 +209,14 @@ function _lowerBinaryLookup( ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid).${opts.keyFieldName} < key) { + ${opts.keyTypeName} currentKey = _unsafeAccess(self, mid).${opts.keyFieldName}; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high;