Skip to content

Commit e0545a0

Browse files
authored
ZJIT: Bail out of HIR translation if we can't handle a send flag (#13182)
Bail out of HIR translation if we can't handle a send flag
1 parent 871d07a commit e0545a0

File tree

2 files changed

+134
-2
lines changed

2 files changed

+134
-2
lines changed

zjit/src/cruby.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,12 +887,15 @@ mod manual_defs {
887887
// From vm_callinfo.h - uses calculation that seems to confuse bindgen
888888
pub const VM_CALL_ARGS_SIMPLE: u32 = 1 << VM_CALL_ARGS_SIMPLE_bit;
889889
pub const VM_CALL_ARGS_SPLAT: u32 = 1 << VM_CALL_ARGS_SPLAT_bit;
890+
pub const VM_CALL_ARGS_SPLAT_MUT: u32 = 1 << VM_CALL_ARGS_SPLAT_MUT_bit;
890891
pub const VM_CALL_ARGS_BLOCKARG: u32 = 1 << VM_CALL_ARGS_BLOCKARG_bit;
891892
pub const VM_CALL_FORWARDING: u32 = 1 << VM_CALL_FORWARDING_bit;
892893
pub const VM_CALL_FCALL: u32 = 1 << VM_CALL_FCALL_bit;
893894
pub const VM_CALL_KWARG: u32 = 1 << VM_CALL_KWARG_bit;
894895
pub const VM_CALL_KW_SPLAT: u32 = 1 << VM_CALL_KW_SPLAT_bit;
896+
pub const VM_CALL_KW_SPLAT_MUT: u32 = 1 << VM_CALL_KW_SPLAT_MUT_bit;
895897
pub const VM_CALL_TAILCALL: u32 = 1 << VM_CALL_TAILCALL_bit;
898+
pub const VM_CALL_SUPER : u32 = 1 << VM_CALL_SUPER_bit;
896899
pub const VM_CALL_ZSUPER : u32 = 1 << VM_CALL_ZSUPER_bit;
897900
pub const VM_CALL_OPT_SEND : u32 = 1 << VM_CALL_OPT_SEND_bit;
898901

zjit/src/hir.rs

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,7 +1410,7 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> {
14101410
}
14111411
}
14121412

1413-
#[derive(Debug, Clone)]
1413+
#[derive(Debug, Clone, PartialEq)]
14141414
pub struct FrameState {
14151415
iseq: IseqPtr,
14161416
insn_idx: usize,
@@ -1556,10 +1556,26 @@ fn compute_jump_targets(iseq: *const rb_iseq_t) -> Vec<u32> {
15561556
result
15571557
}
15581558

1559-
#[derive(Debug)]
1559+
#[derive(Debug, PartialEq)]
1560+
pub enum CallType {
1561+
Splat,
1562+
BlockArg,
1563+
Kwarg,
1564+
KwSplat,
1565+
Tailcall,
1566+
Super,
1567+
Zsuper,
1568+
OptSend,
1569+
KwSplatMut,
1570+
SplatMut,
1571+
Forwarding,
1572+
}
1573+
1574+
#[derive(Debug, PartialEq)]
15601575
pub enum ParseError {
15611576
StackUnderflow(FrameState),
15621577
UnknownOpcode(String),
1578+
UnhandledCallType(CallType),
15631579
}
15641580

15651581
fn num_lead_params(iseq: *const rb_iseq_t) -> usize {
@@ -1573,6 +1589,22 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize {
15731589
(unsafe { get_iseq_body_local_table_size(iseq) }) as usize
15741590
}
15751591

1592+
/// If we can't handle the type of send (yet), bail out.
1593+
fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> {
1594+
if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplatMut)); }
1595+
if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::SplatMut)); }
1596+
if (flag & VM_CALL_ARGS_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::Splat)); }
1597+
if (flag & VM_CALL_KW_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplat)); }
1598+
if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::BlockArg)); }
1599+
if (flag & VM_CALL_KWARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::Kwarg)); }
1600+
if (flag & VM_CALL_TAILCALL) != 0 { return Err(ParseError::UnhandledCallType(CallType::Tailcall)); }
1601+
if (flag & VM_CALL_SUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Super)); }
1602+
if (flag & VM_CALL_ZSUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Zsuper)); }
1603+
if (flag & VM_CALL_OPT_SEND) != 0 { return Err(ParseError::UnhandledCallType(CallType::OptSend)); }
1604+
if (flag & VM_CALL_FORWARDING) != 0 { return Err(ParseError::UnhandledCallType(CallType::Forwarding)); }
1605+
Ok(())
1606+
}
1607+
15761608
/// Compile ISEQ into High-level IR
15771609
pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
15781610
let mut fun = Function::new(iseq);
@@ -1757,6 +1789,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
17571789
// NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq
17581790
let cd: *const rb_call_data = get_arg(pc, 1).as_ptr();
17591791
let call_info = unsafe { rb_get_call_data_ci(cd) };
1792+
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
17601793
let argc = unsafe { vm_ci_argc((*cd).ci) };
17611794

17621795

@@ -1797,6 +1830,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
17971830
YARVINSN_opt_send_without_block => {
17981831
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
17991832
let call_info = unsafe { rb_get_call_data_ci(cd) };
1833+
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
18001834
let argc = unsafe { vm_ci_argc((*cd).ci) };
18011835

18021836

@@ -1819,6 +1853,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
18191853
let cd: *const rb_call_data = get_arg(pc, 0).as_ptr();
18201854
let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq();
18211855
let call_info = unsafe { rb_get_call_data_ci(cd) };
1856+
filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?;
18221857
let argc = unsafe { vm_ci_argc((*cd).ci) };
18231858

18241859
let method_name = unsafe {
@@ -2134,6 +2169,16 @@ mod tests {
21342169
expected_hir.assert_eq(&actual_hir);
21352170
}
21362171

2172+
#[track_caller]
2173+
fn assert_compile_fails(method: &str, reason: ParseError) {
2174+
let iseq = crate::cruby::with_rubyvm(|| get_method_iseq(method));
2175+
unsafe { crate::cruby::rb_zjit_profile_disable(iseq) };
2176+
let result = iseq_to_hir(iseq);
2177+
assert!(result.is_err(), "Expected an error but succesfully compiled to HIR");
2178+
assert_eq!(result.unwrap_err(), reason);
2179+
}
2180+
2181+
21372182
#[test]
21382183
fn test_putobject() {
21392184
eval("def test = 123");
@@ -2610,6 +2655,90 @@ mod tests {
26102655
Return v13
26112656
"#]]);
26122657
}
2658+
2659+
#[test]
2660+
fn test_cant_compile_splat() {
2661+
eval("
2662+
def test(a) = foo(*a)
2663+
");
2664+
assert_compile_fails("test", ParseError::UnknownOpcode("splatarray".into()))
2665+
}
2666+
2667+
#[test]
2668+
fn test_cant_compile_block_arg() {
2669+
eval("
2670+
def test(a) = foo(&a)
2671+
");
2672+
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::BlockArg))
2673+
}
2674+
2675+
#[test]
2676+
fn test_cant_compile_kwarg() {
2677+
eval("
2678+
def test(a) = foo(a: 1)
2679+
");
2680+
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Kwarg))
2681+
}
2682+
2683+
#[test]
2684+
fn test_cant_compile_kw_splat() {
2685+
eval("
2686+
def test(a) = foo(**a)
2687+
");
2688+
assert_compile_fails("test", ParseError::UnhandledCallType(CallType::KwSplat))
2689+
}
2690+
2691+
// TODO(max): Figure out how to generate a call with TAILCALL flag
2692+
2693+
#[test]
2694+
fn test_cant_compile_super() {
2695+
eval("
2696+
def test = super()
2697+
");
2698+
assert_compile_fails("test", ParseError::UnknownOpcode("invokesuper".into()))
2699+
}
2700+
2701+
#[test]
2702+
fn test_cant_compile_zsuper() {
2703+
eval("
2704+
def test = super
2705+
");
2706+
assert_compile_fails("test", ParseError::UnknownOpcode("invokesuper".into()))
2707+
}
2708+
2709+
#[test]
2710+
fn test_cant_compile_super_forward() {
2711+
eval("
2712+
def test(...) = super(...)
2713+
");
2714+
assert_compile_fails("test", ParseError::UnknownOpcode("invokesuperforward".into()))
2715+
}
2716+
2717+
// TODO(max): Figure out how to generate a call with OPT_SEND flag
2718+
2719+
#[test]
2720+
fn test_cant_compile_kw_splat_mut() {
2721+
eval("
2722+
def test(a) = foo **a, b: 1
2723+
");
2724+
assert_compile_fails("test", ParseError::UnknownOpcode("putspecialobject".into()))
2725+
}
2726+
2727+
#[test]
2728+
fn test_cant_compile_splat_mut() {
2729+
eval("
2730+
def test(*) = foo *, 1
2731+
");
2732+
assert_compile_fails("test", ParseError::UnknownOpcode("splatarray".into()))
2733+
}
2734+
2735+
#[test]
2736+
fn test_cant_compile_forwarding() {
2737+
eval("
2738+
def test(...) = foo(...)
2739+
");
2740+
assert_compile_fails("test", ParseError::UnknownOpcode("sendforward".into()))
2741+
}
26132742
}
26142743

26152744
#[cfg(test)]

0 commit comments

Comments
 (0)