Skip to content

Instantly share code, notes, and snippets.

@c-u-l8er
Created July 20, 2025 19:48
Show Gist options
  • Save c-u-l8er/54710fe8268eb48e497f4eb8de415bcb to your computer and use it in GitHub Desktop.
Save c-u-l8er/54710fe8268eb48e497f4eb8de415bcb to your computer and use it in GitHub Desktop.
Enhanced ADT system
defmodule EnhancedADT do
  @moduledoc """
  Enhanced ADT system combining the best of both worlds:
  - Complete type system from ADT
  - Advanced recursion and state management from BenBen
  - Powerful fold/bend operations
  """

  # Modules whose functions and macros `use EnhancedADT` brings into the
  # caller's scope, imported in this order.
  @imported_modules [
    EnhancedADT.DSL,
    EnhancedADT.Fold,
    EnhancedADT.Bend,
    EnhancedADT.Recursion
  ]

  @doc """
  Imports `EnhancedADT.DSL`, `EnhancedADT.Fold`, `EnhancedADT.Bend` and
  `EnhancedADT.Recursion` into the calling module. Options are ignored.
  """
  defmacro __using__(_opts) do
    import_forms =
      for module <- @imported_modules do
        quote do: import(unquote(module))
      end

    {:__block__, [], import_forms}
  end
end
defmodule EnhancedADT.DSL do
  @moduledoc """
  Enhanced DSL for defining algebraic data types with recursion markers.

    * `defsum/2` defines a sum type as a module whose constructors return
      tagged tuples (`{:tag, field, ...}`).
    * `defproduct/2` defines a struct with a validating `new/1` and lenses.
    * `rec/1` marks a variant field as recursive; the generated `fold/3`
      folds such fields before handing them to the variant's handler.
  """

  @doc """
  Enhanced sum type definition with recursion support.

  Generates a module named after the type containing:

    * a constructor per variant returning `{tag, field, ...}` (`{tag}` for a
      variant without fields);
    * an `is_<tag>/1` predicate per variant;
    * `@type t` (or `@type t(params)`), the union of the variant tuples;
    * `fold/3` (see below);
    * `__recursive_fields__/0`, a map from tag to the indexes of `rec/1` fields.

  `fold(value, handlers, opts)` looks up the handler for the value's tag in
  `handlers`. In the default `mode: :bottom_up`, `rec/1` fields are folded
  first and the handler receives the list of processed fields (a variant
  without fields calls its handler with no arguments). With a truthy
  `state:` option, handlers for variants with fields are called as
  `handler.(fields, state)` and must return `{result, new_state}`, and the
  state is threaded left to right through the recursive fields. The
  `:top_down` and `:stateful` modes are not implemented yet and raise.

  ## Example

  ```elixir
  defsum Tree(a) do
    Leaf(a)
    Node(rec(Tree(a)), rec(Tree(a))) # rec() marks recursive fields
  end
  ```
  """
  # NOTE(review): the example uses capitalised call syntax (`Tree(a)`,
  # `Leaf(a)`); confirm the Elixir parser accepts it. Lower-case variant heads
  # such as `leaf(integer())`, bare names, and alias or atom type names are
  # handled by the helpers below.
  defmacro defsum(name_and_params, do: variants) do
    {name, params} = split_name_and_params(name_and_params)
    variant_list = extract_variants(variants)

    quote do
      defmodule unquote(name) do
        # Escaped so the attributes hold the definition ASTs as data instead of
        # evaluating them as code in the module body.
        @type_params unquote(Macro.escape(params))
        @variants unquote(Macro.escape(variant_list))

        # Generate constructor functions
        unquote_splicing(generate_constructors(variant_list))

        # Generate type definition
        unquote(generate_type_def(name, params, variant_list))

        # Generate pattern matching helpers
        unquote_splicing(generate_matchers(variant_list))

        # Generate enhanced fold function with recursion support
        unquote(generate_enhanced_fold_function(name, variant_list))

        # Generate metadata for recursion analysis (a map is not valid AST,
        # so it has to be escaped before being unquoted).
        def __recursive_fields__ do
          unquote(Macro.escape(analyze_recursive_fields(variant_list)))
        end
      end
    end
  end

  @doc """
  Enhanced product type with a validating constructor and lenses.

  Each field is written as `name :: type`, optionally followed by
  `when [default: value]`. Fields without a `:default` are listed in
  `@enforce_keys`, so `new/1` raises `ArgumentError` when they are missing
  (and `KeyError` for keys the struct does not define). For every field,
  `lens_<field>/0` returns a `{getter, setter}` pair and `update_<field>/2`
  applies a function to that field.
  """
  defmacro defproduct(name, do: fields) do
    field_list = extract_fields(fields)

    required =
      for {field, _type, opts} <- field_list, not Keyword.has_key?(opts, :default), do: field

    # Default values stay as code: they are evaluated in the new module's body.
    defaults = for {field, _type, opts} <- field_list, do: {field, Keyword.get(opts, :default)}
    types = for {field, type, _opts} <- field_list, do: {field, type}

    quote do
      defmodule unquote(name) do
        @enforce_keys unquote(required)
        defstruct unquote(defaults)

        @type t :: %__MODULE__{unquote_splicing(types)}

        # Generate constructor with validation. `struct!/2` (unlike `struct/2`)
        # honours @enforce_keys and rejects unknown keys.
        def new(attrs \\ []) do
          struct = struct!(__MODULE__, attrs)
          validate!(struct)
          struct
        end

        # Generate lenses for functional updates
        unquote_splicing(generate_lenses(field_list))

        # Generate validators
        unquote(generate_validators(field_list))

        # Generate computed fields
        unquote_splicing(generate_computed_fields(field_list))
      end
    end
  end

  @doc """
  Marks a variant field as recursive.

  Inside `defsum/2` the call is only inspected as syntax; if it is evaluated
  it returns `{:rec, type}`.
  """
  defmacro rec(type) do
    quote do: {:rec, unquote(type)}
  end

  # Splits the `defsum` head into {module name AST, type params}. A bare alias
  # such as `Maybe` is a 3-tuple `{:__aliases__, _, [:Maybe]}` and must not be
  # mistaken for a call with parameters.
  defp split_name_and_params({:__aliases__, _, _} = alias), do: {alias, []}
  defp split_name_and_params({name, _, params}) when is_list(params), do: {name, params}
  defp split_name_and_params(name) when is_atom(name), do: {name, []}

  defp extract_variants({:__block__, _, variants}), do: Enum.map(variants, &normalize_variant/1)
  defp extract_variants(single_variant), do: [normalize_variant(single_variant)]

  # A bare capitalised variant such as `None` parses as an alias, not a call;
  # rewrite it as a call without fields.
  defp normalize_variant({:__aliases__, meta, [name]}) when is_atom(name), do: {name, meta, []}
  defp normalize_variant(variant), do: variant

  # {tag, number of fields} for a variant AST.
  defp variant_shape({name, _, args}) when is_atom(name) and is_list(args), do: {name, length(args)}
  defp variant_shape({name, _, _}) when is_atom(name), do: {name, 0}
  defp variant_shape(name) when is_atom(name), do: {name, 0}

  # Field ASTs of a variant (empty for variants without fields).
  defp variant_args({_name, _, args}) when is_list(args), do: args
  defp variant_args(_variant), do: []

  # Indexes of the fields wrapped in rec/1.
  defp recursive_positions(args) do
    for {{:rec, _, _}, index} <- Enum.with_index(args), do: index
  end

  defp extract_fields({:__block__, _, fields}), do: Enum.map(fields, &parse_field/1)
  defp extract_fields(single_field), do: [parse_field(single_field)]

  # `name :: type when opts` must be matched before the generic clause below,
  # which would otherwise accept `{:when, _, _}` as a field called `when`.
  defp parse_field({:when, _, [{:::, _, [{name, _, _}, type]}, opts]}) do
    {name, type, extract_field_options(opts)}
  end

  defp parse_field({:::, _, [{name, _, _}, type]}) do
    {name, type, []}
  end

  defp parse_field({name, _, _}) when is_atom(name) do
    {name, quote(do: any()), []}
  end

  # `when [default: 1]` already yields a keyword list; keep it as is.
  defp extract_field_options(opts) when is_list(opts), do: opts
  defp extract_field_options({:__block__, _, opts}), do: opts
  defp extract_field_options(single_opt), do: [single_opt]

  defp generate_constructors(variants) do
    Enum.map(variants, fn variant ->
      {tag, arity} = variant_shape(variant)
      # Fresh variables: the declared field types (`integer()`, `rec(...)`) are
      # not valid function-head patterns.
      args = Macro.generate_arguments(arity, __MODULE__)

      quote do
        def unquote(tag)(unquote_splicing(args)), do: {unquote(tag), unquote_splicing(args)}
      end
    end)
  end

  defp clean_recursive_marker({:rec, _, [arg]}), do: arg
  defp clean_recursive_marker(arg), do: arg

  defp analyze_recursive_fields(variants) do
    Map.new(variants, fn variant ->
      {tag, _arity} = variant_shape(variant)
      {tag, recursive_positions(variant_args(variant))}
    end)
  end

  defp generate_enhanced_fold_function(name, variants) do
    # `quote` of a `->` clause yields a one-element list of clauses, so the
    # fallback can be appended to each mode's clause list.
    fallback =
      quote do
        _ -> raise "Invalid #{unquote(name)} value: #{inspect(value)}"
      end

    bottom_up_clauses = generate_bottom_up_cases(variants) ++ fallback
    top_down_clauses = generate_top_down_cases(variants) ++ fallback
    stateful_clauses = generate_stateful_cases(variants) ++ fallback

    quote do
      def fold(value, handlers, opts \\ []) do
        state = Keyword.get(opts, :state)
        mode = Keyword.get(opts, :mode, :bottom_up)

        case mode do
          :bottom_up -> fold_bottom_up(value, handlers, state)
          :top_down -> fold_top_down(value, handlers, state)
          :stateful -> fold_stateful(value, handlers, state)
        end
      end

      defp fold_bottom_up(value, handlers, state) do
        # Process recursive fields first, then apply handler
        case value do
          unquote(bottom_up_clauses)
        end
      end

      defp fold_top_down(value, handlers, state) do
        # Apply handler first, then process recursive fields
        case value do
          unquote(top_down_clauses)
        end
      end

      defp fold_stateful(value, handlers, initial_state) do
        # Thread state through the computation
        case value do
          unquote(stateful_clauses)
        end
      end
    end
  end

  defp generate_bottom_up_cases(variants) do
    Enum.flat_map(variants, fn variant ->
      case variant_shape(variant) do
        {tag, 0} -> nullary_bottom_up_case(tag)
        {tag, _arity} -> fields_bottom_up_case(tag, variant_args(variant))
      end
    end)
  end

  # Clause for a variant without fields: the handler takes no arguments and,
  # when a state is present, the state is passed through unchanged.
  defp nullary_bottom_up_case(tag) do
    quote do
      {unquote(tag)} ->
        handler = Map.get(handlers, unquote(tag), fn -> nil end)
        if state, do: {handler.(), state}, else: handler.()
    end
  end

  # Clause for a variant with fields: rec/1 fields are folded first (threading
  # the state left to right), then the handler gets the processed field list.
  defp fields_bottom_up_case(tag, args) do
    vars = Macro.generate_arguments(length(args), __MODULE__)
    recursive = recursive_positions(args)

    # {recursive?, field variable} pairs, evaluated in the generated clause.
    tagged_fields =
      vars
      |> Enum.with_index()
      |> Enum.map(fn {var, index} -> {index in recursive, var} end)

    quote do
      {unquote(tag), unquote_splicing(vars)} ->
        {processed_args, state} =
          Enum.map_reduce(unquote(tagged_fields), state, fn
            {true, field}, acc ->
              if acc do
                fold_bottom_up(field, handlers, acc)
              else
                {fold_bottom_up(field, handlers, nil), acc}
              end

            {false, field}, acc ->
              {field, acc}
          end)

        if state do
          handler = Map.get(handlers, unquote(tag), fn _args, current -> {nil, current} end)
          {result, new_state} = handler.(processed_args, state)
          {result, new_state}
        else
          handler = Map.get(handlers, unquote(tag), fn _args -> nil end)
          handler.(processed_args)
        end
    end
  end

  # NOTE(review): not implemented. With no variant clauses, fold/3 in
  # `:top_down` mode reaches only the fallback clause and raises for every value.
  defp generate_top_down_cases(_variants) do
    []
  end

  # NOTE(review): not implemented. With no variant clauses, fold/3 in
  # `:stateful` mode reaches only the fallback clause and raises for every value.
  defp generate_stateful_cases(_variants) do
    []
  end

  defp generate_lenses(fields) do
    Enum.map(fields, fn {name, _type, _opts} ->
      quote do
        def unquote(:"lens_#{name}")() do
          {
            fn struct -> Map.get(struct, unquote(name)) end,
            fn struct, value -> Map.put(struct, unquote(name), value) end
          }
        end

        def unquote(:"update_#{name}")(struct, func) do
          {getter, setter} = unquote(:"lens_#{name}")()
          current = getter.(struct)
          setter.(struct, func.(current))
        end
      end
    end)
  end

  # Placeholder: the generated validate!/1 accepts every struct unchanged.
  defp generate_validators(_fields) do
    quote do
      defp validate!(struct) do
        # Add validation logic here
        struct
      end
    end
  end

  # Placeholder: no computed-field functions are generated yet.
  defp generate_computed_fields(_fields) do
    []
  end

  # Builds `@type t :: {tag, ...} | ...`. Tuples go through `quote` so that
  # 1-tuples are valid AST, and `|` is binary, so the union is folded from the
  # right.
  defp generate_type_def(_name, params, variants) do
    union =
      variants
      |> Enum.map(fn variant ->
        {tag, _arity} = variant_shape(variant)
        fields = Enum.map(variant_args(variant), &clean_recursive_marker/1)
        quote do: {unquote(tag), unquote_splicing(fields)}
      end)
      |> Enum.reverse()
      |> Enum.reduce(fn variant, acc -> {:|, [], [variant, acc]} end)

    if params == [] do
      quote do
        @type t :: unquote(union)
      end
    else
      quote do
        @type t(unquote_splicing(params)) :: unquote(union)
      end
    end
  end

  defp generate_matchers(variants) do
    Enum.map(variants, fn variant ->
      {tag, arity} = variant_shape(variant)
      predicate = :"is_#{tag}"
      # One `_` per field so the pattern matches the constructor's tuple size.
      wildcards = List.duplicate(Macro.var(:_, nil), arity)

      quote do
        def unquote(predicate)({unquote(tag), unquote_splicing(wildcards)}), do: true
        def unquote(predicate)(_), do: false
      end
    end)
  end
end
defmodule EnhancedADT.Fold do
  @moduledoc """
  Enhanced fold operations with BenBen-like power.

  The macros turn a block of `variant(args) -> body` clauses into a map of
  handler functions keyed by variant tag and pass it to the `fold/3` function
  of the value's module. Handlers for clauses with arguments receive the
  field values as a single list (`fn [x, y] -> ... end`); handlers for
  clauses without arguments take none.
  """

  @doc """
  Powerful fold with automatic recursion handling.

  Options:

    * `:state` - initial state; when given (and truthy) it is forwarded to `fold/3`.
    * `:mode` - fold mode forwarded to `fold/3` (default `:bottom_up`).
    * `:type` - module whose `fold/3` is called. Defaults to
      `value.__struct__`, which only works when `value` is a struct.

  ## Example

  ```elixir
  fold tree, state: 0, mode: :bottom_up do
    Leaf(x) -> x
    Node(left, right) -> left + right
  end
  ```
  """
  defmacro fold(value, opts \\ [], do: clauses) do
    handlers = clauses |> extract_fold_cases() |> generate_handler_map()
    state = Keyword.get(opts, :state)
    mode = Keyword.get(opts, :mode, :bottom_up)
    fold_opts = if state, do: [state: state, mode: mode], else: [mode: mode]

    quote do
      # Bound once so the folded expression is evaluated a single time.
      subject = unquote(value)
      unquote(fold_module(opts)).fold(subject, unquote(handlers), unquote(fold_opts))
    end
  end

  @doc """
  Stateful fold with explicit state threading.

  Calls `value.__struct__.fold(value, handlers, state: initial_state,
  mode: :stateful)`. Each clause body can read the current state as `state`;
  the generated handler returns `{body, state}`, i.e. the body's value paired
  with the unchanged state.
  """
  # NOTE(review): a body returning `{result, new_state}` is nested as
  # `{{result, new_state}, state}`; confirm whether bodies are meant to return
  # the new state themselves.
  defmacro fold_with_state(value, initial_state, do: clauses) do
    handlers = clauses |> extract_fold_cases() |> generate_stateful_handler_map()

    quote do
      subject = unquote(value)

      subject.__struct__.fold(
        subject,
        unquote(handlers),
        state: unquote(initial_state),
        mode: :stateful
      )
    end
  end

  @doc """
  Parallel fold for independent computations.

  Expands to `fold/3` with `mode: :parallel`.
  """
  # NOTE(review): the fold/3 generated by EnhancedADT.DSL only handles
  # :bottom_up, :top_down and :stateful, so :parallel would hit its `case`
  # with no matching clause.
  defmacro fold_parallel(value, do: clauses) do
    quote do
      # Fully qualified (works without `import`), with the options as their own
      # argument so the call matches fold/3.
      EnhancedADT.Fold.fold(unquote(value), [mode: :parallel], do: unquote(clauses))
    end
  end

  # Module expression whose fold/3 is called: the explicit `:type` option, or
  # the struct module of the folded value.
  defp fold_module(opts) do
    Keyword.get_lazy(opts, :type, fn -> quote(do: subject.__struct__) end)
  end

  # A `do` block of `->` clauses arrives as a plain list of clauses.
  defp extract_fold_cases(clauses) when is_list(clauses), do: clauses
  defp extract_fold_cases({:__block__, _, clauses}), do: clauses
  defp extract_fold_cases(single_clause), do: [single_clause]

  # Splits a clause head into {variant_tag, field_patterns}. A bare alias such
  # as `None` parses as `{:__aliases__, _, [:None]}` and is treated as a tag.
  defp clause_head({:__aliases__, _, [name]}) when is_atom(name), do: {name, []}
  defp clause_head({name, _, args}) when is_list(args), do: {name, args}
  defp clause_head({name, _, _}), do: {name, []}
  defp clause_head(name) when is_atom(name), do: {name, []}

  defp generate_handler_map(cases) do
    handlers =
      Enum.map(cases, fn {:->, _, [[pattern], body]} ->
        case clause_head(pattern) do
          {name, []} -> {name, quote(do: fn -> unquote(body) end)}
          {name, args} -> {name, quote(do: fn unquote(args) -> unquote(body) end)}
        end
      end)

    quote do: %{unquote_splicing(handlers)}
  end

  defp generate_stateful_handler_map(cases) do
    # Deliberately unhygienic so clause bodies can refer to `state`.
    state = Macro.var(:state, nil)

    handlers =
      Enum.map(cases, fn {:->, _, [[pattern], body]} ->
        case clause_head(pattern) do
          {name, []} ->
            {name, quote(do: fn unquote(state) -> {unquote(body), unquote(state)} end)}

          {name, args} ->
            {name,
             quote(do: fn unquote(args), unquote(state) -> {unquote(body), unquote(state)} end)}
        end
      end)

    quote do: %{unquote_splicing(handlers)}
  end
end
defmodule EnhancedADT.Bend do
  @moduledoc """
  Enhanced bend operations with fork support.

  `bend/2` unfolds a seed into a data structure: each clause maps a seed to a
  tuple whose fields may be `fork(next_seed)` markers, which
  `unfold_with_fork/2` expands recursively.
  """

  @doc """
  Powerful unfold with fork semantics.

  `opts[:from]` is the initial seed; the `do` block holds `case` clauses that
  are matched against each seed.

  ## Example

  ```elixir
  bend from: 10 do
    n when n > 0 -> Tree.Node(fork(n-1), fork(n-1))
    0 -> Tree.Leaf(0)
  end
  ```
  """
  defmacro bend(opts, do: clauses) do
    seed = Keyword.get(opts, :from)
    cases = extract_bend_cases(clauses)

    quote do
      EnhancedADT.Bend.unfold_with_fork(unquote(seed), fn value ->
        # `cases` is a list of `->` clauses, so it is unquoted as the whole
        # body of the case; splicing it next to other clauses is not valid.
        case value do
          unquote(cases)
        end
      end)
    end
  end

  @doc """
  Fork operation for parallel unfold branches.

  Expands to `{:fork, expr}`, the marker `unfold_with_fork/2` recognises.
  """
  defmacro fork(expr) do
    quote do: {:fork, unquote(expr)}
  end

  @doc """
  Unfold with fork support.

  Calls `generator` on `seed` and expands the result:

    * `{:fork, next_seed}` is replaced by the unfolding of `next_seed`;
    * `{tag, [field, ...]}` becomes `{tag, field, ...}`, with every
      `{:fork, s}` field unfolded;
    * any other tuple `{tag, field, ...}` has every `{:fork, s}` field
      unfolded, whatever its position or the tuple size;
    * anything else is returned unchanged.
  """
  def unfold_with_fork(seed, generator) do
    case generator.(seed) do
      {:fork, next_seed} ->
        unfold_with_fork(next_seed, generator)

      {tag, fields} when is_list(fields) ->
        List.to_tuple([tag | expand_forks(fields, generator)])

      node when is_tuple(node) and tuple_size(node) > 1 ->
        [tag | fields] = Tuple.to_list(node)
        List.to_tuple([tag | expand_forks(fields, generator)])

      other ->
        other
    end
  end

  # Unfolds every `{:fork, seed}` field; other fields are kept as they are.
  defp expand_forks(fields, generator) do
    Enum.map(fields, fn
      {:fork, next_seed} -> unfold_with_fork(next_seed, generator)
      field -> field
    end)
  end

  # A `do` block of `->` clauses arrives as a plain list of clauses.
  defp extract_bend_cases(clauses) when is_list(clauses), do: clauses
  defp extract_bend_cases({:__block__, _, clauses}), do: clauses
  defp extract_bend_cases(single_clause), do: [single_clause]
end
defmodule EnhancedADT.Recursion do
  @moduledoc """
  Advanced recursion utilities: a memoizing fold wrapper and a (partial) zipper.
  """

  @doc """
  Folds `value` with the given clauses (same syntax as `EnhancedADT.Fold.fold/2`),
  caching the result in a private ETS table that is deleted afterwards.
  """
  # NOTE(review): the table lives for one memo_fold call and is consulted only
  # for the root value, so repeated sub-trees are not memoized.
  defmacro memo_fold(value, do: clauses) do
    quote do
      # The clauses are compiled into a fold call here; `->` clauses cannot be
      # passed around as runtime data.
      require EnhancedADT.Fold
      memo_table = :ets.new(:memo_table, [:set, :private])

      try do
        EnhancedADT.Recursion.memo_fold_impl(unquote(value), memo_table, fn subject ->
          EnhancedADT.Fold.fold(subject, do: unquote(clauses))
        end)
      after
        :ets.delete(memo_table)
      end
    end
  end

  @doc """
  Returns the result cached for `value` in `memo_table`, or computes it with
  the 1-arity `fold_fun`, stores it and returns it.
  """
  def memo_fold_impl(value, memo_table, fold_fun) when is_function(fold_fun, 1) do
    case :ets.lookup(memo_table, value) do
      [{^value, result}] ->
        result

      [] ->
        result = fold_fun.(value)
        :ets.insert(memo_table, {value, result})
        result
    end
  end

  @doc """
  Zipper for navigating recursive structures.

  Returns a map with `value` as `:focus`, an empty `:path` and `modified: false`.
  """
  def zipper(value) do
    %{
      focus: value,
      path: [],
      modified: false
    }
  end

  # NOTE(review): not implemented yet; returns the zipper unchanged.
  def move_down(zipper, _index) do
    zipper
  end

  # NOTE(review): not implemented yet; returns the zipper unchanged.
  def move_up(zipper) do
    zipper
  end

  @doc """
  Applies `func` to the current focus and marks the zipper as modified.
  """
  def modify(zipper, func) do
    %{zipper | focus: func.(zipper.focus), modified: true}
  end
end
# Enhanced examples
defmodule EnhancedADT.Examples do
  # Usage sketches for the EnhancedADT DSL. `use EnhancedADT` imports
  # EnhancedADT.DSL, EnhancedADT.Fold, EnhancedADT.Bend and EnhancedADT.Recursion.
  # NOTE(review): several examples rely on syntax and dispatch that may not work
  # as written (see the notes below); confirm this module compiles and runs.
  use EnhancedADT

  # Tree with recursion markers
  # NOTE(review): capitalised call heads such as `Tree(a)` / `Leaf(a)` may be
  # rejected by the Elixir parser - confirm.
  defsum Tree(a) do
    Leaf(a)
    Node(rec(Tree(a)), rec(Tree(a)))
  end

  # Maybe with enhanced operations
  # NOTE(review): the bare `None` variant parses as an alias
  # ({:__aliases__, _, [:None]}), not as a call - confirm defsum handles it.
  defsum Maybe(a) do
    Some(a)
    None
  end

  # Enhanced product type with computed fields
  # Three typed fields, none with a default, so all of them end up in the
  # generated @enforce_keys.
  defproduct Person do
    first_name :: String.t()
    last_name :: String.t()
    age :: integer()
    # computed field would be generated
  end

  # Usage examples

  # Sums the leaf values of `tree`, folding children before their parent.
  # NOTE(review): EnhancedADT.Fold.fold dispatches via `value.__struct__`, but
  # defsum constructors build plain tuples - confirm this can reach Tree.fold/3.
  def tree_sum(tree) do
    fold tree, mode: :bottom_up do
      Leaf(x) -> x
      Node(left, right) -> left + right
    end
  end

  # Intended to thread a counter (starting at 0) through the fold.
  # NOTE(review): EnhancedADT.Fold builds the stateful handlers with its own
  # hygienic `state` variable and wraps each body as `{body, state}`, so the
  # `state` below may be undefined here, and `{x, state + 1}` would become the
  # result rather than the new state - confirm the intended semantics.
  def tree_sum_with_depth(tree) do
    fold_with_state tree, 0 do
      Leaf(x) -> {x, state + 1}
      Node(left, right) -> {left + right, state + 1}
    end
  end

  # Unfolds a perfect binary tree of the given depth whose leaves all hold 0.
  # A negative depth matches neither clause.
  # NOTE(review): confirm `Tree.Node(...)` / `Tree.Leaf(...)` resolve to the
  # constructors generated by `defsum Tree(a)` above.
  def generate_perfect_tree(depth) do
    bend from: depth do
      n when n > 0 -> Tree.Node(fork(n-1), fork(n-1))
      0 -> Tree.Leaf(0)
    end
  end

  # Sums the leaves of generate_perfect_tree(n) through memo_fold.
  # NOTE(review): despite the name this is not a Fibonacci computation; every
  # leaf is 0, so if the fold runs the result is 0 for any n.
  def memoized_fibonacci_tree(n) do
    memo_fold generate_perfect_tree(n) do
      Leaf(x) -> x
      Node(left, right) -> left + right
    end
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment