Skip to content

Commit

Permalink
Improve internal DX around byte classification [1] (#16864)
Browse files Browse the repository at this point in the history
This PR improves the internal DX when working with `u8` classification
into a smaller enum. This is done by implementing a `ClassifyBytes` proc
derive macro. The benefit of this is that the DX is much better and
everything you will see here is done at compile time.

Before:
```rs
#[derive(Debug, Clone, Copy, PartialEq)]
enum Class {
    ValidStart,
    ValidInside,
    OpenBracket,
    OpenParen,
    Slash,
    Other,
}

const CLASS_TABLE: [Class; 256] = {
    let mut table = [Class::Other; 256];

    macro_rules! set {
        ($class:expr, $($byte:expr),+ $(,)?) => {
            $(table[$byte as usize] = $class;)+
        };
    }

    macro_rules! set_range {
        ($class:expr, $start:literal ..= $end:literal) => {
            let mut i = $start;
            while i <= $end {
                table[i as usize] = $class;
                i += 1;
            }
        };
    }

    set_range!(Class::ValidStart, b'a'..=b'z');
    set_range!(Class::ValidStart, b'A'..=b'Z');
    set_range!(Class::ValidStart, b'0'..=b'9');

    set!(Class::OpenBracket, b'[');
    set!(Class::OpenParen, b'(');

    set!(Class::Slash, b'/');

    set!(Class::ValidInside, b'-', b'_', b'.');

    table
};
```

After:
```rs
#[derive(Debug, Clone, Copy, PartialEq, ClassifyBytes)]
enum Class {
    #[bytes_range(b'a'..=b'z', b'A'..=b'Z', b'0'..=b'9')]
    ValidStart,

    #[bytes(b'-', b'_', b'.')]
    ValidInside,

    #[bytes(b'[')]
    OpenBracket,

    #[bytes(b'(')]
    OpenParen,

    #[bytes(b'/')]
    Slash,

    #[fallback]
    Other,
}
```

Before we were generating a `CLASS_TABLE` that we could access directly,
but now it will be part of the `Class`. This means that the usage has to
change:

```diff
- CLASS_TABLE[cursor.curr as usize]
+ Class::TABLE[cursor.curr as usize]
```

This is slightly worse UX, and this is where another change comes in. We
implemented the `From<u8> for #enum_name` trait inside of the
`ClassifyBytes` derive macro. This allows us to use `.into()` on any
`u8` as long as we are comparing it to a `Class` instance. In our
scenario:

```diff
- Class::TABLE[cursor.curr as usize]
+ cursor.curr.into()
```

Usage wise, this looks something like this:
```diff
        while cursor.pos < len {
-           match Class::TABLE[cursor.curr as usize] {
+           match cursor.curr.into() {
-               Class::Escape => match Class::Table[cursor.next as usize] {
+               Class::Escape => match cursor.next.into() {
                    // An escaped whitespace character is not allowed
                    Class::Whitespace => return MachineState::Idle,

                    // An escaped character, skip ahead to the next character
                    _ => cursor.advance(),
                },

                // End of the string
                Class::Quote if cursor.curr == end_char => return self.done(start_pos, cursor),

                // Any kind of whitespace is not allowed
                Class::Whitespace => return MachineState::Idle,

                // Everything else is valid
                _ => {}
            };

            cursor.advance()
        }

        MachineState::Idle
    }
}
```


If you manually look at the `Class::TABLE` in your editor for example,
you can see that it is properly generated at compile time.

Given this input:
```rs
#[derive(Clone, Copy, ClassifyBytes)]
enum Class {
    #[bytes_range(b'a'..=b'z')]
    AlphaLower,

    #[bytes_range(b'A'..=b'Z')]
    AlphaUpper,

    #[bytes(b'@')]
    At,

    #[bytes(b':')]
    Colon,

    #[bytes(b'-')]
    Dash,

    #[bytes(b'.')]
    Dot,

    #[bytes(b'\0')]
    End,

    #[bytes(b'!')]
    Exclamation,

    #[bytes_range(b'0'..=b'9')]
    Number,

    #[bytes(b'[')]
    OpenBracket,

    #[bytes(b']')]
    CloseBracket,

    #[bytes(b'(')]
    OpenParen,

    #[bytes(b'%')]
    Percent,

    #[bytes(b'"', b'\'', b'`')]
    Quote,

    #[bytes(b'/')]
    Slash,

    #[bytes(b'_')]
    Underscore,

    #[bytes(b' ', b'\t', b'\n', b'\r', b'\x0C')]
    Whitespace,

    #[fallback]
    Other,
}
```

This is the result:
<img width="1244" alt="image"
src="https://github.com/user-attachments/assets/6ffd6ad3-0b2f-4381-a24c-593e4c72080e"
/>
  • Loading branch information
RobinMalfait authored Mar 5, 2025
1 parent 0b36dd5 commit 4c11001
Show file tree
Hide file tree
Showing 14 changed files with 474 additions and 514 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions crates/classification-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "classification-macros"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2"
quote = "1"
proc-macro2 = "1"
247 changes: 247 additions & 0 deletions crates/classification-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, Data, DataEnum,
DeriveInput, Expr, ExprLit, ExprRange, Ident, Lit, RangeLimits, Result, Variant,
};

/// A custom derive that supports:
///
/// - `#[bytes(…)]` for single byte literals
/// - `#[bytes_range(…)]` for inclusive byte ranges (b'a'..=b'z')
/// - `#[fallback]` for a variant that covers everything else
///
/// Example usage:
///
/// ```rust
/// use classification_macros::ClassifyBytes;
///
/// #[derive(Clone, Copy, ClassifyBytes)]
/// enum Class {
/// #[bytes(b'a', b'b', b'c')]
/// Letters,
///
/// #[bytes_range(b'0'..=b'9')]
/// Digits,
///
/// #[fallback]
/// Other,
/// }
/// ```
/// Then call `b'a'.into()` to get `Example::SomeLetters`.
#[proc_macro_derive(ClassifyBytes, attributes(bytes, bytes_range, fallback))]
pub fn classify_bytes_derive(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

// This derive only works on an enum
let Data::Enum(DataEnum { variants, .. }) = &ast.data else {
return syn::Error::new_spanned(
&ast.ident,
"ClassifyBytes can only be derived on an enum.",
)
.to_compile_error()
.into();
};

let enum_name = &ast.ident;

let mut byte_map: [Option<Ident>; 256] = [const { None }; 256];
let mut fallback_variant: Option<Ident> = None;

// Start parsing the variants
for variant in variants {
let variant_ident = &variant.ident;

// If this variant has #[fallback], record it
if has_fallback_attr(variant) {
if fallback_variant.is_some() {
let err = syn::Error::new_spanned(
variant_ident,
"Multiple variants have #[fallback]. Only one allowed.",
);
return err.to_compile_error().into();
}
fallback_variant = Some(variant_ident.clone());
}

// Get #[bytes(…)]
let single_bytes = get_bytes_attrs(&variant.attrs);

// Get #[bytes_range(…)]
let range_bytes = get_bytes_range_attrs(&variant.attrs);

// Combine them
let all_bytes = single_bytes
.into_iter()
.chain(range_bytes)
.collect::<Vec<_>>();

// Mark them in the table
for b in all_bytes {
byte_map[b as usize] = Some(variant_ident.clone());
}
}

// If no fallback variant is found, default to "Other"
let fallback_ident = fallback_variant.expect("A variant marked with #[fallback] is missing");

// For each of the 256 byte values, fill the table
let fill = byte_map
.clone()
.into_iter()
.map(|variant_opt| match variant_opt {
Some(ident) => quote!(#enum_name::#ident),
None => quote!(#enum_name::#fallback_ident),
});

// Generate the final expanded code
let expanded = quote! {
impl #enum_name {
pub const TABLE: [#enum_name; 256] = [
#(#fill),*
];
}

impl From<u8> for #enum_name {
fn from(byte: u8) -> Self {
#enum_name::TABLE[byte as usize]
}
}
};

TokenStream::from(expanded)
}

/// Checks if a variant has `#[fallback]`
fn has_fallback_attr(variant: &Variant) -> bool {
variant
.attrs
.iter()
.any(|attr| attr.path().is_ident("fallback"))
}

/// Get all single byte literals from `#[bytes(…)]`
fn get_bytes_attrs(attrs: &[Attribute]) -> Vec<u8> {
let mut assigned = Vec::new();
for attr in attrs {
if attr.path().is_ident("bytes") {
match parse_bytes_attr(attr) {
Ok(list) => assigned.extend(list),
Err(e) => panic!("Error parsing #[bytes(...)]: {}", e),
}
}
}
assigned
}

/// Parse `#[bytes(...)]` as a comma-separated list of **byte literals**, e.g. `b'a'`, `b'\n'`.
fn parse_bytes_attr(attr: &Attribute) -> Result<Vec<u8>> {
// We'll parse it as a list of syn::Lit separated by commas: e.g. (b'a', b'b')
let items: Punctuated<Lit, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
let mut out = Vec::new();
for lit in items {
match lit {
Lit::Byte(lb) => out.push(lb.value()),
_ => {
return Err(syn::Error::new_spanned(
lit,
"Expected a byte literal like b'a'",
))
}
}
}
Ok(out)
}

/// Get all byte ranges from `#[bytes_range(...)]`
fn get_bytes_range_attrs(attrs: &[Attribute]) -> Vec<u8> {
let mut assigned = Vec::new();
for attr in attrs {
if attr.path().is_ident("bytes_range") {
match parse_bytes_range_attr(attr) {
Ok(list) => assigned.extend(list),
Err(e) => panic!("Error parsing #[bytes_range(...)]: {}", e),
}
}
}
assigned
}

/// Parse `#[bytes_range(...)]` as a comma-separated list of range expressions, e.g.:
/// `b'a'..=b'z', b'0'..=b'9'`
fn parse_bytes_range_attr(attr: &Attribute) -> Result<Vec<u8>> {
// We'll parse each element as a syn::Expr, then see if it's an Expr::Range
let exprs: Punctuated<Expr, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
let mut out = Vec::new();

for expr in exprs {
if let Expr::Range(ExprRange {
start: Some(start),
end: Some(end),
limits,
..
}) = expr
{
let from = extract_byte_literal(&start)?;
let to = extract_byte_literal(&end)?;

match limits {
RangeLimits::Closed(_) => {
// b'a'..=b'z'
if from <= to {
out.extend(from..=to);
}
}
RangeLimits::HalfOpen(_) => {
// b'a'..b'z' => from..(to-1)
if from < to {
out.extend(from..to);
}
}
}
} else {
return Err(syn::Error::new_spanned(
expr,
"Expected a byte range like b'a'..=b'z'",
));
}
}

Ok(out)
}

/// Extract a u8 from an expression that can be:
///
/// - `Expr::Lit(Lit::Byte(...))`, e.g. b'a'
/// - `Expr::Lit(Lit::Int(...))`, e.g. 0x80 or 255
fn extract_byte_literal(expr: &Expr) -> Result<u8> {
if let Expr::Lit(ExprLit { lit, .. }) = expr {
match lit {
// Existing case: b'a'
Lit::Byte(lb) => Ok(lb.value()),

// New case: 0x80, 255, etc.
Lit::Int(li) => {
let value = li.base10_parse::<u64>()?;
if value <= 255 {
Ok(value as u8)
} else {
Err(syn::Error::new_spanned(
li,
format!("Integer literal {} out of range for a byte (0..255)", value),
))
}
}

_ => Err(syn::Error::new_spanned(
lit,
"Expected b'...' or an integer literal in range 0..=255",
)),
}
} else {
Err(syn::Error::new_spanned(
expr,
"Expected a literal expression like b'a' or 0x80",
))
}
}
1 change: 1 addition & 0 deletions crates/oxide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ignore = "0.4.23"
dunce = "1.0.5"
bexpand = "1.2.0"
fast-glob = "0.4.3"
classification-macros = { path = "../classification-macros" }

[dev-dependencies]
tempfile = "3.13.0"
Expand Down
Loading

0 comments on commit 4c11001

Please sign in to comment.