Skip to content

Commit

Permalink
AbstractInterpreter: implement findsup for OverlayMethodTable
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 7, 2022
1 parent 3bcab39 commit c86d43a
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 44 deletions.
3 changes: 2 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
argtype = Tuple{ft, argtype.parameters...}
result = findsup(types, method_table(interp))
result === nothing && return CallMeta(Any, false)
method, valid_worlds = result
match, valid_worlds = result
method = match.method
update_valid_age!(sv, valid_worlds)
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
Expand Down
61 changes: 38 additions & 23 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,45 +40,51 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit=typemax(Int))
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
return _findall(sig, nothing, table.world, limit)
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=typemax(Int))
result = _findall(sig, table.mt, table.world, limit)
result === missing && return missing
if !isempty(result)
if all(match->match.fully_covers, result)
# no need to fall back to the internal method table
return result
else
# merge the match results with the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(min(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
max(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
end
end
# fall back to the internal method table
return _findall(sig, nothing, table.world, limit)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig)
ms = _methods_by_ftype(sig, mt, limit, world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
elseif isempty(ms)
# fall back to the internal method table
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

"""
findsup(sig::Type, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing}
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
Expand All @@ -92,12 +98,21 @@ upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
result = _findsup(sig, table.mt, table.world)
result === nothing || return result
return _findsup(sig, nothing, table.world) # fall back to the internal method table
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
sig, table.world, min_valid, max_valid)::Union{MethodMatch, Nothing}
result === nothing && return nothing
(result.method, WorldRange(min_valid[], max_valid[]))
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
sig, mt, world, min_valid, max_valid)::Union{MethodMatch, Nothing}
return result === nothing ? result : (result, WorldRange(min_valid[], max_valid[]))
end

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
Expand Down
12 changes: 4 additions & 8 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1347,15 +1347,11 @@ end
print_statement_costs(args...; kwargs...) = print_statement_costs(stdout, args...; kwargs...)

function _which(@nospecialize(tt::Type), world=get_world_counter())
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
tt, world, min_valid, max_valid)
if match === nothing
result = Core.Compiler._findsup(tt, nothing, world)
if result === nothing
error("no unique matching method found for the specified argument types")
end
return match::Core.MethodMatch
return first(result)::Core.MethodMatch
end

"""
Expand Down Expand Up @@ -1478,7 +1474,7 @@ true
function hasmethod(@nospecialize(f), @nospecialize(t); world::UInt=get_world_counter())
t = to_tuple_type(t)
t = signature_type(f, t)
return ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), t, world) !== nothing
return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), t, nothing, world) !== nothing
end

function hasmethod(@nospecialize(f), @nospecialize(t), kwnames::Tuple{Vararg{Symbol}}; world::UInt=get_world_counter())
Expand Down
23 changes: 12 additions & 11 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ static jl_method_instance_t *cache_method(
return newmeth;
}

static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid);
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid);

static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATES_ROOT, jl_datatype_t *tt, size_t world)
{
Expand All @@ -1237,7 +1237,7 @@ static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt JL_PROPAGATE

size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, world, &min_valid, &max_valid);
jl_method_match_t *matc = _gf_invoke_lookup((jl_value_t*)tt, jl_nothing, world, &min_valid, &max_valid);
jl_method_instance_t *nf = NULL;
if (matc) {
JL_GC_PUSH1(&matc);
Expand Down Expand Up @@ -2549,36 +2549,37 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t *F, jl_value_t **args, uint
return _jl_invoke(F, args, nargs, mfunc, world);
}

static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid)
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid)
{
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type)
return NULL;
jl_methtable_t *mt = jl_method_table_for(unw);
if ((jl_value_t*)mt == jl_nothing)
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
if (mt == jl_nothing)
mt = NULL;
jl_value_t *matches = ml_matches(mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL);
jl_value_t *matches = ml_matches((jl_methtable_t*)mt, (jl_tupletype_t*)types, 1, 0, 0, world, 1, min_valid, max_valid, NULL);
if (matches == jl_false || jl_array_len(matches) != 1)
return NULL;
jl_method_match_t *matc = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
return matc;
}

JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, jl_value_t *mt, size_t world)
{
// Deprecated: Use jl_gf_invoke_lookup_worlds for future development
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_match_t *matc = _gf_invoke_lookup(types, world, &min_valid, &max_valid);
jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, &min_valid, &max_valid);
if (matc == NULL)
return jl_nothing;
return (jl_value_t*)matc->method;
}


JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, size_t world, size_t *min_world, size_t *max_world)
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world)
{
jl_method_match_t *matc = _gf_invoke_lookup(types, world, min_world, max_world);
jl_method_match_t *matc = _gf_invoke_lookup(types, mt, world, min_world, max_world);
if (matc == NULL)
return jl_nothing;
return (jl_value_t*)matc;
Expand All @@ -2599,7 +2600,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t *gf, jl_value_t **args,
jl_value_t *types = NULL;
JL_GC_PUSH1(&types);
types = jl_argtype_with_function(gf, types0);
jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, world);
jl_method_t *method = (jl_method_t*)jl_gf_invoke_lookup(types, jl_nothing, world);
JL_GC_PROMISE_ROOTED(method);

if ((jl_value_t*)method == jl_nothing) {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1774,7 +1774,7 @@ function detect_unbound_args(mods...;
params = tuple_sig.parameters[1:(end - 1)]
tuple_sig = Base.rewrap_unionall(Tuple{params...}, m.sig)
world = Base.get_world_counter()
mf = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), tuple_sig, world)
mf = ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tuple_sig, nothing, world)
if mf !== nothing && mf !== m && mf.sig <: tuple_sig
continue
end
Expand Down
18 changes: 18 additions & 0 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,21 @@ CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_co
@test Base.return_types((Int,), MTOverlayInterp()) do x
sin(x)
end == Any[Int]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke sin(x::Float64)
end == Any[Int]

# fallback to the internal method table
@test Base.return_types((Int,), MTOverlayInterp()) do x
cos(x)
end == Any[Float64]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke cos(x::Float64)
end == Any[Float64]

# not fully covered overlay method match
overlay_match(::Any) = nothing
@overlay OverlayedMT overlay_match(::Int) = missing
@test Base.return_types((Any,), MTOverlayInterp()) do x
overlay_match(x)
end == Any[Union{Nothing,Missing}]

0 comments on commit c86d43a

Please sign in to comment.