Last active
April 5, 2025 16:15
-
-
Save Swoorup/798a36a8cc9b2e594f9632ff7cccfc4e to your computer and use it in GitHub Desktop.
A slightly different take on dtype_dispatch.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! # DType Macros
//!
//! This module provides macros and traits for defining strongly-typed enums and the
//! zero-sized tokens that identify each of their variants. The macros:
//!
//! - **Generate tokens**: `build_dtype_tokens` declares a unique token type for each variant.
//! - **Define enums**: `build_dtype_enum` creates the enum and checks at compile time that every
//!   variant has a matching token.
//! - **Downcast**: generated `downcast_ref`, `downcast_mut` and `downcast` methods return a
//!   variant's inner value (or container) when the requested token matches the held variant.
//! - **Pattern match**: generated matcher macros run one body per variant, with the variant's
//!   token (and, for data-carrying enums, its element type) bound to a type alias.
//!
//! ## Key Features
//! - Compile-time validation of tokens and variants.
//! - Support for both simple and container-based enum variants.
//! - Flexible downcasting methods (`downcast_ref`, `downcast_mut`, `downcast`) for accessing inner values.
//! - Customizable pattern-matching macros for concise and type-safe matching.
//!
//! All macros expand through `::paste::paste!`, so any crate that invokes them must depend on `paste`.
//!
//! Thanks to [`dtype_dispatch`](https://github.com/pcodec/pcodec/tree/main/dtype_dispatch) for the inspiration.
// NOTE(review): these crate-wide allows were in the original gist. They only affect
// expansions inside this crate (e.g. the tests), not downstream crates that invoke the
// exported macros. They presumably silence warnings from those in-crate expansions;
// TODO confirm whether both are still needed.
#![allow(unreachable_patterns)]
#![allow(dead_code)]

/// Maps a variant token (`VariantToken`) to the concrete type stored in the matching
/// enum variant.
///
/// `build_dtype_enum_variant_targets!` implements this for data-carrying enums. The
/// generated `downcast_ref` / `downcast_mut` / `downcast` methods return `Target`.
pub trait EnumVariantTarget<VariantToken> {
    /// Type stored in the variant identified by `VariantToken`.
    ///
    /// The `'static` bound lets the generated downcasts identify it at runtime by type id.
    type Target: 'static;
}

/// Maps a variant token to the element type of a container-based variant
/// (e.g. `u16` for a `U16(Vec<u16>)` variant).
pub trait EnumVariantConstraint<VariantToken> {
    /// Element type held by the variant's container.
    type Constraint: 'static;
}

/// Marker implemented for every token declared with `build_dtype_tokens!`.
///
/// `build_dtype_enum!` requires each variant's token to implement this trait, so a
/// variant without a declared token is rejected at compile time.
pub trait DTypeToken {
    /// Name of the token type (required), e.g. `"U16Variant"`.
    const TOKEN_NAME: &'static str;
}
/// Declares one zero-sized token struct per identifier, named `<Ident>Variant`, and
/// implements [`DTypeToken`](crate::DTypeToken) for each of them.
///
/// `build_dtype_tokens!([U16, U32])` declares `pub struct U16Variant;` and
/// `pub struct U32Variant;`. Each `TOKEN_NAME` is the stringified struct name
/// (`"U16Variant"`, `"U32Variant"`). A trailing comma is allowed, and an empty list
/// declares nothing.
///
/// Expands through `::paste::paste!`, so the invoking crate must depend on `paste`.
#[macro_export]
macro_rules! build_dtype_tokens {
    ([$($variant: ident),* $(,)?]) => {
        ::paste::paste! {
            // `*` (not `+`) to match the matcher's repetition; with `+` an empty list
            // fails with "this must repeat at least once".
            $(
                pub struct [<$variant Variant>];

                impl $crate::DTypeToken for [<$variant Variant>] {
                    const TOKEN_NAME: &'static str = stringify!([<$variant Variant>]);
                }
            )*
        }
    };
}
/// Implements [`EnumVariantTarget`](crate::EnumVariantTarget) for `$name` once per
/// `Variant => Type` pair, keyed by the token `$tokens_path::<Variant>Variant`.
///
/// `build_dtype_enum!` uses this for its data-carrying forms. `Target` is the exact
/// type stored in the variant, which is what the generated `downcast*` methods return.
/// The pair list may end with or without a trailing comma.
#[macro_export]
macro_rules! build_dtype_enum_variant_targets {
    (
        $tokens_path: path,
        $name: ident,
        $($variant: ident => $t: ty),+ $(,)?
    ) => {
        ::paste::paste! {
            $(
                // Qualify the token with `$tokens_path` (previously ignored). Tokens
                // declared in another module now resolve, consistent with the other
                // impls `build_dtype_enum!` generates.
                impl $crate::EnumVariantTarget<$tokens_path::[<$variant Variant>]> for $name {
                    type Target = $t;
                }
            )+
        }
    };
}
/// Defines a dtype enum together with its companion items.
///
/// Every form takes `$tokens_path` first. This is the module holding the
/// `<Variant>Variant` tokens declared with [`build_dtype_tokens!`]; use `self` when they
/// live in the invoking module. Each variant's token is checked at compile time, so a
/// variant without a declared token does not build. The enum takes exactly one outer
/// attribute (typically `#[derive(...)]`) just before its name.
///
/// # Forms
///
/// 1. **Unit enum + matcher**
///    `build_dtype_enum!(tokens, matcher, { A, B, }, #[attr] vis Name)`
///    defines `enum Name { A, B }` and `matcher!(value, Name<Tok> => { .. })`. The matcher
///    evaluates the block in the arm of whichever variant `value` holds, with `Tok`
///    aliased to that variant's token.
/// 2. **`#[repr]` enum**
///    `build_dtype_enum!(tokens, Bound, { A => TA, }, #[attr] #[repr(D)] vis Name = CONST)`
///    defines an enum whose discriminants are `<TA>::CONST`. It also defines
///    `Name::new::<T>()` (the variant whose type is exactly `T`, if any) and
///    `Name::from_discriminant(d)` (also available under its old misspelling
///    `from_descriminant`).
/// 3. **Container enum + matcher**
///    `build_dtype_enum!(tokens, matcher, Bound, { A => TA, }, #[attr] vis Name(Container))`
///    defines `enum Name { A(Container<TA>) }`. It also generates:
///    - `From<Container<TA>>`;
///    - [`EnumVariantTarget`](crate::EnumVariantTarget) (target `Container<TA>`);
///    - [`EnumVariantConstraint`](crate::EnumVariantConstraint) (constraint `TA`);
///    - `downcast_ref` / `downcast_mut` / `downcast`;
///    - a matcher accepting `matcher!(v, Name<T, Tok> => { .. })` or
///      `matcher!(v, Name<T, Tok>(inner) => { .. })`, with `T` aliased to the element type.
///
///    This form accepts `Bound` but does not use it.
/// 4. **Value enum + matcher**
///    `build_dtype_enum!(tokens, matcher, { A => TA, }, #[attr] vis Name)`
///    defines `enum Name { A(TA) }` with the same `From`, `EnumVariantTarget`, downcast
///    and matcher items as form 3 (but no `EnumVariantConstraint`).
///
/// In forms 1, 3 and 4 the matcher is `#[macro_export]`ed under the hidden name
/// `_matcher` and re-exported as `matcher` from the invoking module. All forms expand
/// through `::paste::paste!`, so the invoking crate must depend on `paste`.
#[macro_export]
macro_rules! build_dtype_enum {
    // Form 1: unit-variant enum plus a matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        {$($(#[$variant_attrs:meta])* $variant: ident,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident
    ) => {
        $crate::__build_dtype_enum_assert_tokens!($tokens_path, $($variant),+);

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant,)+
        }

        // `$value`, `$enum_`, `$token_type` and `$body` are not metavariables of this
        // macro. They are emitted verbatim and become the inner macro's own.
        ::paste::paste! {
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                ($value:expr, $enum_:ident<$token_type:ident> => $body:block) => {
                    match $value {
                        $($enum_::$variant => {
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };

    // Form 2: `#[repr]` enum whose discriminants come from `<T>::$desc_val`.
    (
        $tokens_path: path,
        $constraint: path,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        #[repr($desc_t: ty)]
        $vis: vis $name: ident = $desc_val: ident
    ) => {
        $crate::__build_dtype_enum_assert_tokens!($tokens_path, $($variant),+);

        #[$enum_attrs]
        #[repr($desc_t)]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant = <$t>::$desc_val,)+
        }

        impl $name {
            /// Returns the variant whose declared type is exactly `T`, or `None` if `T`
            /// is not one of this enum's types.
            #[inline]
            pub fn new<T: $constraint>() -> Option<Self>
            where
                // Required by `TypeId::of`; stated explicitly so `$constraint` need not imply it.
                T: 'static,
            {
                let type_id = std::any::TypeId::of::<T>();
                $(
                    if type_id == std::any::TypeId::of::<$t>() {
                        return Some($name::$variant);
                    }
                )+
                None
            }

            /// Maps a raw discriminant back to its variant, or `None` if no variant uses it.
            // The fallback arm is unreachable when the variants cover every value of
            // `$desc_t`. The crate-level allow does not reach downstream expansions.
            #[allow(unreachable_patterns)]
            pub fn from_discriminant(desc: $desc_t) -> Option<Self> {
                match desc {
                    $(<$t>::$desc_val => Some(Self::$variant),)+
                    _ => None,
                }
            }

            /// Misspelled alias of `from_discriminant`, kept so existing callers still compile.
            #[inline]
            pub fn from_descriminant(desc: $desc_t) -> Option<Self> {
                Self::from_discriminant(desc)
            }
        }
    };

    // Form 3: each variant wraps `$container<$t>`; generates downcasts and a matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        $constraint: path,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident($container: ident)
    ) => {
        $crate::__build_dtype_enum_assert_tokens!($tokens_path, $($variant),+);

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant($container<$t>),)+
        }

        $crate::build_dtype_enum_variant_targets!($tokens_path, $name, $($variant => $container<$t>,)+);

        ::paste::paste! {
            $(
                impl $crate::EnumVariantConstraint<$tokens_path::[<$variant Variant>]> for $name {
                    type Constraint = $t;
                }
            )+
        }

        $(
            impl From<$container<$t>> for $name {
                fn from(inner: $container<$t>) -> Self {
                    Self::$variant(inner)
                }
            }
        )+

        #[allow(dead_code)]
        impl $name {
            /// Returns a reference to the inner container if the held variant matches token `T`.
            pub fn downcast_ref<T>(&self) -> Option<&<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                // `dyn Any` performs the type-id check and the cast safely.
                match self {
                    $(
                        Self::$variant(inner) => (inner as &dyn std::any::Any)
                            .downcast_ref::<<$name as $crate::EnumVariantTarget<T>>::Target>(),
                    )+
                }
            }

            /// Returns a mutable reference to the inner container if the held variant matches token `T`.
            pub fn downcast_mut<T>(&mut self) -> Option<&mut <$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                match self {
                    $(
                        Self::$variant(inner) => (inner as &mut dyn std::any::Any)
                            .downcast_mut::<<$name as $crate::EnumVariantTarget<T>>::Target>(),
                    )+
                }
            }

            /// Consumes the enum and returns the inner container if the held variant matches
            /// token `T`. Otherwise returns `None` and drops the payload.
            pub fn downcast<T>(self) -> Option<<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                match self {
                    $(
                        Self::$variant(inner) => {
                            // Park the payload in an `Option`. `Any::downcast_mut` checks its
                            // type and `Option::take` moves it out, so no `unsafe` is needed.
                            let mut slot = Some(inner);
                            (&mut slot as &mut dyn std::any::Any)
                                .downcast_mut::<Option<<$name as $crate::EnumVariantTarget<T>>::Target>>()
                                .and_then(Option::take)
                        }
                    )+
                }
            }
        }

        // Unbound `$`-names below belong to the inner matcher macro (see form 1).
        ::paste::paste! {
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident> => $body:block) => {
                    match $value {
                        // `(..)` is required: the variants are tuple variants.
                        $($enum_::$variant(..) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident>($inner:ident) => $body:block) => {
                    match $value {
                        $($enum_::$variant($inner) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };

    // Form 4: each variant wraps `$t` directly; generates downcasts and a matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident
    ) => {
        $crate::__build_dtype_enum_assert_tokens!($tokens_path, $($variant),+);

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant($t),)+
        }

        $crate::build_dtype_enum_variant_targets!($tokens_path, $name, $($variant => $t,)+);

        $(
            impl From<$t> for $name {
                fn from(inner: $t) -> Self {
                    Self::$variant(inner)
                }
            }
        )+

        #[allow(dead_code)]
        impl $name {
            /// Returns a reference to the inner value if the held variant matches token `T`.
            pub fn downcast_ref<T>(&self) -> Option<&<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                match self {
                    $(
                        Self::$variant(inner) => (inner as &dyn std::any::Any)
                            .downcast_ref::<<$name as $crate::EnumVariantTarget<T>>::Target>(),
                    )+
                }
            }

            /// Returns a mutable reference to the inner value if the held variant matches token `T`.
            pub fn downcast_mut<T>(&mut self) -> Option<&mut <$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                match self {
                    $(
                        Self::$variant(inner) => (inner as &mut dyn std::any::Any)
                            .downcast_mut::<<$name as $crate::EnumVariantTarget<T>>::Target>(),
                    )+
                }
            }

            /// Consumes the enum and returns the inner value if the held variant matches
            /// token `T`. Otherwise returns `None` and drops the payload.
            pub fn downcast<T>(self) -> Option<<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>,
            {
                match self {
                    $(
                        Self::$variant(inner) => {
                            // Same safe `Option` + `Any` move-out as the container form.
                            let mut slot = Some(inner);
                            (&mut slot as &mut dyn std::any::Any)
                                .downcast_mut::<Option<<$name as $crate::EnumVariantTarget<T>>::Target>>()
                                .and_then(Option::take)
                        }
                    )+
                }
            }
        }

        // Unbound `$`-names below belong to the inner matcher macro (see form 1).
        ::paste::paste! {
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident> => $body:block) => {
                    match $value {
                        $($enum_::$variant(..) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident>($inner:ident) => $body:block) => {
                    match $value {
                        $($enum_::$variant($inner) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };
}

/// Compile-time token check shared by all forms of [`build_dtype_enum!`].
///
/// Emits one anonymous `const` per variant. Each one only type-checks when
/// `$tokens_path::<Variant>Variant` exists and implements
/// [`DTypeToken`](crate::DTypeToken). Nothing runs at runtime.
///
/// Not part of the public API. It is exported only so that expansions in downstream
/// crates can reach it through `$crate`.
#[doc(hidden)]
#[macro_export]
macro_rules! __build_dtype_enum_assert_tokens {
    ($tokens_path: path, $($variant: ident),+ $(,)?) => {
        ::paste::paste! {
            $(
                const _: () = {
                    fn assert_is_token<T: $crate::DTypeToken>() {}
                    // Naming the instantiation is enough to enforce the trait bound.
                    let _ = assert_is_token::<$tokens_path::[<$variant Variant>]>;
                };
            )+
        }
    };
}
// The crate root already allows dead code; this local allow keeps the test-only
// tokens and enums (several are declared but never read) quiet regardless.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    // Bound passed to `build_dtype_enum!`; implemented for every element type below.
    trait Constraint: 'static {}
    impl Constraint for u16 {}
    impl Constraint for u32 {}
    impl Constraint for u64 {}

    build_dtype_tokens!([U16, U32, U64,]);

    // Container form: `MyEnum::U16(Vec<u16>)`, ... plus the `test_match_enum!` matcher.
    build_dtype_enum!(
        self,
        test_match_enum,
        Constraint,
        {
            U16 => u16,
            U32 => u32,
            U64 => u64,
        },
        #[derive(Clone, Debug)]
        pub MyEnum(Vec)
    );

    #[test]
    fn test_simple_enum() {
        // Unit form, declared inside a fn body; `#[default]` marks the default variant.
        build_dtype_enum!(
            self,
            test_match_enum_variant,
            {
                U16,
                U32,
                #[default]
                U64,
            },
            #[derive(Clone, Debug, Default)]
            pub MyEnumVariant
        );

        let a = MyEnumVariant::U16;
        let b = MyEnumVariant::U32;
        let c = MyEnumVariant::default();

        // The matcher aliases `VariantToken` to the held variant's token type.
        let a_name = test_match_enum_variant!(a, MyEnumVariant<VariantToken> => {
            <VariantToken as super::DTypeToken>::TOKEN_NAME
        });
        let b_name = test_match_enum_variant!(b, MyEnumVariant<VariantToken> => {
            <VariantToken as super::DTypeToken>::TOKEN_NAME
        });
        let c_name = test_match_enum_variant!(c, MyEnumVariant<VariantToken> => {
            <VariantToken as super::DTypeToken>::TOKEN_NAME
        });
        assert_eq!(a_name, "U16Variant");
        assert_eq!(b_name, "U32Variant");
        assert_eq!(c_name, "U64Variant");
    }

    // Compile-failure check, kept as a comment because it must not build: no
    // `U8Variant` token was declared with `build_dtype_tokens!`, so the token check
    // fails to resolve `self::U8Variant`.
    //
    // build_dtype_enum!(
    //     self,
    //     test_match_enum_fail,
    //     Constraint,
    //     {
    //         U16 => u16,
    //         U8 => u8,
    //     },
    //     #[derive(Clone, Debug)]
    //     pub MyEnumFail(Vec)
    // );
    //
    // Declaring tokens that no enum uses (e.g. `build_dtype_tokens!([U16, UNUSED])`)
    // is currently not reported.

    #[test]
    fn test_end_to_end() {
        let x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        // `T` is aliased to the element type, so `T::BITS` is 16 here.
        let bit_size = test_match_enum!(&x, MyEnum<T, VariantToken>(inner) => {
            inner.len() * T::BITS as usize
        });
        assert_eq!(bit_size, 80);

        let values = x.downcast::<U16Variant>().unwrap();
        assert_eq!(values, [1, 1, 2, 3, 5]);

        // Downcasting by value with the wrong token yields `None`.
        assert!(MyEnum::from(vec![7_u32]).downcast::<U16Variant>().is_none());
    }

    #[test]
    fn test_token_based_downcast() {
        let mut x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        assert_eq!(x.downcast_ref::<U16Variant>().unwrap()[0], 1_u16);
        assert!(x.downcast_ref::<U32Variant>().is_none());

        x.downcast_mut::<U16Variant>().unwrap().push(8);
        assert_eq!(x.downcast_ref::<U16Variant>().map(Vec::len), Some(6));
        assert!(x.downcast_mut::<U64Variant>().is_none());
    }

    build_dtype_tokens!([I32, F32]);

    // Value form: `DynChunk::I32(i32)` / `DynChunk::F32(f32)` plus `match_dyn_enum!`.
    build_dtype_enum!(
        self,
        match_dyn_enum,
        {
            I32 => i32,
            F32 => f32,
        },
        #[derive(Clone, Debug)]
        DynChunk
    );

    #[test]
    fn test_dyn_chunk() {
        let x = DynChunk::from(42_i32);
        assert!(matches!(x, DynChunk::I32(42)));
        assert_eq!(x.downcast_ref::<I32Variant>(), Some(&42));
        assert_eq!(x.downcast_ref::<F32Variant>(), None);

        // 1.5 and 2.25 are exactly representable, so exact comparison is sound. They
        // also avoid clippy's deny-by-default `approx_constant`, which fires on 3.14.
        let mut y = DynChunk::from(1.5_f32);
        assert!(matches!(y, DynChunk::F32(v) if v == 1.5));

        *y.downcast_mut::<F32Variant>().expect("y holds an f32") = 2.25;
        assert!(matches!(y, DynChunk::F32(v) if v == 2.25));
        assert!(y.downcast_mut::<I32Variant>().is_none());
    }

    #[test]
    fn test_match_dyn_enum_usage() {
        let x = DynChunk::from(42_i32);
        match_dyn_enum!(x, DynChunk<T, Token>(value) => {
            assert_eq!(value.to_string(), "42");
        });

        let y = DynChunk::from(1.5_f32);
        match_dyn_enum!(y, DynChunk<T, Token>(value) => {
            assert_eq!(value.to_string(), "1.5");
        });
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment