Skip to content

Commit

Permalink
follow up the inlining unmatched type param PR (JuliaLang#46484)
Browse files Browse the repository at this point in the history
This commit follows up JuliaLang#45062:
- eliminate closure capturing, and improve type stability a bit
- refactor the test structure so that they are more aligned with
  the other parts of tests
  • Loading branch information
aviatesk authored Aug 26, 2022
1 parent 19f44b6 commit bea7b6f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 35 deletions.
34 changes: 19 additions & 15 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
end

struct LiftedValue
x
LiftedValue(@nospecialize x) = new(x)
val
LiftedValue(@nospecialize val) = new(val)
end
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}

Expand Down Expand Up @@ -578,7 +578,7 @@ function lift_comparison_leaves!(@specialize(tfunc),
visited_phinodes, cmp, lifting_cache, Bool,
lifted_leaves::LiftedLeaves, val, nothing)::LiftedValue

compact[idx] = lifted_val.x
compact[idx] = lifted_val.val
end

struct LiftedPhi
Expand Down Expand Up @@ -626,7 +626,7 @@ function perform_lifting!(compact::IncrementalCompact,
end
end

the_leaf_val = isa(the_leaf, LiftedValue) ? the_leaf.x : nothing
the_leaf_val = isa(the_leaf, LiftedValue) ? the_leaf.val : nothing
if !isa(the_leaf_val, SSAValue)
all_same = false
end
Expand Down Expand Up @@ -690,7 +690,7 @@ function perform_lifting!(compact::IncrementalCompact,
resize!(new_node.values, length(new_node.values)+1)
continue
end
val = lifted_val.x
val = lifted_val.val
if isa(val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
val = simple_walk(compact, val, callback)
Expand Down Expand Up @@ -750,18 +750,18 @@ function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr)
elseif is_known_call(def, Core._compute_sparams, compact)
res = _lift_svec_ref(def, compact)
if res !== nothing
compact[idx] = res
compact[idx] = res.val
end
return
end
end
end

function _lift_svec_ref(def::Expr, compact::IncrementalCompact)
# TODO: We could do the whole lifing machinery here, but really all
# we want to do is clean this up when it got inserted by inlining,
# which always targets simple `svec` call or `_compute_sparams`,
# so this specialized lifting would be enough
# TODO: We could do the whole lifing machinery here, but really all
# we want to do is clean this up when it got inserted by inlining,
# which always targets simple `svec` call or `_compute_sparams`,
# so this specialized lifting would be enough
@inline function _lift_svec_ref(def::Expr, compact::IncrementalCompact)
m = argextype(def.args[2], compact)
isa(m, Const) || return nothing
m = m.val
Expand All @@ -776,9 +776,13 @@ function _lift_svec_ref(def::Expr, compact::IncrementalCompact)
sig.name === Tuple.name || return nothing
length(sig.parameters) >= 1 || return nothing

i = findfirst(j->has_typevar(sig.parameters[j], tvar), 1:length(sig.parameters))
i = let sig=sig
findfirst(j->has_typevar(sig.parameters[j], tvar), 1:length(sig.parameters))
end
i === nothing && return nothing
_any(j->has_typevar(sig.parameters[j], tvar), i+1:length(sig.parameters)) && return nothing
let sig=sig
any(j->has_typevar(sig.parameters[j], tvar), i+1:length(sig.parameters))
end && return nothing

arg = sig.parameters[i]
isa(arg, DataType) || return nothing
Expand Down Expand Up @@ -808,7 +812,7 @@ function _lift_svec_ref(def::Expr, compact::IncrementalCompact)
length(applyTbody.parameters) == length(arg.parameters) == 1 || return nothing
applyTbody.parameters[1] === applyTvar || return nothing
arg.parameters[1] === tvar || return nothing
return argdef.args[3]
return LiftedValue(argdef.args[3])
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
Expand Down Expand Up @@ -1017,7 +1021,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing, InliningState} = nothin
@assert val !== nothing
end

compact[idx] = val === nothing ? nothing : val.x
compact[idx] = val === nothing ? nothing : val.val
end

non_dce_finish!(compact)
Expand Down
43 changes: 23 additions & 20 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1518,26 +1518,29 @@ function oc_capture_oc(z)
end
@test fully_eliminated(oc_capture_oc, (Int,))

@eval struct OldVal{T}
x::T
(OV::Type{OldVal{T}})() where T = $(Expr(:new, :OV))
end
with_unmatched_typeparam1(x::OldVal{i}) where {i} = i
with_unmatched_typeparam2() = [ Base.donotdelete(OldVal{i}()) for i in 1:10000 ]
function with_unmatched_typeparam3()
f(x::OldVal{i}) where {i} = i
r = 0
for i = 1:10000
r += f(OldVal{i}())
end
return r
end

@testset "Inlining with unmatched type parameters" begin
@eval struct OldVal{T}
x::T
(OV::Type{OldVal{T}})() where T = $(Expr(:new, :OV))
end
let f(x) = OldVal{x}()
g() = [ Base.donotdelete(OldVal{i}()) for i in 1:10000 ]
h() = begin
f(x::OldVal{i}) where {i} = i
r = 0
for i = 1:10000
r += f(OldVal{i}())
end
return r
end
srcs = (code_typed1(f, (Any,)),
code_typed1(g),
code_typed1(h))
for src in srcs
@test !any(@nospecialize(x) -> isexpr(x, :call) && length(x.args) == 1, src.code)
end
let src = code_typed1(with_unmatched_typeparam1, (Any,))
@test !any(@nospecialize(x) -> isexpr(x, :call) && length(x.args) == 1, src.code)
end
let src = code_typed1(with_unmatched_typeparam2)
@test !any(@nospecialize(x) -> isexpr(x, :call) && length(x.args) == 1, src.code)
end
let src = code_typed1(with_unmatched_typeparam3)
@test !any(@nospecialize(x) -> isexpr(x, :call) && length(x.args) == 1, src.code)
end
end

0 comments on commit bea7b6f

Please sign in to comment.