Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix a bunch of random tests
  • Loading branch information
coolreader18 committed Feb 20, 2025
commit bd786a959722db69c65ecc28895228369045833c
25 changes: 0 additions & 25 deletions Lib/test/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def randomlist(self, n):
"""Helper function to make a list of random numbers"""
return [self.gen.random() for i in range(n)]

# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
@unittest.expectedFailure
def test_autoseed(self):
self.gen.seed()
state1 = self.gen.getstate()
Expand All @@ -32,8 +30,6 @@ def test_autoseed(self):
state2 = self.gen.getstate()
self.assertNotEqual(state1, state2)

# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
@unittest.expectedFailure
def test_saverestore(self):
N = 1000
self.gen.seed()
Expand All @@ -60,7 +56,6 @@ def __hash__(self):
self.assertRaises(TypeError, self.gen.seed, 1, 2, 3, 4)
self.assertRaises(TypeError, type(self.gen), [])

@unittest.skip("TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray'")
def test_seed_no_mutate_bug_44018(self):
a = bytearray(b'1234')
self.gen.seed(a)
Expand Down Expand Up @@ -386,8 +381,6 @@ def test_getrandbits(self):
self.assertRaises(ValueError, self.gen.getrandbits, -1)
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)

# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
@unittest.expectedFailure
def test_pickling(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
state = pickle.dumps(self.gen, proto)
Expand All @@ -396,8 +389,6 @@ def test_pickling(self):
restoredseq = [newgen.random() for i in range(10)]
self.assertEqual(origseq, restoredseq)

# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
@unittest.expectedFailure
def test_bug_1727780(self):
# verify that version-2-pickles can be loaded
# fine, whether they are created on 32-bit or 64-bit
Expand Down Expand Up @@ -600,11 +591,6 @@ def test_bug_42008(self):
class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
gen = random.Random()

# TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray'
@unittest.expectedFailure
def test_seed_no_mutate_bug_44018(self): # TODO: RUSTPYTHON, remove when this passes
super().test_seed_no_mutate_bug_44018() # TODO: RUSTPYTHON, remove when this passes

def test_guaranteed_stable(self):
# These sequences are guaranteed to stay the same across versions of python
self.gen.seed(3456147, version=1)
Expand Down Expand Up @@ -675,8 +661,6 @@ def test_bug_31482(self):
def test_setstate_first_arg(self):
self.assertRaises(ValueError, self.gen.setstate, (1, None, None))

# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
@unittest.expectedFailure
def test_setstate_middle_arg(self):
start_state = self.gen.getstate()
# Wrong type, s/b tuple
Expand Down Expand Up @@ -1282,15 +1266,6 @@ def test_betavariate_return_zero(self, gammavariate_mock):


class TestRandomSubclassing(unittest.TestCase):
# TODO: RUSTPYTHON Unexpected keyword argument newarg
@unittest.expectedFailure
def test_random_subclass_with_kwargs(self):
# SF bug #1486663 -- this used to erroneously raise a TypeError
class Subclass(random.Random):
def __init__(self, newarg=None):
random.Random.__init__(self)
Subclass(newarg=1)

def test_subclasses_overriding_methods(self):
# Subclasses with an overridden random, but only the original
# getrandbits method should not rely on getrandbits in for randrange,
Expand Down
130 changes: 61 additions & 69 deletions stdlib/src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
mod _random {
use crate::common::lock::PyMutex;
use crate::vm::{
builtins::{PyInt, PyTypeRef},
builtins::{PyInt, PyTupleRef},
convert::ToPyException,
function::OptionalOption,
types::Constructor,
PyObjectRef, PyPayload, PyResult, VirtualMachine,
types::{Constructor, Initializer},
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
};
use itertools::Itertools;
use malachite_bigint::{BigInt, BigUint, Sign};
use mt19937::MT19937;
use num_traits::{Signed, Zero};
use rand::{rngs::StdRng, RngCore, SeedableRng};

#[derive(Debug)]
enum PyRng {
Std(Box<StdRng>),
MT(Box<mt19937::MT19937>),
}

impl Default for PyRng {
fn default() -> Self {
PyRng::Std(Box::new(StdRng::from_os_rng()))
}
}

impl RngCore for PyRng {
fn next_u32(&mut self) -> u32 {
match self {
Self::Std(s) => s.next_u32(),
Self::MT(m) => m.next_u32(),
}
}
fn next_u64(&mut self) -> u64 {
match self {
Self::Std(s) => s.next_u64(),
Self::MT(m) => m.next_u64(),
}
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
match self {
Self::Std(s) => s.fill_bytes(dest),
Self::MT(m) => m.fill_bytes(dest),
}
}
}
use rand::{RngCore, SeedableRng};
use rustpython_vm::types::DefaultConstructor;

#[pyattr]
#[pyclass(name = "Random")]
#[derive(Debug, PyPayload)]
#[derive(Debug, PyPayload, Default)]
struct PyRandom {
rng: PyMutex<PyRng>,
rng: PyMutex<MT19937>,
}

impl Constructor for PyRandom {
type Args = OptionalOption<PyObjectRef>;
impl DefaultConstructor for PyRandom {}

fn py_new(
cls: PyTypeRef,
// TODO: use x as the seed.
_x: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyRandom {
rng: PyMutex::default(),
}
.into_ref_with_type(vm, cls)
.map(Into::into)
impl Initializer for PyRandom {
type Args = OptionalOption;

fn init(zelf: PyRef<Self>, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
zelf.seed(x, vm)
}
}

#[pyclass(flags(BASETYPE), with(Constructor))]
#[pyclass(flags(BASETYPE), with(Constructor, Initializer))]
impl PyRandom {
#[pymethod]
fn random(&self) -> f64 {
Expand All @@ -82,9 +46,8 @@ mod _random {

#[pymethod]
fn seed(&self, n: OptionalOption<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
let new_rng = n
.flatten()
.map(|n| {
*self.rng.lock() = match n.flatten() {
Some(n) => {
// Fallback to using hash if object isn't Int-like.
let (_, mut key) = match n.downcast::<PyInt>() {
Ok(n) => n.as_bigint().abs(),
Expand All @@ -95,27 +58,21 @@ mod _random {
key.reverse();
}
let key = if key.is_empty() { &[0] } else { key.as_slice() };
Ok(PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(
key,
))))
})
.transpose()?
.unwrap_or_default();

*self.rng.lock() = new_rng;
MT19937::new_with_slice_seed(key)
}
None => MT19937::try_from_os_rng()
.map_err(|e| std::io::Error::from(e).to_pyexception(vm))?,
};
Ok(())
}

#[pymethod]
fn getrandbits(&self, k: isize, vm: &VirtualMachine) -> PyResult<BigInt> {
match k {
k if k < 0 => {
Err(vm.new_value_error("number of bits must be non-negative".to_owned()))
}
..0 => Err(vm.new_value_error("number of bits must be non-negative".to_owned())),
0 => Ok(BigInt::zero()),
_ => {
mut k => {
let mut rng = self.rng.lock();
let mut k = k;
let mut gen_u32 = |k| {
let r = rng.next_u32();
if k < 32 {
Expand Down Expand Up @@ -145,5 +102,40 @@ mod _random {
}
}
}

#[pymethod]
fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef {
let rng = self.rng.lock();
vm.new_tuple(
rng.get_state()
.iter()
.copied()
.chain([rng.get_index() as u32])
.map(|i| vm.ctx.new_int(i).into())
.collect::<Vec<PyObjectRef>>(),
)
}

#[pymethod]
fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let state: &[_; mt19937::N + 1] = state
.as_slice()
.try_into()
.map_err(|_| vm.new_value_error("state vector is the wrong size".to_owned()))?;
let (index, state) = state.split_last().unwrap();
let index: usize = index.try_to_value(vm)?;
if index > mt19937::N {
return Err(vm.new_value_error("invalid state".to_owned()));
}
let state: [u32; mt19937::N] = state
.iter()
.map(|i| i.try_to_value(vm))
.process_results(|it| it.collect_array())?
.unwrap();
let mut rng = self.rng.lock();
rng.set_state(&state);
rng.set_index(index);
Ok(())
}
}
}
5 changes: 1 addition & 4 deletions vm/src/stdlib/os.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,7 @@ pub(super) mod _os {
return Err(vm.new_value_error("negative argument not allowed".to_owned()));
}
let mut buf = vec![0u8; size as usize];
getrandom::fill(&mut buf).map_err(|e| match e.raw_os_error() {
Some(errno) => io::Error::from_raw_os_error(errno).into_pyexception(vm),
None => vm.new_os_error("Getting random failed".to_owned()),
})?;
getrandom::fill(&mut buf).map_err(|e| io::Error::from(e).into_pyexception(vm))?;
Ok(buf)
}

Expand Down
Loading