diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index fcd150f7..8765a336 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -162,7 +162,7 @@ "uses": "actions-rs/tarpaulin@v0.1", "with": { "version": "0.18.2", - "args": "--exclude-files derive/" + "args": "--all" } }, { diff --git a/.gitignore b/.gitignore index 19231ce7..29c90d8f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .cargo .vscode rls*.log +tarpaulin-report.html diff --git a/Cargo.toml b/Cargo.toml index f2ef9eb5..c188eb1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ [package] name = "bincode" version = "2.0.0-dev" # remember to update html_root_url and bincode_derive -authors = ["Ty Overby ", "Francesco Mazzoli ", "Zoey Riordan "] +authors = ["Ty Overby ", "Francesco Mazzoli ", "Zoey Riordan ", "Victor Koenders "] exclude = ["logo.png", "examples/*", ".gitignore", ".github/"] publish = true @@ -30,4 +30,9 @@ derive = ["bincode_derive"] [dependencies] bincode_derive = { path = "derive", version = "2.0.0-dev", optional = true } -# serde = { version = "1.0.130", optional = true } +serde = { version = "1.0.130", optional = true } + +# Used for derive tests +[dev-dependencies] +serde_derive = "1.0.130" +serde_json = "1.0.68" diff --git a/derive/Cargo.toml b/derive/Cargo.toml index b4379a89..609a1dfb 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -1,16 +1,17 @@ -[package] -name = "bincode_derive" -version = "2.0.0-dev" # remember to update bincode -edition = "2018" - -[lib] -proc-macro = true - -[dependencies] -quote = "1.0.9" -proc-macro2 = "1.0" - -[dependencies.syn] -version = "1.0.74" -default-features = false -features = ["parsing", "derive", "proc-macro", "printing", "clone-impls"] +[package] +name = "bincode_derive" +version = "2.0.0-dev" # remember to update bincode +authors = ["Zoey Riordan ", "Victor Koenders "] +edition = "2018" + +repository = "https://github.com/bincode-org/bincode" +documentation = "https://docs.rs/bincode_derive" +readme = "./readme.md" +categories = ["encoding", "network-programming"] +keywords = ["binary", "encode", "decode", "serialize", "deserialize"] + +[lib] +proc-macro = true + +[dev-dependencies] +proc-macro2 = "1.0" diff --git a/derive/readme.md b/derive/readme.md new file mode 100644 index 00000000..5dd54987 --- /dev/null +++ b/derive/readme.md @@ -0,0 +1,28 @@ +# Bincode-derive + +The derive crate for bincode. Implements `bincode::Encodable` and `bincode::Decodable`. + +This crate is roughly split into 2 parts: + +# Parsing + +Most of parsing is done in the `src/parse/` folder. This will generate the following types: +- `Attributes`, not being used currently +- `Visibility`, not being used currently +- `DataType` either `Struct` or `Enum`, with the name of the data type being parsed +- `Generics` the generics part of the type, e.g. `struct Foo<'a>` +- `GenericConstraints` the "where" part of the type + +# Generate + +Generating the code implementation is done in either `src/derive_enum.rs` and `src/derive_struct.rs`. + +This is supported by the structs in `src/generate`. The most notable points of this module are: +- `StreamBuilder` is a thin but friendly wrapper around `TokenStream` +- `Generator` is the base type of the code generator. This has helper methods to generate implementations: + - `ImplFor` is a helper struct for a single `impl A for B` construction. In this functions can be defined: + - `GenerateFnBody` is a helper struct for a single function in the above `impl`. This is created with a callback to `FnBuilder` which helps set some properties. `GenerateFnBody` has a `stream()` function which returns ` StreamBuilder` for the function. + +For additional derive testing, see the test cases in `../tests` + +For testing purposes, all generated code is outputted to the current `target` folder, under file name `_Encodeable.rs` and `_Decodeable.rs`. This can help with debugging. diff --git a/derive/src/derive_enum.rs b/derive/src/derive_enum.rs index e08c63ce..6e063198 100644 --- a/derive/src/derive_enum.rs +++ b/derive/src/derive_enum.rs @@ -1,242 +1,187 @@ +use crate::generate::{FnSelfArg, Generator}; +use crate::parse::{EnumVariant, Fields}; +use crate::prelude::*; use crate::Result; -use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; -use quote::{quote, ToTokens}; -use syn::{ - spanned::Spanned, Fields, GenericParam, Generics, Ident, Index, Lifetime, LifetimeDef, Variant, -}; + +const TUPLE_FIELD_PREFIX: &str = "field_"; + pub struct DeriveEnum { - name: Ident, - generics: Generics, - variants: Vec, + pub variants: Vec, } impl DeriveEnum { - pub fn parse(name: Ident, generics: Generics, en: syn::DataEnum) -> Result { - let variants = en.variants.into_iter().collect(); - - Ok(DeriveEnum { - name, - generics, - variants, - }) - } - - pub fn generate_encodable(self) -> Result { - let DeriveEnum { - name, - generics, - variants, - } = self; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let match_arms = variants.iter().enumerate().map(|(index, variant)| { - let fields_section = fields_to_match_arm(&variant.fields); - let encode_statements = field_names_to_encodable(&fields_to_names(&variant.fields)); - let variant_name = variant.ident.clone(); - quote! { - #name :: #variant_name #fields_section => { - encoder.encode_u32(#index as u32)?; - #(#encode_statements)* - } - } - }); - let result = quote! { - impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause { - fn encode(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> { - match self { - #(#match_arms)* + pub fn generate_encodable(self, generator: &mut Generator) -> Result<()> { + let DeriveEnum { variants } = self; + + generator + .impl_for("bincode::enc::Encodeable") + .generate_fn("encode") + .with_generic("E", ["bincode::enc::Encode"]) + .with_self_arg(FnSelfArg::RefSelf) + .with_arg("mut encoder", "E") + .with_return_type("core::result::Result<(), bincode::error::EncodeError>") + .body(|fn_body| { + fn_body.ident_str("match"); + fn_body.ident_str("self"); + fn_body.group(Delimiter::Brace, |match_body| { + for (variant_index, variant) in variants.into_iter().enumerate() { + // Self::Variant + match_body.ident_str("Self"); + match_body.puncts("::"); + match_body.ident(variant.name.clone()); + + // if we have any fields, declare them here + // Self::Variant { a, b, c } + if let Some(delimiter) = variant.fields.delimiter() { + match_body.group(delimiter, |field_body| { + for (idx, field_name) in + variant.fields.names().into_iter().enumerate() + { + if idx != 0 { + field_body.punct(','); + } + field_body.push( + field_name.to_token_tree_with_prefix(TUPLE_FIELD_PREFIX), + ); + } + }); + } + + // Arrow + // Self::Variant { a, b, c } => + match_body.puncts("=>"); + + // Body of this variant + // Note that the fields are available as locals because of the match destructuring above + // { + // encoder.encode_u32(n)?; + // bincode::enc::Encodeable::encode(a, &mut encoder)?; + // bincode::enc::Encodeable::encode(b, &mut encoder)?; + // bincode::enc::Encodeable::encode(c, &mut encoder)?; + // } + match_body.group(Delimiter::Brace, |body| { + // variant index + body.push_parsed(format!("encoder.encode_u32({})?;", variant_index)); + // If we have any fields, encode them all one by one + for field_name in variant.fields.names() { + body.push_parsed(format!( + "bincode::enc::Encodeable::encode({}, &mut encoder)?;", + field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX), + )); + } + }); + match_body.punct(','); } - Ok(()) - } - - } - }; - - Ok(result.into()) + }); + fn_body.push_parsed("Ok(())"); + }); + Ok(()) } - pub fn generate_decodable(self) -> Result { - let DeriveEnum { - name, - generics, - variants, - } = self; - - let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - // check if the type has lifetimes - let mut should_insert_lifetime = false; - - for param in &generics.params { - if let GenericParam::Lifetime(_) = param { - should_insert_lifetime = true; - break; - } - } - - // if we don't have a '__de lifetime, insert it - let mut generics_with_decode_lifetime; - if should_insert_lifetime { - generics_with_decode_lifetime = generics.clone(); - - let mut new_lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); - - for param in &generics.params { - if let GenericParam::Lifetime(lt) = param { - new_lifetime.bounds.push(lt.lifetime.clone()) - } - } - - generics_with_decode_lifetime - .params - .push(GenericParam::Lifetime(new_lifetime)); - - impl_generics = generics_with_decode_lifetime.split_for_impl().0; - } - - let max_variant = (variants.len() - 1) as u32; - let match_arms = variants.iter().enumerate().map(|(index, variant)| { - let index = index as u32; - let decode_statements = field_names_to_decodable( - &fields_to_constructable_names(&variant.fields), - should_insert_lifetime, - ); - let variant_name = variant.ident.clone(); - quote! { - #index => { - #name :: #variant_name { - #(#decode_statements)* - } - } - } - }); - let result = if should_insert_lifetime { - quote! { - impl #impl_generics bincode::de::BorrowDecodable<'__de> for #name #ty_generics #where_clause { - fn borrow_decode>(mut decoder: D) -> Result { - let i = decoder.decode_u32()?; - Ok(match i { - #(#match_arms)* - variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ - min: 0, - max: #max_variant, - found: variant, - }) - }) + pub fn generate_decodable(self, generator: &mut Generator) -> Result<()> { + let DeriveEnum { variants } = self; + + if generator.has_lifetimes() { + // enum has a lifetime, implement BorrowDecodable + + generator.impl_for_with_de_lifetime("bincode::de::BorrowDecodable<'__de>") + .generate_fn("borrow_decode") + .with_generic("D", ["bincode::de::BorrowDecode<'__de>"]) + .with_arg("mut decoder", "D") + .with_return_type("Result") + .body(|fn_builder| { + fn_builder + .push_parsed("let variant_index = bincode::de::Decode::decode_u32(&mut decoder)?;"); + fn_builder.push_parsed("match variant_index"); + fn_builder.group(Delimiter::Brace, |variant_case| { + for (idx, variant) in variants.iter().enumerate() { + // idx => Ok(..) + variant_case.lit_u32(idx as u32); + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); + + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body.push_parsed("bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?,"); + } + }); + }); + variant_case.punct(','); } - } - } + // invalid idx + variant_case.push_parsed(format!( + "variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant }})", + variants.len() - 1 + )); + }); + }); } else { - quote! { - impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { - fn decode(mut decoder: D) -> Result { - let i = decoder.decode_u32()?; - Ok(match i { - #(#match_arms)* - variant => return Err(bincode::error::DecodeError::UnexpectedVariant{ - min: 0, - max: #max_variant, - found: variant, - }) - }) - } - + // enum has no lifetimes, implement Decodable + + generator.impl_for("bincode::de::Decodable") + .generate_fn("decode") + .with_generic("D", ["bincode::de::Decode"]) + .with_arg("mut decoder", "D") + .with_return_type("Result") + .body(|fn_builder| { + + fn_builder + .push_parsed("let variant_index = bincode::de::Decode::decode_u32(&mut decoder)?;"); + fn_builder.push_parsed("match variant_index"); + fn_builder.group(Delimiter::Brace, |variant_case| { + for (idx, variant) in variants.iter().enumerate() { + // idx => Ok(..) + variant_case.lit_u32(idx as u32); + variant_case.puncts("=>"); + variant_case.ident_str("Ok"); + variant_case.group(Delimiter::Parenthesis, |variant_case_body| { + // Self::Variant { } + // Self::Variant { 0: ..., 1: ... 2: ... }, + // Self::Variant { a: ..., b: ... c: ... }, + variant_case_body.ident_str("Self"); + variant_case_body.puncts("::"); + variant_case_body.ident(variant.name.clone()); + + variant_case_body.group(Delimiter::Brace, |variant_body| { + let is_tuple = matches!(variant.fields, Fields::Tuple(_)); + for (idx, field) in variant.fields.names().into_iter().enumerate() { + if is_tuple { + variant_body.lit_usize(idx); + } else { + variant_body.ident(field.unwrap_ident().clone()); + } + variant_body.punct(':'); + variant_body.push_parsed("bincode::de::Decodable::decode(&mut decoder)?,"); + } + }); + }); + variant_case.punct(','); } - } - }; - Ok(result.into()) - } -} - -fn fields_to_match_arm(fields: &Fields) -> TokenStream2 { - match fields { - syn::Fields::Named(fields) => { - let fields: Vec<_> = fields - .named - .iter() - .map(|f| f.ident.clone().unwrap().to_token_stream()) - .collect(); - quote! { - {#(#fields),*} - } - } - syn::Fields::Unnamed(fields) => { - let fields: Vec<_> = fields - .unnamed - .iter() - .enumerate() - .map(|(i, f)| Ident::new(&format!("_{}", i), f.span())) - .collect(); - quote! { - (#(#fields),*) - } + // invalid idx + variant_case.push_parsed(format!( + "variant => return Err(bincode::error::DecodeError::UnexpectedVariant {{ min: 0, max: {}, found: variant }})", + variants.len() - 1 + )); + }); + }); } - syn::Fields::Unit => quote! {}, - } -} - -fn fields_to_names(fields: &Fields) -> Vec { - match fields { - syn::Fields::Named(fields) => fields - .named - .iter() - .map(|f| f.ident.clone().unwrap().to_token_stream()) - .collect(), - syn::Fields::Unnamed(fields) => fields - .unnamed - .iter() - .enumerate() - .map(|(i, f)| Ident::new(&format!("_{}", i), f.span()).to_token_stream()) - .collect(), - syn::Fields::Unit => Vec::new(), - } -} - -fn field_names_to_encodable(names: &[TokenStream2]) -> Vec { - names - .iter() - .map(|field| { - quote! { - bincode::enc::Encodeable::encode(#field, &mut encoder)?; - } - }) - .collect::>() -} -fn fields_to_constructable_names(fields: &Fields) -> Vec { - match fields { - syn::Fields::Named(fields) => fields - .named - .iter() - .map(|f| f.ident.clone().unwrap().to_token_stream()) - .collect(), - syn::Fields::Unnamed(fields) => fields - .unnamed - .iter() - .enumerate() - .map(|(i, _)| Index::from(i).to_token_stream()) - .collect(), - syn::Fields::Unit => Vec::new(), + Ok(()) } } - -fn field_names_to_decodable(names: &[TokenStream2], borrowed: bool) -> Vec { - names - .iter() - .map(|field| { - if borrowed { - quote! { - #field: bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?, - } - } else { - quote! { - #field: bincode::de::Decodable::decode(&mut decoder)?, - } - } - }) - .collect::>() -} diff --git a/derive/src/derive_struct.rs b/derive/src/derive_struct.rs index 881545f7..549f7d43 100644 --- a/derive/src/derive_struct.rs +++ b/derive/src/derive_struct.rs @@ -1,145 +1,97 @@ +use crate::generate::Generator; +use crate::parse::Fields; +use crate::prelude::Delimiter; use crate::Result; -use proc_macro::TokenStream; -use proc_macro2::{Span, TokenStream as TokenStream2}; -use quote::{quote, ToTokens}; -use syn::{GenericParam, Generics, Ident, Index, Lifetime, LifetimeDef}; pub struct DeriveStruct { - name: Ident, - generics: Generics, - fields: Vec, + pub fields: Fields, } impl DeriveStruct { - pub fn parse(name: Ident, generics: Generics, str: syn::DataStruct) -> Result { - let fields = match str.fields { - syn::Fields::Named(fields) => fields - .named - .iter() - .map(|f| f.ident.clone().unwrap().to_token_stream()) - .collect(), - syn::Fields::Unnamed(fields) => fields - .unnamed - .iter() - .enumerate() - .map(|(i, _)| Index::from(i).to_token_stream()) - .collect(), - syn::Fields::Unit => Vec::new(), - }; - - Ok(Self { - name, - generics, - fields, - }) - } - - pub fn generate_encodable(self) -> Result { - let DeriveStruct { - name, - generics, - fields, - } = self; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let fields = fields - .into_iter() - .map(|field| { - quote! { - bincode::enc::Encodeable::encode(&self. #field, &mut encoder)?; - } - }) - .collect::>(); - - let result = quote! { - impl #impl_generics bincode::enc::Encodeable for #name #ty_generics #where_clause { - fn encode(&self, mut encoder: E) -> Result<(), bincode::error::EncodeError> { - #(#fields)* - Ok(()) + pub fn generate_encodable(self, generator: &mut Generator) -> Result<()> { + let DeriveStruct { fields } = self; + + let mut impl_for = generator.impl_for("bincode::enc::Encodeable"); + impl_for + .generate_fn("encode") + .with_generic("E", ["bincode::enc::Encode"]) + .with_self_arg(crate::generate::FnSelfArg::RefSelf) + .with_arg("mut encoder", "E") + .with_return_type("Result<(), bincode::error::EncodeError>") + .body(|fn_body| { + for field in fields.names() { + fn_body.push_parsed(format!( + "bincode::enc::Encodeable::encode(&self.{}, &mut encoder)?;", + field.to_string() + )); } + fn_body.push_parsed("Ok(())"); + }); - } - }; - Ok(result.into()) + Ok(()) } - pub fn generate_decodable(self) -> Result { - let DeriveStruct { - name, - generics, - fields, - } = self; - - let (mut impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - // check if the type has lifetimes - let mut should_insert_lifetime = false; - - for param in &generics.params { - if let GenericParam::Lifetime(_) = param { - should_insert_lifetime = true; - break; - } - } - - // if the type has lifetimes, insert '__de and bound it to the lifetimes - let mut generics_with_decode_lifetime; - if should_insert_lifetime { - generics_with_decode_lifetime = generics.clone(); - let mut new_lifetime = LifetimeDef::new(Lifetime::new("'__de", Span::call_site())); - - for param in &generics.params { - if let GenericParam::Lifetime(lt) = param { - new_lifetime.bounds.push(lt.lifetime.clone()) - } - } - generics_with_decode_lifetime - .params - .push(GenericParam::Lifetime(new_lifetime)); - - impl_generics = generics_with_decode_lifetime.split_for_impl().0; - } - - let fields = fields - .into_iter() - .map(|field| { - if should_insert_lifetime { - quote! { - #field: bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?, - } - } else { - quote! { - #field: bincode::de::Decodable::decode(&mut decoder)?, - } - } - }) - .collect::>(); - - let result = if should_insert_lifetime { - quote! { - impl #impl_generics bincode::de::BorrowDecodable<'__de> for #name #ty_generics #where_clause { - fn borrow_decode>(mut decoder: D) -> Result { - Ok(#name { - #(#fields)* - }) - } - - } - } + pub fn generate_decodable(self, generator: &mut Generator) -> Result<()> { + let DeriveStruct { fields } = self; + + if generator.has_lifetimes() { + // struct has a lifetime, implement BorrowDecodable + + generator + .impl_for_with_de_lifetime("bincode::de::BorrowDecodable<'__de>") + .generate_fn("borrow_decode") + .with_generic("D", ["bincode::de::BorrowDecode<'__de>"]) + .with_arg("mut decoder", "D") + .with_return_type("Result") + .body(|fn_body| { + // Ok(Self { + fn_body.ident_str("Ok"); + fn_body.group(Delimiter::Parenthesis, |ok_group| { + ok_group.ident_str("Self"); + ok_group.group(Delimiter::Brace, |struct_body| { + for field in fields.names() { + struct_body.push_parsed(format!( + "{}: bincode::de::BorrowDecodable::borrow_decode(&mut decoder)?,", + field.to_string() + )); + } + }); + }); + }); + + Ok(()) } else { - quote! { - impl #impl_generics bincode::de::Decodable for #name #ty_generics #where_clause { - fn decode(mut decoder: D) -> Result { - Ok(#name { - #(#fields)* - }) - } - - } - } - }; - - Ok(result.into()) + // struct has no lifetimes, implement Decodable + + let mut impl_for = generator.impl_for("bincode::de::Decodable"); + impl_for + .generate_fn("decode") + .with_generic("D", ["bincode::de::Decode"]) + .with_arg("mut decoder", "D") + .with_return_type("Result") + .body(|fn_body| { + // Ok(Self { + fn_body.ident_str("Ok"); + fn_body.group(Delimiter::Parenthesis, |ok_group| { + ok_group.ident_str("Self"); + ok_group.group(Delimiter::Brace, |struct_body| { + // Fields + // { + // a: bincode::de::Decodable::decode(&mut decoder)?, + // b: bincode::de::Decodable::decode(&mut decoder)?, + // ... + // } + for field in fields.names() { + struct_body.push_parsed(format!( + "{}: bincode::de::Decodable::decode(&mut decoder)?,", + field.to_string() + )); + } + }); + }); + }); + + Ok(()) + } } } diff --git a/derive/src/error.rs b/derive/src/error.rs index ff1abee3..724c3729 100644 --- a/derive/src/error.rs +++ b/derive/src/error.rs @@ -1,24 +1,70 @@ -use proc_macro::TokenStream; -use quote::__private::Span; +use crate::prelude::*; use std::fmt; +#[derive(Debug)] pub enum Error { - UnionNotSupported, + UnknownDataType(Span), + InvalidRustSyntax(Span), + ExpectedIdent(Span), +} + +// helper functions for the unit tests +#[cfg(test)] +impl Error { + pub fn is_unknown_data_type(&self) -> bool { + matches!(self, Error::UnknownDataType(_)) + } + + pub fn is_invalid_rust_syntax(&self) -> bool { + matches!(self, Error::InvalidRustSyntax(_)) + } } impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { - Self::UnionNotSupported => write!(fmt, "Unions are not supported"), + Self::UnknownDataType(_) => { + write!(fmt, "Unknown data type, only enum and struct are supported") + } + Self::InvalidRustSyntax(_) => write!(fmt, "Invalid rust syntax"), + Self::ExpectedIdent(_) => write!(fmt, "Expected ident"), } } } impl Error { pub fn into_token_stream(self) -> TokenStream { - self.into_token_stream_with_span(Span::call_site()) + let maybe_span = match &self { + Error::UnknownDataType(span) + | Error::ExpectedIdent(span) + | Error::InvalidRustSyntax(span) => Some(*span), + }; + self.throw_with_span(maybe_span.unwrap_or_else(Span::call_site)) } - pub fn into_token_stream_with_span(self, span: Span) -> TokenStream { - syn::Error::new(span, self).into_compile_error().into() + + pub fn throw_with_span(self, span: Span) -> TokenStream { + // compile_error!($message) + vec![ + TokenTree::Ident(Ident::new("compile_error", span)), + TokenTree::Punct({ + let mut punct = Punct::new('!', Spacing::Alone); + punct.set_span(span); + punct + }), + TokenTree::Group({ + let mut group = Group::new(Delimiter::Brace, { + TokenTree::Literal({ + let mut string = Literal::string(&self.to_string()); + string.set_span(span); + string + }) + .into() + }); + group.set_span(span); + group + }), + ] + .into_iter() + .collect() } } diff --git a/derive/src/generate/generate_fn.rs b/derive/src/generate/generate_fn.rs new file mode 100644 index 00000000..53428ca4 --- /dev/null +++ b/derive/src/generate/generate_fn.rs @@ -0,0 +1,134 @@ +use super::{ImplFor, StreamBuilder}; +use crate::prelude::Delimiter; +pub struct FnBuilder<'a, 'b> { + generate: &'b mut ImplFor<'a>, + name: String, + + lifetime_and_generics: Vec<(String, Vec)>, + self_arg: FnSelfArg, + args: Vec<(String, String)>, + return_type: Option, +} + +impl<'a, 'b> FnBuilder<'a, 'b> { + pub(super) fn new(generate: &'b mut ImplFor<'a>, name: impl Into) -> Self { + Self { + generate, + name: name.into(), + lifetime_and_generics: Vec::new(), + self_arg: FnSelfArg::None, + args: Vec::new(), + return_type: None, + } + } + + pub fn with_generic(mut self, name: T, dependencies: U) -> Self + where + T: Into, + U: IntoIterator, + V: Into, + { + self.lifetime_and_generics.push(( + name.into(), + dependencies.into_iter().map(|d| d.into()).collect(), + )); + self + } + + pub fn with_self_arg(mut self, self_arg: FnSelfArg) -> Self { + self.self_arg = self_arg; + self + } + + pub fn with_arg(mut self, name: impl Into, ty: impl Into) -> Self { + self.args.push((name.into(), ty.into())); + self + } + + pub fn with_return_type(mut self, ret_type: impl Into) -> Self { + self.return_type = Some(ret_type.into()); + self + } + + pub fn body(self, body_builder: impl FnOnce(&mut StreamBuilder)) { + let FnBuilder { + generate, + name, + lifetime_and_generics, + self_arg, + args, + return_type, + } = self; + + let mut builder = StreamBuilder::new(); + + // function name; `fn name` + builder.ident_str("fn"); + builder.ident_str(name); + + // lifetimes; `<'a: 'b, D: Display>` + if !lifetime_and_generics.is_empty() { + builder.punct('<'); + for (idx, (lifetime_and_generic, dependencies)) in + lifetime_and_generics.into_iter().enumerate() + { + if idx != 0 { + builder.punct(','); + } + builder.ident_str(&lifetime_and_generic); + if !dependencies.is_empty() { + for (idx, dependency) in dependencies.into_iter().enumerate() { + builder.punct(if idx == 0 { ':' } else { '+' }); + builder.push_parsed(&dependency); + } + } + } + builder.punct('>'); + } + + // Arguments; `(&self, foo: &Bar)` + builder.group(Delimiter::Parenthesis, |arg_stream| { + if let Some(self_arg) = self_arg.into_token_tree() { + arg_stream.append(self_arg); + arg_stream.punct(','); + } + for (idx, (arg_name, arg_ty)) in args.into_iter().enumerate() { + if idx != 0 { + arg_stream.punct(','); + } + arg_stream.push_parsed(&arg_name); + arg_stream.punct(':'); + arg_stream.push_parsed(&arg_ty); + } + }); + + // Return type: `-> ResultType` + if let Some(return_type) = return_type { + builder.puncts("->"); + builder.push_parsed(&return_type); + } + + generate.group.append(builder); + + generate.group.group(Delimiter::Brace, body_builder); + } +} + +pub enum FnSelfArg { + None, + RefSelf, +} + +impl FnSelfArg { + fn into_token_tree(self) -> Option { + let mut builder = StreamBuilder::new(); + match self { + Self::None => return None, + Self::RefSelf => { + builder.punct('&'); + builder.ident_str("self"); + } + } + Some(builder) + } +} diff --git a/derive/src/generate/generator.rs b/derive/src/generate/generator.rs new file mode 100644 index 00000000..94edf7df --- /dev/null +++ b/derive/src/generate/generator.rs @@ -0,0 +1,53 @@ +use super::{ImplFor, StreamBuilder}; +use crate::parse::{GenericConstraints, Generics}; +use crate::prelude::{Ident, TokenStream}; + +#[must_use] +pub struct Generator { + pub(super) name: Ident, + pub(super) generics: Option, + pub(super) generic_constraints: Option, + pub(super) stream: StreamBuilder, +} + +impl Generator { + pub(crate) fn new( + name: Ident, + generics: Option, + generic_constraints: Option, + ) -> Self { + Self { + name, + generics, + generic_constraints, + stream: StreamBuilder::new(), + } + } + + pub fn impl_for<'a>(&'a mut self, trait_name: &str) -> ImplFor<'a> { + ImplFor::new(self, trait_name) + } + + pub fn impl_for_with_de_lifetime<'a>(&'a mut self, trait_name: &str) -> ImplFor<'a> { + ImplFor::new_with_de_lifetime(self, trait_name) + } + + pub fn has_lifetimes(&self) -> bool { + self.generics + .as_ref() + .map(|g| g.has_lifetime()) + .unwrap_or(false) + } + + pub fn take_stream(mut self) -> TokenStream { + std::mem::take(&mut self.stream.stream) + } +} + +impl Drop for Generator { + fn drop(&mut self) { + if !self.stream.stream.is_empty() && !std::thread::panicking() { + panic!("Generator dropped but the stream is not empty. Please call `.take_stream()` on the generator"); + } + } +} diff --git a/derive/src/generate/impl_for.rs b/derive/src/generate/impl_for.rs new file mode 100644 index 00000000..753f2065 --- /dev/null +++ b/derive/src/generate/impl_for.rs @@ -0,0 +1,73 @@ +use super::{FnBuilder, Generator, StreamBuilder}; +use crate::prelude::Delimiter; + +#[must_use] +pub struct ImplFor<'a> { + pub(super) generator: &'a mut Generator, + pub(super) group: StreamBuilder, +} + +impl<'a> ImplFor<'a> { + pub(super) fn new(generator: &'a mut Generator, trait_name: &str) -> Self { + let mut builder = StreamBuilder::new(); + builder.ident_str("impl"); + + if let Some(generics) = &generator.generics { + builder.append(generics.impl_generics()); + } + builder.push_parsed(trait_name); + builder.ident_str("for"); + builder.ident(generator.name.clone()); + + if let Some(generics) = &generator.generics { + builder.append(generics.type_generics()); + } + if let Some(generic_constraints) = &generator.generic_constraints { + builder.append(generic_constraints.where_clause()); + } + generator.stream.append(builder); + + let group = StreamBuilder::new(); + Self { generator, group } + } + + pub(super) fn new_with_de_lifetime(generator: &'a mut Generator, trait_name: &str) -> Self { + let mut builder = StreamBuilder::new(); + builder.ident_str("impl"); + + if let Some(generics) = &generator.generics { + builder.append(generics.impl_generics_with_additional_lifetime("__de")); + } else { + builder.punct('<'); + builder.lifetime_str("__de"); + builder.punct('>'); + } + + builder.push_parsed(trait_name); + builder.ident_str("for"); + builder.ident(generator.name.clone()); + if let Some(generics) = &generator.generics { + builder.append(generics.type_generics()); + } + if let Some(generic_constraints) = &generator.generic_constraints { + builder.append(generic_constraints.where_clause()); + } + generator.stream.append(builder); + + let group = StreamBuilder::new(); + Self { generator, group } + } + + pub fn generate_fn<'b>(&'b mut self, name: &str) -> FnBuilder<'a, 'b> { + FnBuilder::new(self, name) + } +} + +impl Drop for ImplFor<'_> { + fn drop(&mut self) { + let stream = std::mem::take(&mut self.group); + self.generator + .stream + .group(Delimiter::Brace, |builder| builder.append(stream)) + } +} diff --git a/derive/src/generate/mod.rs b/derive/src/generate/mod.rs new file mode 100644 index 00000000..696464da --- /dev/null +++ b/derive/src/generate/mod.rs @@ -0,0 +1,9 @@ +mod generate_fn; +mod generator; +mod impl_for; +mod stream_builder; + +pub use self::generate_fn::{FnBuilder, FnSelfArg}; +pub use self::generator::Generator; +pub use self::impl_for::ImplFor; +pub use self::stream_builder::StreamBuilder; diff --git a/derive/src/generate/stream_builder.rs b/derive/src/generate/stream_builder.rs new file mode 100644 index 00000000..a1a08542 --- /dev/null +++ b/derive/src/generate/stream_builder.rs @@ -0,0 +1,95 @@ +use crate::prelude::{ + Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree, +}; +use std::str::FromStr; + +#[must_use] +#[derive(Default)] +pub struct StreamBuilder { + pub(super) stream: TokenStream, +} + +impl StreamBuilder { + pub fn new() -> Self { + Self { + stream: TokenStream::new(), + } + } + + pub fn extend(&mut self, item: impl IntoIterator) { + self.stream.extend(item); + } + + pub fn append(&mut self, builder: StreamBuilder) { + self.stream.extend(builder.stream); + } + + pub fn push(&mut self, item: impl Into) { + self.stream.extend([item.into()]); + } + + pub fn push_parsed(&mut self, item: impl AsRef) { + self.stream + .extend(TokenStream::from_str(item.as_ref()).unwrap_or_else(|e| { + panic!( + "Could not parse string as rust: {:?}\n{:?}", + item.as_ref(), + e + ) + })); + } + + pub fn ident(&mut self, ident: Ident) { + self.stream.extend([TokenTree::Ident(ident)]); + } + + pub fn ident_str(&mut self, ident: impl AsRef) { + self.stream.extend([TokenTree::Ident(Ident::new( + ident.as_ref(), + Span::call_site(), + ))]); + } + + pub fn group(&mut self, delim: Delimiter, inner: impl FnOnce(&mut StreamBuilder)) { + let mut stream = StreamBuilder::new(); + inner(&mut stream); + self.stream + .extend([TokenTree::Group(Group::new(delim, stream.stream))]); + } + + pub fn punct(&mut self, p: char) { + self.stream + .extend([TokenTree::Punct(Punct::new(p, Spacing::Alone))]); + } + + pub fn puncts(&mut self, puncts: &str) { + self.stream.extend( + puncts + .chars() + .map(|char| TokenTree::Punct(Punct::new(char, Spacing::Joint))), + ); + } + + pub fn lifetime(&mut self, lt: Ident) { + self.stream.extend([ + TokenTree::Punct(Punct::new('\'', Spacing::Joint)), + TokenTree::Ident(lt), + ]); + } + pub fn lifetime_str(&mut self, lt: &str) { + self.stream.extend([ + TokenTree::Punct(Punct::new('\'', Spacing::Joint)), + TokenTree::Ident(Ident::new(lt, Span::call_site())), + ]); + } + + pub fn lit_u32(&mut self, val: u32) { + self.stream + .extend([TokenTree::Literal(Literal::u32_unsuffixed(val))]); + } + + pub fn lit_usize(&mut self, val: usize) { + self.stream + .extend([TokenTree::Literal(Literal::usize_unsuffixed(val))]); + } +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 2e91110e..493f1435 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -3,51 +3,127 @@ extern crate proc_macro; mod derive_enum; mod derive_struct; mod error; +mod generate; +mod parse; + +#[cfg(test)] +pub(crate) mod prelude { + pub use proc_macro2::*; +} +#[cfg(not(test))] +pub(crate) mod prelude { + pub use proc_macro::*; +} -use derive_enum::DeriveEnum; -use derive_struct::DeriveStruct; use error::Error; -use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput}; +use prelude::TokenStream; type Result = std::result::Result; #[proc_macro_derive(Encodable)] -pub fn derive_encodable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - derive_encodable_inner(input).unwrap_or_else(|e| e.into_token_stream()) +pub fn derive_encodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + #[allow(clippy::useless_conversion)] + derive_encodable_inner(input.into()) + .unwrap_or_else(|e| e.into_token_stream()) + .into() } -fn derive_encodable_inner(input: DeriveInput) -> Result { - match input.data { - syn::Data::Struct(struct_definition) => { - DeriveStruct::parse(input.ident, input.generics, struct_definition) - .and_then(|str| str.generate_encodable()) +fn derive_encodable_inner(input: TokenStream) -> Result { + let source = &mut input.into_iter().peekable(); + + let _attributes = parse::Attributes::try_take(source)?; + let _visibility = parse::Visibility::try_take(source)?; + let (datatype, name) = parse::DataType::take(source)?; + let generics = parse::Generics::try_take(source)?; + let generic_constraints = parse::GenericConstraints::try_take(source)?; + + let mut generator = generate::Generator::new(name.clone(), generics, generic_constraints); + + match datatype { + parse::DataType::Struct => { + let body = parse::StructBody::take(source)?; + derive_struct::DeriveStruct { + fields: body.fields, + } + .generate_encodable(&mut generator)?; } - syn::Data::Enum(enum_definition) => { - DeriveEnum::parse(input.ident, input.generics, enum_definition) - .and_then(|str| str.generate_encodable()) + parse::DataType::Enum => { + let body = parse::EnumBody::take(source)?; + derive_enum::DeriveEnum { + variants: body.variants, + } + .generate_encodable(&mut generator)?; } - syn::Data::Union(_) => Err(Error::UnionNotSupported), } + + let stream = generator.take_stream(); + dump_output(name, "Encodeable", &stream); + Ok(stream) } #[proc_macro_derive(Decodable)] -pub fn derive_decodable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - derive_decodable_inner(input).unwrap_or_else(|e| e.into_token_stream()) +pub fn derive_decodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + #[allow(clippy::useless_conversion)] + derive_decodable_inner(input.into()) + .unwrap_or_else(|e| e.into_token_stream()) + .into() } -fn derive_decodable_inner(input: DeriveInput) -> Result { - match input.data { - syn::Data::Struct(struct_definition) => { - DeriveStruct::parse(input.ident, input.generics, struct_definition) - .and_then(|str| str.generate_decodable()) +fn derive_decodable_inner(input: TokenStream) -> Result { + let source = &mut input.into_iter().peekable(); + + let _attributes = parse::Attributes::try_take(source)?; + let _visibility = parse::Visibility::try_take(source)?; + let (datatype, name) = parse::DataType::take(source)?; + let generics = parse::Generics::try_take(source)?; + let generic_constraints = parse::GenericConstraints::try_take(source)?; + + let mut generator = generate::Generator::new(name.clone(), generics, generic_constraints); + + match datatype { + parse::DataType::Struct => { + let body = parse::StructBody::take(source)?; + derive_struct::DeriveStruct { + fields: body.fields, + } + .generate_decodable(&mut generator)?; } - syn::Data::Enum(enum_definition) => { - DeriveEnum::parse(input.ident, input.generics, enum_definition) - .and_then(|str| str.generate_decodable()) + parse::DataType::Enum => { + let body = parse::EnumBody::take(source)?; + derive_enum::DeriveEnum { + variants: body.variants, + } + .generate_decodable(&mut generator)?; } - syn::Data::Union(_) => Err(Error::UnionNotSupported), } + + let stream = generator.take_stream(); + dump_output(name, "Decodeable", &stream); + Ok(stream) +} + +fn dump_output(name: crate::prelude::Ident, derive: &str, stream: &crate::prelude::TokenStream) { + use std::io::Write; + + if let Ok(var) = std::env::var("CARGO_MANIFEST_DIR") { + let mut path = std::path::PathBuf::from(var); + path.push("target"); + if path.exists() { + path.push(format!("{}_{}.rs", name, derive)); + if let Ok(mut file) = std::fs::File::create(path) { + let _ = file.write_all(stream.to_string().as_bytes()); + } + } + } +} + +#[cfg(test)] +pub(crate) fn token_stream( + s: &str, +) -> std::iter::Peekable> { + use std::str::FromStr; + + let stream = proc_macro2::TokenStream::from_str(s) + .unwrap_or_else(|e| panic!("Could not parse code: {:?}\n{:?}", s, e)); + stream.into_iter().peekable() } diff --git a/derive/src/parse/attributes.rs b/derive/src/parse/attributes.rs new file mode 100644 index 00000000..74d194f5 --- /dev/null +++ b/derive/src/parse/attributes.rs @@ -0,0 +1,56 @@ +use super::assume_group; +use crate::parse::consume_punct_if; +use crate::prelude::{Delimiter, Group, Punct, TokenTree}; +use crate::{Error, Result}; +use std::iter::Peekable; + +#[derive(Debug)] +pub struct Attributes { + // we don't use these fields yet + #[allow(dead_code)] + punct: Punct, + #[allow(dead_code)] + tokens: Group, +} + +impl Attributes { + pub fn try_take(input: &mut Peekable>) -> Result> { + if let Some(punct) = consume_punct_if(input, '#') { + // found attributes, next token should be a [] group + if let Some(TokenTree::Group(g)) = input.peek() { + if g.delimiter() != Delimiter::Bracket { + return Err(Error::InvalidRustSyntax(g.span())); + } + return Ok(Some(Attributes { + punct, + tokens: assume_group(input.next()), + })); + } + // expected [] group, found something else + return Err(Error::InvalidRustSyntax(match input.peek() { + Some(next_token) => next_token.span(), + None => punct.span(), + })); + } + Ok(None) + } +} + +#[test] +fn test_attributes_try_take() { + use crate::token_stream; + + let stream = &mut token_stream("struct Foo;"); + assert!(Attributes::try_take(stream).unwrap().is_none()); + match stream.next().unwrap() { + TokenTree::Ident(i) => assert_eq!(i, "struct"), + x => panic!("Expected ident, found {:?}", x), + } + + let stream = &mut token_stream("#[cfg(test)] struct Foo;"); + assert!(Attributes::try_take(stream).unwrap().is_some()); + match stream.next().unwrap() { + TokenTree::Ident(i) => assert_eq!(i, "struct"), + x => panic!("Expected ident, found {:?}", x), + } +} diff --git a/derive/src/parse/body.rs b/derive/src/parse/body.rs new file mode 100644 index 00000000..ff9689d6 --- /dev/null +++ b/derive/src/parse/body.rs @@ -0,0 +1,421 @@ +use super::{assume_group, assume_ident, read_tokens_until_punct, Attributes, Visibility}; +use crate::parse::consume_punct_if; +use crate::prelude::{Delimiter, Ident, Span, TokenTree}; +use crate::{Error, Result}; +use std::iter::Peekable; + +#[derive(Debug)] +pub struct StructBody { + pub fields: Fields, +} + +impl StructBody { + pub fn take(input: &mut Peekable>) -> Result { + match input.peek() { + Some(TokenTree::Group(_)) => {} + Some(TokenTree::Punct(p)) if p.as_char() == ';' => { + return Ok(StructBody { + fields: Fields::Unit, + }) + } + Some(t) => { + return Err(Error::InvalidRustSyntax(t.span())); + } + _ => { + return Err(Error::InvalidRustSyntax(Span::call_site())); + } + } + let group = assume_group(input.next()); + let mut stream = group.stream().into_iter().peekable(); + let fields = match group.delimiter() { + Delimiter::Brace => Fields::Struct(UnnamedField::parse_with_name(&mut stream)?), + Delimiter::Parenthesis => Fields::Tuple(UnnamedField::parse(&mut stream)?), + _ => return Err(Error::InvalidRustSyntax(group.span())), + }; + Ok(StructBody { fields }) + } +} + +#[test] +fn test_struct_body_take() { + use crate::token_stream; + + let stream = &mut token_stream( + "struct Foo { pub bar: u8, pub(crate) baz: u32, bla: Vec>> }", + ); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let body = StructBody::take(stream).unwrap(); + + assert_eq!(body.fields.len(), 3); + let (ident, field) = body.fields.get(0).unwrap(); + assert_eq!(ident.unwrap(), "bar"); + assert_eq!(field.vis, Visibility::Pub); + assert_eq!(field.type_string(), "u8"); + + let (ident, field) = body.fields.get(1).unwrap(); + assert_eq!(ident.unwrap(), "baz"); + assert_eq!(field.vis, Visibility::Pub); + assert_eq!(field.type_string(), "u32"); + + let (ident, field) = body.fields.get(2).unwrap(); + assert_eq!(ident.unwrap(), "bla"); + assert_eq!(field.vis, Visibility::Default); + assert_eq!(field.type_string(), "Vec>>"); + + let stream = &mut token_stream( + "struct Foo ( pub u8, pub(crate) u32, Vec>> )", + ); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let body = StructBody::take(stream).unwrap(); + + assert_eq!(body.fields.len(), 3); + + let (ident, field) = body.fields.get(0).unwrap(); + assert!(ident.is_none()); + assert_eq!(field.vis, Visibility::Pub); + assert_eq!(field.type_string(), "u8"); + + let (ident, field) = body.fields.get(1).unwrap(); + assert!(ident.is_none()); + assert_eq!(field.vis, Visibility::Pub); + assert_eq!(field.type_string(), "u32"); + + let (ident, field) = body.fields.get(2).unwrap(); + assert!(ident.is_none()); + assert_eq!(field.vis, Visibility::Default); + assert_eq!(field.type_string(), "Vec>>"); + + let stream = &mut token_stream("struct Foo;"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let body = StructBody::take(stream).unwrap(); + assert_eq!(body.fields.len(), 0); + + let stream = &mut token_stream("struct Foo {}"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let body = StructBody::take(stream).unwrap(); + assert_eq!(body.fields.len(), 0); + + let stream = &mut token_stream("struct Foo ()"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + assert_eq!(body.fields.len(), 0); +} + +#[derive(Debug)] +pub struct EnumBody { + pub variants: Vec, +} + +impl EnumBody { + pub fn take(input: &mut Peekable>) -> Result { + match input.peek() { + Some(TokenTree::Group(_)) => {} + Some(TokenTree::Punct(p)) if p.as_char() == ';' => { + return Ok(EnumBody { + variants: Vec::new(), + }) + } + Some(t) => { + return Err(Error::InvalidRustSyntax(t.span())); + } + _ => { + return Err(Error::InvalidRustSyntax(Span::call_site())); + } + } + let group = assume_group(input.next()); + let mut variants = Vec::new(); + let stream = &mut group.stream().into_iter().peekable(); + while stream.peek().is_some() { + let attributes = Attributes::try_take(stream)?; + let ident = match stream.peek() { + Some(TokenTree::Ident(_)) => assume_ident(stream.next()), + Some(x) => return Err(Error::InvalidRustSyntax(x.span())), + None => return Err(Error::InvalidRustSyntax(Span::call_site())), + }; + + let mut fields = Fields::Unit; + + if let Some(TokenTree::Group(_)) = stream.peek() { + let group = assume_group(stream.next()); + let stream = &mut group.stream().into_iter().peekable(); + match group.delimiter() { + Delimiter::Brace => { + fields = Fields::Struct(UnnamedField::parse_with_name(stream)?) + } + Delimiter::Parenthesis => fields = Fields::Tuple(UnnamedField::parse(stream)?), + _ => return Err(Error::InvalidRustSyntax(group.span())), + } + } + consume_punct_if(stream, ','); + + variants.push(EnumVariant { + name: ident, + fields, + attributes, + }); + } + + Ok(EnumBody { variants }) + } +} + +#[test] +fn test_enum_body_take() { + use crate::token_stream; + + let stream = &mut token_stream("enum Foo { }"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Enum); + assert_eq!(ident, "Foo"); + let body = EnumBody::take(stream).unwrap(); + assert_eq!(0, body.variants.len()); + + let stream = &mut token_stream("enum Foo { Bar, Baz(u8), Blah { a: u32, b: u128 } }"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Enum); + assert_eq!(ident, "Foo"); + let body = EnumBody::take(stream).unwrap(); + assert_eq!(3, body.variants.len()); + + assert_eq!(body.variants[0].name, "Bar"); + assert!(body.variants[0].fields.is_unit()); + + assert_eq!(body.variants[1].name, "Baz"); + assert_eq!(1, body.variants[1].fields.len()); + let (ident, field) = body.variants[1].fields.get(0).unwrap(); + assert!(ident.is_none()); + assert_eq!(field.type_string(), "u8"); + + assert_eq!(body.variants[2].name, "Blah"); + assert_eq!(2, body.variants[2].fields.len()); + let (ident, field) = body.variants[2].fields.get(0).unwrap(); + assert_eq!(ident.unwrap(), "a"); + assert_eq!(field.type_string(), "u32"); + let (ident, field) = body.variants[2].fields.get(1).unwrap(); + assert_eq!(ident.unwrap(), "b"); + assert_eq!(field.type_string(), "u128"); +} + +#[derive(Debug)] +pub struct EnumVariant { + pub name: Ident, + pub fields: Fields, + pub attributes: Option, +} + +#[derive(Debug)] +pub enum Fields { + /// Empty variant. + /// ```rs + /// enum Foo { + /// Baz, + /// } + /// struct Bar { } + /// ``` + Unit, + + /// Tuple-like variant + /// ```rs + /// enum Foo { + /// Baz(u32) + /// } + /// struct Bar(u32); + /// ``` + Tuple(Vec), + + /// Struct-like variant + /// ```rs + /// enum Foo { + /// Baz { + /// baz: u32 + /// } + /// } + /// struct Bar { + /// baz: u32 + /// } + /// ``` + Struct(Vec<(Ident, UnnamedField)>), +} + +impl Fields { + pub fn names(&self) -> Vec { + match self { + Self::Tuple(fields) => fields + .iter() + .enumerate() + .map(|(idx, field)| IdentOrIndex::Index(idx, field.span())) + .collect(), + Self::Struct(fields) => fields + .iter() + .map(|(ident, _)| IdentOrIndex::Ident(ident)) + .collect(), + Self::Unit => Vec::new(), + } + } + + pub fn delimiter(&self) -> Option { + match self { + Self::Tuple(_) => Some(Delimiter::Parenthesis), + Self::Struct(_) => Some(Delimiter::Brace), + Self::Unit => None, + } + } +} + +#[cfg(test)] +impl Fields { + pub fn is_unit(&self) -> bool { + matches!(self, Self::Unit) + } + + pub fn len(&self) -> usize { + match self { + Self::Tuple(fields) => fields.len(), + Self::Struct(fields) => fields.len(), + Self::Unit => 0, + } + } + + pub fn get(&self, index: usize) -> Option<(Option<&Ident>, &UnnamedField)> { + match self { + Self::Tuple(fields) => fields.get(index).map(|f| (None, f)), + Self::Struct(fields) => fields.get(index).map(|(ident, field)| (Some(ident), field)), + Self::Unit => None, + } + } +} + +#[derive(Debug)] +pub struct UnnamedField { + pub vis: Visibility, + pub r#type: Vec, + pub attributes: Option, +} + +impl UnnamedField { + pub fn parse_with_name( + input: &mut Peekable>, + ) -> Result> { + let mut result = Vec::new(); + loop { + let attributes = Attributes::try_take(input)?; + let vis = Visibility::try_take(input)?; + + let ident = match input.peek() { + Some(TokenTree::Ident(_)) => assume_ident(input.next()), + Some(x) => return Err(Error::InvalidRustSyntax(x.span())), + None => break, + }; + match input.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == ':' => { + input.next(); + } + Some(x) => return Err(Error::InvalidRustSyntax(x.span())), + None => return Err(Error::InvalidRustSyntax(Span::call_site())), + } + let r#type = read_tokens_until_punct(input, &[','])?; + consume_punct_if(input, ','); + result.push(( + ident, + Self { + vis, + r#type, + attributes, + }, + )); + } + Ok(result) + } + + pub fn parse(input: &mut Peekable>) -> Result> { + let mut result = Vec::new(); + while input.peek().is_some() { + let attributes = Attributes::try_take(input)?; + let vis = Visibility::try_take(input)?; + + let r#type = read_tokens_until_punct(input, &[','])?; + consume_punct_if(input, ','); + result.push(Self { + vis, + r#type, + attributes, + }); + } + Ok(result) + } + + #[cfg(test)] + pub fn type_string(&self) -> String { + self.r#type.iter().map(|t| t.to_string()).collect() + } + + pub fn span(&self) -> Span { + // BlockedTODO: https://github.com/rust-lang/rust/issues/54725 + // Span::join is unstable + // if let Some(first) = self.r#type.first() { + // let mut span = first.span(); + // for token in self.r#type.iter().skip(1) { + // span = span.join(span).unwrap(); + // } + // span + // } else { + // Span::call_site() + // } + + match self.r#type.first() { + Some(first) => first.span(), + None => Span::call_site(), + } + } +} + +#[derive(Debug)] +pub enum IdentOrIndex<'a> { + Ident(&'a Ident), + Index(usize, Span), +} + +impl<'a> IdentOrIndex<'a> { + pub fn unwrap_ident(&self) -> &'a Ident { + match self { + Self::Ident(i) => i, + x => panic!("Expected ident, found {:?}", x), + } + } + + pub fn to_token_tree_with_prefix(&self, prefix: &str) -> TokenTree { + TokenTree::Ident(match self { + IdentOrIndex::Ident(i) => (*i).clone(), + IdentOrIndex::Index(idx, span) => { + let name = format!("{}{}", prefix, idx); + Ident::new(&name, *span) + } + }) + } + pub fn to_string_with_prefix(&self, prefix: &str) -> String { + match self { + IdentOrIndex::Ident(i) => i.to_string(), + IdentOrIndex::Index(idx, _) => { + format!("{}{}", prefix, idx) + } + } + } +} + +impl std::fmt::Display for IdentOrIndex<'_> { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + IdentOrIndex::Ident(i) => write!(fmt, "{}", i), + IdentOrIndex::Index(idx, _) => write!(fmt, "{}", idx), + } + } +} diff --git a/derive/src/parse/data_type.rs b/derive/src/parse/data_type.rs new file mode 100644 index 00000000..64d07603 --- /dev/null +++ b/derive/src/parse/data_type.rs @@ -0,0 +1,77 @@ +use crate::prelude::{Ident, Span, TokenTree}; +use crate::{Error, Result}; +use std::iter::Peekable; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DataType { + Enum, + Struct, +} + +impl DataType { + pub fn take(input: &mut Peekable>) -> Result<(Self, Ident)> { + if let Some(TokenTree::Ident(ident)) = input.peek() { + let result = match ident.to_string().as_str() { + "struct" => DataType::Struct, + "enum" => DataType::Enum, + _ => return Err(Error::UnknownDataType(ident.span())), + }; + let ident = super::assume_ident(input.next()); + return match input.next() { + Some(TokenTree::Ident(ident)) => Ok((result, ident)), + Some(t) => Err(Error::InvalidRustSyntax(t.span())), + None => Err(Error::InvalidRustSyntax(ident.span())), + }; + } + let span = input + .peek() + .map(|t| t.span()) + .unwrap_or_else(Span::call_site); + Err(Error::InvalidRustSyntax(span)) + } +} + +#[test] +fn test_datatype_take() { + use crate::token_stream; + + fn validate_output_eq(input: &str, expected_dt: DataType, expected_ident: &str) { + let (dt, ident) = DataType::take(&mut token_stream(input)).unwrap_or_else(|e| { + panic!("Could not parse tokenstream {:?}: {:?}", input, e); + }); + if dt != expected_dt || ident != expected_ident { + println!("While parsing {:?}", input); + panic!( + "Expected {:?} {:?}, received {:?} {:?}", + dt, ident, expected_dt, expected_ident + ); + } + } + + assert!(DataType::take(&mut token_stream("enum")) + .unwrap_err() + .is_invalid_rust_syntax()); + validate_output_eq("enum Foo", DataType::Enum, "Foo"); + validate_output_eq("enum Foo { }", DataType::Enum, "Foo"); + validate_output_eq("enum Foo { bar, baz }", DataType::Enum, "Foo"); + validate_output_eq("enum Foo<'a, T> { bar, baz }", DataType::Enum, "Foo"); + + assert!(DataType::take(&mut token_stream("struct")) + .unwrap_err() + .is_invalid_rust_syntax()); + validate_output_eq("struct Foo { }", DataType::Struct, "Foo"); + validate_output_eq("struct Foo { bar: u32, baz: u32 }", DataType::Struct, "Foo"); + validate_output_eq("struct Foo<'a, T> { bar: &'a T }", DataType::Struct, "Foo"); + + assert!(DataType::take(&mut token_stream("fn foo() {}")) + .unwrap_err() + .is_unknown_data_type()); + + assert!(DataType::take(&mut token_stream("() {}")) + .unwrap_err() + .is_invalid_rust_syntax()); + + assert!(DataType::take(&mut token_stream("")) + .unwrap_err() + .is_invalid_rust_syntax()); +} diff --git a/derive/src/parse/generics.rs b/derive/src/parse/generics.rs new file mode 100644 index 00000000..3af1581c --- /dev/null +++ b/derive/src/parse/generics.rs @@ -0,0 +1,417 @@ +use super::assume_punct; +use crate::generate::StreamBuilder; +use crate::parse::{ident_eq, read_tokens_until_punct}; +use crate::prelude::{Ident, TokenTree}; +use crate::{Error, Result}; +use std::iter::Peekable; + +#[derive(Debug)] +pub struct Generics { + lifetimes_and_generics: Vec, +} + +impl Generics { + pub fn try_take(input: &mut Peekable>) -> Result> { + let maybe_punct = input.peek(); + if let Some(TokenTree::Punct(punct)) = maybe_punct { + if punct.as_char() == '<' { + let punct = super::assume_punct(input.next(), '<'); + let mut result = Generics { + lifetimes_and_generics: Vec::new(), + }; + loop { + match input.peek() { + Some(TokenTree::Punct(punct)) if punct.as_char() == '\'' => { + result + .lifetimes_and_generics + .push(Lifetime::take(input)?.into()); + super::consume_punct_if(input, ','); + } + Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => { + assume_punct(input.next(), '>'); + break; + } + Some(TokenTree::Ident(_)) => { + result + .lifetimes_and_generics + .push(Generic::take(input)?.into()); + super::consume_punct_if(input, ','); + } + x => { + return Err(Error::InvalidRustSyntax( + x.map(|x| x.span()).unwrap_or_else(|| punct.span()), + )); + } + } + } + return Ok(Some(result)); + } + } + Ok(None) + } + + pub fn has_lifetime(&self) -> bool { + self.lifetimes_and_generics + .iter() + .any(|lt| lt.is_lifetime()) + } + + pub fn impl_generics(&self) -> StreamBuilder { + let mut result = StreamBuilder::new(); + result.punct('<'); + + for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() { + if idx > 0 { + result.punct(','); + } + + if generic.is_lifetime() { + result.lifetime(generic.ident()); + } else { + result.ident(generic.ident()); + } + + if generic.has_constraints() { + result.punct(':'); + result.extend(generic.constraints()); + } + } + + result.punct('>'); + + result + } + + pub fn impl_generics_with_additional_lifetime(&self, lifetime: &str) -> StreamBuilder { + assert!(self.has_lifetime()); + + let mut result = StreamBuilder::new(); + result.punct('<'); + result.lifetime_str(lifetime); + + if self.has_lifetime() { + for (idx, lt) in self + .lifetimes_and_generics + .iter() + .filter_map(|lt| lt.as_lifetime()) + .enumerate() + { + result.punct(if idx == 0 { ':' } else { '+' }); + result.lifetime(lt.ident.clone()); + } + } + + for generic in &self.lifetimes_and_generics { + result.punct(','); + + if generic.is_lifetime() { + result.lifetime(generic.ident()); + } else { + result.ident(generic.ident()); + } + + if generic.has_constraints() { + result.punct(':'); + result.extend(generic.constraints()); + } + } + + result.punct('>'); + + result + } + + pub fn type_generics(&self) -> StreamBuilder { + let mut result = StreamBuilder::new(); + result.punct('<'); + + for (idx, generic) in self.lifetimes_and_generics.iter().enumerate() { + if idx > 0 { + result.punct(','); + } + if generic.is_lifetime() { + result.lifetime(generic.ident()); + } else { + result.ident(generic.ident()); + } + } + + result.punct('>'); + result + } +} + +#[derive(Debug)] +enum LifetimeOrGeneric { + Lifetime(Lifetime), + Generic(Generic), +} + +impl LifetimeOrGeneric { + fn is_lifetime(&self) -> bool { + matches!(self, LifetimeOrGeneric::Lifetime(_)) + } + + fn ident(&self) -> Ident { + match self { + Self::Lifetime(lt) => lt.ident.clone(), + Self::Generic(gen) => gen.ident.clone(), + } + } + + fn as_lifetime(&self) -> Option<&Lifetime> { + match self { + Self::Lifetime(lt) => Some(lt), + Self::Generic(_) => None, + } + } + + fn has_constraints(&self) -> bool { + match self { + Self::Lifetime(lt) => !lt.constraint.is_empty(), + Self::Generic(gen) => !gen.constraints.is_empty(), + } + } + + fn constraints(&self) -> Vec { + match self { + Self::Lifetime(lt) => lt.constraint.clone(), + Self::Generic(gen) => gen.constraints.clone(), + } + } +} + +impl From for LifetimeOrGeneric { + fn from(lt: Lifetime) -> Self { + Self::Lifetime(lt) + } +} + +impl From for LifetimeOrGeneric { + fn from(gen: Generic) -> Self { + Self::Generic(gen) + } +} + +#[test] +fn test_generics_try_take() { + use crate::token_stream; + + assert!(Generics::try_take(&mut token_stream("")).unwrap().is_none()); + assert!(Generics::try_take(&mut token_stream("foo")) + .unwrap() + .is_none()); + assert!(Generics::try_take(&mut token_stream("()")) + .unwrap() + .is_none()); + + let stream = &mut token_stream("struct Foo<'a, T>()"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + assert_eq!(generics.lifetimes_and_generics.len(), 2); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); + assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + + let stream = &mut token_stream("struct Foo()"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + assert_eq!(generics.lifetimes_and_generics.len(), 2); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "A"); + assert_eq!(generics.lifetimes_and_generics[1].ident(), "B"); + + let stream = &mut token_stream("struct Foo<'a, T: Display>()"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + dbg!(&generics); + assert_eq!(generics.lifetimes_and_generics.len(), 2); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); + assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + + let stream = &mut token_stream("struct Foo<'a, T: for<'a> Bar<'a> + 'static>()"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Foo"); + dbg!(&generics); + assert_eq!(generics.lifetimes_and_generics.len(), 2); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "a"); + assert_eq!(generics.lifetimes_and_generics[1].ident(), "T"); + + let stream = &mut token_stream( + "struct Baz Bar<'a, for<'b> Bar<'b, for<'c> Bar<'c, u32>>>> {}", + ); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Baz"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + dbg!(&generics); + assert_eq!(generics.lifetimes_and_generics.len(), 1); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "T"); + + let stream = &mut token_stream("struct Baz<()> {}"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Baz"); + assert!(Generics::try_take(stream) + .unwrap_err() + .is_invalid_rust_syntax()); + + let stream = &mut token_stream("struct Bar SomeStruct, B>"); + let (data_type, ident) = super::DataType::take(stream).unwrap(); + assert_eq!(data_type, super::DataType::Struct); + assert_eq!(ident, "Bar"); + let generics = Generics::try_take(stream).unwrap().unwrap(); + dbg!(&generics); + assert_eq!(generics.lifetimes_and_generics.len(), 2); + assert_eq!(generics.lifetimes_and_generics[0].ident(), "A"); + assert_eq!(generics.lifetimes_and_generics[1].ident(), "B"); +} + +#[derive(Debug)] +pub struct Lifetime { + ident: Ident, + constraint: Vec, +} + +impl Lifetime { + pub fn take(input: &mut Peekable>) -> Result { + let start = super::assume_punct(input.next(), '\''); + let ident = match input.peek() { + Some(TokenTree::Ident(_)) => super::assume_ident(input.next()), + Some(t) => return Err(Error::ExpectedIdent(t.span())), + None => return Err(Error::ExpectedIdent(start.span())), + }; + + let mut constraint = Vec::new(); + if let Some(TokenTree::Punct(p)) = input.peek() { + if p.as_char() == ':' { + assume_punct(input.next(), ':'); + constraint = super::read_tokens_until_punct(input, &[',', '>'])?; + } + } + + Ok(Self { ident, constraint }) + } + + #[cfg(test)] + fn is_ident(&self, s: &str) -> bool { + self.ident.to_string() == s + } +} + +#[test] +fn test_lifetime_take() { + use crate::token_stream; + use std::panic::catch_unwind; + assert!(Lifetime::take(&mut token_stream("'a")) + .unwrap() + .is_ident("a")); + assert!(catch_unwind(|| Lifetime::take(&mut token_stream("'0"))).is_err()); + assert!(catch_unwind(|| Lifetime::take(&mut token_stream("'("))).is_err()); + assert!(catch_unwind(|| Lifetime::take(&mut token_stream("')"))).is_err()); + assert!(catch_unwind(|| Lifetime::take(&mut token_stream("'0'"))).is_err()); + + let stream = &mut token_stream("'a: 'b>"); + let lifetime = Lifetime::take(stream).unwrap(); + assert_eq!(lifetime.ident, "a"); + assert_eq!(lifetime.constraint.len(), 2); + assume_punct(stream.next(), '>'); + assert!(stream.next().is_none()); +} + +#[derive(Debug)] +pub struct Generic { + ident: Ident, + constraints: Vec, +} + +impl Generic { + pub fn take(input: &mut Peekable>) -> Result { + let ident = super::assume_ident(input.next()); + let mut constraints = Vec::new(); + if let Some(TokenTree::Punct(punct)) = input.peek() { + if punct.as_char() == ':' { + super::assume_punct(input.next(), ':'); + constraints = super::read_tokens_until_punct(input, &['>', ','])?; + } + } + Ok(Generic { ident, constraints }) + } +} + +#[derive(Debug)] +pub struct GenericConstraints { + constraints: Vec, +} + +impl GenericConstraints { + pub fn try_take(input: &mut Peekable>) -> Result> { + match input.peek() { + Some(TokenTree::Ident(ident)) => { + if !ident_eq(ident, "where") { + return Ok(None); + } + } + _ => { + return Ok(None); + } + } + input.next(); + let constraints = read_tokens_until_punct(input, &['{', '('])?; + Ok(Some(Self { constraints })) + } + + pub fn where_clause(&self) -> StreamBuilder { + let mut result = StreamBuilder::new(); + result.ident_str("where"); + result.extend(self.constraints.clone()); + result + } +} + +#[test] +fn test_generic_constraints_try_take() { + use super::{DataType, StructBody, Visibility}; + use crate::token_stream; + + let stream = &mut token_stream("struct Foo where Foo: Bar { }"); + super::DataType::take(stream).unwrap(); + assert!(GenericConstraints::try_take(stream).unwrap().is_some()); + + let stream = &mut token_stream("struct Foo { }"); + super::DataType::take(stream).unwrap(); + assert!(GenericConstraints::try_take(stream).unwrap().is_none()); + + let stream = &mut token_stream("struct Foo where Foo: Bar(Foo)"); + super::DataType::take(stream).unwrap(); + assert!(GenericConstraints::try_take(stream).unwrap().is_some()); + + let stream = &mut token_stream("struct Foo()"); + super::DataType::take(stream).unwrap(); + assert!(GenericConstraints::try_take(stream).unwrap().is_none()); + + let stream = &mut token_stream("struct Foo()"); + assert!(GenericConstraints::try_take(stream).unwrap().is_none()); + + let stream = &mut token_stream("{}"); + assert!(GenericConstraints::try_take(stream).unwrap().is_none()); + + let stream = &mut token_stream(""); + assert!(GenericConstraints::try_take(stream).unwrap().is_none()); + + let stream = &mut token_stream("pub(crate) struct Test {}"); + assert_eq!(Visibility::Pub, Visibility::try_take(stream).unwrap()); + let (data_type, ident) = DataType::take(stream).unwrap(); + assert_eq!(data_type, DataType::Struct); + assert_eq!(ident, "Test"); + let constraints = Generics::try_take(stream).unwrap().unwrap(); + assert_eq!(constraints.lifetimes_and_generics.len(), 1); + assert_eq!(constraints.lifetimes_and_generics[0].ident(), "T"); + let body = StructBody::take(stream).unwrap(); + assert_eq!(body.fields.len(), 0); +} diff --git a/derive/src/parse/mod.rs b/derive/src/parse/mod.rs new file mode 100644 index 00000000..e888abb9 --- /dev/null +++ b/derive/src/parse/mod.rs @@ -0,0 +1,142 @@ +use crate::error::Error; +use crate::prelude::{Delimiter, Group, Ident, Punct, TokenTree}; +use std::iter::Peekable; + +mod attributes; +mod body; +mod data_type; +mod generics; +mod visibility; + +pub use self::attributes::Attributes; +pub use self::body::{EnumBody, EnumVariant, Fields, StructBody, UnnamedField}; +pub use self::data_type::DataType; +pub use self::generics::{Generic, GenericConstraints, Generics, Lifetime}; +pub use self::visibility::Visibility; + +pub(self) fn assume_group(t: Option) -> Group { + match t { + Some(TokenTree::Group(group)) => group, + _ => unreachable!(), + } +} +pub(self) fn assume_ident(t: Option) -> Ident { + match t { + Some(TokenTree::Ident(ident)) => ident, + _ => unreachable!(), + } +} +pub(self) fn assume_punct(t: Option, punct: char) -> Punct { + match t { + Some(TokenTree::Punct(p)) => { + debug_assert_eq!(punct, p.as_char()); + p + } + _ => unreachable!(), + } +} + +pub(self) fn consume_punct_if( + input: &mut Peekable>, + punct: char, +) -> Option { + if let Some(TokenTree::Punct(p)) = input.peek() { + if p.as_char() == punct { + match input.next() { + Some(TokenTree::Punct(p)) => return Some(p), + _ => unreachable!(), + } + } + } + None +} + +#[cfg(test)] +pub(self) fn ident_eq(ident: &Ident, text: &str) -> bool { + ident == text +} + +#[cfg(not(test))] +pub(self) fn ident_eq(ident: &Ident, text: &str) -> bool { + ident.to_string() == text +} + +fn check_if_arrow(tokens: &[TokenTree], punct: &Punct) -> bool { + if punct.as_char() == '>' { + if let Some(TokenTree::Punct(previous_punct)) = tokens.last() { + if previous_punct.as_char() == '-' { + return true; + } + } + } + false +} + +const OPEN_BRACKETS: &[char] = &['<', '(', '[', '{']; +const CLOSING_BRACKETS: &[char] = &['>', ')', ']', '}']; +const BRACKET_DELIMITER: &[Option] = &[ + None, + Some(Delimiter::Parenthesis), + Some(Delimiter::Bracket), + Some(Delimiter::Brace), +]; + +pub(self) fn read_tokens_until_punct( + input: &mut Peekable>, + expected_puncts: &[char], +) -> Result, Error> { + let mut result = Vec::new(); + let mut open_brackets = Vec::::new(); + 'outer: loop { + match input.peek() { + Some(TokenTree::Punct(punct)) => { + if check_if_arrow(&result, punct) { + // do nothing + } else if OPEN_BRACKETS.contains(&punct.as_char()) { + open_brackets.push(punct.as_char()); + } else if let Some(index) = + CLOSING_BRACKETS.iter().position(|c| c == &punct.as_char()) + { + let last_bracket = match open_brackets.pop() { + Some(bracket) => bracket, + None => { + if expected_puncts.contains(&punct.as_char()) { + break; + } + return Err(Error::InvalidRustSyntax(punct.span())); + } + }; + let expected = OPEN_BRACKETS[index]; + assert_eq!( + expected, + last_bracket, + "Unexpected closing bracket: found {}, expected {}", + punct.as_char(), + expected + ); + } else if expected_puncts.contains(&punct.as_char()) && open_brackets.is_empty() { + break; + } + result.push(input.next().unwrap()); + } + Some(TokenTree::Group(g)) if open_brackets.is_empty() => { + for punct in expected_puncts { + if let Some(idx) = OPEN_BRACKETS.iter().position(|c| c == punct) { + if let Some(delim) = BRACKET_DELIMITER[idx] { + if delim == g.delimiter() { + // we need to split on this delimiter + break 'outer; + } + } + } + } + result.push(input.next().unwrap()); + } + Some(_) => result.push(input.next().unwrap()), + None => { + break; + } + } + } + Ok(result) +} diff --git a/derive/src/parse/visibility.rs b/derive/src/parse/visibility.rs new file mode 100644 index 00000000..49a7ebf1 --- /dev/null +++ b/derive/src/parse/visibility.rs @@ -0,0 +1,68 @@ +use crate::prelude::TokenTree; +use crate::Result; +use std::iter::Peekable; + +#[derive(Debug, PartialEq, Clone)] +pub enum Visibility { + Default, + Pub, +} + +impl Visibility { + pub fn try_take(input: &mut Peekable>) -> Result { + if let Some(TokenTree::Ident(ident)) = input.peek() { + if super::ident_eq(ident, "pub") { + // Consume this token + super::assume_ident(input.next()); + + // check if the next token is `pub(...)` + if let Some(TokenTree::Group(_)) = input.peek() { + // we just consume the visibility, we're not actually using it for generation + super::assume_group(input.next()); + } + + return Ok(Visibility::Pub); + } + } + Ok(Visibility::Default) + } +} + +#[test] +fn test_visibility_try_take() { + use crate::token_stream; + + assert_eq!( + Visibility::Default, + Visibility::try_take(&mut token_stream("")).unwrap() + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream("pub")).unwrap() + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream(" pub ")).unwrap(), + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream("\tpub\t")).unwrap() + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream("pub(crate)")).unwrap() + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream(" pub ( crate ) ")).unwrap() + ); + assert_eq!( + Visibility::Pub, + Visibility::try_take(&mut token_stream("\tpub\t(\tcrate\t)\t")).unwrap() + ); + + assert_eq!( + Visibility::Default, + Visibility::try_take(&mut token_stream("pb")).unwrap() + ); +} diff --git a/src/de/mod.rs b/src/de/mod.rs index 199889eb..71ef2812 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -1,45 +1,45 @@ -use crate::error::DecodeError; - -mod decoder; -mod impls; - -pub mod read; -pub use self::decoder::Decoder; - -pub trait Decodable: for<'de> BorrowDecodable<'de> { - fn decode(decoder: D) -> Result; -} - -pub trait BorrowDecodable<'de>: Sized { - fn borrow_decode>(decoder: D) -> Result; -} - -impl<'de, T: Decodable> BorrowDecodable<'de> for T { - fn borrow_decode(decoder: D) -> Result { - Decodable::decode(decoder) - } -} - -pub trait Decode { - fn decode_u8(&mut self) -> Result; - fn decode_u16(&mut self) -> Result; - fn decode_u32(&mut self) -> Result; - fn decode_u64(&mut self) -> Result; - fn decode_u128(&mut self) -> Result; - fn decode_usize(&mut self) -> Result; - - fn decode_i8(&mut self) -> Result; - fn decode_i16(&mut self) -> Result; - fn decode_i32(&mut self) -> Result; - fn decode_i64(&mut self) -> Result; - fn decode_i128(&mut self) -> Result; - fn decode_isize(&mut self) -> Result; - - fn decode_f32(&mut self) -> Result; - fn decode_f64(&mut self) -> Result; - fn decode_array(&mut self) -> Result<[u8; N], DecodeError>; -} - -pub trait BorrowDecode<'de>: Decode { - fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError>; -} +use crate::error::DecodeError; + +mod decoder; +mod impls; + +pub mod read; +pub use self::decoder::Decoder; + +pub trait Decodable: for<'de> BorrowDecodable<'de> { + fn decode(decoder: D) -> Result; +} + +pub trait BorrowDecodable<'de>: Sized { + fn borrow_decode>(decoder: D) -> Result; +} + +impl<'de, T: Decodable> BorrowDecodable<'de> for T { + fn borrow_decode(decoder: D) -> Result { + Decodable::decode(decoder) + } +} + +pub trait Decode { + fn decode_u8(&mut self) -> Result; + fn decode_u16(&mut self) -> Result; + fn decode_u32(&mut self) -> Result; + fn decode_u64(&mut self) -> Result; + fn decode_u128(&mut self) -> Result; + fn decode_usize(&mut self) -> Result; + + fn decode_i8(&mut self) -> Result; + fn decode_i16(&mut self) -> Result; + fn decode_i32(&mut self) -> Result; + fn decode_i64(&mut self) -> Result; + fn decode_i128(&mut self) -> Result; + fn decode_isize(&mut self) -> Result; + + fn decode_f32(&mut self) -> Result; + fn decode_f64(&mut self) -> Result; + fn decode_array(&mut self) -> Result<[u8; N], DecodeError>; +} + +pub trait BorrowDecode<'de>: Decode { + fn decode_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError>; +} diff --git a/src/features/impl_alloc.rs b/src/features/impl_alloc.rs index 8b137891..b770a3c9 100644 --- a/src/features/impl_alloc.rs +++ b/src/features/impl_alloc.rs @@ -1 +1,28 @@ +use crate::{config, enc, error, Config}; +use alloc::vec::Vec; +#[derive(Default)] +struct VecWriter { + inner: Vec, +} + +impl enc::write::Writer for VecWriter { + fn write(&mut self, bytes: &[u8]) -> Result<(), error::EncodeError> { + self.inner.extend_from_slice(bytes); + Ok(()) + } +} + +pub fn encode_to_vec(val: E) -> Result, error::EncodeError> { + encode_to_vec_with_config(val, config::Default) +} + +pub fn encode_to_vec_with_config( + val: E, + _config: C, +) -> Result, error::EncodeError> { + let writer = VecWriter::default(); + let mut encoder = enc::Encoder::<_, C>::new(writer); + val.encode(&mut encoder)?; + Ok(encoder.into_writer().inner) +} diff --git a/src/lib.rs b/src/lib.rs index 35f8441d..71e4cb99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,61 +1,61 @@ -#![no_std] - -//! Bincode is a crate for encoding and decoding using a tiny binary -//! serialization strategy. Using it, you can easily go from having -//! an object in memory, quickly serialize it to bytes, and then -//! deserialize it back just as fast! - -#![doc(html_root_url = "https://docs.rs/bincode/2.0.0-dev")] -#![crate_name = "bincode"] -#![crate_type = "rlib"] -#![crate_type = "dylib"] - -#[cfg(feature = "alloc")] -extern crate alloc; -#[cfg(any(feature = "std", test))] -extern crate std; - -mod features; -pub(crate) mod varint; - -pub use features::*; - -pub mod config; -pub mod de; -pub mod enc; -pub mod error; - -use config::Config; - -pub fn encode_into_slice( - val: E, - dst: &mut [u8], -) -> Result { - encode_into_slice_with_config(val, dst, config::Default) -} - -pub fn encode_into_slice_with_config( - val: E, - dst: &mut [u8], - _config: C, -) -> Result { - let writer = enc::write::SliceWriter::new(dst); - let mut encoder = enc::Encoder::<_, C>::new(writer); - val.encode(&mut encoder)?; - Ok(encoder.into_writer().bytes_written()) -} - -pub fn decode<'__de, D: de::BorrowDecodable<'__de>>( - src: &'__de mut [u8], -) -> Result { - decode_with_config(src, config::Default) -} - -pub fn decode_with_config<'__de, D: de::BorrowDecodable<'__de>, C: Config>( - src: &'__de mut [u8], - _config: C, -) -> Result { - let reader = de::read::SliceReader::new(src); - let mut decoder = de::Decoder::<_, C>::new(reader, _config); - D::borrow_decode(&mut decoder) -} +#![no_std] + +//! Bincode is a crate for encoding and decoding using a tiny binary +//! serialization strategy. Using it, you can easily go from having +//! an object in memory, quickly serialize it to bytes, and then +//! deserialize it back just as fast! + +#![doc(html_root_url = "https://docs.rs/bincode/2.0.0-dev")] +#![crate_name = "bincode"] +#![crate_type = "rlib"] +#![crate_type = "dylib"] + +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(any(feature = "std", test))] +extern crate std; + +mod features; +pub(crate) mod varint; + +pub use features::*; + +pub mod config; +pub mod de; +pub mod enc; +pub mod error; + +use config::Config; + +pub fn encode_into_slice( + val: E, + dst: &mut [u8], +) -> Result { + encode_into_slice_with_config(val, dst, config::Default) +} + +pub fn encode_into_slice_with_config( + val: E, + dst: &mut [u8], + _config: C, +) -> Result { + let writer = enc::write::SliceWriter::new(dst); + let mut encoder = enc::Encoder::<_, C>::new(writer); + val.encode(&mut encoder)?; + Ok(encoder.into_writer().bytes_written()) +} + +pub fn decode<'__de, D: de::BorrowDecodable<'__de>>( + src: &'__de [u8], +) -> Result { + decode_with_config(src, config::Default) +} + +pub fn decode_with_config<'__de, D: de::BorrowDecodable<'__de>, C: Config>( + src: &'__de [u8], + _config: C, +) -> Result { + let reader = de::read::SliceReader::new(src); + let mut decoder = de::Decoder::<_, C>::new(reader, _config); + D::borrow_decode(&mut decoder) +} diff --git a/tests/derive.rs b/tests/derive.rs index b6aa2fe1..3738cf06 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -3,7 +3,7 @@ use bincode::{de::Decodable, enc::Encodeable}; #[derive(bincode::Encodable, PartialEq, Debug)] -pub struct Test { +pub(crate) struct Test { a: T, b: u32, c: u8, diff --git a/tests/serde.rs b/tests/serde.rs new file mode 100644 index 00000000..7cb083bf --- /dev/null +++ b/tests/serde.rs @@ -0,0 +1,28 @@ +#![cfg(all(feature = "serde", feature = "alloc", feature = "derive"))] + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, bincode::Encodable, bincode::Decodable)] +pub struct SerdeRoundtrip { + pub a: u32, + #[serde(skip)] + pub b: u32, +} + +#[test] +fn test_serde_round_trip() { + // validate serde attribute working + let json = serde_json::to_string(&SerdeRoundtrip { a: 5, b: 5 }).unwrap(); + assert_eq!("{\"a\":5}", json); + + let result: SerdeRoundtrip = serde_json::from_str(&json).unwrap(); + assert_eq!(result.a, 5); + assert_eq!(result.b, 0); + + // validate bincode working + let bytes = bincode::encode_to_vec(SerdeRoundtrip { a: 15, b: 15 }).unwrap(); + assert_eq!(bytes, &[15, 15]); + let result: SerdeRoundtrip = bincode::decode(&bytes).unwrap(); + assert_eq!(result.a, 15); + assert_eq!(result.b, 15); +}