Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions crates/vm/src/types/structseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
builtins::{PyBaseExceptionRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef},
class::{PyClassImpl, StaticType},
function::{Either, PyComparisonValue},
function::{Either, FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags},
iter::PyExactSizeIterator,
protocol::{PyMappingMethods, PySequenceMethods},
sliceable::{SequenceIndex, SliceableSequenceOp},
Expand All @@ -11,6 +11,15 @@ use crate::{
};
use std::sync::LazyLock;

const DEFAULT_STRUCTSEQ_REDUCE: PyMethodDef = PyMethodDef::new_const(
"__reduce__",
|zelf: PyRef<PyTuple>, vm: &VirtualMachine| -> PyTupleRef {
vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
},
PyMethodFlags::METHOD,
None,
);

/// Create a new struct sequence instance from a sequence.
///
/// The class must have `n_sequence_fields` and `n_fields` attributes set
Expand Down Expand Up @@ -206,19 +215,13 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
};
let (body, suffix) =
if let Some(_guard) = rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_ref()) {
if field_names.len() == 1 {
let value = zelf.first().unwrap();
let formatted = format_field((value, field_names[0]))?;
(formatted, ",")
} else {
let fields: PyResult<Vec<_>> = zelf
.iter()
.map(|value| value.as_ref())
.zip(field_names.iter().copied())
.map(format_field)
.collect();
(fields?.join(", "), "")
}
let fields: PyResult<Vec<_>> = zelf
.iter()
.map(|value| value.as_ref())
.zip(field_names.iter().copied())
.map(format_field)
.collect();
(fields?.join(", "), "")
} else {
(String::new(), "...")
};
Expand All @@ -232,8 +235,45 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
}

#[pymethod]
fn __reduce__(zelf: PyRef<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
fn __replace__(zelf: PyRef<PyTuple>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
if !args.args.is_empty() {
return Err(vm.new_type_error("__replace__() takes no positional arguments".to_owned()));
}

if Self::Data::UNNAMED_FIELDS_LEN > 0 {
return Err(vm.new_type_error(format!(
"__replace__() is not supported for {} because it has unnamed field(s)",
zelf.class().slot_name()
)));
}

let n_fields =
Self::Data::REQUIRED_FIELD_NAMES.len() + Self::Data::OPTIONAL_FIELD_NAMES.len();
let mut items: Vec<PyObjectRef> = zelf.as_slice()[..n_fields].to_vec();

let mut kwargs = args.kwargs.clone();

// Replace fields from kwargs
let all_field_names: Vec<&str> = Self::Data::REQUIRED_FIELD_NAMES
.iter()
.chain(Self::Data::OPTIONAL_FIELD_NAMES.iter())
.copied()
.collect();
for (i, &name) in all_field_names.iter().enumerate() {
if let Some(val) = kwargs.shift_remove(name) {
items[i] = val;
}
}

// Check for unexpected keyword arguments
if !kwargs.is_empty() {
let names: Vec<&str> = kwargs.keys().map(|k| k.as_str()).collect();
return Err(vm.new_type_error(format!("Got unexpected field name(s): {:?}", names)));
}

PyTuple::new_unchecked(items.into_boxed_slice())
.into_ref_with_type(vm, zelf.class().to_owned())
.map(Into::into)
}

#[pymethod]
Expand Down Expand Up @@ -327,6 +367,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
.slots
.richcompare
.store(Some(struct_sequence_richcompare));

// Default __reduce__: only set if not already overridden by the impl's extend_class.
// This allows struct sequences like sched_param to provide a custom __reduce__
// (equivalent to METH_COEXIST in structseq.c).
if !class
.attributes
.read()
.contains_key(ctx.intern_str("__reduce__"))
{
class.set_attr(
ctx.intern_str("__reduce__"),
DEFAULT_STRUCTSEQ_REDUCE.to_proper_method(class, ctx),
);
}
}
}

Expand Down
Loading