From 3f12fd9175bf72e24afc82c6b4e671de99c6eba7 Mon Sep 17 00:00:00 2001 From: Abel Braaksma Date: Mon, 18 Dec 2023 03:13:38 +0100 Subject: [PATCH] Implement `take`, `truncate`, `skip` and `drop` --- src/FSharp.Control.TaskSeq/TaskSeq.fs | 26 ++-- src/FSharp.Control.TaskSeq/TaskSeq.fsi | 59 ++++++++ src/FSharp.Control.TaskSeq/TaskSeqInternal.fs | 129 ++++++++++++++++-- 3 files changed, 193 insertions(+), 21 deletions(-) diff --git a/src/FSharp.Control.TaskSeq/TaskSeq.fs b/src/FSharp.Control.TaskSeq/TaskSeq.fs index fb8d67b8..42c00ee5 100644 --- a/src/FSharp.Control.TaskSeq/TaskSeq.fs +++ b/src/FSharp.Control.TaskSeq/TaskSeq.fs @@ -6,22 +6,15 @@ open System.Threading.Tasks #nowarn "57" +// Just for convenience +module Internal = TaskSeqInternal + [] module TaskSeqExtensions = + // these need to be in a module, not a type for proper auto-initialization of generic values module TaskSeq = + let empty<'T> = Internal.empty<'T> - let empty<'T> = - { new IAsyncEnumerable<'T> with - member _.GetAsyncEnumerator(_) = - { new IAsyncEnumerator<'T> with - member _.MoveNextAsync() = ValueTask.False - member _.Current = Unchecked.defaultof<'T> - member _.DisposeAsync() = ValueTask.CompletedTask - } - } - -// Just for convenience -module Internal = TaskSeqInternal [] type TaskSeq private () = @@ -289,18 +282,27 @@ type TaskSeq private () = static member choose chooser source = Internal.choose (TryPick chooser) source static member chooseAsync chooser source = Internal.choose (TryPickAsync chooser) source + static member filter predicate source = Internal.filter (Predicate predicate) source static member filterAsync predicate source = Internal.filter (PredicateAsync predicate) source + + static member skip count source = Internal.skipOrTake Skip count source + static member drop count source = Internal.skipOrTake Drop count source + static member take count source = Internal.skipOrTake Take count source + static member truncate count source = Internal.skipOrTake Truncate count source + static member takeWhile predicate source = Internal.takeWhile Exclusive (Predicate predicate) source static member takeWhileAsync predicate source = Internal.takeWhile Exclusive (PredicateAsync predicate) source static member takeWhileInclusive predicate source = Internal.takeWhile Inclusive (Predicate predicate) source static member takeWhileInclusiveAsync predicate source = Internal.takeWhile Inclusive (PredicateAsync predicate) source + static member tryPick chooser source = Internal.tryPick (TryPick chooser) source static member tryPickAsync chooser source = Internal.tryPick (TryPickAsync chooser) source static member tryFind predicate source = Internal.tryFind (Predicate predicate) source static member tryFindAsync predicate source = Internal.tryFind (PredicateAsync predicate) source static member tryFindIndex predicate source = Internal.tryFindIndex (Predicate predicate) source static member tryFindIndexAsync predicate source = Internal.tryFindIndex (PredicateAsync predicate) source + static member except itemsToExclude source = Internal.except itemsToExclude source static member exceptOfSeq itemsToExclude source = Internal.exceptOfSeq itemsToExclude source diff --git a/src/FSharp.Control.TaskSeq/TaskSeq.fsi b/src/FSharp.Control.TaskSeq/TaskSeq.fsi index 27d27a25..a58775df 100644 --- a/src/FSharp.Control.TaskSeq/TaskSeq.fsi +++ b/src/FSharp.Control.TaskSeq/TaskSeq.fsi @@ -725,6 +725,65 @@ type TaskSeq = /// Thrown when the input task sequence is null. static member filterAsync: predicate: ('T -> #Task) -> source: TaskSeq<'T> -> TaskSeq<'T> + /// + /// Returns a task sequence that, when iterated, skips elements of the + /// underlying sequence, and then returns the remainder of the elements. Raises an exception if there are not enough + /// elements in the sequence. See for a version that does not raise an exception. + /// See also for the inverse of this operation. + /// + /// + /// The number of items to skip. + /// The input task sequence. + /// The resulting task sequence. + /// Thrown when the input task sequence is null. + /// Thrown when is less than zero. + /// Thrown when count exceeds the number of elements in the sequence. + static member skip: count: int -> source: TaskSeq<'T> -> TaskSeq<'T> + + + /// + /// Returns a task sequence that, when iterated, drops at most elements of the + /// underlying sequence, and then returns the remainder of the elements, if any. + /// See for a version that raises an exception if there + /// are not enough elements. See also for the inverse of this operation. + /// + /// + /// The number of items to drop. + /// The input task sequence. + /// The resulting task sequence. + /// Thrown when the input task sequence is null. + /// Thrown when is less than zero. + static member drop: count: int -> source: TaskSeq<'T> -> TaskSeq<'T> + + /// + /// Returns a task sequence that, when iterated, yields elements of the + /// underlying sequence, and then returns no further elements. Raises an exception if there are not enough + /// elements in the sequence. See for a version that does not raise an exception. + /// See also for the inverse of this operation. + /// + /// + /// The number of items to take. + /// The input task sequence. + /// The resulting task sequence. + /// Thrown when the input task sequence is null. + /// Thrown when is less than zero. + /// Thrown when count exceeds the number of elements in the sequence. + static member take: count: int -> source: TaskSeq<'T> -> TaskSeq<'T> + + /// + /// Returns a task sequence that, when iterated, yields at most elements of the underlying + /// sequence, truncating the remainder, if any. + /// See for a version that raises an exception if there are not enough elements in the + /// sequence. See also for the inverse of this operation. + /// + /// + /// The maximum number of items to enumerate. + /// The input task sequence. + /// The resulting task sequence. + /// Thrown when the input task sequence is null. + /// Thrown when is less than zero. + static member truncate: count: int -> source: TaskSeq<'T> -> TaskSeq<'T> + /// /// Returns a task sequence that, when iterated, yields elements of the underlying sequence while the /// given function returns , and then returns no further elements. diff --git a/src/FSharp.Control.TaskSeq/TaskSeqInternal.fs b/src/FSharp.Control.TaskSeq/TaskSeqInternal.fs index 7ea572a5..9d963b9b 100644 --- a/src/FSharp.Control.TaskSeq/TaskSeqInternal.fs +++ b/src/FSharp.Control.TaskSeq/TaskSeqInternal.fs @@ -18,6 +18,17 @@ type internal WhileKind = /// The item under test is always excluded | Exclusive +[] +type internal TakeOrSkipKind = + /// use the Seq.take semantics, raises exception if not enough elements + | Take + /// use the Seq.skip semantics, raises exception if not enough elements + | Skip + /// use the Seq.truncate semantics, safe operation, returns all if count exceeds the seq + | Truncate + /// no Seq equiv, but like Stream.drop in Scala: safe operation, return empty if not enough elements + | Drop + [] type internal Action<'T, 'U, 'TaskU when 'TaskU :> Task<'U>> = | CountableAction of countable_action: (int -> 'T -> 'U) @@ -51,20 +62,15 @@ module internal TaskSeqInternal = if isNull arg then nullArg argName - let inline raiseEmptySeq () = - ArgumentException("The asynchronous input sequence was empty.", "source") - |> raise + let inline raiseEmptySeq () = invalidArg "source" "The input task sequence was empty." - let inline raiseCannotBeNegative (name: string) = - ArgumentException("The value cannot be negative", name) - |> raise + let inline raiseCannotBeNegative name = invalidArg name "The value must be non-negative" let inline raiseInsufficient () = - ArgumentException("The asynchronous input sequence was has an insufficient number of elements.", "source") - |> raise + invalidArg "source" "The input task sequence was has an insufficient number of elements." let inline raiseNotFound () = - KeyNotFoundException("The predicate function or index did not satisfy any item in the async sequence.") + KeyNotFoundException("The predicate function or index did not satisfy any item in the task sequence.") |> raise let isEmpty (source: TaskSeq<_>) = @@ -76,6 +82,16 @@ module internal TaskSeqInternal = return not step } + let empty<'T> = + { new IAsyncEnumerable<'T> with + member _.GetAsyncEnumerator(_) = + { new IAsyncEnumerator<'T> with + member _.MoveNextAsync() = ValueTask.False + member _.Current = Unchecked.defaultof<'T> + member _.DisposeAsync() = ValueTask.CompletedTask + } + } + let singleton (value: 'T) = { new IAsyncEnumerable<'T> with member _.GetAsyncEnumerator(_) = @@ -613,6 +629,101 @@ module internal TaskSeqInternal = | false -> () } + + let skipOrTake skipOrTake count (source: TaskSeq<_>) = + checkNonNull (nameof source) source + + if count < 0 then + raiseCannotBeNegative (nameof count) + + match skipOrTake with + | Skip -> + // don't create a new sequence if count = 0 + if count = 0 then + source + else + taskSeq { + use e = source.GetAsyncEnumerator CancellationToken.None + + for _ in 1..count do + let! step = e.MoveNextAsync() + + if not step then + raiseInsufficient () + + let mutable cont = true + + while cont do + yield e.Current + let! moveNext = e.MoveNextAsync() + cont <- moveNext + + } + | Drop -> + // don't create a new sequence if count = 0 + if count = 0 then + source + else + taskSeq { + use e = source.GetAsyncEnumerator CancellationToken.None + + let! step = e.MoveNextAsync() + let mutable cont = step + let mutable pos = 0 + + // skip, or stop looping if we reached the end + while cont do + pos <- pos + 1 + let! moveNext = e.MoveNextAsync() + cont <- moveNext && pos <= count + + // return the rest + while cont do + yield e.Current + let! moveNext = e.MoveNextAsync() + cont <- moveNext + + } + | Take -> + // don't initialize an empty task sequence + if count = 0 then + empty + else + taskSeq { + use e = source.GetAsyncEnumerator CancellationToken.None + + for _ in count .. - 1 .. 1 do + let! step = e.MoveNextAsync() + + if not step then + raiseInsufficient () + + yield e.Current + } + + | Truncate -> + // don't create a new sequence if count = 0 + if count = 0 then + empty + else + taskSeq { + use e = source.GetAsyncEnumerator CancellationToken.None + + let! step = e.MoveNextAsync() + let mutable cont = step + let mutable pos = 0 + + // return items until we've exhausted the seq + // report this line, weird error: + //while! e.MoveNextAsync() && pos < 1 do + while cont do + yield e.Current + pos <- pos + 1 + let! moveNext = e.MoveNextAsync() + cont <- moveNext && pos <= count + + } + let takeWhile whileKind predicate (source: TaskSeq<_>) = checkNonNull (nameof source) source