Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 1 addition & 10 deletions crates/vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,6 @@ num_cpus = "1.17.0"
[target.'cfg(windows)'.dependencies]
junction = { workspace = true }

[target.'cfg(windows)'.dependencies.windows]
version = "0.52.0"
features = [
"Win32_Foundation",
"Win32_System_LibraryLoader",
"Win32_System_Threading",
"Win32_System_Time",
"Win32_UI_Shell",
]

[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
features = [
Expand All @@ -143,6 +133,7 @@ features = [
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_System_Time",
"Win32_System_WindowsProgramming",
"Win32_UI_Shell",
"Win32_UI_WindowsAndMessaging",
Expand Down
37 changes: 20 additions & 17 deletions crates/vm/src/stdlib/nt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,23 +556,26 @@ pub(crate) mod module {
String::from_utf16(wstr).map_err(|e| vm.new_unicode_decode_error(e.to_string()))
}

let wbuf = windows::core::PCWSTR::from_raw(backslashed.as_ptr());
let (root, path) = match unsafe { windows::Win32::UI::Shell::PathCchSkipRoot(wbuf) } {
Ok(end) => {
assert!(!end.is_null());
let len: usize = unsafe { end.as_ptr().offset_from(wbuf.as_ptr()) }
.try_into()
.expect("len must be non-negative");
assert!(
len < backslashed.len(), // backslashed is null-terminated
"path: {:?} {} < {}",
std::path::PathBuf::from(std::ffi::OsString::from_wide(&backslashed)),
len,
backslashed.len()
);
(from_utf16(&orig[..len], vm)?, from_utf16(&orig[len..], vm)?)
}
Err(_) => ("".to_owned(), from_utf16(&orig, vm)?),
let mut end: *const u16 = std::ptr::null();
let hr = unsafe {
windows_sys::Win32::UI::Shell::PathCchSkipRoot(backslashed.as_ptr(), &mut end)
};
let (root, path) = if hr == 0 {
// S_OK
assert!(!end.is_null());
let len: usize = unsafe { end.offset_from(backslashed.as_ptr()) }
.try_into()
.expect("len must be non-negative");
assert!(
len < backslashed.len(), // backslashed is null-terminated
"path: {:?} {} < {}",
std::path::PathBuf::from(std::ffi::OsString::from_wide(&backslashed)),
len,
backslashed.len()
);
(from_utf16(&orig[..len], vm)?, from_utf16(&orig[len..], vm)?)
} else {
("".to_owned(), from_utf16(&orig, vm)?)
};
Ok((root, path))
}
Expand Down
9 changes: 4 additions & 5 deletions crates/vm/src/stdlib/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ mod decl {
use std::time::Duration;
#[cfg(target_env = "msvc")]
#[cfg(not(target_arch = "wasm32"))]
use windows::Win32::System::Time;
use windows_sys::Win32::System::Time::{GetTimeZoneInformation, TIME_ZONE_INFORMATION};

#[allow(dead_code)]
pub(super) const SEC_TO_MS: i64 = 1000;
Expand Down Expand Up @@ -186,10 +186,9 @@ mod decl {

#[cfg(target_env = "msvc")]
#[cfg(not(target_arch = "wasm32"))]
fn get_tz_info() -> Time::TIME_ZONE_INFORMATION {
let mut info = Time::TIME_ZONE_INFORMATION::default();
let info_ptr = &mut info as *mut Time::TIME_ZONE_INFORMATION;
let _ = unsafe { Time::GetTimeZoneInformation(info_ptr) };
fn get_tz_info() -> TIME_ZONE_INFORMATION {
let mut info: TIME_ZONE_INFORMATION = unsafe { std::mem::zeroed() };
unsafe { GetTimeZoneInformation(&mut info) };
info
}
Comment on lines +189 to 193
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Missing error check for GetTimeZoneInformation.

The return value of GetTimeZoneInformation is discarded. This function returns TIME_ZONE_ID_INVALID (0xFFFFFFFF) on failure, which should be checked to avoid returning zeroed/invalid timezone data.

Consider adding error handling:

 fn get_tz_info() -> TIME_ZONE_INFORMATION {
     let mut info: TIME_ZONE_INFORMATION = unsafe { std::mem::zeroed() };
-    unsafe { GetTimeZoneInformation(&mut info) };
+    let result = unsafe { GetTimeZoneInformation(&mut info) };
+    debug_assert!(result != 0xFFFFFFFF, "GetTimeZoneInformation failed");
     info
 }

Alternatively, if the callers can handle errors, consider returning a Result type.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
fn get_tz_info() -> TIME_ZONE_INFORMATION {
let mut info: TIME_ZONE_INFORMATION = unsafe { std::mem::zeroed() };
unsafe { GetTimeZoneInformation(&mut info) };
info
}
fn get_tz_info() -> TIME_ZONE_INFORMATION {
let mut info: TIME_ZONE_INFORMATION = unsafe { std::mem::zeroed() };
let result = unsafe { GetTimeZoneInformation(&mut info) };
debug_assert!(result != 0xFFFFFFFF, "GetTimeZoneInformation failed");
info
}


Expand Down
99 changes: 49 additions & 50 deletions crates/vm/src/stdlib/winapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ mod _winapi {
convert::{ToPyException, ToPyResult},
function::{ArgMapping, ArgSequence, OptionalArg},
stdlib::os::errno_err,
windows::WindowsSysResult,
windows::{WinHandle, WindowsSysResult},
};
use std::ptr::{null, null_mut};
use windows::{
Win32::Foundation::{HANDLE, HINSTANCE, MAX_PATH},
core::PCWSTR,
};
use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE;
use windows_sys::Win32::Foundation::{INVALID_HANDLE_VALUE, MAX_PATH};

#[pyattr]
use windows_sys::Win32::{
Expand Down Expand Up @@ -78,15 +74,15 @@ mod _winapi {
const NULL: isize = 0;

#[pyfunction]
fn CloseHandle(handle: HANDLE) -> WindowsSysResult<i32> {
WindowsSysResult(unsafe { windows_sys::Win32::Foundation::CloseHandle(handle.0 as _) })
fn CloseHandle(handle: WinHandle) -> WindowsSysResult<i32> {
WindowsSysResult(unsafe { windows_sys::Win32::Foundation::CloseHandle(handle.0) })
}

#[pyfunction]
fn GetStdHandle(
std_handle: windows_sys::Win32::System::Console::STD_HANDLE,
vm: &VirtualMachine,
) -> PyResult<Option<HANDLE>> {
) -> PyResult<Option<WinHandle>> {
let handle = unsafe { windows_sys::Win32::System::Console::GetStdHandle(std_handle) };
if handle == INVALID_HANDLE_VALUE {
return Err(errno_err(vm));
Expand All @@ -95,7 +91,7 @@ mod _winapi {
// NULL handle - return None
None
} else {
Some(HANDLE(handle as isize))
Some(WinHandle(handle))
})
}

Expand All @@ -104,47 +100,49 @@ mod _winapi {
_pipe_attrs: PyObjectRef,
size: u32,
vm: &VirtualMachine,
) -> PyResult<(HANDLE, HANDLE)> {
) -> PyResult<(WinHandle, WinHandle)> {
use windows_sys::Win32::Foundation::HANDLE;
let (read, write) = unsafe {
let mut read = std::mem::MaybeUninit::<isize>::uninit();
let mut write = std::mem::MaybeUninit::<isize>::uninit();
let mut read = std::mem::MaybeUninit::<HANDLE>::uninit();
let mut write = std::mem::MaybeUninit::<HANDLE>::uninit();
WindowsSysResult(windows_sys::Win32::System::Pipes::CreatePipe(
read.as_mut_ptr() as _,
write.as_mut_ptr() as _,
read.as_mut_ptr(),
write.as_mut_ptr(),
std::ptr::null(),
size,
))
.to_pyresult(vm)?;
(read.assume_init(), write.assume_init())
};
Ok((HANDLE(read), HANDLE(write)))
Ok((WinHandle(read), WinHandle(write)))
}

#[pyfunction]
fn DuplicateHandle(
src_process: HANDLE,
src: HANDLE,
target_process: HANDLE,
src_process: WinHandle,
src: WinHandle,
target_process: WinHandle,
access: u32,
inherit: i32,
options: OptionalArg<u32>,
vm: &VirtualMachine,
) -> PyResult<HANDLE> {
) -> PyResult<WinHandle> {
use windows_sys::Win32::Foundation::HANDLE;
let target = unsafe {
let mut target = std::mem::MaybeUninit::<isize>::uninit();
let mut target = std::mem::MaybeUninit::<HANDLE>::uninit();
WindowsSysResult(windows_sys::Win32::Foundation::DuplicateHandle(
src_process.0 as _,
src.0 as _,
target_process.0 as _,
target.as_mut_ptr() as _,
src_process.0,
src.0,
target_process.0,
target.as_mut_ptr(),
access,
inherit,
options.unwrap_or(0),
))
.to_pyresult(vm)?;
target.assume_init()
};
Ok(HANDLE(target))
Ok(WinHandle(target))
}

#[pyfunction]
Expand All @@ -153,16 +151,16 @@ mod _winapi {
}

#[pyfunction]
fn GetCurrentProcess() -> HANDLE {
unsafe { windows::Win32::System::Threading::GetCurrentProcess() }
fn GetCurrentProcess() -> WinHandle {
WinHandle(unsafe { windows_sys::Win32::System::Threading::GetCurrentProcess() })
}

#[pyfunction]
fn GetFileType(
h: HANDLE,
h: WinHandle,
vm: &VirtualMachine,
) -> PyResult<windows_sys::Win32::Storage::FileSystem::FILE_TYPE> {
let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(h.0 as _) };
let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(h.0) };
if file_type == 0 && unsafe { windows_sys::Win32::Foundation::GetLastError() } != 0 {
Err(errno_err(vm))
} else {
Expand Down Expand Up @@ -206,7 +204,7 @@ mod _winapi {
fn CreateProcess(
args: CreateProcessArgs,
vm: &VirtualMachine,
) -> PyResult<(HANDLE, HANDLE, u32, u32)> {
) -> PyResult<(WinHandle, WinHandle, u32, u32)> {
let mut si: windows_sys::Win32::System::Threading::STARTUPINFOEXW =
unsafe { std::mem::zeroed() };
si.StartupInfo.cb = std::mem::size_of_val(&si) as _;
Expand Down Expand Up @@ -285,8 +283,8 @@ mod _winapi {
};

Ok((
HANDLE(procinfo.hProcess as _),
HANDLE(procinfo.hThread as _),
WinHandle(procinfo.hProcess),
WinHandle(procinfo.hThread),
procinfo.dwProcessId,
procinfo.dwThreadId,
))
Expand Down Expand Up @@ -434,7 +432,7 @@ mod _winapi {
0,
(2 & 0xffff) | 0x20000, // PROC_THREAD_ATTRIBUTE_HANDLE_LIST
handlelist.as_mut_ptr() as _,
(handlelist.len() * std::mem::size_of::<HANDLE>()) as _,
(handlelist.len() * std::mem::size_of::<isize>()) as _,
std::ptr::null_mut(),
std::ptr::null(),
)
Expand All @@ -447,9 +445,8 @@ mod _winapi {
}

#[pyfunction]
fn WaitForSingleObject(h: HANDLE, ms: u32, vm: &VirtualMachine) -> PyResult<u32> {
let ret =
unsafe { windows_sys::Win32::System::Threading::WaitForSingleObject(h.0 as _, ms) };
fn WaitForSingleObject(h: WinHandle, ms: u32, vm: &VirtualMachine) -> PyResult<u32> {
let ret = unsafe { windows_sys::Win32::System::Threading::WaitForSingleObject(h.0, ms) };
if ret == windows_sys::Win32::Foundation::WAIT_FAILED {
Err(errno_err(vm))
} else {
Expand All @@ -458,11 +455,11 @@ mod _winapi {
}

#[pyfunction]
fn GetExitCodeProcess(h: HANDLE, vm: &VirtualMachine) -> PyResult<u32> {
fn GetExitCodeProcess(h: WinHandle, vm: &VirtualMachine) -> PyResult<u32> {
unsafe {
let mut ec = std::mem::MaybeUninit::uninit();
WindowsSysResult(windows_sys::Win32::System::Threading::GetExitCodeProcess(
h.0 as _,
h.0,
ec.as_mut_ptr(),
))
.to_pyresult(vm)?;
Expand All @@ -471,33 +468,35 @@ mod _winapi {
}

#[pyfunction]
fn TerminateProcess(h: HANDLE, exit_code: u32) -> WindowsSysResult<i32> {
fn TerminateProcess(h: WinHandle, exit_code: u32) -> WindowsSysResult<i32> {
WindowsSysResult(unsafe {
windows_sys::Win32::System::Threading::TerminateProcess(h.0 as _, exit_code)
windows_sys::Win32::System::Threading::TerminateProcess(h.0, exit_code)
})
}

// TODO: ctypes.LibraryLoader.LoadLibrary
#[allow(dead_code)]
fn LoadLibrary(path: PyStrRef, vm: &VirtualMachine) -> PyResult<isize> {
let path = path.as_str().to_wide_with_nul();
let handle = unsafe {
windows::Win32::System::LibraryLoader::LoadLibraryW(PCWSTR::from_raw(path.as_ptr()))
.unwrap()
};
if handle.is_invalid() {
let handle =
unsafe { windows_sys::Win32::System::LibraryLoader::LoadLibraryW(path.as_ptr()) };
if handle.is_null() {
return Err(vm.new_runtime_error("LoadLibrary failed"));
}
Ok(handle.0)
Ok(handle as isize)
}

#[pyfunction]
fn GetModuleFileName(handle: isize, vm: &VirtualMachine) -> PyResult<String> {
let mut path: Vec<u16> = vec![0; MAX_PATH as usize];
let handle = HINSTANCE(handle);

let length =
unsafe { windows::Win32::System::LibraryLoader::GetModuleFileNameW(handle, &mut path) };
let length = unsafe {
windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW(
handle as windows_sys::Win32::Foundation::HMODULE,
path.as_mut_ptr(),
path.len() as u32,
)
};
if length == 0 {
return Err(vm.new_runtime_error("GetModuleFileName failed"));
}
Expand Down
Loading
Loading