Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions crates/stdlib/src/contextvars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ thread_local! {
mod _contextvars {
use crate::vm::{
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
builtins::{PyGenericAlias, PyStrRef, PyType, PyTypeRef},
builtins::{PyGenericAlias, PyList, PyStrRef, PyType, PyTypeRef},
class::StaticType,
common::hash::PyHash,
function::{ArgCallable, FuncArgs, OptionalArg},
protocol::{PyMappingMethods, PySequenceMethods},
types::{AsMapping, AsSequence, Constructor, Hashable, Representable},
types::{AsMapping, AsSequence, Constructor, Hashable, Iterable, Representable},
};
use core::{
cell::{Cell, RefCell, UnsafeCell},
Expand Down Expand Up @@ -163,7 +163,7 @@ mod _contextvars {
}
}

#[pyclass(with(Constructor, AsMapping, AsSequence))]
#[pyclass(with(Constructor, AsMapping, AsSequence, Iterable))]
impl PyContext {
#[pymethod]
fn run(
Expand Down Expand Up @@ -205,11 +205,6 @@ mod _contextvars {
self.borrow_vars().len()
}

#[pymethod]
fn __iter__(&self) -> PyResult {
unimplemented!("Context.__iter__ is currently under construction")
}

#[pymethod]
fn get(
&self,
Expand Down Expand Up @@ -238,6 +233,15 @@ mod _contextvars {
let vars = zelf.borrow_vars();
vars.values().map(|value| value.to_owned()).collect()
}

// TODO: wrong return type
#[pymethod]
fn items(zelf: PyRef<Self>, vm: &VirtualMachine) -> Vec<PyObjectRef> {
let vars = zelf.borrow_vars();
vars.iter()
.map(|(k, v)| vm.ctx.new_tuple(vec![k.clone().into(), v.clone()]).into())
.collect()
}
}

impl Constructor for PyContext {
Expand Down Expand Up @@ -281,6 +285,15 @@ mod _contextvars {
}
}

impl Iterable for PyContext {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let vars = zelf.borrow_vars();
let keys: Vec<PyObjectRef> = vars.keys().map(|k| k.clone().into()).collect();
let list = vm.ctx.new_list(keys);
<PyList as Iterable>::iter(list, vm)
}
}

#[pyattr]
#[pyclass(name, traverse)]
#[derive(PyPayload)]
Expand Down Expand Up @@ -574,6 +587,22 @@ mod _contextvars {
) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}

#[pymethod]
fn __enter__(zelf: PyRef<Self>) -> PyRef<Self> {
zelf
}

#[pymethod]
fn __exit__(
zelf: &Py<Self>,
_ty: PyObjectRef,
_val: PyObjectRef,
_tb: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<()> {
ContextVar::reset(&zelf.var, zelf.to_owned(), vm)
}
}

impl Constructor for ContextToken {
Expand Down
48 changes: 47 additions & 1 deletion crates/vm/src/stdlib/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub(crate) use _thread::{
pub(crate) mod _thread {
use crate::{
AsObject, Py, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyDictRef, PyStr, PyTupleRef, PyType, PyTypeRef},
builtins::{PyDictRef, PyStr, PyStrRef, PyTupleRef, PyType, PyTypeRef},
frame::FrameRef,
function::{ArgCallable, Either, FuncArgs, KwArgs, OptionalArg, PySetterValue},
types::{Constructor, GetAttr, Representable, SetAttr},
Expand Down Expand Up @@ -260,6 +260,11 @@ pub(crate) mod _thread {
Ok(())
}

#[pymethod]
fn locked(&self) -> bool {
self.mu.is_locked()
}

#[pymethod]
fn _is_owned(&self) -> bool {
self.mu.is_owned_by_current_thread()
Expand Down Expand Up @@ -293,6 +298,47 @@ pub(crate) mod _thread {
current_thread_id()
}

/// Set the name of the current thread
#[pyfunction]
fn set_name(name: PyStrRef) {
#[cfg(target_os = "linux")]
{
use std::ffi::CString;
if let Ok(c_name) = CString::new(name.as_str()) {
// pthread_setname_np on Linux has a 16-byte limit including null terminator
// TODO: Potential UTF-8 boundary issue when truncating thread name on Linux.
// https://github.com/RustPython/RustPython/pull/6726/changes#r2689379171
let truncated = if c_name.as_bytes().len() > 15 {
CString::new(&c_name.as_bytes()[..15]).unwrap_or(c_name)
} else {
c_name
};
unsafe {
libc::pthread_setname_np(libc::pthread_self(), truncated.as_ptr());
}
}
Comment on lines 304 to 319
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Potential UTF-8 boundary issue when truncating thread name on Linux.

The truncation at byte 15 may split a multi-byte UTF-8 character, producing an invalid UTF-8 sequence. While CString::new won't fail (it only checks for interior nulls), passing invalid UTF-8 to the OS could cause unexpected behavior.

Consider truncating at a valid UTF-8 character boundary:

Proposed fix
         #[cfg(target_os = "linux")]
         {
             use std::ffi::CString;
             if let Ok(c_name) = CString::new(name.as_str()) {
                 // pthread_setname_np on Linux has a 16-byte limit including null terminator
-                let truncated = if c_name.as_bytes().len() > 15 {
-                    CString::new(&c_name.as_bytes()[..15]).unwrap_or(c_name)
+                let truncated = if name.as_str().len() > 15 {
+                    // Find a valid UTF-8 boundary at or before byte 15
+                    let s = name.as_str();
+                    let mut end = 15;
+                    while end > 0 && !s.is_char_boundary(end) {
+                        end -= 1;
+                    }
+                    CString::new(&s[..end]).unwrap_or(c_name)
                 } else {
                     c_name
                 };
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#[cfg(target_os = "linux")]
{
use std::ffi::CString;
if let Ok(c_name) = CString::new(name.as_str()) {
// pthread_setname_np on Linux has a 16-byte limit including null terminator
let truncated = if c_name.as_bytes().len() > 15 {
CString::new(&c_name.as_bytes()[..15]).unwrap_or(c_name)
} else {
c_name
};
unsafe {
libc::pthread_setname_np(libc::pthread_self(), truncated.as_ptr());
}
}
#[cfg(target_os = "linux")]
{
use std::ffi::CString;
if let Ok(c_name) = CString::new(name.as_str()) {
// pthread_setname_np on Linux has a 16-byte limit including null terminator
let truncated = if name.as_str().len() > 15 {
// Find a valid UTF-8 boundary at or before byte 15
let s = name.as_str();
let mut end = 15;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
CString::new(&s[..end]).unwrap_or(c_name)
} else {
c_name
};
unsafe {
libc::pthread_setname_np(libc::pthread_self(), truncated.as_ptr());
}
}
}

}
#[cfg(target_os = "macos")]
{
use std::ffi::CString;
if let Ok(c_name) = CString::new(name.as_str()) {
unsafe {
libc::pthread_setname_np(c_name.as_ptr());
}
}
}
#[cfg(windows)]
{
// Windows doesn't have a simple pthread_setname_np equivalent
// SetThreadDescription requires Windows 10+
let _ = name;
}
#[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
{
let _ = name;
}
}

/// Get OS-level thread ID (pthread_self on Unix)
/// This is important for fork compatibility - the ID must remain stable after fork
#[cfg(unix)]
Expand Down
Loading