From 2a9e3478c7d879f1057fffc70bc9c81cd0141615 Mon Sep 17 00:00:00 2001 From: "Demi M. Obenour" Date: Wed, 18 Dec 2019 14:40:48 -0500 Subject: [PATCH] Add `UdpSocket::{poll_recv_from, poll_send_to}` This is needed for QUIC implementations based on async-std. It could be done with `UdpSocket::{recv_from, send_to}`, but this requires boxing, reference counting, and `unsafe` code. --- src/net/udp/mod.rs | 71 ++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs index 418b4b60a..c85c5d93b 100644 --- a/src/net/udp/mod.rs +++ b/src/net/udp/mod.rs @@ -1,6 +1,6 @@ use std::io; -use std::net::SocketAddr; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::task::{Context, Poll}; use crate::future; use crate::net::driver::Watcher; @@ -69,9 +69,7 @@ impl UdpSocket { /// ``` pub async fn bind(addrs: A) -> io::Result { let mut last_err = None; - let addrs = addrs - .to_socket_addrs() - .await?; + let addrs = addrs.to_socket_addrs().await?; for addr in addrs { match mio::net::UdpSocket::bind(&addr) { @@ -116,6 +114,19 @@ impl UdpSocket { .context(|| String::from("could not get local address")) } + /// Sends data on the socket to the given address. + /// + /// If this function returns `Poll::Ready(Ok(_))`, returns the number of bytes written. + pub fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + addr: &SocketAddr, + ) -> Poll> { + self.watcher + .poll_write_with(cx, |inner| inner.send_to(buf, &addr)) + } + /// Sends data on the socket to the given address. /// /// On success, returns the number of bytes written. @@ -153,12 +164,21 @@ impl UdpSocket { } }; - future::poll_fn(|cx| { - self.watcher - .poll_write_with(cx, |inner| inner.send_to(buf, &addr)) - }) - .await - .context(|| format!("could not send packet to {}", addr)) + future::poll_fn(|cx| self.poll_send_to(cx, buf, &addr)) + .await + .context(|| format!("could not send packet to {}", addr)) + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the origin. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.watcher + .poll_read_with(cx, |inner| inner.recv_from(buf)) } /// Receives data from the socket. @@ -181,22 +201,19 @@ impl UdpSocket { /// # Ok(()) }) } /// ``` pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - future::poll_fn(|cx| { - self.watcher - .poll_read_with(cx, |inner| inner.recv_from(buf)) - }) - .await - .context(|| { - use std::fmt::Write; - - let mut error = String::from("could not receive data on "); - if let Ok(addr) = self.local_addr() { - let _ = write!(&mut error, "{}", addr); - } else { - error.push_str("socket"); - } - error - }) + future::poll_fn(|cx| self.poll_recv_from(cx, buf)) + .await + .context(|| { + use std::fmt::Write; + + let mut error = String::from("could not receive data on "); + if let Ok(addr) = self.local_addr() { + let _ = write!(&mut error, "{}", addr); + } else { + error.push_str("socket"); + } + error + }) } /// Connects the UDP socket to a remote address.