Skip to content

Commit c7447de

Browse files
committed
[derive] Support TryFromBytes on field-less enums
Makes progress on #5
1 parent dd107ad commit c7447de

File tree

10 files changed

+650
-562
lines changed

10 files changed

+650
-562
lines changed

src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1244,7 +1244,8 @@ pub unsafe trait TryFromBytes {
12441244
// SAFETY: `candidate` has no uninitialized sub-ranges because it
12451245
// derived from `bytes: &[u8]`, and is therefore at least as-initialized
12461246
// as `Self`.
1247-
let candidate = unsafe { candidate.assume_as_initialized() };
1247+
let candidate =
1248+
unsafe { candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };
12481249

12491250
// This call may panic. If that happens, it doesn't cause any soundness
12501251
// issues, as we have not generated any invalid state which we need to

src/macros.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ macro_rules! unsafe_impl {
146146

147147
// SAFETY: The caller has promised that the referenced memory region
148148
// will contain a valid `$repr`.
149-
let $candidate = unsafe { candidate.assume_valid() };
149+
let $candidate = unsafe { candidate.assume_validity::<crate::pointer::invariant::Valid>() };
150150
$is_bit_valid
151151
}
152152
};
@@ -166,7 +166,7 @@ macro_rules! unsafe_impl {
166166

167167
// SAFETY: The caller has promised that `$repr` is as-initialized as
168168
// `Self`.
169-
let $candidate = unsafe { $candidate.assume_as_initialized() };
169+
let $candidate = unsafe { $candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };
170170

171171
$is_bit_valid
172172
}

src/pointer/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ where
4040
if T::is_bit_valid(self.forget_aligned()) {
4141
// SAFETY: If `T::is_bit_valid`, code may assume that `self`
4242
// contains a bit-valid instance of `Self`.
43-
Some(unsafe { self.assume_valid() })
43+
Some(unsafe { self.assume_validity::<invariant::Valid>() })
4444
} else {
4545
None
4646
}
@@ -82,7 +82,7 @@ where
8282
{
8383
// SAFETY: The alignment of `T` is 1 and thus is always aligned
8484
// because `T: Unaligned`.
85-
let ptr = unsafe { self.assume_aligned() };
85+
let ptr = unsafe { self.assume_alignment::<invariant::Aligned>() };
8686
ptr.as_ref()
8787
}
8888
}

src/pointer/ptr.rs

+37-19
Original file line numberDiff line numberDiff line change
@@ -458,51 +458,69 @@ mod _transitions {
458458
T: 'a + ?Sized,
459459
I: Invariants,
460460
{
461-
/// Assumes that `Ptr`'s referent is validly-aligned for `T`.
461+
/// Assumes that `Ptr`'s referent is validly-aligned for `T` if required
462+
/// by `A`.
462463
///
463464
/// # Safety
464465
///
465466
/// The caller promises that `Ptr`'s referent conforms to the alignment
466-
/// invariant of `T`.
467+
/// invariant of `T` if required by `A`.
467468
#[inline]
468-
pub(crate) unsafe fn assume_aligned(
469+
pub(crate) unsafe fn assume_alignment<A: invariant::Alignment>(
469470
self,
470-
) -> Ptr<'a, T, (I::Aliasing, invariant::Aligned, I::Validity)> {
471+
) -> Ptr<'a, T, (I::Aliasing, A, I::Validity)> {
471472
// SAFETY: The caller promises that `self`'s referent is
472-
// well-aligned for `T`.
473+
// well-aligned for `T` if required by `A` .
473474
unsafe { Ptr::from_ptr(self) }
474475
}
475476

476-
/// Assumes that `Ptr`'s referent is as-initialized as `T`.
477+
/// Assumes that `Ptr`'s referent conforms to the validity requirement
478+
/// of `V`.
477479
///
478480
/// # Safety
479481
///
480-
/// The caller promises that `Ptr`'s referent conforms to the
481-
/// [`invariant::AsInitialized`] invariant (see documentation there).
482+
/// The caller promises that `Ptr`'s referent conforms to the validity
483+
/// requirement of `V`.
484+
#[doc(hidden)]
485+
#[inline]
486+
pub unsafe fn assume_validity<V: invariant::Validity>(
487+
self,
488+
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, V)> {
489+
// SAFETY: The caller promises that `self`'s referent conforms to
490+
// the validity requirement of `V`.
491+
unsafe { Ptr::from_ptr(self) }
492+
}
493+
494+
/// A shorthand for `self.assume_validity<invariant::AsInitialized>()`.
495+
///
496+
/// # Safety
497+
///
498+
/// The caller promises to uphold the safety preconditions of
499+
/// `self.assume_validity<invariant::AsInitialized>()`.
482500
#[doc(hidden)]
483501
#[inline]
484502
pub unsafe fn assume_as_initialized(
485503
self,
486504
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::AsInitialized)> {
487-
// SAFETY: The caller promises that `self`'s referent only contains
488-
// uninitialized bytes in a subset of the uninitialized ranges in
489-
// `T`. for `T`.
490-
unsafe { Ptr::from_ptr(self) }
505+
// SAFETY: The caller has promised to uphold the safety
506+
// preconditions.
507+
unsafe { self.assume_validity::<invariant::AsInitialized>() }
491508
}
492509

493-
/// Assumes that `Ptr`'s referent is validly initialized for `T`.
510+
/// A shorthand for `self.assume_validity<invariant::Valid>()`.
494511
///
495512
/// # Safety
496513
///
497-
/// The caller promises that `Ptr`'s referent conforms to the
498-
/// bit validity invariants on `T`.
514+
/// The caller promises to uphold the safety preconditions of
515+
/// `self.assume_validity<invariant::Valid>()`.
516+
#[doc(hidden)]
499517
#[inline]
500-
pub(crate) unsafe fn assume_valid(
518+
pub unsafe fn assume_valid(
501519
self,
502520
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::Valid)> {
503-
// SAFETY: The caller promises that `self`'s referent is bit-valid
504-
// for `T`.
505-
unsafe { Ptr::from_ptr(self) }
521+
// SAFETY: The caller has promised to uphold the safety
522+
// preconditions.
523+
unsafe { self.assume_validity::<invariant::Valid>() }
506524
}
507525

508526
/// Forgets that `Ptr`'s referent is validly-aligned for `T`.

zerocopy-derive/src/lib.rs

+143-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use {
3636
quote::quote,
3737
syn::{
3838
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
39-
GenericParam, Ident, Lit,
39+
GenericParam, Ident, Index, Lit,
4040
},
4141
};
4242

@@ -268,9 +268,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
268268
let ast = syn::parse_macro_input!(ts as DeriveInput);
269269
match &ast.data {
270270
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
271-
Data::Enum(_) => {
272-
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
273-
}
271+
Data::Enum(enm) => derive_try_from_bytes_enum(&ast, enm),
274272
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
275273
}
276274
.into()
@@ -401,6 +399,147 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
401399
&[StructRepr::C, StructRepr::Packed],
402400
];
403401

402+
fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
403+
if !enm.is_fieldless() {
404+
return Error::new_spanned(ast, "only field-less enums can implement TryFromBytes")
405+
.to_compile_error();
406+
}
407+
408+
let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
409+
let discriminant_type = match reprs.as_slice() {
410+
[EnumRepr::U8] => quote!(u8),
411+
[EnumRepr::U16] => quote!(u16),
412+
[EnumRepr::U32] => quote!(u32),
413+
[EnumRepr::U64] => quote!(u64),
414+
[EnumRepr::Usize] => quote!(usize),
415+
[EnumRepr::I8] => quote!(i8),
416+
[EnumRepr::I16] => quote!(i16),
417+
[EnumRepr::I32] => quote!(i32),
418+
[EnumRepr::I64] => quote!(i64),
419+
[EnumRepr::Isize] => quote!(isize),
420+
// `validate_reprs` has already validated that it's one of the preceding
421+
// patterns.
422+
_ => unreachable!(),
423+
};
424+
425+
let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| {
426+
Some(disc.update_and_generate_expr(&var.discriminant))
427+
});
428+
let extras = Some(quote!(
429+
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
430+
// corresponds to one of the C-like enum's variant discriminants.
431+
// Thus, this is a sound implementation of `is_bit_valid`.
432+
fn is_bit_valid(
433+
candidate: zerocopy::Ptr<
434+
'_,
435+
Self,
436+
(
437+
zerocopy::pointer::invariant::Shared,
438+
zerocopy::pointer::invariant::AnyAlignment,
439+
zerocopy::pointer::invariant::AsInitialized,
440+
),
441+
>,
442+
) -> bool {
443+
// SAFETY:
444+
// - `cast` is implemented as required.
445+
// - Since we cast to the type specified by `Self`'s repr, `p`'s
446+
// referent and the referent of the returned pointer have the
447+
// same size.
448+
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut #discriminant_type) };
449+
// SAFETY: Since `candidate` has the invariant `AsInitialized`,
450+
// we know that `candidate`'s referent (and thus
451+
// `discriminant`'s referent) is as-initialized as `Self`. Since
452+
// `Self`'s repr is the same type as `discriminant`, we know
453+
// that `discriminant`'s referent satisfies the as-initialized
454+
// property.
455+
let discriminant = unsafe { discriminant.assume_valid() };
456+
let discriminant = discriminant.read_unaligned();
457+
458+
false #(|| (discriminant == (#discriminant_exprs)))*
459+
}
460+
));
461+
impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
462+
}
463+
464+
// Enum variant discriminants can be manually set not only as literal values,
465+
// but as arbitrary const expressions. In order to handle this, we keep track of
466+
// the most-recently-seen expression and a count of how many variants have been
467+
// encountered since then.
468+
//
469+
// #[repr(u8)]
470+
// enum Foo {
471+
// A, // 0
472+
// B = 5, // 5
473+
// C, // 6
474+
// D = 1 + 1, // 2
475+
// E, // 3
476+
// }
477+
//
478+
// Note: Default::default does the right thing (initializes to { None, 0 }).
479+
#[derive(Default, Copy, Clone)]
480+
struct Discriminant<'a> {
481+
// The most-recently-set explicit discriminant.
482+
previous: Option<&'a Expr>,
483+
// When the next variant is encountered, what offset should be used compared
484+
// to `previous` to determine the variant's discriminant?
485+
next_offset: usize,
486+
}
487+
488+
impl<'a> Discriminant<'a> {
489+
/// Called when encountering a variant with discriminant set to `ast`.
490+
/// Updates `self` in preparation for the next variant and generates an
491+
/// expression which will evaluate to the numeric value this variant's
492+
/// discriminant.
493+
fn update_and_generate_expr(
494+
&mut self,
495+
ast: &'a Option<(syn::token::Eq, Expr)>,
496+
) -> proc_macro2::TokenStream {
497+
match ast.as_ref().map(|(_eq, expr)| expr) {
498+
Some(expr) => {
499+
self.previous = Some(expr);
500+
self.next_offset = 1;
501+
quote!(#expr)
502+
}
503+
None => {
504+
let previous = self.previous.iter();
505+
// Use `Index` instead of `usize` so that the number is
506+
// formatted just as `0` rather than as `0usize`; the latter
507+
// syntax is only valid if the repr is `usize`; otherwise,
508+
// comparison will result in a type mismatch.
509+
let offset = Index::from(self.next_offset);
510+
let tokens = quote!(#(#previous +)* #offset);
511+
512+
self.next_offset += 1;
513+
tokens
514+
}
515+
}
516+
}
517+
}
518+
519+
#[rustfmt::skip]
520+
const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
521+
use EnumRepr::*;
522+
Config {
523+
allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
524+
derive_unaligned: false,
525+
allowed_combinations: &[
526+
&[U8],
527+
&[U16],
528+
&[U32],
529+
&[U64],
530+
&[Usize],
531+
&[I8],
532+
&[I16],
533+
&[I32],
534+
&[I64],
535+
&[Isize],
536+
],
537+
disallowed_but_legal_combinations: &[
538+
&[C],
539+
],
540+
}
541+
};
542+
404543
// A struct is `FromZeros` if:
405544
// - all fields are `FromZeros`
406545

0 commit comments

Comments
 (0)