diff --git a/gotatun/src/packet/ipv4/mod.rs b/gotatun/src/packet/ipv4/mod.rs index d69711d5..0125529b 100644 --- a/gotatun/src/packet/ipv4/mod.rs +++ b/gotatun/src/packet/ipv4/mod.rs @@ -10,6 +10,7 @@ // SPDX-License-Identifier: MPL-2.0 use bitfield_struct::bitfield; +use eyre::{Context, eyre}; use std::{fmt::Debug, net::Ipv4Addr}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, big_endian}; @@ -208,6 +209,43 @@ impl Ipv4 { pub const MAX_LEN: usize = 65535; } +impl Ipv4

{ + pub fn update_ip_checksum(&mut self) { + // TODO: handle IP options + debug_assert!(self.assert_no_ip_options().is_ok()); + + let checksum = pnet_packet::util::checksum(self.header.as_bytes(), 5); + self.header.header_checksum.set(checksum); + } + + /// Assert that [`Ipv4Header::ihl`] is 5, which means that the IPv4 header does not contain + /// any optional values. + pub(crate) fn assert_no_ip_options(&self) -> eyre::Result<()> { + match self.header.ihl() { + 5 => Ok(()), + 6.. => Err(eyre!("IP header: {:?}", self.header)) + .wrap_err(eyre!("IPv4 packets with options are not supported")), + ihl @ ..5 => { + Err(eyre!("IP header: {:?}", self.header)).wrap_err(eyre!("Bad IHL value: {ihl}")) + } + } + } +} + +impl Ipv4

+where + Self: IntoBytes + Immutable, +{ + pub fn try_update_ip_len(&mut self) -> eyre::Result<()> { + self.header.total_len = self + .as_bytes() + .len() + .try_into() + .map_err(|_| eyre!("IPv4 packet was larger than {}", u16::MAX))?; + Ok(()) + } +} + impl Debug for Ipv4Header { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Ipv4Header") diff --git a/gotatun/src/packet/ipv6/mod.rs b/gotatun/src/packet/ipv6/mod.rs index c53259e9..997e3adb 100644 --- a/gotatun/src/packet/ipv6/mod.rs +++ b/gotatun/src/packet/ipv6/mod.rs @@ -10,6 +10,7 @@ // SPDX-License-Identifier: MPL-2.0 use bitfield_struct::bitfield; +use eyre::eyre; use std::{fmt::Debug, net::Ipv6Addr}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, big_endian}; @@ -122,6 +123,21 @@ impl Ipv6Header { } } +impl Ipv6

+where + P: IntoBytes + Immutable, +{ + pub fn try_update_ip_len(&mut self) -> eyre::Result<()> { + self.header.payload_length = self + .payload + .as_bytes() + .len() + .try_into() + .map_err(|_| eyre!("IPv6 payload was larger than {}", u16::MAX))?; + Ok(()) + } +} + impl Debug for Ipv6Header { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Ipv6Header") diff --git a/gotatun/src/packet/mod.rs b/gotatun/src/packet/mod.rs index e1fc33a0..d3a3b834 100644 --- a/gotatun/src/packet/mod.rs +++ b/gotatun/src/packet/mod.rs @@ -69,6 +69,7 @@ mod ip; mod ipv4; mod ipv6; mod pool; +mod tcp; mod udp; mod util; mod wg; @@ -77,6 +78,7 @@ pub use ip::*; pub use ipv4::*; pub use ipv6::*; pub use pool::*; +pub use tcp::*; pub use udp::*; pub use wg::*; @@ -137,6 +139,7 @@ impl CheckedPayload for Ip {} impl CheckedPayload for Ipv6

{} impl CheckedPayload for Ipv4

{} impl CheckedPayload for Udp

{} +impl CheckedPayload for Tcp {} impl CheckedPayload for WgHandshakeInit {} impl CheckedPayload for WgHandshakeResp {} impl CheckedPayload for WgCookieReply {} @@ -191,14 +194,20 @@ impl Packet { FromType ToType; [Ipv4] [Ipv4]; [Ipv6] [Ipv6]; + [Ipv4] [Ipv4]; + [Ipv6] [Ipv6]; [Ipv4] [Ip]; [Ipv6] [Ip]; + [Ipv4] [Ip]; + [Ipv6] [Ip]; [Ipv4] [Ip]; [Ipv6] [Ip]; [Ipv4] [[u8]]; [Ipv6] [[u8]]; + [Ipv4] [[u8]]; + [Ipv6] [[u8]]; [Ipv4] [[u8]]; [Ipv6] [[u8]]; [Ip] [[u8]]; @@ -213,6 +222,26 @@ impl From> for Packet { } } +#[duplicate_item( + FromType ToType either_fn; + [[u8]] [Ipv4] [ left]; + [[u8]] [Ipv6] [right]; + [ Ip ] [Ipv4] [ left]; + [ Ip ] [Ipv6] [right]; +)] +impl TryFrom> for Packet { + type Error = eyre::Report; + + fn try_from(packet: Packet) -> Result { + packet.try_into_ipvx()?.either_fn().ok_or_else(|| { + eyre!( + "Expected {} but found another IP version", + stringify!(ToType) + ) + }) + } +} + impl Default for Packet<[u8]> { fn default() -> Self { Self { @@ -383,17 +412,7 @@ impl Packet { // because there we can still parse the part of the Ipv4 header that is always present // and ignore the options. To parse the UDP packet, we must know that the IHL is 5, // otherwise it will not start at the right offset. - match ip.header.ihl() { - 5 => {} - 6.. => { - return Err(eyre!("IP header: {:?}", ip.header)) - .wrap_err(eyre!("IPv4 packets with options are not supported")); - } - ihl @ ..5 => { - return Err(eyre!("IP header: {:?}", ip.header)) - .wrap_err(eyre!("Bad IHL value: {ihl}")); - } - } + self.assert_no_ip_options()?; if ip.header.fragment_offset() != 0 || ip.header.more_fragments() { eyre::bail!("IPv4 packet is a fragment: {:?}", ip.header); @@ -406,6 +425,32 @@ impl Packet { // update `_kind` to reflect this. Ok(self.cast::>()) } + + /// Try to cast this [`Ipv4`] packet into an [`Tcp`] packet. + /// + /// Returns `Packet>` if the packet is a valid, + /// non-fragmented IPv4 TCP packet with no options (IHL == `5`). 
+ /// + /// # Errors + /// Returns an error if + /// - the packet is a fragment + /// - the IHL is not `5` + /// - TCP validation fails + pub fn try_into_tcp(self) -> eyre::Result>> { + // We validate the IHL here, instead of in the `try_into_ipvx` method, + // because there we can still parse the part of the Ipv4 header that is always present + // and ignore the options. To parse the TCP packet, we must know that the IHL is 5, + // otherwise it will not start at the right offset. + self.assert_no_ip_options()?; + + let ip = self.deref(); + validate_tcp(ip.header.next_protocol(), &ip.payload) + .wrap_err_with(|| eyre!("IP header: {:?}", ip.header))?; + + // we have asserted that the packet is a valid IPv4 TCP packet. + // update `_kind` to reflect this. + Ok(self.cast::>()) + } } impl Packet { @@ -425,6 +470,22 @@ impl Packet { // update `_kind` to reflect this. Ok(self.cast::>()) } + + /// Try to cast this [`Ipv6`] packet into an [`Tcp`] packet. + /// + /// Returns `Packet>` if the packet is a valid IPv6 TCP packet. + /// + /// # Errors + /// Returns an error if TCP validation fails + pub fn try_into_tcp(self) -> eyre::Result>> { + let ip = self.deref(); + validate_tcp(ip.header.next_protocol(), &ip.payload) + .wrap_err_with(|| eyre!("IP header: {:?}", ip.header))?; + + // we have asserted that the packet is a valid IPv6 TCP packet. + // update `_kind` to reflect this. + Ok(self.cast::>()) + } } impl Packet> { @@ -482,6 +543,23 @@ fn validate_udp(next_protocol: IpNextProtocol, payload: &[u8]) -> eyre::Result<( Ok(()) } +fn validate_tcp(next_protocol: IpNextProtocol, payload: &[u8]) -> eyre::Result<()> { + let IpNextProtocol::Tcp = next_protocol else { + bail!("Expected TCP, but packet was {next_protocol:?}"); + }; + + let tcp = Tcp::ref_from_bytes(payload).map_err(|_| eyre!("Too small to be a TCP packet"))?; + + // Check that `data_offset` is correct by trying to look at the payload. 
+ tcp.payload() + .ok_or(eyre!("{:?}", tcp.header)) + .wrap_err_with(|| eyre!("Bad TCP packet"))?; + + // TODO: validate checksum? + + Ok(()) +} + impl Deref for Packet where Kind: CheckedPayload + ?Sized, diff --git a/gotatun/src/packet/tcp.rs b/gotatun/src/packet/tcp.rs new file mode 100644 index 00000000..30a38826 --- /dev/null +++ b/gotatun/src/packet/tcp.rs @@ -0,0 +1,262 @@ +use std::fmt; + +use bitfield_struct::bitfield; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, big_endian}; + +use crate::packet::util::size_must_be; + +use super::{Ipv4, Ipv6}; + +#[repr(C)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct Tcp { + pub header: TcpHeader, + pub options_and_payload: OptionsAndPayload, +} + +impl fmt::Debug for Tcp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Tcp") + .field("header", &self.header) + .field("options", &self.options()) + .field("payload", &self.payload()) + .finish() + } +} + +#[repr(C, packed)] +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)] +pub struct TcpHeader { + pub source_port: big_endian::U16, + pub destination_port: big_endian::U16, + pub seq_num: big_endian::U32, + pub ack_num: big_endian::U32, + pub data_offset: TcpDataOffset, + pub flags: TcpFlags, + pub window: big_endian::U16, + pub checksum: big_endian::U16, + pub urgent_pointer: big_endian::U16, +} + +#[bitfield(u8, order = Msb)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct TcpFlags { + pub cwr: bool, + pub ece: bool, + pub urg: bool, + pub ack: bool, + pub psh: bool, + pub rst: bool, + pub syn: bool, + pub fin: bool, +} + +#[bitfield(u8, order = Msb)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct TcpDataOffset { + /// Offset in `u32`s from the start of [TcpHeader] to the start of the payload. + /// + /// Must be at least 5. 
+ #[bits(4)] + pub data_offset: u8, + + #[bits(4)] + _reserved: u8, +} + +impl TcpDataOffset { + /// Set data_offset to `5`, which means that the TCP header contains _no_ options. + pub const fn no_options() -> Self { + TcpDataOffset::new().with_data_offset(5) + } +} + +impl TcpHeader { + /// Length of a [TcpHeader]. Not including TCP options. + pub const LEN: usize = size_must_be::(20); + + pub const fn fin(&self) -> bool { + self.flags.fin() + } + pub const fn syn(&self) -> bool { + self.flags.syn() + } + pub const fn rst(&self) -> bool { + self.flags.rst() + } + pub const fn psh(&self) -> bool { + self.flags.psh() + } + pub const fn ack(&self) -> bool { + self.flags.ack() + } + pub const fn urg(&self) -> bool { + self.flags.urg() + } + pub const fn ece(&self) -> bool { + self.flags.ece() + } + pub const fn cwr(&self) -> bool { + self.flags.cwr() + } + + pub const fn set_fin(&mut self, value: bool) { + self.flags.set_fin(value); + } + pub const fn set_syn(&mut self, value: bool) { + self.flags.set_syn(value); + } + pub const fn set_rst(&mut self, value: bool) { + self.flags.set_rst(value); + } + pub const fn set_psh(&mut self, value: bool) { + self.flags.set_psh(value); + } + pub const fn set_ack(&mut self, value: bool) { + self.flags.set_ack(value); + } + pub const fn set_urg(&mut self, value: bool) { + self.flags.set_urg(value); + } + pub const fn set_ece(&mut self, value: bool) { + self.flags.set_ece(value); + } + pub const fn set_cwr(&mut self, value: bool) { + self.flags.set_cwr(value); + } + + pub const fn data_offset(&self) -> u8 { + self.data_offset.data_offset() + } +} + +impl Tcp { + /// Get the length of the TCP header options, in bytes. + pub fn options_len(&self) -> Option { + let data_offset = usize::from(self.header.data_offset()); + let options_words = data_offset.checked_sub(5)?; + Some(options_words * size_of::()) + } + + /// Get the TCP payload portion of this packet. 
+ /// + /// Returns `None` if [TcpHeader::data_offset] is either: + /// - Malformed (i.e. `data_offset < 5`) + /// - Too big and would overflow [Tcp::options_and_payload]. + pub fn payload(&self) -> Option<&[u8]> { + let i = self.options_len()?; + self.options_and_payload.get(i..) + } + + /// Get the TCP options portion of the header. + /// + /// Returns `None` if [TcpHeader::data_offset] is either: + /// - Malformed (i.e. `data_offset < 5`) + /// - Too big and would overflow [Tcp::options_and_payload]. + pub fn options(&self) -> Option<&[u8]> { + let i = self.options_len()?; + self.options_and_payload.get(..i) + } +} + +impl Ipv4 { + /// Calculate and return the TCP checksum for this packet. + #[must_use] + pub fn calculate_tcp_checksum(&self) -> u16 { + let tcp = &self.payload; + pnet_packet::util::ipv4_checksum( + tcp.as_bytes(), + 8, + &[], + &self.header.source(), + &self.header.destination(), + pnet_packet::ip::IpNextHeaderProtocols::Tcp, + ) + } + + /// Calculate and set the TCP checksum for this packet. + pub fn update_tcp_checksum(&mut self) { + self.payload.header.checksum = self.calculate_tcp_checksum().into(); + } +} + +impl Ipv6 { + /// Calculate and return the TCP checksum for this packet. + #[must_use] + pub fn calculate_tcp_checksum(&self) -> u16 { + let tcp = &self.payload; + pnet_packet::util::ipv6_checksum( + tcp.as_bytes(), + 8, + &[], + &self.header.source(), + &self.header.destination(), + pnet_packet::ip::IpNextHeaderProtocols::Tcp, + ) + } + + /// Calculate and set the TCP checksum for this packet. 
+ pub fn update_tcp_checksum(&mut self) { + self.payload.header.checksum = self.calculate_tcp_checksum().into(); + } +} + +impl fmt::Debug for TcpHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpHeader") + .field("source_port", &self.source_port.get()) + .field("destination_port", &self.destination_port.get()) + .field("seq_num", &self.seq_num.get()) + .field("ack_num", &self.ack_num.get()) + .field("fin", &self.fin()) + .field("syn", &self.syn()) + .field("rst", &self.rst()) + .field("psh", &self.psh()) + .field("ack", &self.ack()) + .field("urg", &self.urg()) + .field("ece", &self.ece()) + .field("cwr", &self.cwr()) + .field("data_offset", &self.data_offset()) + .field("window", &self.window.get()) + .field("checksum", &self.checksum.get()) + .field("urgent_pointer", &self.urgent_pointer.get()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use crate::packet::{Ipv4, Tcp}; + use zerocopy::TryFromBytes; + + const EXAMPLE_IPV4_TCP: &[u8] = &[ + 0x45, 0x0, 0x1, 0x88, 0x47, 0x7a, 0x40, 0x0, 0x40, 0x6, 0xa5, 0x8f, 0xc0, 0xa8, 0x65, 0x7e, + 0xc0, 0xa8, 0x65, 0x97, 0xc5, 0xd8, 0x17, 0x66, 0x8f, 0x1, 0xa5, 0x50, 0xc2, 0x1d, 0x36, + 0x16, 0x80, 0x18, 0x60, 0x76, 0x7b, 0x9a, 0x0, 0x0, 0x1, 0x1, 0x8, 0xa, 0xcf, 0xc9, 0x84, + 0xe5, 0xd7, 0xd0, 0xdf, 0x50, /* payload snipped */ + ]; + + #[test] + fn tcp_header_layout() { + let packet = Ipv4::::try_ref_from_bytes(EXAMPLE_IPV4_TCP).unwrap(); + let packet = &packet.payload; + let header = &packet.header; + + assert!(header.psh()); + assert!(header.ack()); + + assert!(!header.fin()); + assert!(!header.syn()); + assert!(!header.rst()); + assert!(!header.urg()); + assert!(!header.ece()); + assert!(!header.cwr()); + + assert_eq!(header.data_offset(), 8); + assert_eq!(packet.payload(), Some(&[][..])); + + assert_eq!(header.ack_num, 3256694294); + assert_eq!(header.seq_num, 2399249744); + assert_eq!(header.urgent_pointer, 0); + } +} diff --git a/gotatun/src/tun/tun_async_device.rs 
b/gotatun/src/tun/tun_async_device.rs index e3fb83aa..1f40ae71 100644 --- a/gotatun/src/tun/tun_async_device.rs +++ b/gotatun/src/tun/tun_async_device.rs @@ -11,8 +11,15 @@ //! Implementations of [`IpSend`] and [`IpRecv`] for the [`tun`] crate. +mod linux; +mod tso; +mod virtio; + +use bytes::BytesMut; use tokio::{sync::watch, time::sleep}; +use tso::try_enable_tso; use tun::AbstractDevice; +use zerocopy::IntoBytes; use crate::{ packet::{Ip, Packet, PacketBufPool}, @@ -20,7 +27,7 @@ use crate::{ tun::{IpRecv, IpSend, MtuWatcher}, }; -use std::{convert::Infallible, io, iter, sync::Arc, time::Duration}; +use std::{convert::Infallible, io, iter, ops::Deref, sync::Arc, time::Duration}; /// Error from [`TunDevice`]. #[derive(Debug, thiserror::Error)] @@ -79,10 +86,18 @@ impl TunDevice { tun_config.platform_config(|p| { p.enable_routing(false); }); + + #[cfg(target_os = "linux")] + tun_config.platform_config(|p| { + p.vnet_hdr(true); + }); + // TODO: for wintun, must set path or enable signature check // we should upstream to `tun` let tun = tun::create_as_async(&tun_config).map_err(Error::OpenTun)?; + try_enable_tso(tun.deref()).unwrap(); let tun = TunDevice::from_tun_device(tun)?; + Ok(tun) } @@ -133,13 +148,31 @@ impl TunDevice { } } +// TODO +const VNET_HDR: bool = true; impl IpSend for TunDevice { async fn send(&mut self, packet: Packet) -> io::Result<()> { + let mut packet = packet.into_bytes(); + if VNET_HDR { + let header = virtio::VirtioNetHeader { + flags: virtio::Flags::new(), + gso_type: virtio::GsoType::VIRTIO_NET_HDR_GSO_NONE, + hdr_len: 0, + gso_size: 0, + csum_start: 0, + csum_offset: 0, + }; + let mut buf = BytesMut::new(); + buf.extend_from_slice(header.as_bytes()); + buf.extend_from_slice(packet.as_bytes()); + *packet.buf_mut() = buf; + } self.tun.send(&packet.into_bytes()).await?; Ok(()) } } +#[cfg(not(any(target_os = "linux", target_os = "android")))] impl IpRecv for TunDevice { async fn recv<'a>( &'a mut self, @@ -158,3 +191,51 @@ impl IpRecv for 
TunDevice { self.state.mtu.clone() } } + +#[cfg(any(target_os = "linux", target_os = "android"))] +impl IpRecv for TunDevice { + async fn recv<'a>( + &'a mut self, + pool: &mut PacketBufPool, + ) -> io::Result> + 'a> { + use bytes::BytesMut; + use either::Either; + use zerocopy::FromBytes; + + use crate::tun::tun_async_device::virtio::VirtioNetHeader; + + // FIXME: pool buffers have a cap of 4096, but we need more + //let mut packet = pool.get(); + let _ = pool; + + let mut buf = BytesMut::zeroed(usize::from(u16::MAX)); + let n = self.tun.recv(&mut buf).await?; + buf.truncate(n); + + let vnet_hdr = buf.split_to(size_of::()); + let vnet_hdr = *VirtioNetHeader::ref_from_bytes(&vnet_hdr).unwrap(); + + let packet = Packet::from_bytes(buf) + .try_into_ipvx() + .map_err(|e| io::Error::other(e.to_string()))?; + + // TODO + let mtu = 1200; + + // TODO: if segmentation and checksum offload is disabled, + // we could take a more efficient branch where we do not need to check + // packet length, and whether it's an IP/TCP packet. + match packet { + Either::Left(ipv4_packet) => { + tso::new_tso_iter_ipv4(ipv4_packet, usize::from(vnet_hdr.gso_size)) + } + Either::Right(ipv6_packet) => { + tso::new_tso_iter_ipv6(ipv6_packet, usize::from(vnet_hdr.gso_size)) + } + } + } + + fn mtu(&self) -> MtuWatcher { + self.state.mtu.clone() + } +} diff --git a/gotatun/src/tun/tun_async_device/linux.rs b/gotatun/src/tun/tun_async_device/linux.rs new file mode 100644 index 00000000..e69de29b diff --git a/gotatun/src/tun/tun_async_device/tso.rs b/gotatun/src/tun/tun_async_device/tso.rs new file mode 100644 index 00000000..06dc8f0c --- /dev/null +++ b/gotatun/src/tun/tun_async_device/tso.rs @@ -0,0 +1,439 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright (c) Mullvad VPN AB. All rights reserved. +// +// SPDX-License-Identifier: MPL-2.0 + +use crate::packet::{ + Ip, IpNextProtocol, Ipv4, Ipv4Header, Ipv4VersionIhl, Ipv6, Ipv6Header, Packet, Tcp, TcpHeader, +}; +use bytes::BytesMut; +use duplicate::duplicate_item; +use libc::{TUN_F_CSUM, TUN_F_TSO4, TUN_F_TSO6, TUNSETOFFLOAD}; +use std::io; +use std::os::fd::AsRawFd; +use zerocopy::{FromBytes, IntoBytes}; + +/// Enable TCP offloading on the given tun device +/// +/// Returns `EINVAL` if TSO is not supported (pre Linux 2.6) +/// +/// +pub fn try_enable_tso(tun: &impl AsRawFd) -> io::Result<()> { + // TODO: ask the OS what linux version we're running. + let linux_version = (6, 16); + + let offload_flags = match linux_version { + v if v >= (6, 2) => { + TUN_F_CSUM // checksum offload, this is required for TSO + | TUN_F_TSO4 // TCP segmentation offload (IPv4) + | TUN_F_TSO6 // TCP segmentation offload (IPv6) + // TODO: TUN_F_USO4 + // TODO: TUN_F_USO6 + } + + v if v >= (2, 6) => { + TUN_F_CSUM // checksum offload, this is required for TSO + | TUN_F_TSO4 // TCP segmentation offload (IPv4) + | TUN_F_TSO6 // TCP segmentation offload (IPv6) + } + + _ => return Err(io::ErrorKind::InvalidInput.into()), + }; + + let tun_fd = tun.as_raw_fd(); + + // SAFETY: TODO: perfectly safe + let status = unsafe { libc::ioctl(tun_fd, TUNSETOFFLOAD, offload_flags) }; + if status != 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +#[duplicate_item( + new_tso_iter_ipvx IpvX IpvXHeader CoalescedIpvX; + [new_tso_iter_ipv4] [Ipv4] [Ipv4Header] [CoalescedIpv4]; + [new_tso_iter_ipv6] [Ipv6] [Ipv6Header] [CoalescedIpv6]; + )] +pub fn new_tso_iter_ipvx(ipvx_packet: Packet, gso_size: usize) -> io::Result { + let packet_len = ipvx_packet.as_bytes().len(); + + match ipvx_packet.header.next_protocol() { + IpNextProtocol::Tcp => { + let mut tcp_packet = ipvx_packet + 
.try_into_tcp() + .map_err(|e| io::Error::other(e.to_string()))?; + + // TODO: also check gso_type + if 0 < gso_size && gso_size < packet_len { + let tcp_options_len = tcp_packet + .payload + .options() + .expect("We've validated the TCP packet") + .len(); + let header_len = IpvXHeader::LEN + TcpHeader::LEN + tcp_options_len; + + let mut packet = tcp_packet.into_bytes(); + + // Split the giant packet into IP/TCP header and giant payload. The payload + // will be segmented, and the header will be prepended to each segment. + let headers = packet.buf_mut().split_to(header_len); + let payload = packet; + let mut headers = Packet::from_bytes(headers); + + // Update IP header length field + IpvX::::mut_from_bytes(headers.as_mut_bytes()) + .expect("`headers` contains Ip/Tcp headers") + .try_update_ip_len() + .expect("IP packet is not too large"); + + let headers = Packet::::try_from(headers) + .and_then(|headers| headers.try_into_tcp()) + .expect("We're copying valid IP/TCP headers"); + + // Length of the giant payload. + let payload_len = packet_len - header_len; + + // Target size of the segment payloads + // TODO: does gso_size already exclude headers? 
+ let segment_payload_len = gso_size + .checked_sub(header_len) + .unwrap_or_else(|| panic!("gso_size ({gso_size}) must be greater than the length of the IP/TCP headers ({header_len})")); + + // We'll need this many segments + let segment_count = payload_len.div_ceil(segment_payload_len); + + // TODO: segmentation should not block the next tun.read + return Ok(TsoIter::CoalescedIpvX { + // TODO: consider using a pool + buf: BytesMut::with_capacity((header_len + gso_size) * segment_count), + segment_payload_len: segment_payload_len, + headers, + payload, + i: 0, + }); + } + + tcp_packet.update_tcp_checksum(); + + Ok(TsoIter::SinglePacket { + packet: Some(tcp_packet.into()), + }) + } + _ => Ok(TsoIter::SinglePacket { + packet: Some(ipvx_packet.into()), + }), + } +} + +/// An iterator that segments a large TCP packet into smaller TCP packets. +pub enum TsoIter { + SinglePacket { + packet: Option>, + }, + CoalescedIpv4 { + buf: BytesMut, + + i: usize, + segment_payload_len: usize, + + headers: Packet>, + payload: Packet<[u8]>, + }, + CoalescedIpv6 { + buf: BytesMut, + + i: usize, + segment_payload_len: usize, + + headers: Packet>, + payload: Packet<[u8]>, + }, +} + +impl Iterator for TsoIter { + type Item = Packet; + + fn next(&mut self) -> Option { + match self { + TsoIter::SinglePacket { packet } => packet.take(), + + TsoIter::CoalescedIpv4 { + buf, + i, + segment_payload_len, + headers, + payload, + } => { + if payload.is_empty() { + return None; + } + + // TODO: remove me + if cfg!(debug_assertions) { + log::info!("##########"); + log::info!( + "TSO (v4): i={i} buf.len={}, payload.len={}, segment_payload_len={segment_payload_len}", + buf.len(), + payload.len() + ); + log::info!("##########"); + } + + let len = payload.len().min(*segment_payload_len); + let segment_payload = payload.buf_mut().split_to(len).freeze(); + + let is_last_segment = payload.is_empty(); + + // Headers from the original TSO packet + let ipv4_header = &headers.header; + let tcp_header = 
&headers.payload.header; + let tcp_options = headers.payload.options(); + let tcp_options = tcp_options.expect("We've validated the TCP header"); + + let seq_num = (*segment_payload_len).wrapping_mul(*i) as u32; + let seq_num = seq_num.wrapping_add(tcp_header.seq_num.get()); + + // TODO: explain how identification works and why we need to inc it + let identification = ipv4_header.identification.get(); + let identification = identification.wrapping_add(*i as u16); + + let total_len = (const { Ipv4Header::LEN + TcpHeader::LEN } + + tcp_options.len() + + segment_payload.len()) as u16; + + // Use them to construct the headers for this segment + let mut segment_headers = Ipv4 { + header: Ipv4Header { + version_and_ihl: Ipv4VersionIhl::new().with_version(4).with_ihl(5), + dscp_and_ecn: ipv4_header.dscp_and_ecn, + total_len: total_len.into(), + + identification: identification.into(), + + // TODO: handle this field + // we should never receive fragmented ip packets, i *think*. + flags_and_fragment_offset: ipv4_header.flags_and_fragment_offset, + + time_to_live: ipv4_header.time_to_live, + + protocol: IpNextProtocol::Tcp, + header_checksum: 0.into(), + + source_address: ipv4_header.source_address, + destination_address: ipv4_header.destination_address, + }, + payload: TcpHeader { + source_port: tcp_header.source_port, + destination_port: tcp_header.destination_port, + + seq_num: seq_num.into(), + ack_num: tcp_header.ack_num, + + data_offset: tcp_header.data_offset, + flags: tcp_header.flags, + window: tcp_header.window, + checksum: 0.into(), + urgent_pointer: tcp_header.urgent_pointer, + }, + }; + + if !is_last_segment { + segment_headers.payload.set_fin(false); + segment_headers.payload.set_psh(false); + } + + // Copy the data into the large `buf` allocation and split it off into Packet. 
+ buf.extend_from_slice(segment_headers.as_bytes()); + buf.extend_from_slice(tcp_options); + buf.extend_from_slice(&segment_payload); + let packet = Packet::from_bytes(buf.split()); + + let mut packet = packet + .try_into_ipvx() + .map(|either| either.expect_left("The packet is IPv4")) + .and_then(|packet| packet.try_into_tcp()) + .expect("we've correctly initialized the packet"); + + packet.update_tcp_checksum(); + packet.update_ip_checksum(); + + *i += 1; + + Some(packet.into()) + } + + TsoIter::CoalescedIpv6 { + buf, + i, + segment_payload_len, + headers, + payload, + } => { + if payload.is_empty() { + return None; + } + + // TODO: remove me + if cfg!(debug_assertions) { + log::info!("##########"); + log::info!( + "TSO (v6): i={i} buf.len={}, payload.len={}, segment_payload_len={segment_payload_len}", + buf.len(), + payload.len() + ); + log::info!("##########"); + } + + let len = payload.len().min(*segment_payload_len); + let segment_payload = payload.buf_mut().split_to(len).freeze(); + + let is_last_segment = payload.is_empty(); + + // Headers from the original TSO packet + let ipv6_header = &headers.header; + let tcp_header = &headers.payload.header; + let tcp_options = headers.payload.options().unwrap_or_default(); + + let seq_num = (*segment_payload_len).wrapping_mul(*i) as u32; + let seq_num = seq_num.wrapping_add(tcp_header.seq_num.get()); + + // Use them to construct the headers for this segment + let mut segment_headers = Ipv6 { + header: Ipv6Header { + version_traffic_flow: headers.header.version_traffic_flow, + payload_length: (TcpHeader::LEN + + tcp_options.len() + + segment_payload.len()) + .try_into() + .unwrap(), + + next_header: IpNextProtocol::Tcp, + hop_limit: headers.header.hop_limit, + + source_address: ipv6_header.source_address, + destination_address: ipv6_header.destination_address, + }, + + // TODO: deduplicate with CoalescedIpv4 + payload: TcpHeader { + source_port: tcp_header.source_port, + destination_port: tcp_header.destination_port, 
+ + seq_num: seq_num.into(), + ack_num: tcp_header.ack_num, + + data_offset: tcp_header.data_offset, + flags: tcp_header.flags, + window: tcp_header.window, + checksum: 0.into(), + urgent_pointer: tcp_header.urgent_pointer, + }, + }; + + if !is_last_segment { + segment_headers.payload.set_fin(false); + segment_headers.payload.set_psh(false); + } + + // Copy the data into the large `buf` allocation and split it off into Packet. + buf.extend_from_slice(segment_headers.as_bytes()); + buf.extend_from_slice(tcp_options); + buf.extend_from_slice(&segment_payload); + let packet = Packet::from_bytes(buf.split()); + + let mut packet = packet + .try_into_ipvx() + .map(|either| either.expect_right("The packet is IPv6")) + .and_then(|packet| packet.try_into_tcp()) + .expect("we've correctly initialized the packet"); + + packet.update_tcp_checksum(); + + *i += 1; + + Some(packet.into()) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + + #[test] + fn test_tso_split() { + // TODO: test with TCP options + + let tcp = Tcp { + header: TcpHeader { + source_port: 111.into(), + destination_port: 222.into(), + seq_num: 1000.into(), + ack_num: 444.into(), + data_offset: crate::packet::TcpDataOffset::no_options(), + // TODO: more flags? 
+ flags: crate::packet::TcpFlags::new() + .with_syn(true) + .with_ack(true) + .with_fin(true), + window: 555.into(), + checksum: 0.into(), // TODO + urgent_pointer: 666.into(), + }, + options_and_payload: *b"1st segment!\02nd segment!\03rd segment?\0", + }; + + let mut ip = Ipv4 { + header: Ipv4Header { + identification: 1212.into(), + ..Ipv4Header::new_for_length( + Ipv4Addr::new(1, 2, 3, 4), + Ipv4Addr::new(4, 3, 2, 1), + IpNextProtocol::Tcp, + tcp.as_bytes().len().try_into().unwrap(), + ) + }, + payload: tcp, + }; + + let payload_segment_size = 13; + let mtu = payload_segment_size + size_of::>(); + let expected_payloads: Vec = ip + .payload + .options_and_payload + .chunks(payload_segment_size) + .map(|bytes| std::str::from_utf8(bytes).unwrap().to_string()) + .collect(); + + ip.update_ip_checksum(); + + let packet = Packet::copy_from(ip.as_bytes()); + let packet = packet.try_into_ipvx().unwrap().unwrap_left(); + let packet = packet.try_into_tcp().unwrap(); + + let segmented_packets: Vec<_> = new_tso_iter_ipv4(packet.into(), mtu) + .unwrap() + .map(|packet| packet.try_into_ipvx().unwrap().unwrap_left()) + .map(|packet| packet.try_into_tcp().unwrap()) + .collect(); + + println!("tso count: {}", segmented_packets.len()); + for (packet, expected_payload) in segmented_packets.into_iter().zip(expected_payloads) { + let payload = packet.payload.payload().unwrap(); + assert_eq!(payload, expected_payload.as_bytes()); + println!("{:#?}", &*packet); + } + + panic!() // TODO: remove me + } +} diff --git a/gotatun/src/tun/tun_async_device/virtio.rs b/gotatun/src/tun/tun_async_device/virtio.rs new file mode 100644 index 00000000..8138014e --- /dev/null +++ b/gotatun/src/tun/tun_async_device/virtio.rs @@ -0,0 +1,240 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright (c) Mullvad VPN AB. All rights reserved. +// +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the Virtio net header. +//! +//! The header can be enabled on TUN devices using the [libc::IFF_VNET_HDR]-flag, +//! or using [tun::PlatformConfig::vnet_hdr], and enables use of GSO. + +use bitfield_struct::bitfield; +use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; + +/// See [module](self) docs. +/// +/// Definition in linux include/uapi/linux/virtio_net.h +#[repr(C)] +#[derive(Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct VirtioNetPacket { + pub header: VirtioNetHeader, + pub payload: Payload, +} + +/// See [module](self) docs. +/// +/// Definition in linux include/uapi/linux/virtio_net.h +#[repr(C, packed)] +#[derive( + Clone, Copy, Debug, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq, +)] +pub struct VirtioNetHeader { + pub flags: Flags, + + pub gso_type: GsoType, + + /// Ethernet + IP + tcp/udp headers + pub hdr_len: u16, + + /// Bytes to append to `hdr_len` per frame + pub gso_size: u16, + + /// Position to start checksumming from + pub csum_start: u16, + + /// Offset after that to place checksum + pub csum_offset: u16, +} + +/// A field of [VirtioNetHeader]. 
+#[repr(transparent)] +#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct GsoType(u8); + +impl GsoType { + /// Not a GSO frame + pub const VIRTIO_NET_HDR_GSO_NONE: GsoType = GsoType(0); + + /// GSO frame, IPv4 TCP (TSO) + pub const VIRTIO_NET_HDR_GSO_TCPV4: GsoType = GsoType(1); + + /// GSO frame, IPv4 UDP (UFO) + pub const VIRTIO_NET_HDR_GSO_UDP: GsoType = GsoType(3); + + /// GSO frame, IPv6 TCP + pub const VIRTIO_NET_HDR_GSO_TCPV6: GsoType = GsoType(4); + + /// GSO frame, IPv4& IPv6 UDP (USO) + pub const VIRTIO_NET_HDR_GSO_UDP_L4: GsoType = GsoType(5); + + /// TCP has ECN set + pub const VIRTIO_NET_HDR_GSO_ECN: GsoType = GsoType(0x80); +} + +impl std::fmt::Debug for GsoType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = match *self { + GsoType::VIRTIO_NET_HDR_GSO_NONE => "VIRTIO_NET_HDR_GSO_NONE ", + GsoType::VIRTIO_NET_HDR_GSO_TCPV4 => "VIRTIO_NET_HDR_GSO_TCPV4 ", + GsoType::VIRTIO_NET_HDR_GSO_UDP => "VIRTIO_NET_HDR_GSO_UDP", + GsoType::VIRTIO_NET_HDR_GSO_TCPV6 => "VIRTIO_NET_HDR_GSO_TCPV6", + GsoType::VIRTIO_NET_HDR_GSO_UDP_L4 => "VIRTIO_NET_HDR_GSO_UDP_L4", + GsoType::VIRTIO_NET_HDR_GSO_ECN => "VIRTIO_NET_HDR_GSO_ECN", + GsoType(..) => "UNKNOWN_GSO_TYPE", + }; + + f.debug_tuple(name).field(&self.0).finish() + } +} + +/// A field of [VirtioNetHeader]. +#[bitfield(u8)] +#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)] +pub struct Flags { + /// Use csum_start, csum_offset + pub needs_csum: bool, + /// Csum is valid + pub data_valid: bool, + /// rsc info in csum_ fields + pub rsc_info: bool, + + #[bits(5)] + _reserved: u8, +} + +// TODO: handle VIRTIO_NET_HDR_GSO_NONE and VIRTIO_NET_HDR_F_NEEDS_CSUM + +/* + * TODO: + * +/// Split big packet with multiple segments into independent IP packets +/// +/// The unsegmented packet is a VirtioNetHeader followed by an IP header, and a TCP header. 
+/// Followed by GSO segment-sized (specified in VirtioNetHeader) TCP segments.
+fn gso_split(packet: &mut VirtioNetPacket<[u8]>) {
+    let hdr = &packet.header;
+    let payload = &mut packet.payload;
+
+    let ip = Ip::try_ref_from_bytes(payload).unwrap();
+
+    // Clear IPv4 checksum
+    if ip.header.version() == 4 {
+        let ipv4 = Ipv4::<[u8]>::mut_from_bytes(payload).unwrap();
+        ipv4.header.header_checksum = 0.into();
+        let mut ipv4_id = ip.header.identification;
+    } else {
+        panic!("no IPv6");
+    }
+
+    // Check GSO type (UDP or TCP)
+    // And clear TCP/UDP checksum
+    match hdr.gso_type {
+        GsoType::VIRTIO_NET_HDR_GSO_TCPV4 | GsoType::VIRTIO_NET_HDR_GSO_TCPV6 => {
+            // FIXME: IPv6
+            let tcp = Ipv4::::mut_from_bytes(payload).unwrap();
+            tcp.header.header_checksum = 0.into();
+        }
+        GsoType::VIRTIO_NET_HDR_GSO_UDP => {
+            let udp = Ipv4::::mut_from_bytes(payload).unwrap();
+            udp.header.header_checksum = 0.into();
+        }
+        // FIXME: handle VIRTIO_NET_HDR_GSO_UDP_L4
+        // see https://github.com/WireGuard/wireguard-go/blob/f333402bd9cbe0f3eeb02507bd14e23d7d639280/tun/tun_linux.go#L421
+        // TODO: Is it actually unreachable?
+        _ => unreachable!(),
+    };
+
+    let ip_header_len = usize::from(hdr.csum_start);
+    let ip_header = &payload[..ip_header_len];
+    let transport_header_end = usize::from(hdr.hdr_len);
+    let transport_header_len = usize::from(hdr.hdr_len) - ip_header_len;
+    let transport_header = &payload[ip_header_len..transport_header_end];
+
+    let first_segment_index = ip_header_len + usize::from(hdr.csum_offset);
+
+    let (header, segments) = payload.split_at_mut(first_segment_index);
+
+    for (i, segment) in segments.chunks(usize::from(hdr.gso_size)).enumerate() {
+        let mut out = BytesMut::new();
+
+        // copy the IP header, plus the TCP or UDP header that follows.
+        out.extend_from_slice(&header[..usize::from(hdr.hdr_len)]);
+        out.extend_from_slice(segment);
+
+        // FIXME:
+        // IPv6: Set payload length field
+
+        match (ip_version, proto) {
+            (4, tcp) => {
+                // IPv4: Increment ID field, set total length, compute checksum
+
+                // TODO: Do we need to care about ipv4 options or ipv6 extra headers?
+                let segment = Ipv4::::mut_from_bytes(&mut out)
+                    .expect("header + segment should be large enough");
+
+                // TODO: improve
+                ipv4_id += 1;
+                segment.header.identification = ipv4_id;
+
+
+                // TCP: Set sequence and flags
+                // TODO: update TCP sequence number
+                let tcp_header = &mut segment.payload.header;
+                tcp_header.seq_num = first_tcp_seq_num + gso_size * i;
+
+                if !last_tcp_segment {
+                    tcp_header.set_fin(false);
+                    tcp_header.set_psh(false);
+                }
+                // TODO: update TCP flags
+                // - only the last segmented packet may have FIN and PSH set.
+            }
+            (_, udp) => {
+                // UDP: Set header length field
+            }
+        }
+
+        // NOTE(review): the two lines below duplicate the header+segment copy
+        // already done at the top of the loop body — keep only one copy when
+        // finishing this draft.
+        out.extend_from_slice(&payload[..usize::from(hdr.hdr_len)]);
+        out.extend_from_slice(segment);
+        // TODO: Compute UDP/TCP checksum
+
+    }
+}
+*/