diff --git a/contracts/utils/structs/Checkpoints.sol b/contracts/utils/structs/Checkpoints.sol index 5ba6ad92d5e..ba2ac3ef41f 100644 --- a/contracts/utils/structs/Checkpoints.sol +++ b/contracts/utils/structs/Checkpoints.sol @@ -168,10 +168,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key > key) { + uint32 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -192,10 +196,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key < key) { + uint32 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high; @@ -558,10 +566,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key > key) { + uint96 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -582,10 +594,14 @@ library Checkpoints { ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid)._key < key) { + uint96 currentKey = _unsafeAccess(self, mid)._key; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high; diff --git a/scripts/generate/templates/Checkpoints.js b/scripts/generate/templates/Checkpoints.js index 0a677e27f07..d9d26eb18d5 100644 --- a/scripts/generate/templates/Checkpoints.js +++ b/scripts/generate/templates/Checkpoints.js @@ -189,10 +189,14 @@ function _upperBinaryLookup( ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid).${opts.keyFieldName} > key) { + ${opts.keyTypeName} currentKey = _unsafeAccess(self, mid).${opts.keyFieldName}; + if (currentKey > key) { high = mid; } else { low = mid + 1; + if (currentKey == key) { + return low; + } } } return high; @@ -213,10 +217,14 @@ function _lowerBinaryLookup( ) private view returns (uint256) { while (low < high) { uint256 mid = Math.average(low, high); - if (_unsafeAccess(self, mid).${opts.keyFieldName} < key) { + ${opts.keyTypeName} currentKey = _unsafeAccess(self, mid).${opts.keyFieldName}; + if (currentKey < key) { low = mid + 1; } else { high = mid; + if (currentKey == key) { + return high; + } } } return high;