Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,7 +2281,6 @@ class Ints(enum.IntEnum):
self.assertEqual(Union[Literal[1], Literal[Ints.B], Literal[True]].__args__,
(Literal[1], Literal[Ints.B], Literal[True]))

@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: types.UnionType[int, str] | float != types.UnionType[int, str, float]
def test_allow_non_types_in_or(self):
# gh-140348: Test that using | with a Union object allows things that are
# not allowed by is_unionable().
Expand Down
7 changes: 1 addition & 6 deletions crates/vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2048,12 +2048,7 @@ pub(crate) fn call_slot_new(
}

pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if !union_::is_unionable(zelf.clone(), vm) || !union_::is_unionable(other.clone(), vm) {
return Ok(vm.ctx.not_implemented());
}

let tuple = PyTuple::new_ref(vec![zelf, other], &vm.ctx);
union_::make_union(&tuple, vm)
union_::or_op(zelf, other, vm)
}

fn take_next_base(bases: &mut [Vec<PyTypeRef>]) -> Option<PyTypeRef> {
Expand Down
34 changes: 32 additions & 2 deletions crates/vm/src/builtins/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
convert::ToPyObject,
function::PyComparisonValue,
protocol::{PyMappingMethods, PyNumberMethods},
stdlib::typing::TypeAliasType,
stdlib::typing::{TypeAliasType, call_typing_func_object},
types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable},
};
use alloc::fmt;
Expand Down Expand Up @@ -193,7 +193,7 @@ impl PyUnion {
}
}

pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
let cls = obj.class();
cls.is(vm.ctx.types.none_type)
|| obj.downcastable::<PyType>()
Expand All @@ -202,6 +202,36 @@ pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
|| obj.downcast_ref::<TypeAliasType>().is_some()
}

fn type_check(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
// Fast path to avoid calling into typing.py
if is_unionable(arg.clone(), vm) {
return Ok(arg);
}
let message_str: PyObjectRef = vm
.ctx
.new_str("Union[arg, ...]: each arg must be a type.")
.into();
call_typing_func_object(vm, "_type_check", (arg, message_str))
}

fn has_union_operands(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> bool {
let union_type = vm.ctx.types.union_type;
a.class().is(union_type) || b.class().is(union_type)
}

pub fn or_op(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if !has_union_operands(zelf.clone(), other.clone(), vm)
&& (!is_unionable(zelf.clone(), vm) || !is_unionable(other.clone(), vm))
{
return Ok(vm.ctx.not_implemented());
}

let left = type_check(zelf, vm)?;
let right = type_check(other, vm)?;
let tuple = PyTuple::new_ref(vec![left, right], &vm.ctx);
make_union(&tuple, vm)
}

fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
let parameters = genericalias::make_parameters(args, vm);
let result = dedup_and_flatten_args(&parameters, vm)?;
Expand Down
25 changes: 8 additions & 17 deletions crates/vm/src/stdlib/typevar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,21 @@ pub use typevar::*;
pub(crate) mod typevar {
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyTuple, PyTupleRef, PyType, PyTypeRef, make_union, pystr::AsPyStr},
builtins::{PyTuple, PyTupleRef, PyType, PyTypeRef, make_union},
common::lock::PyMutex,
function::{FuncArgs, IntoFuncArgs, PyComparisonValue},
function::{FuncArgs, PyComparisonValue},
protocol::PyNumberMethods,
stdlib::typing::call_typing_func_object,
types::{AsNumber, Comparable, Constructor, Iterable, PyComparisonOp, Representable},
};

pub(crate) fn _call_typing_func_object<'a>(
vm: &VirtualMachine,
func_name: impl AsPyStr<'a>,
args: impl IntoFuncArgs,
) -> PyResult {
let module = vm.import("typing", 0)?;
let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?;
func.call(args, vm)
}

fn type_check(arg: PyObjectRef, msg: &str, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
// Calling typing.py here leads to bootstrapping problems
if vm.is_none(&arg) {
return Ok(arg.class().to_owned().into());
}
let message_str: PyObjectRef = vm.ctx.new_str(msg).into();
_call_typing_func_object(vm, "_type_check", (arg, message_str))
call_typing_func_object(vm, "_type_check", (arg, message_str))
}

/// Get the module of the caller frame, similar to CPython's caller() function.
Expand Down Expand Up @@ -169,7 +160,7 @@ pub(crate) mod typevar {
vm: &VirtualMachine,
) -> PyResult {
let self_obj: PyObjectRef = zelf.into();
_call_typing_func_object(vm, "_typevar_subst", (self_obj, arg))
call_typing_func_object(vm, "_typevar_subst", (self_obj, arg))
}

#[pymethod]
Expand Down Expand Up @@ -514,7 +505,7 @@ pub(crate) mod typevar {
vm: &VirtualMachine,
) -> PyResult {
let self_obj: PyObjectRef = zelf.into();
_call_typing_func_object(vm, "_paramspec_subst", (self_obj, arg))
call_typing_func_object(vm, "_paramspec_subst", (self_obj, arg))
}

#[pymethod]
Expand All @@ -525,7 +516,7 @@ pub(crate) mod typevar {
vm: &VirtualMachine,
) -> PyResult {
let self_obj: PyObjectRef = zelf.into();
_call_typing_func_object(vm, "_paramspec_prepare_subst", (self_obj, alias, args))
call_typing_func_object(vm, "_paramspec_prepare_subst", (self_obj, alias, args))
}
}

Expand Down Expand Up @@ -711,7 +702,7 @@ pub(crate) mod typevar {
vm: &VirtualMachine,
) -> PyResult {
let self_obj: PyObjectRef = zelf.into();
_call_typing_func_object(vm, "_typevartuple_prepare_subst", (self_obj, alias, args))
call_typing_func_object(vm, "_typevartuple_prepare_subst", (self_obj, alias, args))
}
}

Expand Down
29 changes: 16 additions & 13 deletions crates/vm/src/stdlib/typing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// spell-checker:ignore typevarobject funcobj
use crate::{Context, class::PyClassImpl};
use crate::{
Context, PyResult, VirtualMachine, builtins::pystr::AsPyStr, class::PyClassImpl,
function::IntoFuncArgs,
};

pub use crate::stdlib::typevar::{
Generic, ParamSpec, ParamSpecArgs, ParamSpecKwargs, TypeVar, TypeVarTuple,
Expand All @@ -13,26 +16,26 @@ pub fn init(ctx: &Context) {
NoDefault::extend_class(ctx, ctx.types.typing_no_default_type);
}

pub fn call_typing_func_object<'a>(
vm: &VirtualMachine,
func_name: impl AsPyStr<'a>,
args: impl IntoFuncArgs,
) -> PyResult {
let module = vm.import("typing", 0)?;
let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?;
func.call(args, vm)
}

#[pymodule(name = "_typing", with(super::typevar::typevar))]
pub(crate) mod decl {
use crate::{
Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, type_},
function::{FuncArgs, IntoFuncArgs},
builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, type_},
function::FuncArgs,
protocol::PyNumberMethods,
types::{AsNumber, Constructor, Representable},
};

pub(crate) fn _call_typing_func_object<'a>(
vm: &VirtualMachine,
func_name: impl AsPyStr<'a>,
args: impl IntoFuncArgs,
) -> PyResult {
let module = vm.import("typing", 0)?;
let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?;
func.call(args, vm)
}

#[pyfunction]
pub(crate) fn _idfunc(args: FuncArgs, _vm: &VirtualMachine) -> PyObjectRef {
args.args[0].clone()
Expand Down