Skip to content

Instantly share code, notes, and snippets.

@Swoorup
Last active April 5, 2025 16:15
Show Gist options
  • Save Swoorup/798a36a8cc9b2e594f9632ff7cccfc4e to your computer and use it in GitHub Desktop.
Save Swoorup/798a36a8cc9b2e594f9632ff7cccfc4e to your computer and use it in GitHub Desktop.
A slightly different take on dtype_dispatch.
//! # DType Macros
//!
//! This module provides a set of macros and traits for defining and working with strongly-typed
//! enums and their associated tokens. These macros are designed to:
//!
//! - **Generate Tokens**: Define unique tokens for each variant of an enum using `build_dtype_tokens`.
//! - **Define Enums**: Create enums with compile-time checks to ensure all variants are valid tokens
//! using `build_dtype_enum`.
//! - **Dynamic Downcasting**: Enable safe downcasting of enum variants to their underlying types
//! or containers.
//! - **Pattern Matching**: Provide utilities for ergonomic pattern matching on enums and their variants.
//!
//! ## Key Features
//! - Compile-time validation of tokens and variants.
//! - Support for both simple and container-based enum variants.
//! - Flexible downcasting methods (`downcast_ref`, `downcast_mut`, `downcast`) for accessing inner values.
//! - Customizable pattern matching macros for concise and type-safe matching.
//!
//! Thanks to <https://github.com/pcodec/pcodec/tree/main/dtype_dispatch> for the inspiration.
#![allow(unreachable_patterns)]
#![allow(dead_code)]
/// Maps a variant token to the payload type an enum stores for that variant.
///
/// `build_dtype_enum!` implements this (through `build_dtype_enum_variant_targets!`)
/// for every variant. The generated `downcast_ref` / `downcast_mut` / `downcast`
/// methods use `Target` as their return type.
pub trait EnumVariantTarget<Token> {
    /// Payload type of the variant identified by `Token`.
    ///
    /// Must be `'static` so the generated downcasts can compare `TypeId`s.
    type Target: 'static;
}
/// Maps a variant token to the element type wrapped by that variant's container.
///
/// For example, a variant holding `Vec<u16>` has `Constraint = u16`.
pub trait EnumVariantConstraint<Token> {
    /// Element type of the variant identified by `Token`.
    type Constraint: 'static;
}
/// Marker trait for the token types generated by `build_dtype_tokens!`.
///
/// `build_dtype_enum!` requires each variant's token to implement this trait.
/// That requirement is what rejects variant names that have no matching token.
pub trait DTypeToken {
    /// Name of the token type, e.g. `"U16Variant"`; useful in diagnostics.
    const TOKEN_NAME: &'static str;
}
/// Defines one zero-sized marker type per name, each implementing `DTypeToken`.
///
/// `build_dtype_tokens!([U16, U32])` defines `pub struct U16Variant;` and
/// `pub struct U32Variant;`. Their `TOKEN_NAME`s are `"U16Variant"` and `"U32Variant"`.
/// `build_dtype_enum!` checks each of its variant names against these tokens.
///
/// A trailing comma is allowed. An empty list (`[]`) is accepted and defines nothing.
#[macro_export]
macro_rules! build_dtype_tokens {
    ([$($variant: ident),* $(,)?]) => {
        ::paste::paste! {
            // `*` (not `+`) to agree with the matcher: a `+` transcription of a
            // zero-length repetition is a hard "must repeat at least once" error.
            $(
                pub struct [<$variant Variant>];

                impl $crate::DTypeToken for [<$variant Variant>] {
                    const TOKEN_NAME: &'static str = stringify!([<$variant Variant>]);
                }
            )*
        }
    };
}
/// Implements `EnumVariantTarget<$tokens_path::<Variant>Variant>` for `$name`
/// for every `Variant => Type` entry, with `Target = Type`.
///
/// `build_dtype_enum!` uses this internally. Token types are resolved through
/// `$tokens_path`, matching the token checks in `build_dtype_enum!`, so they do
/// not need to be in scope at the invocation site. Each entry must end with a comma.
#[macro_export]
macro_rules! build_dtype_enum_variant_targets {
    (
        $tokens_path: path,
        $name: ident,
        $($variant: ident => $t: ty,)+
    ) => {
        ::paste::paste! {
            $(
                impl $crate::EnumVariantTarget<$tokens_path::[<$variant Variant>]> for $name {
                    type Target = $t;
                }
            )+
        }
    };
}
/// Defines a strongly-typed enum whose variants are tied to dtype tokens and,
/// for most forms, a companion pattern-matching macro.
///
/// Every form takes `$tokens_path` first. This is the module path where the tokens
/// generated by `build_dtype_tokens!` live (e.g. `self`). For each variant `Foo`,
/// a compile-time check requires `$tokens_path::FooVariant` to exist and to
/// implement `DTypeToken`. A variant without a token therefore fails to compile.
///
/// The accepted forms, tried in this order, are:
///
/// 1. `tokens, matcher, { A, B, }, #[attr] vis Name`
///    - Generates a unit-variant enum.
///    - Generates a matcher `matcher!(value, Name<Tok> => { .. })` that binds
///      `Tok` to the matched variant's token type.
/// 2. `tokens, ConstraintTrait, { A => TyA, .. }, #[attr] #[repr(Int)] vis Name = CONST`
///    - Generates a fieldless enum whose discriminants are `<TyA>::CONST`.
///    - Generates `Name::new::<T>()`, which looks up the variant by element type.
///    - Generates `Name::from_discriminant`.
/// 3. `tokens, matcher, ConstraintTrait, { A => TyA, .. }, #[attr] vis Name(Container)`
///    - Variants wrap `Container<TyA>`.
///    - Generates `From` impls and `EnumVariantTarget` / `EnumVariantConstraint` impls.
///    - Generates `downcast_ref`, `downcast_mut` and `downcast`.
///    - Generates a matcher accepting either `Name<T, Tok> => { .. }` or
///      `Name<T, Tok>(inner) => { .. }`, where `T` is bound to `TyA`.
///    - `ConstraintTrait` is accepted but not referenced by the generated code.
/// 4. `tokens, matcher, { A => TyA, .. }, #[attr] vis Name`
///    - Like form 3, but variants wrap `TyA` directly.
///    - No `EnumVariantConstraint` impls are generated.
///
/// Each matcher is defined as a hidden `#[macro_export]` macro named `_<matcher>`
/// and re-exported as `<matcher>`.
///
/// Each variant entry must end with a comma. Exactly one attribute must precede
/// the enum name (plus `#[repr(..)]` in form 2).
#[macro_export]
macro_rules! build_dtype_enum {
    // Form 1: unit-variant enum + token matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        {$($(#[$variant_attrs:meta])* $variant: ident,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident
    ) => {
        // Compile-time check that every variant has a matching `DTypeToken`.
        ::paste::paste! {
            $(
                const _: fn() = || {
                    // Fails to compile if the `<Variant>Variant` token under the
                    // tokens path is missing or does not implement `DTypeToken`.
                    fn assert_is_token<T: $crate::DTypeToken>() {}
                    let _: fn() = || {
                        assert_is_token::<$tokens_path::[<$variant Variant>]>();
                        // Naming `TOKEN_NAME` points the error at the token trait.
                        let _token_name = <$tokens_path::[<$variant Variant>] as $crate::DTypeToken>::TOKEN_NAME;
                    };
                };
            )+
        }

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant,)+
        }

        ::paste::paste!{
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                ($value:expr, $enum_:ident<$token_type:ident> => $body:block) => {
                    match $value {
                        $($enum_::$variant => {
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };

    // Form 2: fieldless enum with `#[repr]` discriminants taken from `<Ty>::CONST`.
    (
        $tokens_path: path,
        $constraint: path,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        #[repr($desc_t: ty)]
        $vis: vis $name: ident = $desc_val: ident
    ) => {
        // Compile-time check that every variant has a matching `DTypeToken`.
        ::paste::paste! {
            $(
                const _: fn() = || {
                    // Fails to compile if the `<Variant>Variant` token under the
                    // tokens path is missing or does not implement `DTypeToken`.
                    fn assert_is_token<T: $crate::DTypeToken>() {}
                    let _: fn() = || {
                        assert_is_token::<$tokens_path::[<$variant Variant>]>();
                        // Naming `TOKEN_NAME` points the error at the token trait.
                        let _token_name = <$tokens_path::[<$variant Variant>] as $crate::DTypeToken>::TOKEN_NAME;
                    };
                };
            )+
        }

        #[$enum_attrs]
        #[repr($desc_t)]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant = <$t>::$desc_val,)+
        }

        impl $name {
            /// Returns the variant whose element type is exactly `T`, or `None`
            /// if `T` is not one of this enum's element types.
            #[inline]
            pub fn new<T: $constraint>() -> Option<Self>
            where
                // `TypeId::of` requires `'static`; stated explicitly so this
                // compiles even when the constraint trait lacks that supertrait.
                T: 'static,
            {
                let type_id = std::any::TypeId::of::<T>();
                $(
                    if type_id == std::any::TypeId::of::<$t>() {
                        return Some($name::$variant);
                    }
                )+
                None
            }

            /// Returns the variant whose discriminant equals `desc`, or `None`
            /// if no variant has that discriminant.
            pub fn from_discriminant(desc: $desc_t) -> Option<Self> {
                match desc {
                    $(<$t>::$desc_val => Some(Self::$variant),)+
                    _ => None
                }
            }

            /// Same as `from_discriminant`; this misspelled name is kept so
            /// existing callers keep compiling.
            #[inline]
            pub fn from_descriminant(desc: $desc_t) -> Option<Self> {
                Self::from_discriminant(desc)
            }
        }
    };

    // Form 3: variants wrap `Container<Ty>`; downcasts + element/token matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        $constraint: path,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident($container: ident)
    ) => {
        // Compile-time check that every variant has a matching `DTypeToken`.
        ::paste::paste! {
            $(
                const _: fn() = || {
                    // Fails to compile if the `<Variant>Variant` token under the
                    // tokens path is missing or does not implement `DTypeToken`.
                    fn assert_is_token<T: $crate::DTypeToken>() {}
                    let _: fn() = || {
                        assert_is_token::<$tokens_path::[<$variant Variant>]>();
                        // Naming `TOKEN_NAME` points the error at the token trait.
                        let _token_name = <$tokens_path::[<$variant Variant>] as $crate::DTypeToken>::TOKEN_NAME;
                    };
                };
            )+
        }

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant($container<$t>),)+
        }

        // `Target` for each token is the full container type (e.g. `Vec<u16>`).
        $crate::build_dtype_enum_variant_targets!($tokens_path, $name, $($variant => $container<$t>,)+);

        // `Constraint` for each token is the element type (e.g. `u16`).
        ::paste::paste! {
            $(
                impl $crate::EnumVariantConstraint<$tokens_path::[<$variant Variant>]> for $name {
                    type Constraint = $t;
                }
            )+
        }

        $(
            impl From<$container<$t>> for $name {
                fn from(inner: $container<$t>) -> Self {
                    Self::$variant(inner)
                }
            }
        )+

        #[allow(dead_code)]
        impl $name {
            /// Returns a reference to the inner container if this value is the
            /// variant identified by token `T`, otherwise `None`.
            pub fn downcast_ref<T>(&self) -> Option<&<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                let target_type_id = std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>();
                match self {
                    $(
                        Self::$variant(inner) => {
                            if target_type_id == std::any::TypeId::of::<$container<$t>>() {
                                // SAFETY: equal `TypeId`s mean `Target` and the
                                // stored container type are the same type, so the
                                // cast only changes the pointer's static type.
                                unsafe {
                                    Some(&*(inner as *const $container<$t> as *const <$name as $crate::EnumVariantTarget<T>>::Target))
                                }
                            } else {
                                None
                            }
                        },
                    )+
                }
            }

            /// Returns a mutable reference to the inner container if this value
            /// is the variant identified by token `T`, otherwise `None`.
            pub fn downcast_mut<T>(&mut self) -> Option<&mut <$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                let target_type_id = std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>();
                match self {
                    $(
                        Self::$variant(inner) => {
                            if target_type_id == std::any::TypeId::of::<$container<$t>>() {
                                // SAFETY: equal `TypeId`s mean identical types; the
                                // unique borrow of `self` is reborrowed unchanged.
                                unsafe {
                                    Some(&mut *(inner as *mut $container<$t> as *mut <$name as $crate::EnumVariantTarget<T>>::Target))
                                }
                            } else {
                                None
                            }
                        },
                    )+
                }
            }

            /// Consumes the enum and returns the inner container if this value is
            /// the variant identified by token `T`, otherwise `None`.
            pub fn downcast<T>(self) -> Option<<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                let target_type_id = std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>();
                match self {
                    $(
                        Self::$variant(inner) => {
                            if target_type_id == std::any::TypeId::of::<$container<$t>>() {
                                let ptr = &inner as *const $container<$t> as *const <$name as $crate::EnumVariantTarget<T>>::Target;
                                // SAFETY: equal `TypeId`s mean identical types, and
                                // `inner` is forgotten right after this bitwise read,
                                // so the value is moved out once and dropped once.
                                let result = unsafe { ptr.read() };
                                #[allow(forgetting_copy_types)]
                                std::mem::forget(inner);
                                Some(result)
                            } else {
                                None
                            }
                        },
                    )+
                }
            }
        }

        ::paste::paste! {
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident> => $body:block) => {
                    match $value {
                        // Tuple variants need the `(_)`; a bare path pattern for
                        // a tuple variant is a compile error.
                        $($enum_::$variant(_) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident>($inner:ident) => $body:block) => {
                    match $value {
                        $($enum_::$variant($inner) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };

    // Form 4: variants wrap `Ty` directly; downcasts + element/token matcher.
    (
        $tokens_path: path,
        $(#[$matcher_attrs: meta])*
        $matcher: ident,
        {$($(#[$variant_attrs:meta])* $variant: ident => $t: ty,)+}$(,)?
        #[$enum_attrs: meta]
        $vis: vis $name: ident
    ) => {
        // Compile-time check that every variant has a matching `DTypeToken`.
        ::paste::paste! {
            $(
                const _: fn() = || {
                    // Fails to compile if the `<Variant>Variant` token under the
                    // tokens path is missing or does not implement `DTypeToken`.
                    fn assert_is_token<T: $crate::DTypeToken>() {}
                    let _: fn() = || {
                        assert_is_token::<$tokens_path::[<$variant Variant>]>();
                        // Naming `TOKEN_NAME` points the error at the token trait.
                        let _token_name = <$tokens_path::[<$variant Variant>] as $crate::DTypeToken>::TOKEN_NAME;
                    };
                };
            )+
        }

        #[$enum_attrs]
        $vis enum $name {
            $($(#[$variant_attrs])* $variant($t),)+
        }

        // `Target` for each token is the payload type itself.
        $crate::build_dtype_enum_variant_targets!($tokens_path, $name, $($variant => $t,)+);

        $(
            impl From<$t> for $name {
                fn from(inner: $t) -> Self {
                    Self::$variant(inner)
                }
            }
        )+

        #[allow(dead_code)]
        impl $name {
            /// Returns a reference to the inner value if this value is the
            /// variant identified by token `T`, otherwise `None`.
            pub fn downcast_ref<T>(&self) -> Option<&<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                match self {
                    $(
                        Self::$variant(inner) => {
                            if std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>() == std::any::TypeId::of::<$t>() {
                                // SAFETY: equal `TypeId`s mean identical types, so
                                // the cast only changes the pointer's static type.
                                Some(unsafe { &*(inner as *const $t as *const <$name as $crate::EnumVariantTarget<T>>::Target) })
                            } else {
                                None
                            }
                        },
                    )+
                }
            }

            /// Returns a mutable reference to the inner value if this value is
            /// the variant identified by token `T`, otherwise `None`.
            pub fn downcast_mut<T>(&mut self) -> Option<&mut <$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                match self {
                    $(
                        Self::$variant(inner) => {
                            if std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>() == std::any::TypeId::of::<$t>() {
                                // SAFETY: equal `TypeId`s mean identical types; the
                                // unique borrow of `self` is reborrowed unchanged.
                                Some(unsafe { &mut *(inner as *mut $t as *mut <$name as $crate::EnumVariantTarget<T>>::Target) })
                            } else {
                                None
                            }
                        },
                    )+
                }
            }

            /// Consumes the enum and returns the inner value if this value is the
            /// variant identified by token `T`, otherwise `None`.
            pub fn downcast<T>(self) -> Option<<$name as $crate::EnumVariantTarget<T>>::Target>
            where
                $name: $crate::EnumVariantTarget<T>
            {
                let target_type_id = std::any::TypeId::of::<<$name as $crate::EnumVariantTarget<T>>::Target>();
                match self {
                    $(
                        Self::$variant(inner) => {
                            if target_type_id == std::any::TypeId::of::<$t>() {
                                let ptr = &inner as *const $t as *const <$name as $crate::EnumVariantTarget<T>>::Target;
                                // SAFETY: equal `TypeId`s mean identical types, and
                                // `inner` is forgotten right after this bitwise read,
                                // so the value is moved out once and dropped once.
                                let result = unsafe { ptr.read() };
                                #[allow(forgetting_copy_types)]
                                std::mem::forget(inner);
                                Some(result)
                            } else {
                                None
                            }
                        },
                    )+
                }
            }
        }

        ::paste::paste!{
            #[doc(hidden)]
            #[macro_export]
            $(#[$matcher_attrs])*
            macro_rules! [<_ $matcher>] {
                // Same no-binding form as the container enum's matcher.
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident> => $body:block) => {
                    match $value {
                        $($enum_::$variant(_) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
                ($value:expr, $enum_:ident<$generic:ident, $token_type:ident>($inner:ident) => $body:block) => {
                    match $value {
                        $($enum_::$variant($inner) => {
                            #[allow(unused)]
                            type $generic = $t;
                            #[allow(unused)]
                            type $token_type = $tokens_path::[<$variant Variant>];
                            #[allow(unused_braces)]
                            $body
                        })+
                    }
                };
            }
            #[allow(unused_imports)]
            pub use [<_ $matcher>] as $matcher;
        }
    };
}
#[allow(dead_code)]
#[cfg(test)]
mod tests {
    trait Constraint: 'static {}
    impl Constraint for u16 {}
    impl Constraint for u32 {}
    impl Constraint for u64 {}

    build_dtype_tokens!([U16, U32, U64,]);

    build_dtype_enum!(
        self,
        test_match_enum,
        Constraint,
        {
            U16 => u16,
            U32 => u32,
            U64 => u64,
        },
        #[derive(Clone, Debug)]
        pub MyEnum(Vec)
    );

    #[test]
    fn test_simple_enum() {
        build_dtype_enum!(
            self,
            test_match_enum_variant,
            {
                U16,
                U32,
                #[default]
                U64,
            },
            #[derive(Clone, Debug, Default)]
            pub MyEnumVariant
        );

        // Variant attributes such as `#[default]` are forwarded to the enum.
        assert!(matches!(MyEnumVariant::default(), MyEnumVariant::U64));

        // The matcher binds `VariantToken` to the matched variant's token type.
        let token_name = |value: MyEnumVariant| -> &'static str {
            test_match_enum_variant!(value, MyEnumVariant<VariantToken> => {
                <VariantToken as super::DTypeToken>::TOKEN_NAME
            })
        };
        assert_eq!(token_name(MyEnumVariant::U16), "U16Variant");
        assert_eq!(token_name(MyEnumVariant::U32), "U32Variant");
        assert_eq!(token_name(MyEnumVariant::U64), "U64Variant");
    }

    // Uncomment to check the compile-time error for an undefined token: no
    // `U8Variant` exists, so the `DTypeToken` assertion for `U8` fails.
    // build_dtype_enum!(
    //     self,
    //     test_match_enum_fail,
    //     Constraint,
    //     {
    //         U16 => u16,
    //         U8 => u8,
    //     },
    //     #[derive(Clone, Debug)]
    //     pub MyEnumFail(Vec)
    // );

    // Tokens that no enum references are not detected: `build_dtype_tokens!` only
    // defines marker types, so an extra, unused token compiles without error.

    #[test]
    fn test_end_to_end() {
        let x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        let bit_size = test_match_enum!(&x, MyEnum<T, VariantToken>(inner) => { inner.len() * T::BITS as usize });
        assert_eq!(bit_size, 80);
        let x = x.downcast::<U16Variant>().unwrap();
        assert_eq!(x[0], 1);
    }

    #[test]
    fn test_token_based_downcast() {
        let mut x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        let first_element = x.downcast_ref::<U16Variant>().unwrap()[0];
        assert_eq!(first_element, 1_u16);

        // A token for a different variant yields `None` instead of a bad cast.
        assert!(x.downcast_ref::<U32Variant>().is_none());

        x.downcast_mut::<U16Variant>().unwrap().push(8);
        assert_eq!(x.downcast_ref::<U16Variant>().unwrap().len(), 6);

        assert!(x.downcast::<U64Variant>().is_none());
    }

    build_dtype_tokens!([I32, F32]);

    build_dtype_enum!(
        self,
        match_dyn_enum,
        {
            I32 => i32,
            F32 => f32,
        },
        #[derive(Clone, Debug)]
        DynChunk
    );

    #[test]
    fn test_dyn_chunk() {
        let x = DynChunk::from(42_i32);
        if let DynChunk::I32(value) = x {
            assert_eq!(value, 42);
        } else {
            panic!("Expected DynChunk::I32");
        }

        // Float literals are exactly representable. They also avoid clippy's
        // deny-by-default `approx_constant` lint, which rejects `3.14`.
        let mut y = DynChunk::from(1.5_f32);
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 1.5);
        } else {
            panic!("Expected DynChunk::F32");
        }

        let downcasted: Option<&i32> = x.downcast_ref::<I32Variant>();
        assert_eq!(*downcasted.unwrap(), 42);
        assert!(x.downcast_ref::<F32Variant>().is_none());

        let downcasted_mut: Option<&mut f32> = y.downcast_mut::<F32Variant>();
        *downcasted_mut.unwrap() = 2.5;
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 2.5);
        } else {
            panic!("Expected DynChunk::F32");
        }
    }

    #[test]
    fn test_match_dyn_enum_usage() {
        let x = DynChunk::from(42_i32);
        match_dyn_enum!(x, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "42");
        });

        let y = DynChunk::from(1.5_f32);
        match_dyn_enum!(y, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "1.5");
        });
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment