Skip to content

Commit

Permalink
llvm: Switch to trampoline-based jlcall annotations (JuliaLang#45088)
Browse files Browse the repository at this point in the history
As discussed extensively in JuliaLang#45057, when enabling LLVM's opaque pointer
support, we get significant miscompilations in jlcall callsites, because
calls with mismatching calling conventions are considered undefined behavior.
This implements Option D) from JuliaLang#45057, switching our jlcall callsites to use
a `julia.call` trampoline intrinsic instead. The lowering for this intrinsic
is essentially the same as the CC-based lowering before, except that the
callee is now of course the first argument rather than the actual callee.
Other than that, the changes are mostly mechanical.

Fixes JuliaLang#45057
  • Loading branch information
Keno authored Jun 10, 2022
1 parent 295d741 commit 23f39f8
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 57 deletions.
10 changes: 3 additions & 7 deletions doc/src/devdocs/llvm.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,11 @@ array. However, this would betray the SSA nature of the uses at the call site,
making optimizations (including GC root placement), significantly harder.
Instead, we emit it as follows:
```llvm
%bitcast = bitcast @any_unoptimized_call to %jl_value_t *(*)(%jl_value_t *, %jl_value_t *)
call cc 37 %jl_value_t *%bitcast(%jl_value_t *%arg1, %jl_value_t *%arg2)
call %jl_value_t *@julia.call(jl_value_t *(*)(...) @any_unoptimized_call, %jl_value_t *%arg1, %jl_value_t *%arg2)
```
The special `cc 37` annotation marks the fact that this call site is really using
the jlcall calling convention. This allows us to retain the SSA-ness of the
This allows us to retain the SSA-ness of the
uses throughout the optimizer. GC root placement will later lower this call to
the original C ABI. In the code the calling convention number is represented by
the `JLCALL_F_CC` constant. In addition, there is the `JLCALL_CC` calling
convention which functions similarly, but omits the first argument.
the original C ABI.

## GC root placement

Expand Down
4 changes: 2 additions & 2 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1802,7 +1802,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
ret = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, julia_call);
ret = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(ret.typ, jltype)) {
Expand Down Expand Up @@ -3549,7 +3549,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
rhs = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, julia_call);
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(rhs.typ, jfty)) {
Expand Down
88 changes: 62 additions & 26 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,43 @@ static const auto pointer_from_objref_func = new JuliaFunction{
None); },
};

// julia.call represents a call with julia calling convention, it is used as
//
// ptr julia.call(ptr %fptr, ptr %f, ptr %arg1, ptr %arg2, ...)
//
// In late lowering the call will then be rewritten as
//
// ptr %fptr(ptr %f, ptr args, i64 nargs)
//
// with all the spelled out args appropriately moved into the argument stack buffer.
// By representing it this way rather than allocating the stack buffer earlier, we
// allow LLVM to make more aggressive optimizations on the call arguments.
static const auto julia_call = new JuliaFunction{
"julia.call",
[](LLVMContext &C) { return FunctionType::get(JuliaType::get_prjlvalue_ty(C),
#ifdef JL_LLVM_OPAQUE_POINTERS
{PointerType::get(C, 0)},
#else
{get_func_sig(C)->getPointerTo()},
#endif
true); },
nullptr
};

// julia.call2 is like julia.call, except that %arg1 gets passed as a register
// argument at the end of the argument list.
static const auto julia_call2 = new JuliaFunction{
"julia.call2",
[](LLVMContext &C) { return FunctionType::get(JuliaType::get_prjlvalue_ty(C),
#ifdef JL_LLVM_OPAQUE_POINTERS
{PointerType::get(C, 0)},
#else
{get_func_sig(C)->getPointerTo()},
#endif
true); },
nullptr
};

static const auto jltuple_func = new JuliaFunction{XSTR(jl_f_tuple), get_func_sig, get_func_attrs};
static const auto &builtin_func_map() {
static std::map<jl_fptr_args_t, JuliaFunction*> builtins = {
Expand Down Expand Up @@ -1442,9 +1479,9 @@ static Value *get_last_age_field(jl_codectx_t &ctx);
static Value *get_current_signal_page(jl_codectx_t &ctx);
static void CreateTrap(IRBuilder<> &irbuilder, bool create_new_block = true);
static CallInst *emit_jlcall(jl_codectx_t &ctx, Function *theFptr, Value *theF,
const jl_cgval_t *args, size_t nargs, CallingConv::ID cc);
const jl_cgval_t *args, size_t nargs, JuliaFunction *trampoline);
static CallInst *emit_jlcall(jl_codectx_t &ctx, JuliaFunction *theFptr, Value *theF,
const jl_cgval_t *args, size_t nargs, CallingConv::ID cc);
const jl_cgval_t *args, size_t nargs, JuliaFunction *trampoline);
static Value *emit_f_is(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2,
Value *nullcheck1 = nullptr, Value *nullcheck2 = nullptr);
static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t nargs, const jl_cgval_t *argv, bool is_promotable=false);
Expand Down Expand Up @@ -3729,34 +3766,31 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,

// Returns ctx.types().T_prjlvalue
static CallInst *emit_jlcall(jl_codectx_t &ctx, Function *theFptr, Value *theF,
const jl_cgval_t *argv, size_t nargs, CallingConv::ID cc)
const jl_cgval_t *argv, size_t nargs, JuliaFunction *trampoline)
{
++EmittedJLCalls;
Function *TheTrampoline = prepare_call(trampoline);
// emit arguments
SmallVector<Value*, 3> theArgs;
SmallVector<Type*, 3> argsT;
if (theF) {
theArgs.push_back(ctx.builder.CreateBitCast(theFptr,
TheTrampoline->getFunctionType()->getParamType(0)));
if (theF)
theArgs.push_back(theF);
argsT.push_back(ctx.types().T_prjlvalue);
}
for (size_t i = 0; i < nargs; i++) {
Value *arg = boxed(ctx, argv[i]);
theArgs.push_back(arg);
argsT.push_back(ctx.types().T_prjlvalue);
}
FunctionType *FTy = FunctionType::get(ctx.types().T_prjlvalue, argsT, false);
CallInst *result = ctx.builder.CreateCall(FTy,
ctx.builder.CreateBitCast(theFptr, FTy->getPointerTo()),
CallInst *result = ctx.builder.CreateCall(TheTrampoline->getFunctionType(),
TheTrampoline,
theArgs);
addRetAttr(result, Attribute::NonNull);
result->setCallingConv(cc);
return result;
}
// Returns ctx.types().T_prjlvalue
static CallInst *emit_jlcall(jl_codectx_t &ctx, JuliaFunction *theFptr, Value *theF,
const jl_cgval_t *argv, size_t nargs, CallingConv::ID cc)
const jl_cgval_t *argv, size_t nargs, JuliaFunction *trampoline)
{
return emit_jlcall(ctx, prepare_call(theFptr), theF, argv, nargs, cc);
return emit_jlcall(ctx, prepare_call(theFptr), theF, argv, nargs, trampoline);
}


Expand Down Expand Up @@ -3882,7 +3916,7 @@ static jl_cgval_t emit_call_specfun_boxed(jl_codectx_t &ctx, jl_value_t *jlretty
jl_Module->getOrInsertFunction(specFunctionObject, ctx.types().T_jlfunc).getCallee());
addRetAttr(theFptr, Attribute::NonNull);
theFptr->addFnAttr(Attribute::get(ctx.builder.getContext(), "thunk"));
Value *ret = emit_jlcall(ctx, theFptr, nullptr, argv, nargs, JLCALL_F_CC);
Value *ret = emit_jlcall(ctx, theFptr, nullptr, argv, nargs, julia_call);
return update_julia_type(ctx, mark_julia_type(ctx, ret, true, jlretty), inferred_retty);
}

Expand Down Expand Up @@ -3979,7 +4013,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
}
}
if (!handled) {
Value *r = emit_jlcall(ctx, jlinvoke_func, boxed(ctx, lival), argv, nargs, JLCALL_F2_CC);
Value *r = emit_jlcall(ctx, jlinvoke_func, boxed(ctx, lival), argv, nargs, julia_call2);
result = mark_julia_type(ctx, r, true, rt);
}
if (result.typ == jl_bottom_type)
Expand Down Expand Up @@ -4008,7 +4042,7 @@ static jl_cgval_t emit_invoke_modify(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_
return ret;
auto it = builtin_func_map().find(jl_f_modifyfield_addr);
assert(it != builtin_func_map().end());
Value *oldnew = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, JLCALL_F_CC);
Value *oldnew = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, julia_call);
return mark_julia_type(ctx, oldnew, true, rt);
}
if (f.constant && jl_typeis(f.constant, jl_intrinsic_type)) {
Expand All @@ -4018,7 +4052,7 @@ static jl_cgval_t emit_invoke_modify(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_
}

// emit function and arguments
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, nargs, JLCALL_F_CC);
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, nargs, julia_call);
return mark_julia_type(ctx, callval, true, rt);
}

Expand Down Expand Up @@ -4063,13 +4097,13 @@ static jl_cgval_t emit_call(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_t *rt, bo
// special case for known builtin not handled by emit_builtin_call
auto it = builtin_func_map().find(jl_get_builtin_fptr(f.constant));
if (it != builtin_func_map().end()) {
Value *ret = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, JLCALL_F_CC);
Value *ret = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, julia_call);
return mark_julia_type(ctx, ret, true, rt);
}
}

// emit function and arguments
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, generic_argv, n_generic_args, JLCALL_F_CC);
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, generic_argv, n_generic_args, julia_call);
return mark_julia_type(ctx, callval, true, rt);
}

Expand Down Expand Up @@ -5087,7 +5121,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
res.promotion_ssa = ssaidx_0based;
return res;
}
Value *val = emit_jlcall(ctx, jlnew_func, nullptr, argv, nargs, JLCALL_F_CC);
Value *val = emit_jlcall(ctx, jlnew_func, nullptr, argv, nargs, julia_call);
// temporarily mark as `Any`, expecting `emit_ssaval_assign` to update
// it to the inferred type.
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
Expand Down Expand Up @@ -5177,7 +5211,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
}

return mark_julia_type(ctx,
emit_jlcall(ctx, jl_new_opaque_closure_jlcall_func, Constant::getNullValue(ctx.types().T_prjlvalue), argv.data(), nargs, JLCALL_F_CC),
emit_jlcall(ctx, jl_new_opaque_closure_jlcall_func, Constant::getNullValue(ctx.types().T_prjlvalue), argv.data(), nargs, julia_call),
true, jl_any_type);
}
else if (head == jl_exc_sym) {
Expand Down Expand Up @@ -5414,7 +5448,7 @@ static void emit_cfunc_invalidate(
}
}
assert(AI == gf_thunk->arg_end());
Value *gf_ret = emit_jlcall(ctx, target, nullptr, myargs, nargs, JLCALL_F_CC);
Value *gf_ret = emit_jlcall(ctx, target, nullptr, myargs, nargs, julia_call);
jl_cgval_t gf_retbox = mark_julia_type(ctx, gf_ret, true, jl_any_type);
if (cc != jl_returninfo_t::Boxed) {
emit_typecheck(ctx, gf_retbox, rettype, "cfunction");
Expand Down Expand Up @@ -5834,11 +5868,11 @@ static Function* gen_cfun_wrapper(
// for jlcall, we need to pass the function object even if it is a ghost.
Value *theF = boxed(ctx, inputargs[0]);
assert(theF);
ret_jlcall = emit_jlcall(ctx, theFptr, theF, &inputargs[1], nargs, JLCALL_F_CC);
ret_jlcall = emit_jlcall(ctx, theFptr, theF, &inputargs[1], nargs, julia_call);
ctx.builder.CreateBr(b_after);
ctx.builder.SetInsertPoint(b_generic);
}
Value *ret = emit_jlcall(ctx, jlapplygeneric_func, NULL, inputargs, nargs + 1, JLCALL_F_CC);
Value *ret = emit_jlcall(ctx, jlapplygeneric_func, NULL, inputargs, nargs + 1, julia_call);
if (age_ok) {
ctx.builder.CreateBr(b_after);
ctx.builder.SetInsertPoint(b_after);
Expand Down Expand Up @@ -7213,7 +7247,7 @@ static jl_llvm_functions_t
}
else {
restTuple = emit_jlcall(ctx, jltuple_func, Constant::getNullValue(ctx.types().T_prjlvalue),
vargs, ctx.nvargs, JLCALL_F_CC);
vargs, ctx.nvargs, julia_call);
jl_cgval_t tuple = mark_julia_type(ctx, restTuple, true, vi.value.typ);
emit_varinfo_assign(ctx, vi, tuple);
}
Expand Down Expand Up @@ -8371,6 +8405,8 @@ static void init_jit_functions(void)
add_named_global(gc_preserve_end_func, (void*)NULL);
add_named_global(pointer_from_objref_func, (void*)NULL);
add_named_global(except_enter_func, (void*)NULL);
add_named_global(julia_call, (void*)NULL);
add_named_global(julia_call2, (void*)NULL);

#ifdef _OS_WINDOWS_
#if defined(_CPU_X86_64_)
Expand Down
4 changes: 0 additions & 4 deletions src/codegen_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ namespace JuliaType {
}
}

// JLCALL with API arguments ([extra], arg0, arg1, arg2, ...) has the following ABI calling conventions defined:
#define JLCALL_F_CC (CallingConv::ID)37 // (jl_value_t *arg0, jl_value_t **argv, uint32_t nargv)
#define JLCALL_F2_CC (CallingConv::ID)38 // (jl_value_t *arg0, jl_value_t **argv, uint32_t nargv, jl_value_t *extra)

// return how many Tracked pointers are in T (count > 0),
// and if there is anything else in T (all == false)
struct CountTrackedPointers {
Expand Down
9 changes: 6 additions & 3 deletions src/llvm-gc-invariant-verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,15 @@ void GCInvariantVerifier::visitGetElementPtrInst(GetElementPtrInst &GEP) {
}

void GCInvariantVerifier::visitCallInst(CallInst &CI) {
CallingConv::ID CC = CI.getCallingConv();
if (CC == JLCALL_F_CC || CC == JLCALL_F2_CC) {
Function *Callee = CI.getCalledFunction();
if (Callee && (Callee->getName() == "julia.call" ||
Callee->getName() == "julia.call2")) {
bool First = true;
for (Value *Arg : CI.args()) {
Type *Ty = Arg->getType();
Check(Ty->isPointerTy() && cast<PointerType>(Ty)->getAddressSpace() == AddressSpace::Tracked,
Check(Ty->isPointerTy() && cast<PointerType>(Ty)->getAddressSpace() == (First ? 0 : AddressSpace::Tracked),
"Invalid derived pointer in jlcall", &CI);
First = false;
}
}
}
Expand Down
30 changes: 18 additions & 12 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,6 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
++it;
continue;
}
CallingConv::ID CC = CI->getCallingConv();
Value *callee = CI->getCalledOperand();
if (callee && (callee == gc_flush_func || callee == gc_preserve_begin_func
|| callee == gc_preserve_end_func)) {
Expand Down Expand Up @@ -2389,53 +2388,60 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
ChangesMade = true;
++it;
continue;
} else if (CC == JLCALL_F_CC ||
CC == JLCALL_F2_CC) {
} else if ((call_func && callee == call_func) ||
(call2_func && callee == call2_func)) {
assert(T_prjlvalue);
size_t nargs = CI->arg_size();
size_t nframeargs = nargs;
if (CC == JLCALL_F_CC)
size_t nframeargs = nargs-1;
if (callee == call_func)
nframeargs -= 1;
else if (CC == JLCALL_F2_CC)
else if (callee == call2_func)
nframeargs -= 2;
SmallVector<Value*, 4> ReplacementArgs;
auto arg_it = CI->arg_begin();
assert(arg_it != CI->arg_end());
Value *new_callee = *(arg_it++);
assert(arg_it != CI->arg_end());
ReplacementArgs.push_back(*(arg_it++));
if (CC != JLCALL_F_CC) {
if (callee == call2_func) {
assert(arg_it != CI->arg_end());
ReplacementArgs.push_back(*(arg_it++));
}
maxframeargs = std::max(maxframeargs, nframeargs);
int slot = 0;
IRBuilder<> Builder (CI);
for (; arg_it != CI->arg_end(); ++arg_it) {
Builder.CreateAlignedStore(*arg_it,
// Julia emits IR with proper pointer types here, but because
// the julia.call signature is varargs, the optimizer is allowed
// to rewrite pointee types. It'll go away with opaque pointer
// types anyway.
Builder.CreateAlignedStore(Builder.CreateBitCast(*arg_it, T_prjlvalue),
Builder.CreateInBoundsGEP(T_prjlvalue, Frame, ConstantInt::get(T_int32, slot++)),
Align(sizeof(void*)));
}
ReplacementArgs.push_back(nframeargs == 0 ?
(llvm::Value*)ConstantPointerNull::get(T_pprjlvalue) :
(llvm::Value*)Frame);
ReplacementArgs.push_back(ConstantInt::get(T_int32, nframeargs));
if (CC == JLCALL_F2_CC) {
if (callee == call2_func) {
// move trailing arg to the end now
Value *front = ReplacementArgs.front();
ReplacementArgs.erase(ReplacementArgs.begin());
ReplacementArgs.push_back(front);
}
FunctionType *FTy;
if (CC == JLCALL_F_CC) // jl_fptr_args
if (callee == call_func) // jl_fptr_args
FTy = FunctionType::get(T_prjlvalue, {T_prjlvalue, T_pprjlvalue, T_int32}, false);
else // CC == JLCALL_F2_CC // jl_invoke
else // callee == call2_func // jl_invoke
FTy = FunctionType::get(T_prjlvalue, {T_prjlvalue, T_pprjlvalue, T_int32, T_prjlvalue}, false);
Value *newFptr = Builder.CreateBitCast(callee, FTy->getPointerTo());
Value *newFptr = Builder.CreateBitCast(new_callee, FTy->getPointerTo());
CallInst *NewCall = CallInst::Create(FTy, newFptr, ReplacementArgs, "", CI);
NewCall->setTailCallKind(CI->getTailCallKind());
auto old_attrs = CI->getAttributes();
NewCall->setAttributes(AttributeList::get(CI->getContext(),
getFnAttrs(old_attrs),
getRetAttrs(old_attrs), {}));
NewCall->takeName(CI);
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
Expand Down
2 changes: 1 addition & 1 deletion src/llvm-muladd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ static bool combineMulAdd(Function &F)
}
}
}
assert(!verifyFunction(F));
assert(!verifyFunction(F, &errs()));
return modified;
}

Expand Down
5 changes: 4 additions & 1 deletion src/llvm-pass-helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ JuliaPassContext::JuliaPassContext()
gc_preserve_begin_func(nullptr), gc_preserve_end_func(nullptr),
pointer_from_objref_func(nullptr), alloc_obj_func(nullptr),
typeof_func(nullptr), write_barrier_func(nullptr),
write_barrier_binding_func(nullptr), module(nullptr)
write_barrier_binding_func(nullptr), call_func(nullptr),
call2_func(nullptr), module(nullptr)
{
}

Expand All @@ -51,6 +52,8 @@ void JuliaPassContext::initFunctions(Module &M)
write_barrier_func = M.getFunction("julia.write_barrier");
write_barrier_binding_func = M.getFunction("julia.write_barrier_binding");
alloc_obj_func = M.getFunction("julia.gc_alloc_obj");
call_func = M.getFunction("julia.call");
call2_func = M.getFunction("julia.call2");
}

void JuliaPassContext::initAll(Module &M)
Expand Down
2 changes: 2 additions & 0 deletions src/llvm-pass-helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ struct JuliaPassContext {
llvm::Function *typeof_func;
llvm::Function *write_barrier_func;
llvm::Function *write_barrier_binding_func;
llvm::Function *call_func;
llvm::Function *call2_func;

// Creates a pass context. Type and function pointers
// are set to `nullptr`. Metadata nodes are initialized.
Expand Down
Loading

0 comments on commit 23f39f8

Please sign in to comment.