diff --git a/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs b/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs index 8c8177068209..cd4c6a523002 100644 --- a/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs +++ b/crates/shadowsocks/src/net/sys/unix/bsd/freebsd.rs @@ -105,46 +105,62 @@ impl AsyncRead for TcpStream { } impl AsyncWrite for TcpStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { - let this = self.project(); - - if let TcpStreamState::FastOpenConnect(addr) = this.state { - loop { - // TCP_FASTOPEN was supported since FreeBSD 12.0 - // - // Example program: - // - - // Wait until socket is writable - ready!(this.inner.poll_write_ready(cx))?; - - unsafe { - let saddr = SockAddr::from(*addr); - - let ret = libc::sendto( - this.inner.as_raw_fd(), - buf.as_ptr() as *const libc::c_void, - buf.len(), - 0, // Yes, BSD doesn't need MSG_FASTOPEN - saddr.as_ptr(), - saddr.len(), - ); - - if ret >= 0 { - // Connect successfully. - *(this.state) = TcpStreamState::Connected; - return Ok(ret as usize).into(); - } else { - // Error occurs - let err = io::Error::last_os_error(); - if err.kind() != ErrorKind::WouldBlock { - return Err(err).into(); + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + loop { + let this = self.as_mut().project(); + + match this.state { + TcpStreamState::FastOpenConnect(addr) => { + // TCP_FASTOPEN was supported since FreeBSD 12.0 + // + // Example program: + // + + // Wait until socket is writable + ready!(this.inner.poll_write_ready(cx))?; + + unsafe { + let saddr = SockAddr::from(*addr); + + let ret = libc::sendto( + this.inner.as_raw_fd(), + buf.as_ptr() as *const libc::c_void, + buf.len(), + 0, // Yes, BSD doesn't need MSG_FASTOPEN + saddr.as_ptr(), + saddr.len(), + ); + + if ret >= 0 { + // Connect successfully. + *(this.state) = TcpStreamState::Connected; + return Ok(ret as usize).into(); + } else { + // Error occurs + let err = io::Error::last_os_error(); + + // EAGAIN, EWOULDBLOCK + if err.kind() != ErrorKind::WouldBlock { + // EINPROGRESS + if let Some(libc::EINPROGRESS) = err.raw_os_error() { + // For non-blocking socket, it returns the number of bytes queued (and transmitted in the SYN-data packet) if cookie is available. + // If cookie is not available, it transmits a data-less SYN packet with Fast Open cookie request option and returns -EINPROGRESS like connect(). + // + // So in this state. We have to loop again to call `poll_write` for sending the first packet. + *(this.state) = TcpStreamState::Connected; + } else { + // Other errors + return Err(err).into(); + } + } else { + // Pending on poll_write_ready + } } } } + + TcpStreamState::Connected => return this.inner.poll_write(cx, buf), } - } else { - this.inner.poll_write(cx, buf) } } diff --git a/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs b/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs index cc6719121cb4..dce91dfc731f 100644 --- a/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs +++ b/crates/shadowsocks/src/net/sys/unix/bsd/macos.rs @@ -9,6 +9,7 @@ use std::{ task::{self, Poll}, }; +use futures::ready; use log::error; use pin_project::pin_project; use socket2::SockAddr; @@ -23,9 +24,18 @@ use crate::net::{ ConnectOpts, }; +enum TcpStreamState { + Connected, + FastOpenWrite, +} + /// A `TcpStream` that supports TFO (TCP Fast Open) #[pin_project] -pub struct TcpStream(#[pin] TokioTcpStream); +pub struct TcpStream { + #[pin] + inner: TokioTcpStream, + state: TcpStreamState, +} impl TcpStream { pub async fn connect(addr: SocketAddr, opts: &ConnectOpts) -> io::Result { @@ -45,7 +55,10 @@ impl TcpStream { // If TFO is not enabled, it just works like a normal TcpStream let stream = socket.connect(addr).await?; set_common_sockopt_after_connect(&stream, opts)?; - return Ok(TcpStream(stream)); + return Ok(TcpStream { + inner: stream, + state: TcpStreamState::Connected, + }); } // TFO in macos uses connectx @@ -76,7 +89,10 @@ impl TcpStream { let stream = TokioTcpStream::from_std(unsafe { StdTcpStream::from_raw_fd(socket.into_raw_fd()) })?; set_common_sockopt_after_connect(&stream, opts)?; - Ok(TcpStream(stream)) + Ok(TcpStream { + inner: stream, + state: TcpStreamState::FastOpenWrite, + }) } } @@ -84,33 +100,69 @@ impl Deref for TcpStream { type Target = TokioTcpStream; fn deref(&self) -> &Self::Target { - &self.0 + &self.inner } } impl DerefMut for TcpStream { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + &mut self.inner } } impl AsyncRead for TcpStream { fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - self.project().0.poll_read(cx, buf) + self.project().inner.poll_read(cx, buf) } } impl AsyncWrite for TcpStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { - self.project().0.poll_write(cx, buf) + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + loop { + let this = self.as_mut().project(); + + match this.state { + TcpStreamState::FastOpenWrite => { + // `CONNECT_RESUME_ON_READ_WRITE` is set when calling `connectx`, + // so the first call of `send` will perform the actual SYN with TFO cookie. + // + // (NOT SURE) If remote server doesn't support TFO or this is the first connection, + // it may return EINPROGRESS just like other platforms (Linux, FreeBSD). + + match ready!(this.inner.poll_write(cx, buf)) { + Ok(n) => { + *(this.state) = TcpStreamState::Connected; + return Ok(n).into(); + } + Err(err) => { + // EAGAIN and EWOULDBLOCK should have been handled by tokio + // + // EINPROGRESS + if let Some(libc::EINPROGRESS) = err.raw_os_error() { + // For non-blocking socket, it returns the number of bytes queued (and transmitted in the SYN-data packet) if cookie is available. + // If cookie is not available, it transmits a data-less SYN packet with Fast Open cookie request option and returns -EINPROGRESS like connect(). + // + // So in this state. We have to loop again to call `poll_write` for sending the first packet. + *(this.state) = TcpStreamState::Connected; + } else { + // Other errors + return Err(err).into(); + } + } + } + } + + TcpStreamState::Connected => return this.inner.poll_write(cx, buf), + } + } } fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().0.poll_flush(cx) + self.project().inner.poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - self.project().0.poll_shutdown(cx) + self.project().inner.poll_shutdown(cx) } } diff --git a/crates/shadowsocks/src/net/sys/unix/linux/mod.rs b/crates/shadowsocks/src/net/sys/unix/linux/mod.rs index a009c9805237..b4739927475d 100644 --- a/crates/shadowsocks/src/net/sys/unix/linux/mod.rs +++ b/crates/shadowsocks/src/net/sys/unix/linux/mod.rs @@ -28,6 +28,7 @@ use crate::net::{ enum TcpStreamState { Connected, FastOpenConnect(SocketAddr), + FastOpenWrite, } /// A `TcpStream` that supports TFO (TCP Fast Open) @@ -147,7 +148,7 @@ impl TcpStream { Ok(TcpStream { inner: stream, state: if connected { - TcpStreamState::Connected + TcpStreamState::FastOpenWrite } else { TcpStreamState::FastOpenConnect(addr) }, @@ -176,45 +177,84 @@ impl AsyncRead for TcpStream { } impl AsyncWrite for TcpStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { - let this = self.project(); - - if let TcpStreamState::FastOpenConnect(addr) = this.state { - loop { - // Fallback mode. Must be kernal < 4.11 - // - // Uses sendto as BSD-like systems - - // Wait until socket is writable - ready!(this.inner.poll_write_ready(cx))?; - - unsafe { - let saddr = SockAddr::from(*addr); - - let ret = libc::sendto( - this.inner.as_raw_fd(), - buf.as_ptr() as *const libc::c_void, - buf.len(), - libc::MSG_FASTOPEN, - saddr.as_ptr(), - saddr.len(), - ); - - if ret >= 0 { - // Connect successfully. - *(this.state) = TcpStreamState::Connected; - return Ok(ret as usize).into(); - } else { - // Error occurs - let err = io::Error::last_os_error(); - if err.kind() != ErrorKind::WouldBlock { - return Err(err).into(); + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + loop { + let this = self.as_mut().project(); + + match this.state { + TcpStreamState::FastOpenConnect(addr) => { + // Fallback mode. Must be kernal < 4.11 + // + // Uses sendto as BSD-like systems + + // Wait until socket is writable + ready!(this.inner.poll_write_ready(cx))?; + + unsafe { + let saddr = SockAddr::from(*addr); + + let ret = libc::sendto( + this.inner.as_raw_fd(), + buf.as_ptr() as *const libc::c_void, + buf.len(), + libc::MSG_FASTOPEN, + saddr.as_ptr(), + saddr.len(), + ); + + if ret >= 0 { + // Connect successfully. + *(this.state) = TcpStreamState::Connected; + return Ok(ret as usize).into(); + } else { + // Error occurs + let err = io::Error::last_os_error(); + + // EAGAIN, EWOULDBLOCK + if err.kind() != ErrorKind::WouldBlock { + // EINPROGRESS + if let Some(libc::EINPROGRESS) = err.raw_os_error() { + // For non-blocking socket, it returns the number of bytes queued (and transmitted in the SYN-data packet) if cookie is available. + // If cookie is not available, it transmits a data-less SYN packet with Fast Open cookie request option and returns -EINPROGRESS like connect(). + // + // So in this state. We have to loop again to call `poll_write` for sending the first packet. + *(this.state) = TcpStreamState::Connected; + } else { + // Other errors + return Err(err).into(); + } + } else { + // Pending on poll_write_ready + } } } } + + TcpStreamState::FastOpenWrite => { + // First `write` after `TCP_FASTOPEN_CONNECT` + // Kernel >= 4.11 + + match ready!(this.inner.poll_write(cx, buf)) { + Ok(n) => { + *(this.state) = TcpStreamState::Connected; + return Ok(n).into(); + } + Err(err) => { + // EAGAIN and EWOULDBLOCK should have been handled by tokio + // + // EINPROGRESS + if let Some(libc::EINPROGRESS) = err.raw_os_error() { + // loop again to call `poll_write` for sending the first packet + *(this.state) = TcpStreamState::Connected; + } else { + return Err(err).into(); + } + } + } + } + + TcpStreamState::Connected => return this.inner.poll_write(cx, buf), } - } else { - this.inner.poll_write(cx, buf) } }