Skip to content

Commit

Permalink
Add stack(iterator_of_arrays) (JuliaLang#43334)
Browse files Browse the repository at this point in the history
* generalises `reduce(hcat, vector_of_vectors)` to handle more dimensions and handle iterators efficiently.

* add `stack(f, xs) = stack(f(x) for x in xs)`

* add doc and test

* disallow stack on empty iterators

* add NEWS and compat note
  • Loading branch information
mcabbott authored Aug 21, 2022
1 parent aac466f commit 696f7d3
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 1 deletion.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ New library functions
inspecting which function `f` was originally wrapped. ([#42717])
* New `pkgversion(m::Module)` function to get the version of the package that loaded
a given module, similar to `pkgdir(m::Module)`. ([#45607])
* New function `stack(x)` which generalises `reduce(hcat, x::Vector{<:Vector})` to any dimensionality,
and allows any iterators of iterators. Method `stack(f, x)` generalises `mapreduce(f, hcat, x)` and
is efficient. ([#43334])

Library changes
---------------
Expand Down
230 changes: 230 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2605,6 +2605,236 @@ end
Ai
end

"""
stack(iter; [dims])
Combine a collection of arrays (or other iterable objects) of equal size
into one larger array, by arranging them along one or more new dimensions.
By default the axes of the elements are placed first,
giving `size(result) = (size(first(iter))..., size(iter)...)`.
This has the same order of elements as [`Iterators.flatten`](@ref)`(iter)`.
With keyword `dims::Integer`, instead the `i`th element of `iter` becomes the slice
[`selectdim`](@ref)`(result, dims, i)`, so that `size(result, dims) == length(iter)`.
In this case `stack` reverses the action of [`eachslice`](@ref) with the same `dims`.
The various [`cat`](@ref) functions also combine arrays. However, these all
extend the arrays' existing (possibly trivial) dimensions, rather than placing
the arrays along new dimensions.
They also accept arrays as separate arguments, rather than a single collection.
!!! compat "Julia 1.9"
This function requires at least Julia 1.9.
# Examples
```jldoctest
julia> vecs = (1:2, [30, 40], Float32[500, 600]);
julia> mat = stack(vecs)
2×3 Matrix{Float32}:
1.0 30.0 500.0
2.0 40.0 600.0
julia> mat == hcat(vecs...) == reduce(hcat, collect(vecs))
true
julia> vec(mat) == vcat(vecs...) == reduce(vcat, collect(vecs))
true
julia> stack(zip(1:4, 10:99)) # accepts any iterators of iterators
2×4 Matrix{Int64}:
1 2 3 4
10 11 12 13
julia> vec(ans) == collect(Iterators.flatten(zip(1:4, 10:99)))
true
julia> stack(vecs; dims=1) # unlike any cat function, 1st axis of vecs[1] is 2nd axis of result
3×2 Matrix{Float32}:
1.0 2.0
30.0 40.0
500.0 600.0
julia> x = rand(3,4);
julia> x == stack(eachcol(x)) == stack(eachrow(x), dims=1) # inverse of eachslice
true
```
Higher-dimensional examples:
```jldoctest
julia> A = rand(5, 7, 11);
julia> E = eachslice(A, dims=2); # a vector of matrices
julia> (element = size(first(E)), container = size(E))
(element = (5, 11), container = (7,))
julia> stack(E) |> size
(5, 11, 7)
julia> stack(E) == stack(E; dims=3) == cat(E...; dims=3)
true
julia> A == stack(E; dims=2)
true
julia> M = (fill(10i+j, 2, 3) for i in 1:5, j in 1:7);
julia> (element = size(first(M)), container = size(M))
(element = (2, 3), container = (5, 7))
julia> stack(M) |> size # keeps all dimensions
(2, 3, 5, 7)
julia> stack(M; dims=1) |> size # vec(container) along dims=1
(35, 2, 3)
julia> hvcat(5, M...) |> size # hvcat puts matrices next to each other
(14, 15)
```
"""
stack(iter; dims=:) = _stack(dims, iter)

"""
stack(f, args...; [dims])
Apply a function to each element of a collection, and `stack` the result.
Or to several collections, [`zip`](@ref)ped together.
The function should return arrays (or tuples, or other iterators) all of the same size.
These become slices of the result, each separated along `dims` (if given) or by default
along the last dimensions.
See also [`mapslices`](@ref), [`eachcol`](@ref).
# Examples
```jldoctest
julia> stack(c -> (c, c-32), "julia")
2×5 Matrix{Char}:
'j' 'u' 'l' 'i' 'a'
'J' 'U' 'L' 'I' 'A'
julia> stack(eachrow([1 2 3; 4 5 6]), (10, 100); dims=1) do row, n
vcat(row, row .* n, row ./ n)
end
2×9 Matrix{Float64}:
1.0 2.0 3.0 10.0 20.0 30.0 0.1 0.2 0.3
4.0 5.0 6.0 400.0 500.0 600.0 0.04 0.05 0.06
```
"""
stack(f, iter; dims=:) = _stack(dims, f(x) for x in iter)
stack(f, xs, yzs...; dims=:) = _stack(dims, f(xy...) for xy in zip(xs, yzs...))

_stack(dims::Union{Integer, Colon}, iter) = _stack(dims, IteratorSize(iter), iter)

_stack(dims, ::IteratorSize, iter) = _stack(dims, collect(iter))

function _stack(dims, ::Union{HasShape, HasLength}, iter)
S = @default_eltype iter
T = S != Union{} ? eltype(S) : Any # Union{} occurs for e.g. stack(1,2), postpone the error
if isconcretetype(T)
_typed_stack(dims, T, S, iter)
else # Need to look inside, but shouldn't run an expensive iterator twice:
array = iter isa Union{Tuple, AbstractArray} ? iter : collect(iter)
isempty(array) && return _empty_stack(dims, T, S, iter)
T2 = mapreduce(eltype, promote_type, array)
_typed_stack(dims, T2, eltype(array), array)
end
end

function _typed_stack(::Colon, ::Type{T}, ::Type{S}, A, Aax=_iterator_axes(A)) where {T, S}
xit = iterate(A)
nothing === xit && return _empty_stack(:, T, S, A)
x1, _ = xit
ax1 = _iterator_axes(x1)
B = similar(_ensure_array(x1), T, ax1..., Aax...)
off = firstindex(B)
len = length(x1)
while xit !== nothing
x, state = xit
_stack_size_check(x, ax1)
copyto!(B, off, x)
off += len
xit = iterate(A, state)
end
B
end

_iterator_axes(x) = _iterator_axes(x, IteratorSize(x))
_iterator_axes(x, ::HasLength) = (OneTo(length(x)),)
_iterator_axes(x, ::IteratorSize) = axes(x)

# For some dims values, stack(A; dims) == stack(vec(A)), and the : path will be faster
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} =
_typed_stack(dims, T, S, IteratorSize(S), A)
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasLength, A) where {T,S} =
_typed_stack(dims, T, S, HasShape{1}(), A)
function _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasShape{N}, A) where {T,S,N}
if dims == N+1
_typed_stack(:, T, S, A, (_vec_axis(A),))
else
_dim_stack(dims, T, S, A)
end
end
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S} =
_dim_stack(dims, T, S, A)

_vec_axis(A, ax=_iterator_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1))

@constprop :aggressive function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
xit = Iterators.peel(A)
nothing === xit && return _empty_stack(dims, T, S, A)
x1, xrest = xit
ax1 = _iterator_axes(x1)
N1 = length(ax1)+1
dims in 1:N1 || throw(ArgumentError(LazyString("cannot stack slices ndims(x) = ", N1-1, " along dims = ", dims)))

newaxis = _vec_axis(A)
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
B = similar(_ensure_array(x1), T, outax...)

if dims == 1
_dim_stack!(Val(1), B, x1, xrest)
elseif dims == 2
_dim_stack!(Val(2), B, x1, xrest)
else
_dim_stack!(Val(dims), B, x1, xrest)
end
B
end

function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
before = ntuple(d -> Colon(), dims - 1)
after = ntuple(d -> Colon(), ndims(B) - dims)

i = firstindex(B, dims)
copyto!(view(B, before..., i, after...), x1)

for x in xrest
_stack_size_check(x, _iterator_axes(x1))
i += 1
@inbounds copyto!(view(B, before..., i, after...), x)
end
end

@inline function _stack_size_check(x, ax1::Tuple)
if _iterator_axes(x) != ax1
uax1 = map(UnitRange, ax1)
uaxN = map(UnitRange, axes(x))
throw(DimensionMismatch(
LazyString("stack expects uniform slices, got axes(x) == ", uaxN, " while first had ", uax1)))
end
end

_ensure_array(x::AbstractArray) = x
_ensure_array(x) = 1:0 # passed to similar, makes stack's output an Array

_empty_stack(_...) = throw(ArgumentError("`stack` on an empty collection is not allowed"))


## Reductions and accumulates ##

function isequal(A::AbstractArray, B::AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ export
sortperm!,
sortslices,
dropdims,
stack,
step,
stride,
strides,
Expand Down
16 changes: 15 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ See also [`Iterators.flatten`](@ref), [`Iterators.map`](@ref).
# Examples
```jldoctest
julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
julia> Iterators.flatmap(n -> -n:2:n, 1:3) |> collect
9-element Vector{Int64}:
-1
1
Expand All @@ -1210,6 +1210,20 @@ julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect
-1
1
3
julia> stack(n -> -n:2:n, 1:3)
ERROR: DimensionMismatch: stack expects uniform slices, got axes(x) == (1:3,) while first had (1:2,)
[...]
julia> Iterators.flatmap(n -> (-n, 10n), 1:2) |> collect
4-element Vector{Int64}:
-1
10
-2
20
julia> ans == vec(stack(n -> (-n, 10n), 1:2))
true
```
"""
flatmap(f, c...) = flatten(map(f, c...))
Expand Down
1 change: 1 addition & 0 deletions doc/src/base/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Base.vcat
Base.hcat
Base.hvcat
Base.hvncat
Base.stack
Base.vect
Base.circshift
Base.circshift!
Expand Down
Loading

0 comments on commit 696f7d3

Please sign in to comment.