Skip to content

Commit

Permalink
Add waitany and waitall functions to wait multiple tasks at once (#53341
Browse files Browse the repository at this point in the history
)

This adds two functions: `waitany` and `waitall`, as discussed in the
issue #53226. These functions wait for multiple tasks at once. The
`waitany` function blocks until one task finishes. The `waitall`
function blocks until all tasks finish.

Co-authored-by: Shuhei Kadowaki <[email protected]>
Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
3 people authored Mar 11, 2024
1 parent f882c00 commit 8413b97
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 0 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ New library functions

* `logrange(start, stop; length)` makes a range of constant ratio, instead of constant step ([#39071])
* The new `isfull(c::Channel)` function can be used to check if `put!(c, some_value)` will block. ([#53159])
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).

New library features
--------------------
Expand Down
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,8 @@ export
yield,
yieldto,
wait,
waitany,
waitall,
timedwait,
asyncmap,
asyncmap!,
Expand Down
142 changes: 142 additions & 0 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,148 @@ function wait(t::Task)
nothing
end

# Wait multiple tasks

"""
waitany(tasks; throw=true) -> (done_tasks, remaining_tasks)
Wait until at least one of the given tasks have been completed.
If `throw` is `true`, throw `CompositeException` when one of the
completed tasks completes with an exception.
The return value consists of two task vectors. The first one consists of
completed tasks, and the other consists of uncompleted tasks.
!!! warning
This may scale poorly compared to writing code that uses multiple individual tasks that
each runs serially, since this needs to scan the list of `tasks` each time and
synchronize with each one every time this is called. Or consider using
[`waitall(tasks; failfast=true)`](@ref waitall) instead.
"""
waitany(tasks; throw=true) = _wait_multiple(tasks, throw)

"""
waitall(tasks; failfast=true, throw=true) -> (done_tasks, remaining_tasks)
Wait until all the given tasks have been completed.
If `failfast` is `true`, the function will return when at least one of the
given tasks is finished by exception. If `throw` is `true`, throw
`CompositeException` when one of the completed tasks has failed.
`failfast` and `throw` keyword arguments work independently; when only
`throw=true` is specified, this function waits for all the tasks to complete.
The return value consists of two task vectors. The first one consists of
completed tasks, and the other consists of uncompleted tasks.
"""
waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)

function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
tasks = Task[]

for t in waiting_tasks
t isa Task || error("Expected an iterator of `Task` object")
push!(tasks, t)
end

if (all && !failfast) || length(tasks) <= 1
exception = false
# Force everything to finish synchronously for the case of waitall
# with failfast=false
for t in tasks
_wait(t)
exception |= istaskfailed(t)
end
if exception && throwexc
exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return tasks, Task[]
end
end

exception = false
nremaining::Int = length(tasks)
done_mask = falses(nremaining)
for (i, t) in enumerate(tasks)
if istaskdone(t)
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1
else
done_mask[i] = false
end
end

if nremaining == 0
return tasks, Task[]
elseif any(done_mask) && (!all || (failfast && exception))
if throwexc && (!all || failfast) && exception
exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return tasks[done_mask], tasks[.~done_mask]
end
end

chan = Channel{Int}(Inf)
sentinel = current_task()
waiter_tasks = fill(sentinel, length(tasks))

for (i, done) in enumerate(done_mask)
done && continue
t = tasks[i]
if istaskdone(t)
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1
exception && failfast && break
else
waiter = @task put!(chan, i)
waiter.sticky = false
_wait2(t, waiter)
waiter_tasks[i] = waiter
end
end

while nremaining > 0
i = take!(chan)
t = tasks[i]
waiter_tasks[i] = sentinel
done_mask[i] = true
exception |= istaskfailed(t)
nremaining -= 1

# stop early if requested, unless there is something immediately
# ready to consume from the channel (using a race-y check)
if (!all || (failfast && exception)) && !isready(chan)
break
end
end

close(chan)

if nremaining == 0
return tasks, Task[]
else
remaining_mask = .~done_mask
for i in findall(remaining_mask)
waiter = waiter_tasks[i]
donenotify = tasks[i].donenotify::ThreadSynchronizer
@lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
end
done_tasks = tasks[done_mask]
if throwexc && exception
exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
throw(CompositeException(exceptions))
else
return done_tasks, tasks[remaining_mask]
end
end
end

"""
fetch(x::Any)
Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/parallel.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Base.schedule
Base.errormonitor
Base.@sync
Base.wait
Base.waitany
Base.waitall
Base.fetch(t::Task)
Base.fetch(x::Any)
Base.timedwait
Expand Down
123 changes: 123 additions & 0 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1186,4 +1186,127 @@ end
@testset "threadcall + threads" begin
threadcall_threads() #Shouldn't crash!
end

@testset "Wait multiple tasks" begin
convert_tasks(t, x) = x
convert_tasks(::Set{Task}, x::Vector{Task}) = Set{Task}(x)
convert_tasks(::Tuple{Task}, x::Vector{Task}) = tuple(x...)

function create_tasks()
tasks = Task[]
event = Threads.Event()
push!(tasks,
Threads.@spawn begin
sleep(0.01)
end)
push!(tasks,
Threads.@spawn begin
sleep(0.02)
end)
push!(tasks,
Threads.@spawn begin
wait(event)
end)
return tasks, event
end

function teardown(tasks, event)
notify(event)
waitall(resize!(tasks, 3), throw=true)
end

for tasks_type in (Vector{Task}, Set{Task}, Tuple{Task})
@testset "waitany" begin
@testset "throw=false" begin
tasks, event = create_tasks()
wait(tasks[1])
wait(tasks[2])
done, pending = waitany(convert_tasks(tasks_type, tasks); throw=false)
@test length(done) == 2
@test tasks[1] done
@test tasks[2] done
@test length(pending) == 1
@test tasks[3] pending
teardown(tasks, event)
end

@testset "throw=true" begin
tasks, event = create_tasks()
push!(tasks, Threads.@spawn error("Error"))

@test_throws CompositeException begin
waitany(convert_tasks(tasks_type, tasks); throw=true)
end

teardown(tasks, event)
end
end

@testset "waitall" begin
@testset "All tasks succeed" begin
tasks, event = create_tasks()

wait(tasks[1])
wait(tasks[2])
waiter = Threads.@spawn waitall(convert_tasks(tasks_type, tasks))
@test !istaskdone(waiter)

notify(event)
done, pending = fetch(waiter)
@test length(done) == 3
@test tasks[1] done
@test tasks[2] done
@test tasks[3] done
@test length(pending) == 0
end

@testset "failfast=true, throw=false" begin
tasks, event = create_tasks()
push!(tasks, Threads.@spawn error("Error"))

wait(tasks[1])
wait(tasks[2])
waiter = Threads.@spawn waitall(convert_tasks(tasks_type, tasks); failfast=true, throw=false)

done, pending = fetch(waiter)
@test length(done) == 3
@test tasks[1] done
@test tasks[2] done
@test tasks[4] done
@test length(pending) == 1
@test tasks[3] pending

teardown(tasks, event)
end

@testset "failfast=false, throw=true" begin
tasks, event = create_tasks()
push!(tasks, Threads.@spawn error("Error"))

notify(event)

@test_throws CompositeException begin
waitall(convert_tasks(tasks_type, tasks); failfast=false, throw=true)
end

@test all(istaskdone.(tasks))

teardown(tasks, event)
end

@testset "failfast=true, throw=true" begin
tasks, event = create_tasks()
push!(tasks, Threads.@spawn error("Error"))

@test_throws CompositeException begin
waitall(convert_tasks(tasks_type, tasks); failfast=true, throw=true)
end

@test !istaskdone(tasks[3])

teardown(tasks, event)
end
end
end
end
end # main testset

0 comments on commit 8413b97

Please sign in to comment.