|
36 | 36 | quote::quote,
|
37 | 37 | syn::{
|
38 | 38 | parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
|
39 |
| - GenericParam, Ident, Lit, |
| 39 | + GenericParam, Ident, Index, Lit, |
40 | 40 | },
|
41 | 41 | };
|
42 | 42 |
|
@@ -268,9 +268,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
|
268 | 268 | let ast = syn::parse_macro_input!(ts as DeriveInput);
|
269 | 269 | match &ast.data {
|
270 | 270 | Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
|
271 |
| - Data::Enum(_) => { |
272 |
| - Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() |
273 |
| - } |
| 271 | + Data::Enum(enm) => derive_try_from_bytes_enum(&ast, enm), |
274 | 272 | Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
|
275 | 273 | }
|
276 | 274 | .into()
|
@@ -401,6 +399,147 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
|
401 | 399 | &[StructRepr::C, StructRepr::Packed],
|
402 | 400 | ];
|
403 | 401 |
|
| 402 | +fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { |
| 403 | + if !enm.is_fieldless() { |
| 404 | + return Error::new_spanned(ast, "only field-less enums can implement TryFromBytes") |
| 405 | + .to_compile_error(); |
| 406 | + } |
| 407 | + |
| 408 | + let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast)); |
| 409 | + let discriminant_type = match reprs.as_slice() { |
| 410 | + [EnumRepr::U8] => quote!(u8), |
| 411 | + [EnumRepr::U16] => quote!(u16), |
| 412 | + [EnumRepr::U32] => quote!(u32), |
| 413 | + [EnumRepr::U64] => quote!(u64), |
| 414 | + [EnumRepr::Usize] => quote!(usize), |
| 415 | + [EnumRepr::I8] => quote!(i8), |
| 416 | + [EnumRepr::I16] => quote!(i16), |
| 417 | + [EnumRepr::I32] => quote!(i32), |
| 418 | + [EnumRepr::I64] => quote!(i64), |
| 419 | + [EnumRepr::Isize] => quote!(isize), |
| 420 | + // `validate_reprs` has already validated that it's one of the preceding |
| 421 | + // patterns. |
| 422 | + _ => unreachable!(), |
| 423 | + }; |
| 424 | + |
| 425 | + let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| { |
| 426 | + Some(disc.update_and_generate_expr(&var.discriminant)) |
| 427 | + }); |
| 428 | + let extras = Some(quote!( |
| 429 | + // SAFETY: We use `is_bit_valid` to validate that the bit pattern |
| 430 | + // corresponds to one of the C-like enum's variant discriminants. |
| 431 | + // Thus, this is a sound implementation of `is_bit_valid`. |
| 432 | + fn is_bit_valid( |
| 433 | + candidate: zerocopy::Ptr< |
| 434 | + '_, |
| 435 | + Self, |
| 436 | + ( |
| 437 | + zerocopy::pointer::invariant::Shared, |
| 438 | + zerocopy::pointer::invariant::AnyAlignment, |
| 439 | + zerocopy::pointer::invariant::AsInitialized, |
| 440 | + ), |
| 441 | + >, |
| 442 | + ) -> bool { |
| 443 | + // SAFETY: |
| 444 | + // - `cast` is implemented as required. |
| 445 | + // - Since we cast to the type specified by `Self`'s repr, `p`'s |
| 446 | + // referent and the referent of the returned pointer have the |
| 447 | + // same size. |
| 448 | + let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut #discriminant_type) }; |
| 449 | + // SAFETY: Since `candidate` has the invariant `AsInitialized`, |
| 450 | + // we know that `candidate`'s referent (and thus |
| 451 | + // `discriminant`'s referent) is as-initialized as `Self`. Since |
| 452 | + // `Self`'s repr is the same type as `discriminant`, we know |
| 453 | + // that `discriminant`'s referent satisfies the as-initialized |
| 454 | + // property. |
| 455 | + let discriminant = unsafe { discriminant.assume_valid() }; |
| 456 | + let discriminant = discriminant.read_unaligned(); |
| 457 | + |
| 458 | + false #(|| (discriminant == (#discriminant_exprs)))* |
| 459 | + } |
| 460 | + )); |
| 461 | + impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras) |
| 462 | +} |
| 463 | + |
| 464 | +// Enum variant discriminants can be manually set not only as literal values, |
| 465 | +// but as arbitrary const expressions. In order to handle this, we keep track of |
| 466 | +// the most-recently-seen expression and a count of how many variants have been |
| 467 | +// encountered since then. |
| 468 | +// |
| 469 | +// #[repr(u8)] |
| 470 | +// enum Foo { |
| 471 | +// A, // 0 |
| 472 | +// B = 5, // 5 |
| 473 | +// C, // 6 |
| 474 | +// D = 1 + 1, // 2 |
| 475 | +// E, // 3 |
| 476 | +// } |
| 477 | +// |
| 478 | +// Note: Default::default does the right thing (initializes to { None, 0 }). |
| 479 | +#[derive(Default, Copy, Clone)] |
| 480 | +struct Discriminant<'a> { |
| 481 | + // The most-recently-set explicit discriminant. |
| 482 | + previous: Option<&'a Expr>, |
| 483 | + // When the next variant is encountered, what offset should be used compared |
| 484 | + // to `previous` to determine the variant's discriminant? |
| 485 | + next_offset: usize, |
| 486 | +} |
| 487 | + |
| 488 | +impl<'a> Discriminant<'a> { |
| 489 | + /// Called when encountering a variant with discriminant set to `ast`. |
| 490 | + /// Updates `self` in preparation for the next variant and generates an |
| 491 | + /// expression which will evaluate to the numeric value this variant's |
| 492 | + /// discriminant. |
| 493 | + fn update_and_generate_expr( |
| 494 | + &mut self, |
| 495 | + ast: &'a Option<(syn::token::Eq, Expr)>, |
| 496 | + ) -> proc_macro2::TokenStream { |
| 497 | + match ast.as_ref().map(|(_eq, expr)| expr) { |
| 498 | + Some(expr) => { |
| 499 | + self.previous = Some(expr); |
| 500 | + self.next_offset = 1; |
| 501 | + quote!(#expr) |
| 502 | + } |
| 503 | + None => { |
| 504 | + let previous = self.previous.iter(); |
| 505 | + // Use `Index` instead of `usize` so that the number is |
| 506 | + // formatted just as `0` rather than as `0usize`; the latter |
| 507 | + // syntax is only valid if the repr is `usize`; otherwise, |
| 508 | + // comparison will result in a type mismatch. |
| 509 | + let offset = Index::from(self.next_offset); |
| 510 | + let tokens = quote!(#(#previous +)* #offset); |
| 511 | + |
| 512 | + self.next_offset += 1; |
| 513 | + tokens |
| 514 | + } |
| 515 | + } |
| 516 | + } |
| 517 | +} |
| 518 | + |
| 519 | +#[rustfmt::skip] |
| 520 | +const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = { |
| 521 | + use EnumRepr::*; |
| 522 | + Config { |
| 523 | + allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#, |
| 524 | + derive_unaligned: false, |
| 525 | + allowed_combinations: &[ |
| 526 | + &[U8], |
| 527 | + &[U16], |
| 528 | + &[U32], |
| 529 | + &[U64], |
| 530 | + &[Usize], |
| 531 | + &[I8], |
| 532 | + &[I16], |
| 533 | + &[I32], |
| 534 | + &[I64], |
| 535 | + &[Isize], |
| 536 | + ], |
| 537 | + disallowed_but_legal_combinations: &[ |
| 538 | + &[C], |
| 539 | + ], |
| 540 | + } |
| 541 | +}; |
| 542 | + |
404 | 543 | // A struct is `FromZeros` if:
|
405 | 544 | // - all fields are `FromZeros`
|
406 | 545 |
|
|
0 commit comments