Skip to content

Commit

Permalink
src: connection: Fix panic when DNS lookup fails
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoantoniocardoso authored and patrickelectric committed Jan 17, 2024
1 parent a03e2b8 commit b469cc6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
16 changes: 16 additions & 0 deletions src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,19 @@ pub fn connect<M: Message>(address: &str) -> io::Result<Box<dyn MavConnection<M>
protocol_err
}
}

/// Returns the socket address for the given address.
pub(crate) fn get_socket_addr<T: std::net::ToSocketAddrs>(
address: T,
) -> Result<std::net::SocketAddr, io::Error> {
let addr = match address.to_socket_addrs()?.next() {
Some(addr) => addr,
None => {
return Err(io::Error::new(
io::ErrorKind::Other,
"Host address lookup failed",
));
}
};
Ok(addr)
}
19 changes: 7 additions & 12 deletions src/connection/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::net::{TcpListener, TcpStream};
use std::sync::Mutex;
use std::time::Duration;

use super::get_socket_addr;

/// TCP MAVLink connection
pub fn select_protocol<M: Message>(
Expand All @@ -24,12 +26,9 @@ pub fn select_protocol<M: Message>(
}

pub fn tcpout<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Host address lookup failed.");
let socket = TcpStream::connect(&addr)?;
let addr = get_socket_addr(address)?;

let socket = TcpStream::connect(addr)?;
socket.set_read_timeout(Some(Duration::from_millis(100)))?;

Ok(TcpConnection {
Expand All @@ -43,12 +42,8 @@ pub fn tcpout<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
}

pub fn tcpin<T: ToSocketAddrs>(address: T) -> io::Result<TcpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let listener = TcpListener::bind(&addr)?;
let addr = get_socket_addr(address)?;
let listener = TcpListener::bind(addr)?;

//For now we only accept one incoming stream: this blocks until we get one
for incoming in listener.incoming() {
Expand Down
26 changes: 8 additions & 18 deletions src/connection/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::net::{SocketAddr, UdpSocket};
use std::str::FromStr;

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (macos-latest, x86_64-apple-darwin, --features all-dialects)

unused import: `std::str::FromStr`

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, arm-unknown-linux-musleabihf, --features ardupilotmega)

unused import: `std::str::FromStr`

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, armv7-unknown-linux-musleabihf, --features ardupilotmega)

unused import: `std::str::FromStr`

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, x86_64-unknown-linux-musl, --features all-dialects)

unused import: `std::str::FromStr`

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest, x86_64-unknown-linux-musl, --features all-dialects,emit-extensions)

unused import: `std::str::FromStr`

Check warning on line 7 in src/connection/udp.rs

View workflow job for this annotation

GitHub Actions / build (windows-latest, x86_64-pc-windows-msvc, --features all-dialects)

unused import: `std::str::FromStr`
use std::sync::Mutex;

use super::get_socket_addr;

/// UDP MAVLink connection
pub fn select_protocol<M: Message>(
Expand All @@ -27,35 +29,23 @@ pub fn select_protocol<M: Message>(
}

pub fn udpbcast<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let socket = UdpSocket::bind(&SocketAddr::from_str("0.0.0.0:0").unwrap()).unwrap();
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket
.set_broadcast(true)
.expect("Couldn't bind to broadcast address.");
UdpConnection::new(socket, false, Some(addr))
}

pub fn udpout<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let socket = UdpSocket::bind(&SocketAddr::from_str("0.0.0.0:0").unwrap())?;
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0")?;
UdpConnection::new(socket, false, Some(addr))
}

pub fn udpin<T: ToSocketAddrs>(address: T) -> io::Result<UdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let socket = UdpSocket::bind(&addr)?;
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind(addr)?;
UdpConnection::new(socket, true, None)
}

Expand Down

0 comments on commit b469cc6

Please sign in to comment.