diff --git a/src/System.Linq/src/System/Linq/Partition.cs b/src/System.Linq/src/System/Linq/Partition.cs index 12127982a7ae..9b9ae1a14a8d 100644 --- a/src/System.Linq/src/System/Linq/Partition.cs +++ b/src/System.Linq/src/System/Linq/Partition.cs @@ -140,19 +140,19 @@ public int GetCount(bool onlyIfCheap) internal sealed class OrderedPartition : IPartition { private readonly OrderedEnumerable _source; - private readonly int _minIndex; - private readonly int _maxIndex; + private readonly int _minIndexInclusive; + private readonly int _maxIndexInclusive; - public OrderedPartition(OrderedEnumerable source, int minIdx, int maxIdx) + public OrderedPartition(OrderedEnumerable source, int minIdxInclusive, int maxIdxInclusive) { _source = source; - _minIndex = minIdx; - _maxIndex = maxIdx; + _minIndexInclusive = minIdxInclusive; + _maxIndexInclusive = maxIdxInclusive; } public IEnumerator GetEnumerator() { - return _source.GetEnumerator(_minIndex, _maxIndex); + return _source.GetEnumerator(_minIndexInclusive, _maxIndexInclusive); } IEnumerator IEnumerable.GetEnumerator() @@ -162,26 +162,26 @@ IEnumerator IEnumerable.GetEnumerator() public IPartition Skip(int count) { - int minIndex = _minIndex + count; - return (uint)minIndex > (uint)_maxIndex ? EmptyPartition.Instance : new OrderedPartition(_source, minIndex, _maxIndex); + int minIndex = _minIndexInclusive + count; + return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition.Instance : new OrderedPartition(_source, minIndex, _maxIndexInclusive); } public IPartition Take(int count) { - int maxIndex = _minIndex + count - 1; - if ((uint)maxIndex >= (uint)_maxIndex) + int maxIndex = _minIndexInclusive + count - 1; + if ((uint)maxIndex >= (uint)_maxIndexInclusive) { return this; } - return new OrderedPartition(_source, _minIndex, maxIndex); + return new OrderedPartition(_source, _minIndexInclusive, maxIndex); } public TElement TryGetElementAt(int index, out bool found) { - if ((uint)index <= (uint)(_maxIndex - _minIndex)) + if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive)) { - return _source.TryGetElementAt(index + _minIndex, out found); + return _source.TryGetElementAt(index + _minIndexInclusive, out found); } found = false; @@ -190,27 +190,27 @@ public TElement TryGetElementAt(int index, out bool found) public TElement TryGetFirst(out bool found) { - return _source.TryGetElementAt(_minIndex, out found); + return _source.TryGetElementAt(_minIndexInclusive, out found); } public TElement TryGetLast(out bool found) { - return _source.TryGetLast(_minIndex, _maxIndex, out found); + return _source.TryGetLast(_minIndexInclusive, _maxIndexInclusive, out found); } public TElement[] ToArray() { - return _source.ToArray(_minIndex, _maxIndex); + return _source.ToArray(_minIndexInclusive, _maxIndexInclusive); } public List ToList() { - return _source.ToList(_minIndex, _maxIndex); + return _source.ToList(_minIndexInclusive, _maxIndexInclusive); } public int GetCount(bool onlyIfCheap) { - return _source.GetCount(_minIndex, _maxIndex, onlyIfCheap); + return _source.GetCount(_minIndexInclusive, _maxIndexInclusive, onlyIfCheap); } } @@ -219,8 +219,8 @@ public static partial class Enumerable private sealed class ListPartition : Iterator, IPartition { private readonly IList _source; - private readonly int _minIndex; - private readonly int _maxIndex; + private readonly int _minIndexInclusive; + private readonly int _maxIndexInclusive; private int _index; public ListPartition(IList source, int minIndexInclusive, int maxIndexInclusive) @@ -229,19 +229,19 @@ public ListPartition(IList source, int minIndexInclusive, int maxIndexI Debug.Assert(minIndexInclusive >= 0); Debug.Assert(minIndexInclusive <= maxIndexInclusive); _source = source; - _minIndex = minIndexInclusive; - _maxIndex = maxIndexInclusive; + _minIndexInclusive = minIndexInclusive; + _maxIndexInclusive = maxIndexInclusive; _index = minIndexInclusive; } public override Iterator Clone() { - return new ListPartition(_source, _minIndex, _maxIndex); + return new ListPartition(_source, _minIndexInclusive, _maxIndexInclusive); } public override bool MoveNext() { - if ((_state == 1 & _index <= _maxIndex) && _index < _source.Count) + if ((_state == 1 & _index <= _maxIndexInclusive) && _index < _source.Count) { _current = _source[_index]; ++_index; @@ -254,27 +254,27 @@ public override bool MoveNext() public override IEnumerable Select(Func selector) { - return new SelectListPartitionIterator(_source, selector, _minIndex, _maxIndex); + return new SelectListPartitionIterator(_source, selector, _minIndexInclusive, _maxIndexInclusive); } public IPartition Skip(int count) { - int minIndex = _minIndex + count; - return minIndex >= _maxIndex ? EmptyPartition.Instance : new ListPartition(_source, minIndex, _maxIndex); + int minIndex = _minIndexInclusive + count; + return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition.Instance : new ListPartition(_source, minIndex, _maxIndexInclusive); } public IPartition Take(int count) { - int maxIndex = _minIndex + count - 1; - return (uint)maxIndex >= (uint)_maxIndex ? this : new ListPartition(_source, _minIndex, maxIndex); + int maxIndex = _minIndexInclusive + count - 1; + return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new ListPartition(_source, _minIndexInclusive, maxIndex); } public TSource TryGetElementAt(int index, out bool found) { - if ((uint)index <= (uint)(_maxIndex - _minIndex) && index < _source.Count - _minIndex) + if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive) { found = true; - return _source[_minIndex + index]; + return _source[_minIndexInclusive + index]; } found = false; @@ -283,10 +283,10 @@ public TSource TryGetElementAt(int index, out bool found) public TSource TryGetFirst(out bool found) { - if (_source.Count > _minIndex) + if (_source.Count > _minIndexInclusive) { found = true; - return _source[_minIndex]; + return _source[_minIndexInclusive]; } found = false; @@ -296,10 +296,10 @@ public TSource TryGetFirst(out bool found) public TSource TryGetLast(out bool found) { int lastIndex = _source.Count - 1; - if (lastIndex >= _minIndex) + if (lastIndex >= _minIndexInclusive) { found = true; - return _source[Math.Min(lastIndex, _maxIndex)]; + return _source[Math.Min(lastIndex, _maxIndexInclusive)]; } found = false; @@ -311,12 +311,12 @@ private int Count get { int count = _source.Count; - if (count <= _minIndex) + if (count <= _minIndexInclusive) { return 0; } - return Math.Min(count - 1, _maxIndex) - _minIndex + 1; + return Math.Min(count - 1, _maxIndexInclusive) - _minIndexInclusive + 1; } } @@ -329,7 +329,7 @@ public TSource[] ToArray() } TSource[] array = new TSource[count]; - for (int i = 0, curIdx = _minIndex; i != array.Length; ++i, ++curIdx) + for (int i = 0, curIdx = _minIndexInclusive; i != array.Length; ++i, ++curIdx) { array[i] = _source[curIdx]; } @@ -346,8 +346,8 @@ public List ToList() } List list = new List(count); - int end = _minIndex + count; - for (int i = _minIndex; i != end; ++i) + int end = _minIndexInclusive + count; + for (int i = _minIndexInclusive; i != end; ++i) { list.Add(_source[i]); } diff --git a/src/System.Linq/src/System/Linq/Select.cs b/src/System.Linq/src/System/Linq/Select.cs index 1189efbab8d9..9e3940631437 100644 --- a/src/System.Linq/src/System/Linq/Select.cs +++ b/src/System.Linq/src/System/Linq/Select.cs @@ -686,8 +686,8 @@ private sealed class SelectListPartitionIterator : Iterator _source; private readonly Func _selector; - private readonly int _minIndex; - private readonly int _maxIndex; + private readonly int _minIndexInclusive; + private readonly int _maxIndexInclusive; private int _index; public SelectListPartitionIterator(IList source, Func selector, int minIndexInclusive, int maxIndexInclusive) @@ -698,19 +698,19 @@ public SelectListPartitionIterator(IList source, Func Debug.Assert(minIndexInclusive <= maxIndexInclusive); _source = source; _selector = selector; - _minIndex = minIndexInclusive; - _maxIndex = maxIndexInclusive; + _minIndexInclusive = minIndexInclusive; + _maxIndexInclusive = maxIndexInclusive; _index = minIndexInclusive; } public override Iterator Clone() { - return new SelectListPartitionIterator(_source, _selector, _minIndex, _maxIndex); + return new SelectListPartitionIterator(_source, _selector, _minIndexInclusive, _maxIndexInclusive); } public override bool MoveNext() { - if ((_state == 1 & _index <= _maxIndex) && _index < _source.Count) + if ((_state == 1 & _index <= _maxIndexInclusive) && _index < _source.Count) { _current = _selector(_source[_index]); ++_index; @@ -723,28 +723,28 @@ public override bool MoveNext() public override IEnumerable Select(Func selector) { - return new SelectListPartitionIterator(_source, CombineSelectors(_selector, selector), _minIndex, _maxIndex); + return new SelectListPartitionIterator(_source, CombineSelectors(_selector, selector), _minIndexInclusive, _maxIndexInclusive); } public IPartition Skip(int count) { Debug.Assert(count > 0); - int minIndex = _minIndex + count; - return minIndex >= _maxIndex ? EmptyPartition.Instance : new SelectListPartitionIterator(_source, _selector, minIndex, _maxIndex); + int minIndex = _minIndexInclusive + count; + return (uint)minIndex > (uint)_maxIndexInclusive ? EmptyPartition.Instance : new SelectListPartitionIterator(_source, _selector, minIndex, _maxIndexInclusive); } public IPartition Take(int count) { - int maxIndex = _minIndex + count - 1; - return (uint)maxIndex >= (uint)_maxIndex ? this : new SelectListPartitionIterator(_source, _selector, _minIndex, maxIndex); + int maxIndex = _minIndexInclusive + count - 1; + return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator(_source, _selector, _minIndexInclusive, maxIndex); } public TResult TryGetElementAt(int index, out bool found) { - if ((uint)index <= (uint)(_maxIndex - _minIndex) && index < _source.Count - _minIndex) + if ((uint)index <= (uint)(_maxIndexInclusive - _minIndexInclusive) && index < _source.Count - _minIndexInclusive) { found = true; - return _selector(_source[_minIndex + index]); + return _selector(_source[_minIndexInclusive + index]); } found = false; @@ -753,10 +753,10 @@ public TResult TryGetElementAt(int index, out bool found) public TResult TryGetFirst(out bool found) { - if (_source.Count > _minIndex) + if (_source.Count > _minIndexInclusive) { found = true; - return _selector(_source[_minIndex]); + return _selector(_source[_minIndexInclusive]); } found = false; @@ -766,10 +766,10 @@ public TResult TryGetFirst(out bool found) public TResult TryGetLast(out bool found) { int lastIndex = _source.Count - 1; - if (lastIndex >= _minIndex) + if (lastIndex >= _minIndexInclusive) { found = true; - return _selector(_source[Math.Min(lastIndex, _maxIndex)]); + return _selector(_source[Math.Min(lastIndex, _maxIndexInclusive)]); } found = false; @@ -781,12 +781,12 @@ private int Count get { int count = _source.Count; - if (count <= _minIndex) + if (count <= _minIndexInclusive) { return 0; } - return Math.Min(count - 1, _maxIndex) - _minIndex + 1; + return Math.Min(count - 1, _maxIndexInclusive) - _minIndexInclusive + 1; } } @@ -799,7 +799,7 @@ public TResult[] ToArray() } TResult[] array = new TResult[count]; - for (int i = 0, curIdx = _minIndex; i != array.Length; ++i, ++curIdx) + for (int i = 0, curIdx = _minIndexInclusive; i != array.Length; ++i, ++curIdx) { array[i] = _selector(_source[curIdx]); } @@ -816,8 +816,8 @@ public List ToList() } List list = new List(count); - int end = _minIndex + count; - for (int i = _minIndex; i != end; ++i) + int end = _minIndexInclusive + count; + for (int i = _minIndexInclusive; i != end; ++i) { list.Add(_selector(_source[i])); } diff --git a/src/System.Linq/tests/OrderedSubsetting.cs b/src/System.Linq/tests/OrderedSubsetting.cs index d190fa7c7975..3835a37c0e80 100644 --- a/src/System.Linq/tests/OrderedSubsetting.cs +++ b/src/System.Linq/tests/OrderedSubsetting.cs @@ -329,6 +329,15 @@ public void Count() Assert.Equal(1, Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Take(2).Skip(1).Count()); } + [Fact] + public void SkipTakesOnlyOne() + { + Assert.Equal(new[] { 1 }, Enumerable.Range(1, 10).Shuffle().OrderBy(i => i).Take(1)); + Assert.Equal(new[] { 2 }, Enumerable.Range(1, 10).Shuffle().OrderBy(i => i).Skip(1).Take(1)); + Assert.Equal(new[] { 3 }, Enumerable.Range(1, 10).Shuffle().OrderBy(i => i).Take(3).Skip(2)); + Assert.Equal(new[] { 1 }, Enumerable.Range(1, 10).Shuffle().OrderBy(i => i).Take(3).Take(1)); + } + [Fact] public void EmptyToArray() { diff --git a/src/System.Linq/tests/RangeTests.cs b/src/System.Linq/tests/RangeTests.cs index 3b3544905b2a..ee6aee6b2fa5 100644 --- a/src/System.Linq/tests/RangeTests.cs +++ b/src/System.Linq/tests/RangeTests.cs @@ -167,6 +167,15 @@ public void SkipExcessive() Assert.Empty(Enumerable.Range(10, 10).Skip(20)); } + [Fact] + public void SkipTakeCanOnlyBeOne() + { + Assert.Equal(new[] { 1 }, Enumerable.Range(1, 10).Take(1)); + Assert.Equal(new[] { 2 }, Enumerable.Range(1, 10).Skip(1).Take(1)); + Assert.Equal(new[] { 3 }, Enumerable.Range(1, 10).Take(3).Skip(2)); + Assert.Equal(new[] { 1 }, Enumerable.Range(1, 10).Take(3).Take(1)); + } + [Fact] public void ElementAt() { diff --git a/src/System.Linq/tests/RepeatTests.cs b/src/System.Linq/tests/RepeatTests.cs index 3f6ac825c3fe..e24403edcaa6 100644 --- a/src/System.Linq/tests/RepeatTests.cs +++ b/src/System.Linq/tests/RepeatTests.cs @@ -172,6 +172,15 @@ public void SkipExcessive() Assert.Empty(Enumerable.Repeat(12, 8).Skip(22)); } + [Fact] + public void TakeCanOnlyBeOne() + { + Assert.Equal(new[] { 1 }, Enumerable.Repeat(1, 10).Take(1)); + Assert.Equal(new[] { 1 }, Enumerable.Repeat(1, 10).Skip(1).Take(1)); + Assert.Equal(new[] { 1 }, Enumerable.Repeat(1, 10).Take(3).Skip(2)); + Assert.Equal(new[] { 1 }, Enumerable.Repeat(1, 10).Take(3).Take(1)); + } + [Fact] public void SkipNone() { diff --git a/src/System.Linq/tests/SelectTests.cs b/src/System.Linq/tests/SelectTests.cs index e45820ca85c3..5f6568245d17 100644 --- a/src/System.Linq/tests/SelectTests.cs +++ b/src/System.Linq/tests/SelectTests.cs @@ -842,6 +842,10 @@ public void Select_SourceIsArray_Take() Assert.Empty(source.Take(-1)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(4)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(40)); + Assert.Equal(new[] { 2 }, source.Take(1)); + Assert.Equal(new[] { 4 }, source.Skip(1).Take(1)); + Assert.Equal(new[] { 6 }, source.Take(3).Skip(2)); + Assert.Equal(new[] { 2 }, source.Take(3).Take(1)); } [Fact] @@ -853,6 +857,10 @@ public void Select_SourceIsList_Take() Assert.Empty(source.Take(-1)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(4)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(40)); + Assert.Equal(new[] { 2 }, source.Take(1)); + Assert.Equal(new[] { 4 }, source.Skip(1).Take(1)); + Assert.Equal(new[] { 6 }, source.Take(3).Skip(2)); + Assert.Equal(new[] { 2 }, source.Take(3).Take(1)); } [Fact] @@ -864,6 +872,10 @@ public void Select_SourceIsIList_Take() Assert.Empty(source.Take(-1)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(4)); Assert.Equal(new[] { 2, 4, 6, 8 }, source.Take(40)); + Assert.Equal(new[] { 2 }, source.Take(1)); + Assert.Equal(new[] { 4 }, source.Skip(1).Take(1)); + Assert.Equal(new[] { 6 }, source.Take(3).Skip(2)); + Assert.Equal(new[] { 2 }, source.Take(3).Take(1)); } [Fact] diff --git a/src/System.Linq/tests/TakeTests.cs b/src/System.Linq/tests/TakeTests.cs index c7fbeda6ea79..bb9de6b1d746 100644 --- a/src/System.Linq/tests/TakeTests.cs +++ b/src/System.Linq/tests/TakeTests.cs @@ -415,6 +415,26 @@ public void ToListNotList() Assert.Empty(source.Take(-10).ToList()); } + [Fact] + public void TakeCanOnlyBeOneList() + { + var source = new[] { 2, 4, 6, 8, 10 }; + Assert.Equal(new[] { 2 }, source.Take(1)); + Assert.Equal(new[] { 4 }, source.Skip(1).Take(1)); + Assert.Equal(new[] { 6 }, source.Take(3).Skip(2)); + Assert.Equal(new[] { 2 }, source.Take(3).Take(1)); + } + + [Fact] + public void TakeCanOnlyBeOneNotList() + { + var source = GuaranteeNotIList(new[] { 2, 4, 6, 8, 10 }); + Assert.Equal(new[] { 2 }, source.Take(1)); + Assert.Equal(new[] { 4 }, source.Skip(1).Take(1)); + Assert.Equal(new[] { 6 }, source.Take(3).Skip(2)); + Assert.Equal(new[] { 2 }, source.Take(3).Take(1)); + } + [Fact] public void RepeatEnumerating() {