Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[derive] Support TryFromBytes on field-less enums #803

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,8 @@ pub unsafe trait TryFromBytes {
// SAFETY: `candidate` has no uninitialized sub-ranges because it
// derived from `bytes: &[u8]`, and is therefore at least as-initialized
// as `Self`.
let candidate = unsafe { candidate.assume_as_initialized() };
let candidate =
unsafe { candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };

// This call may panic. If that happens, it doesn't cause any soundness
// issues, as we have not generated any invalid state which we need to
Expand Down
4 changes: 2 additions & 2 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ macro_rules! unsafe_impl {

// SAFETY: The caller has promised that the referenced memory region
// will contain a valid `$repr`.
let $candidate = unsafe { candidate.assume_valid() };
let $candidate = unsafe { candidate.assume_validity::<crate::pointer::invariant::Valid>() };
$is_bit_valid
}
};
Expand All @@ -166,7 +166,7 @@ macro_rules! unsafe_impl {

// SAFETY: The caller has promised that `$repr` is as-initialized as
// `Self`.
let $candidate = unsafe { $candidate.assume_as_initialized() };
let $candidate = unsafe { $candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };

$is_bit_valid
}
Expand Down
4 changes: 2 additions & 2 deletions src/pointer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where
if T::is_bit_valid(self.forget_aligned()) {
// SAFETY: If `T::is_bit_valid`, code may assume that `self`
// contains a bit-valid instance of `Self`.
Some(unsafe { self.assume_valid() })
Some(unsafe { self.assume_validity::<invariant::Valid>() })
} else {
None
}
Expand Down Expand Up @@ -82,7 +82,7 @@ where
{
// SAFETY: The alignment of `T` is 1 and thus is always aligned
// because `T: Unaligned`.
let ptr = unsafe { self.assume_aligned() };
let ptr = unsafe { self.assume_alignment::<invariant::Aligned>() };
ptr.as_ref()
}
}
56 changes: 37 additions & 19 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,51 +458,69 @@ mod _transitions {
T: 'a + ?Sized,
I: Invariants,
{
/// Assumes that `Ptr`'s referent is validly-aligned for `T`.
/// Assumes that `Ptr`'s referent is validly-aligned for `T` if required
/// by `A`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the alignment
/// invariant of `T`.
/// invariant of `T` if required by `A`.
#[inline]
pub(crate) unsafe fn assume_aligned(
pub(crate) unsafe fn assume_alignment<A: invariant::Alignment>(
self,
) -> Ptr<'a, T, (I::Aliasing, invariant::Aligned, I::Validity)> {
) -> Ptr<'a, T, (I::Aliasing, A, I::Validity)> {
// SAFETY: The caller promises that `self`'s referent is
// well-aligned for `T`.
// well-aligned for `T` if required by `A` .
unsafe { Ptr::from_ptr(self) }
}

/// Assumes that `Ptr`'s referent is as-initialized as `T`.
/// Assumes that `Ptr`'s referent conforms to the validity requirement
/// of `V`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the
/// [`invariant::AsInitialized`] invariant (see documentation there).
/// The caller promises that `Ptr`'s referent conforms to the validity
/// requirement of `V`.
#[doc(hidden)]
#[inline]
pub unsafe fn assume_validity<V: invariant::Validity>(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, V)> {
// SAFETY: The caller promises that `self`'s referent conforms to
// the validity requirement of `V`.
unsafe { Ptr::from_ptr(self) }
}

/// A shorthand for `self.assume_validity<invariant::AsInitialized>()`.
///
/// # Safety
///
/// The caller promises to uphold the safety preconditions of
/// `self.assume_validity<invariant::AsInitialized>()`.
#[doc(hidden)]
#[inline]
pub unsafe fn assume_as_initialized(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::AsInitialized)> {
// SAFETY: The caller promises that `self`'s referent only contains
// uninitialized bytes in a subset of the uninitialized ranges in
// `T`. for `T`.
unsafe { Ptr::from_ptr(self) }
// SAFETY: The caller has promised to uphold the safety
// preconditions.
unsafe { self.assume_validity::<invariant::AsInitialized>() }
}

/// Assumes that `Ptr`'s referent is validly initialized for `T`.
/// A shorthand for `self.assume_validity<invariant::Valid>()`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the
/// bit validity invariants on `T`.
/// The caller promises to uphold the safety preconditions of
/// `self.assume_validity<invariant::Valid>()`.
#[doc(hidden)]
#[inline]
pub(crate) unsafe fn assume_valid(
pub unsafe fn assume_valid(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::Valid)> {
// SAFETY: The caller promises that `self`'s referent is bit-valid
// for `T`.
unsafe { Ptr::from_ptr(self) }
// SAFETY: The caller has promised to uphold the safety
// preconditions.
unsafe { self.assume_validity::<invariant::Valid>() }
}

/// Forgets that `Ptr`'s referent is validly-aligned for `T`.
Expand Down
147 changes: 143 additions & 4 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use {
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Lit,
GenericParam, Ident, Index, Lit,
},
};

Expand Down Expand Up @@ -268,9 +268,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Enum(enm) => derive_try_from_bytes_enum(&ast, enm),
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
}
.into()
Expand Down Expand Up @@ -401,6 +399,147 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C, StructRepr::Packed],
];

fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
if !enm.is_fieldless() {
return Error::new_spanned(ast, "only field-less enums can implement TryFromBytes")
.to_compile_error();
}

let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let discriminant_type = match reprs.as_slice() {
[EnumRepr::U8] => quote!(u8),
[EnumRepr::U16] => quote!(u16),
[EnumRepr::U32] => quote!(u32),
[EnumRepr::U64] => quote!(u64),
[EnumRepr::Usize] => quote!(usize),
[EnumRepr::I8] => quote!(i8),
[EnumRepr::I16] => quote!(i16),
[EnumRepr::I32] => quote!(i32),
[EnumRepr::I64] => quote!(i64),
[EnumRepr::Isize] => quote!(isize),
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};

let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| {
Some(disc.update_and_generate_expr(&var.discriminant))
});
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the C-like enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::AsInitialized,
),
>,
) -> bool {
// SAFETY:
// - `cast` is implemented as required.
// - Since we cast to the type specified by `Self`'s repr, `p`'s
// referent and the referent of the returned pointer have the
// same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut ::zerocopy::macro_util::core_reexport::primitive::#discriminant_type) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`,
// we know that `candidate`'s referent (and thus
// `discriminant`'s referent) is as-initialized as `Self`. Since
// `Self`'s repr is the same type as `discriminant`, we know
// that `discriminant`'s referent satisfies the as-initialized
// property.
let discriminant = unsafe { discriminant.assume_valid() };
let discriminant = discriminant.read_unaligned();

false #(|| (discriminant == (#discriminant_exprs)))*
}
));
impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// Enum variant discriminants can be manually set not only as literal values,
// but as arbitrary const expressions. In order to handle this, we keep track of
// the most-recently-seen expression and a count of how many variants have been
// encountered since then.
//
// #[repr(u8)]
// enum Foo {
// A, // 0
// B = 5, // 5
// C, // 6
// D = 1 + 1, // 2
// E, // 3
// }
//
// Note: Default::default does the right thing (initializes to { None, 0 }).
#[derive(Default, Copy, Clone)]
struct Discriminant<'a> {
// The most-recently-set explicit discriminant.
previous: Option<&'a Expr>,
// When the next variant is encountered, what offset should be used compared
// to `previous` to determine the variant's discriminant?
next_offset: usize,
}

impl<'a> Discriminant<'a> {
/// Called when encountering a variant with discriminant set to `ast`.
/// Updates `self` in preparation for the next variant and generates an
/// expression which will evaluate to the numeric value this variant's
/// discriminant.
fn update_and_generate_expr(
&mut self,
ast: &'a Option<(syn::token::Eq, Expr)>,
) -> proc_macro2::TokenStream {
match ast.as_ref().map(|(_eq, expr)| expr) {
Some(expr) => {
self.previous = Some(expr);
self.next_offset = 1;
quote!(#expr)
}
None => {
let previous = self.previous.iter();
// Use `Index` instead of `usize` so that the number is
// formatted just as `0` rather than as `0usize`; the latter
// syntax is only valid if the repr is `usize`; otherwise,
// comparison will result in a type mismatch.
let offset = Index::from(self.next_offset);
let tokens = quote!(#(#previous +)* #offset);

self.next_offset += 1;
tokens
}
}
}
}

#[rustfmt::skip]
const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
derive_unaligned: false,
allowed_combinations: &[
&[U8],
&[U16],
&[U32],
&[U64],
&[Usize],
&[I8],
&[I16],
&[I32],
&[I64],
&[Isize],
],
disallowed_but_legal_combinations: &[
&[C],
],
}
};

// A struct is `FromZeros` if:
// - all fields are `FromZeros`

Expand Down
Loading