From dadbf63dac27494d4725833a727ba715c73bfc9f Mon Sep 17 00:00:00 2001 From: Chris Denton Date: Mon, 25 May 2026 05:47:39 +0000 Subject: [PATCH] Special case assume_init for Bool on Windows Fixes #564 --- src/socket.rs | 6 ++--- src/sys/windows.rs | 59 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/socket.rs b/src/socket.rs index 5f37529d..defbb1e0 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -918,7 +918,7 @@ impl Socket { pub fn keepalive(&self) -> io::Result { unsafe { getsockopt::(self.as_raw(), sys::SOL_SOCKET, sys::SO_KEEPALIVE) - .map(|keepalive| keepalive != false as Bool) + .map(|keepalive| keepalive != 0) } } @@ -2060,7 +2060,7 @@ impl Socket { /// [`set_only_v6`]: Socket::set_only_v6 pub fn only_v6(&self) -> io::Result { unsafe { - getsockopt::(self.as_raw(), sys::IPPROTO_IPV6, sys::IPV6_V6ONLY) + getsockopt::(self.as_raw(), sys::IPPROTO_IPV6, sys::IPV6_V6ONLY) .map(|only_v6| only_v6 != 0) } } @@ -2356,7 +2356,7 @@ impl Socket { pub fn tcp_nodelay(&self) -> io::Result { unsafe { getsockopt::(self.as_raw(), sys::IPPROTO_TCP, sys::TCP_NODELAY) - .map(|nodelay| nodelay != false as Bool) + .map(|nodelay| nodelay != 0) } } diff --git a/src/sys/windows.rs b/src/sys/windows.rs index c935c8b0..67de3670 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -94,13 +94,53 @@ pub(crate) use windows_sys::Win32::Networking::WinSock::{ pub(crate) const IPPROTO_IP: c_int = windows_sys::Win32::Networking::WinSock::IPPROTO_IP as c_int; pub(crate) const SOL_SOCKET: c_int = windows_sys::Win32::Networking::WinSock::SOL_SOCKET as c_int; -/// Type used in set/getsockopt to retrieve the `TCP_NODELAY` option. +// This is so we can special case MaybeUninit::assume_init for Bool. +// See Bool for why. +pub(crate) trait GetsockoptOutput: Sized { + unsafe fn assume_init(uninit: MaybeUninit, size: c_int) -> Self { + debug_assert_eq!(size as usize, size_of::()); + uninit.assume_init() + } +} + +impl GetsockoptOutput for i32 {} +impl GetsockoptOutput for u32 {} +impl GetsockoptOutput for IN_ADDR {} +impl GetsockoptOutput for linger {} +impl GetsockoptOutput for WSAPROTOCOL_INFOW {} + +/// Type used in getsockopt to retrieve options such as `TCP_NODELAY` or `IPV6_V6ONLY`. /// /// NOTE: -/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a -/// `BOOL` (alias for `c_int`, 4 bytes), however in practice this turns out to -/// be false (or misleading) as a `bool` (1 byte) is returned by `getsockopt`. -pub(crate) type Bool = bool; +/// documents that options such as `TCP_NODELAY` and `SO_KEEPALIVE` expect a "DWORD (boolean)". +/// A DWORD is 4 bytes but in practice only a 1 byte bool is often written. +/// While this behaviour is mostly consistent, it's been oberved that `getsockopt` with +/// IPV6_V6ONLY can sometimes write 4 bytes and sometimes write 1, so we handle both cases. +#[derive(Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub(crate) struct Bool { + value: c_int, +} +impl PartialEq for Bool { + #[inline(always)] + fn eq(&self, other: &c_int) -> bool { + self.value == *other + } +} + +impl GetsockoptOutput for Bool { + unsafe fn assume_init(uninit: MaybeUninit, size: c_int) -> Self { + if size == 1 { + // SAFETY: 1 byte has been initialized + let value = unsafe { *uninit.as_ptr().cast::() } as c_int; + Self { value } + } else { + debug_assert_eq!(size as usize, size_of::()); + // SAFETY: outside of debug, we assume the caller has correctly initialised the value + unsafe { uninit.assume_init() } + } + } +} /// Maximum size of a buffer passed to system call like `recv` and `send`. const MAX_BUF_LEN: usize = c_int::MAX as usize; @@ -831,7 +871,11 @@ pub(crate) fn set_tcp_ack_frequency(socket: RawSocket, frequency: u8) -> io::Res /// Caller must ensure `T` is the correct type for `level` and `optname`. // NOTE: `optname` is actually `i32`, but all constants are `u32`. -pub(crate) unsafe fn getsockopt(socket: RawSocket, level: c_int, optname: i32) -> io::Result { +pub(crate) unsafe fn getsockopt( + socket: RawSocket, + level: c_int, + optname: i32, +) -> io::Result { let mut optval: MaybeUninit = MaybeUninit::uninit(); let mut optlen = mem::size_of::() as c_int; syscall!( @@ -846,9 +890,8 @@ pub(crate) unsafe fn getsockopt(socket: RawSocket, level: c_int, optname: i32 SOCKET_ERROR ) .map(|_| { - debug_assert_eq!(optlen as usize, mem::size_of::()); // Safety: `getsockopt` initialised `optval` for us. - optval.assume_init() + T::assume_init(optval, optlen) }) }