Skip to content

Instantly share code, notes, and snippets.

@needlesslygrim
Created February 14, 2026 17:41
Show Gist options
  • Select an option

  • Save needlesslygrim/4d726369461d58482c0a30f55d6951de to your computer and use it in GitHub Desktop.

Select an option

Save needlesslygrim/4d726369461d58482c0a30f55d6951de to your computer and use it in GitHub Desktop.
import SwiftCompilerPlugin
import SwiftSyntax
import SwiftSyntaxMacros
import SwiftDiagnostics
/// Error thrown when `AnyErasedProtocolMacro` cannot expand the declaration it is attached to.
///
/// Conforms to `CustomStringConvertible` so that `String(describing:)` yields the message itself
/// rather than a dump of the struct (presumably what the plug-in host uses to render thrown
/// errors as diagnostics — TODO confirm against the swift-syntax version in use).
struct AnyErasedProtocolError: Error, CustomStringConvertible {
    /// Human-readable explanation of the failure; empty when the thrower supplied none.
    let message: String

    init(message: String = "") {
        self.message = message
    }

    /// The message, or a generic fallback so an empty message never renders as a blank diagnostic.
    var description: String {
        message.isEmpty ? "AnyErasedProtocol macro expansion failed." : message
    }
}
/// Selects which flavour of member `AnyErasedProtocolMacro` generates for a protocol requirement.
enum Config {
    /// Members of the abstract base class; every generated body is `fatalError("Subclass responsibility")`.
    case erasedClass

    /// Members of the private generic `Concrete` subclass, which forward to the wrapped value.
    /// `genericParameterName` names the wrapped type and is the receiver used for static requirements.
    case concreteClass(suppressedCopyableConformance: Bool, genericParameterName: TokenSyntax)

    /// Members re-exposed on the generated struct, forwarding to its private `inner` reference.
    case exposedStruct(suppressedCopyableConformance: Bool)
}
/// The expression `fatalError("Subclass responsibility")`.
///
/// Used as the body of every member generated for the abstract base class, so invoking a
/// member that a subclass did not override traps at runtime.
let subclassResponsibilitySyntax: FunctionCallExprSyntax = {
    let callee = DeclReferenceExprSyntax(baseName: "fatalError")
    let reason = LabeledExprSyntax(expression: "Subclass responsibility".makeLiteralSyntax())
    return FunctionCallExprSyntax(
        calledExpression: callee,
        leftParen: .leftParenToken(),
        arguments: [reason],
        rightParen: .rightParenToken()
    )
}()
/// How protocol-member modifiers translate onto class members.
///
/// A key whose value is `nil` marks a modifier to drop (`mutating` / `nonmutating` have no
/// meaning on class members); `static` becomes `class` so the generated subclass can `override`
/// it. Modifiers that are not keys here are carried over unchanged.
let classMemberModifierMap: [TokenKind: TokenKind?] = [
    .keyword(.static): .keyword(.class),
    .keyword(.mutating): nil,
    .keyword(.nonmutating): nil,
]
/// Peer macro behind the `AnyErasedProtocolStruct` and `AnyErasedProtocolClass` attributes.
///
/// Attached to a protocol `P`, it emits a type-erasing peer named `AnyP`:
/// - `AnyErasedProtocolStruct`: a struct holding a private `Erased` class reference. `Erased` is an
///   abstract base whose members trap, and a private generic `Concrete<T>` subclass stores the
///   wrapped value and forwards to it. The struct re-exposes P's instance requirements by
///   forwarding to `inner`; static requirements are skipped with a warning.
/// - `AnyErasedProtocolClass`: `AnyP` itself is the abstract base class, with a
///   `static func create(from:)` factory returning a `Concrete<T>` subclass instance.
///
/// Only protocols that inherit nothing, or only `~Copyable`, are supported.
struct AnyErasedProtocolMacro: PeerMacro {
    // MARK: - Rendering helpers

    /// Warning text used when a static requirement cannot be re-exposed on the generated struct.
    private static let staticMemberOnStructWarning =
        "Static protocol members cannot be exposed on struct types. Use `AnyErasedProtocolClass` instead if this is worth more than value semantics."

    /// Returns `true` when `modifiers` contains the `static` keyword.
    private static func isStatic(_ modifiers: DeclModifierListSyntax) -> Bool {
        modifiers.contains { $0.name.tokenKind == .keyword(.static) }
    }

    /// The type `~Copyable`, as written in a constraint or inheritance clause.
    private static func suppressedCopyableType() -> SuppressedTypeSyntax {
        SuppressedTypeSyntax(
            withoutTilde: .prefixOperator("~"),
            type: IdentifierTypeSyntax(name: .identifier("Copyable"))
        )
    }

    /// The `try ` / `await ` prefix needed to call something with the given effects
    /// (`try await` is the order Swift requires).
    private static func tryAwaitPrefix(isAsync: Bool, isThrowing: Bool) -> String {
        (isThrowing ? "try " : "") + (isAsync ? "await " : "")
    }

    /// Call prefix for forwarding to a function or initialiser with `effects`.
    private static func callPrefix(for effects: FunctionEffectSpecifiersSyntax?) -> String {
        tryAwaitPrefix(isAsync: effects?.asyncSpecifier != nil, isThrowing: effects?.throwsClause != nil)
    }

    /// Call prefix for forwarding to a property accessor with `effects`.
    private static func callPrefix(for effects: AccessorEffectSpecifiersSyntax?) -> String {
        tryAwaitPrefix(isAsync: effects?.asyncSpecifier != nil, isThrowing: effects?.throwsClause != nil)
    }

    /// Renders the comma-separated argument list that passes `parameters` straight through:
    /// labelled parameters as `label: name`, unlabelled (`_`) ones positionally, and `inout`
    /// ones with a leading `&`.
    private static func forwardedArguments(_ parameters: FunctionParameterListSyntax) -> String {
        parameters.map { parameter -> String in
            let name = (parameter.secondName ?? parameter.firstName).trimmedDescription
            let isInout = parameter.type.as(AttributedTypeSyntax.self)?.specifiers
                .contains { $0.trimmedDescription == "inout" } ?? false
            let value = isInout ? "&\(name)" : name
            if parameter.firstName.tokenKind == .wildcard {
                return value
            }
            return "\(parameter.firstName.trimmedDescription): \(value)"
        }
        .joined(separator: ", ")
    }

    /// Renders `[try ][await ]<receiver>.<name>(<arguments>)`, forwarding `function`'s parameters.
    private static func forwardingCall(to function: FunctionDeclSyntax, on receiver: String) -> String {
        let prefix = callPrefix(for: function.signature.effectSpecifiers)
        let arguments = forwardedArguments(function.signature.parameterClause.parameters)
        return "\(prefix)\(receiver).\(function.name.trimmedDescription)(\(arguments))"
    }

    /// `override` followed by `modifiers` rewritten for a class member.
    private static func overridingClassModifiers(_ modifiers: DeclModifierListSyntax) -> DeclModifierListSyntax {
        DeclModifierListSyntax {
            DeclModifierSyntax(name: .keyword(.override))
            for modifier in modifiers.mapClassMemberModifiers() {
                modifier
            }
        }
    }

    /// Emits the warning for a static requirement that the struct flavour cannot expose.
    private static func warnStaticMemberNotExposed(
        _ node: some SyntaxProtocol,
        in context: some MacroExpansionContext
    ) {
        context.diagnose(
            Diagnostic(
                node: node,
                message: MacroExpansionWarningMessage(staticMemberOnStructWarning)
            )
        )
    }

    // MARK: - Member generation

    /// Generates the initialiser corresponding to the protocol's `initialiser` requirement.
    ///
    /// - `.erasedClass`: returns `nil`; the abstract base exposes no initialisers.
    /// - `.concreteClass`: an initialiser with the same signature that builds the wrapped value
    ///   with `T(...)`.
    /// - `.exposedStruct`: a generic `init<T: P>(of _: T.Type, ...)` that builds a `Concrete<T>`.
    ///
    /// Failable requirements bind the result with `guard let` and return `nil` on failure;
    /// throwing/async requirements keep their effects and forward with `try`/`await`.
    static func generateInitialiser(
        from initialiser: InitializerDeclSyntax,
        forProtocol protocolDecl: ProtocolDeclSyntax,
        in context: some MacroExpansionContext,
        config: Config
    ) -> InitializerDeclSyntax? {
        let parameters = initialiser.signature.parameterClause.parameters
        let effects = initialiser.signature.effectSpecifiers
        let isFailable = initialiser.optionalMark != nil

        switch config {
        case .erasedClass:
            return nil

        case .concreteClass(_, let genericParameterName):
            let construction =
                "\(callPrefix(for: effects))\(genericParameterName.text)(\(forwardedArguments(parameters)))"
            return InitializerDeclSyntax(
                attributes: initialiser.attributes,
                modifiers: initialiser.modifiers,
                optionalMark: initialiser.optionalMark,
                genericParameterClause: initialiser.genericParameterClause,
                signature: initialiser.signature,
                genericWhereClause: initialiser.genericWhereClause
            ) {
                if isFailable {
                    "guard let inner = \(raw: construction) else { return nil }"
                    "self.inner = inner"
                } else {
                    "inner = \(raw: construction)"
                }
            }

        case .exposedStruct(let suppressedCopyableConformance):
            let genericParameterName = context.makeUniqueName("T")
            // `of _: T.Type` selects the concrete type; the requirement's own parameters follow.
            let parameterClause = FunctionParameterClauseSyntax {
                FunctionParameterSyntax(
                    firstName: "of",
                    secondName: "_",
                    type: MetatypeTypeSyntax(
                        baseType: IdentifierTypeSyntax(name: genericParameterName),
                        metatypeSpecifier: .keyword(.Type)
                    )
                )
                for parameter in parameters {
                    parameter.trimmed
                }
            }
            let constraint = CompositionTypeSyntax(
                elements: CompositionTypeElementListSyntax {
                    if suppressedCopyableConformance {
                        CompositionTypeElementSyntax(
                            type: suppressedCopyableType(),
                            ampersand: .binaryOperator("&")
                        )
                    }
                    CompositionTypeElementSyntax(type: IdentifierTypeSyntax(name: protocolDecl.name))
                }
            )
            let construction =
                "\(callPrefix(for: effects))Concrete<\(genericParameterName.text)>(\(forwardedArguments(parameters)))"
            return InitializerDeclSyntax(
                attributes: initialiser.attributes,
                modifiers: initialiser.modifiers,
                optionalMark: initialiser.optionalMark,
                genericParameterClause: GenericParameterClauseSyntax {
                    GenericParameterSyntax(
                        name: genericParameterName,
                        colon: .colonToken(),
                        inheritedType: constraint
                    )
                    for parameter in initialiser.genericParameterClause?.parameters ?? [] {
                        parameter.trimmed
                    }
                },
                // Keep `throws`/`async` so the forwarded `try`/`await` call is legal.
                signature: FunctionSignatureSyntax(parameterClause: parameterClause, effectSpecifiers: effects),
                genericWhereClause: initialiser.genericWhereClause
            ) {
                if isFailable {
                    "guard let inner = \(raw: construction) else { return nil }"
                    "self.inner = inner"
                } else {
                    "inner = \(raw: construction)"
                }
            }
        }
    }

    /// Generates the method corresponding to the protocol's `member` function requirement.
    ///
    /// - `.erasedClass`: same signature, `static` mapped to `class`, body traps.
    /// - `.concreteClass`: `override` version forwarding to `inner` (instance requirements) or to
    ///   the generic parameter type (static requirements).
    /// - `.exposedStruct`: forwards to `inner`; static requirements are warned about and skipped
    ///   (returns `nil`).
    static func generateFunction(
        from member: FunctionDeclSyntax,
        forProtocol protocolDecl: ProtocolDeclSyntax,
        in context: some MacroExpansionContext,
        config: Config
    ) -> FunctionDeclSyntax? {
        switch config {
        case .erasedClass:
            return FunctionDeclSyntax(
                attributes: member.attributes,
                modifiers: member.modifiers.mapClassMemberModifiers(),
                name: member.name,
                genericParameterClause: member.genericParameterClause,
                signature: member.signature,
                genericWhereClause: member.genericWhereClause
            ) {
                subclassResponsibilitySyntax
            }

        case .concreteClass(_, let genericParameterName):
            let receiver = isStatic(member.modifiers) ? genericParameterName.text : "inner"
            let forwarded = forwardingCall(to: member, on: receiver)
            return FunctionDeclSyntax(
                attributes: member.attributes,
                modifiers: overridingClassModifiers(member.modifiers),
                name: member.name,
                genericParameterClause: member.genericParameterClause,
                signature: member.signature,
                genericWhereClause: member.genericWhereClause
            ) {
                "\(raw: forwarded)"
            }

        case .exposedStruct:
            if isStatic(member.modifiers) {
                warnStaticMemberNotExposed(member, in: context)
                return nil
            }
            let forwarded = forwardingCall(to: member, on: "inner")
            return FunctionDeclSyntax(
                attributes: member.attributes,
                modifiers: member.modifiers,
                name: member.name,
                genericParameterClause: member.genericParameterClause,
                signature: member.signature,
                genericWhereClause: member.genericWhereClause
            ) {
                "\(raw: forwarded)"
            }
        }
    }

    /// Maps an accessor's modifier for use on a class member via `classMemberModifierMap`:
    /// `mutating`/`nonmutating` are dropped, anything else is kept.
    private static func classAccessorModifier(for accessor: AccessorDeclSyntax) -> DeclModifierSyntax? {
        guard let original = accessor.modifier else { return nil }
        let mapped = classMemberModifierMap[original.name.tokenKind, default: original.name.tokenKind]
        return mapped.map { DeclModifierSyntax(name: TokenSyntax($0, presence: .present)) }
    }

    /// Generates one `get`/`set` accessor for the property requirement `name`.
    ///
    /// `isTypeMember` marks a static property, which the concrete class forwards to its generic
    /// parameter type instead of `inner`. Effect specifiers (`async`/`throws`) are kept and the
    /// forwarded read is prefixed with `try`/`await` accordingly.
    ///
    /// - Throws: `AnyErasedProtocolError` for accessors other than `get` and `set`.
    private static func generateAccessor(
        from accessor: AccessorDeclSyntax,
        ofProperty name: TokenSyntax,
        isTypeMember: Bool,
        config: Config
    ) throws -> AccessorDeclSyntax {
        // `nil` target means the accessor belongs to the abstract base and must trap.
        let target: String?
        switch config {
        case .erasedClass:
            target = nil
        case .exposedStruct:
            target = "inner.\(name.trimmedDescription)"
        case .concreteClass(_, let genericParameterName):
            let receiver = isTypeMember ? genericParameterName.text : "inner"
            target = "\(receiver).\(name.trimmedDescription)"
        }

        let forwarded: String?
        switch accessor.accessorSpecifier.tokenKind {
        case .keyword(.get):
            forwarded = target.map { callPrefix(for: accessor.effectSpecifiers) + $0 }
        case .keyword(.set):
            forwarded = target.map { "\($0) = newValue" }
        default:
            throw AnyErasedProtocolError(
                message: "Unsupported accessor `\(accessor.accessorSpecifier.text)`; only `get` and `set` are supported."
            )
        }

        let modifier: DeclModifierSyntax?
        if case .exposedStruct = config {
            modifier = accessor.modifier
        } else {
            modifier = classAccessorModifier(for: accessor)
        }

        return AccessorDeclSyntax(
            modifier: modifier,
            accessorSpecifier: accessor.accessorSpecifier,
            effectSpecifiers: accessor.effectSpecifiers
        ) {
            if let forwarded {
                "\(raw: forwarded)"
            } else {
                subclassResponsibilitySyntax
            }
        }
    }

    /// Generates the computed property corresponding to the protocol's `member` requirement.
    ///
    /// Each binding keeps its pattern and type annotation and gets one generated accessor per
    /// accessor in the requirement.
    ///
    /// - Throws: `AnyErasedProtocolError` for non-identifier patterns or unsupported accessors.
    static func generateVariable(
        from member: VariableDeclSyntax,
        forProtocol protocolDecl: ProtocolDeclSyntax,
        in context: some MacroExpansionContext,
        config: Config
    ) throws -> VariableDeclSyntax {
        let isTypeMember = isStatic(member.modifiers)
        var bindings = PatternBindingListSyntax()
        for binding in member.bindings {
            guard let name = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier else {
                throw AnyErasedProtocolError(
                    message: "Unsupported property pattern `\(binding.pattern.trimmedDescription)`; only simple identifiers are supported."
                )
            }
            var accessors = AccessorDeclListSyntax()
            for accessor in binding.accessorBlock?.accessors.as(AccessorDeclListSyntax.self) ?? [] {
                try accessors.append(
                    generateAccessor(from: accessor, ofProperty: name, isTypeMember: isTypeMember, config: config)
                )
            }
            bindings.append(
                PatternBindingSyntax(
                    pattern: binding.pattern,
                    typeAnnotation: binding.typeAnnotation,
                    accessorBlock: AccessorBlockSyntax(accessors: .accessors(accessors))
                )
            )
        }

        let modifiers: DeclModifierListSyntax = switch config {
        case .exposedStruct:
            // The requirement's modifiers are not carried onto the struct property.
            DeclModifierListSyntax()
        case .concreteClass:
            overridingClassModifiers(member.modifiers)
        case .erasedClass:
            member.modifiers.mapClassMemberModifiers()
        }
        return VariableDeclSyntax(
            modifiers: modifiers,
            bindingSpecifier: .keyword(.var),
            bindings: bindings
        )
    }

    /// Generates the member for one protocol requirement, or `nil` when `config` has nothing
    /// to emit for it (initialisers on the abstract base, static requirements on the struct).
    ///
    /// - Throws: `AnyErasedProtocolError` for requirement kinds other than initialisers,
    ///   functions and properties.
    static func generateMember(
        from member: MemberBlockItemSyntax,
        forProtocol protocolDecl: ProtocolDeclSyntax,
        in context: some MacroExpansionContext,
        config: Config
    ) throws -> MemberBlockItemSyntax? {
        if let initialiser = member.decl.as(InitializerDeclSyntax.self) {
            return generateInitialiser(
                from: initialiser,
                forProtocol: protocolDecl,
                in: context,
                config: config
            ).map { MemberBlockItemSyntax(decl: $0) }
        }
        if let function = member.decl.as(FunctionDeclSyntax.self) {
            return generateFunction(
                from: function,
                forProtocol: protocolDecl,
                in: context,
                config: config
            ).map { MemberBlockItemSyntax(decl: $0) }
        }
        if let variable = member.decl.as(VariableDeclSyntax.self) {
            // A struct instance property cannot forward to a `class var` on `Erased`; skip it
            // with the same warning static functions get.
            if case .exposedStruct = config, isStatic(variable.modifiers) {
                warnStaticMemberNotExposed(variable, in: context)
                return nil
            }
            return MemberBlockItemSyntax(
                decl: try generateVariable(
                    from: variable,
                    forProtocol: protocolDecl,
                    in: context,
                    config: config
                )
            )
        }
        throw AnyErasedProtocolError(
            message: "Unsupported protocol member `\(member.decl.trimmedDescription)`; only initialisers, functions and properties are supported."
        )
    }

    /// Generates members for every requirement of `protocolDecl` under `config`, in order.
    private static func generateMembers(
        of protocolDecl: ProtocolDeclSyntax,
        in context: some MacroExpansionContext,
        config: Config
    ) throws -> [MemberBlockItemSyntax] {
        try protocolDecl.memberBlock.members.compactMap { member in
            try self.generateMember(from: member, forProtocol: protocolDecl, in: context, config: config)
        }
    }

    // MARK: - Expansion

    /// Returns whether `protocolDecl` is declared `~Copyable`.
    ///
    /// - Throws: `AnyErasedProtocolError` for any inherited type other than `~Copyable`.
    private static func protocolSuppressesCopyable(_ protocolDecl: ProtocolDeclSyntax) throws -> Bool {
        var suppresses = false
        for inherited in protocolDecl.inheritanceClause?.inheritedTypes ?? [] {
            guard let suppressed = inherited.type.as(SuppressedTypeSyntax.self),
                  let innerType = suppressed.type.as(IdentifierTypeSyntax.self),
                  innerType.name.identifier == Identifier(canonicalName: "Copyable")
            else {
                throw AnyErasedProtocolError(message: "Protocol inheritance from \(inherited) is not supported.")
            }
            suppresses = true
        }
        return suppresses
    }

    /// Builds `private final class Concrete<T: P [& ~Copyable]>: <superclassName>`, which stores
    /// the wrapped value in `inner` and overrides every requirement to forward to it.
    private static func makeConcreteClass(
        of protocolDecl: ProtocolDeclSyntax,
        inheriting superclassName: TokenSyntax,
        suppressesCopyable: Bool,
        in context: some MacroExpansionContext
    ) throws -> ClassDeclSyntax {
        let genericParameterName = context.makeUniqueName("T")
        let constraint = CompositionTypeSyntax(
            elements: CompositionTypeElementListSyntax {
                CompositionTypeElementSyntax(
                    type: IdentifierTypeSyntax(name: protocolDecl.name),
                    ampersand: suppressesCopyable ? .binaryOperator("&") : nil
                )
                if suppressesCopyable {
                    CompositionTypeElementSyntax(type: suppressedCopyableType())
                }
            }
        )
        let concreteConfig = Config.concreteClass(
            suppressedCopyableConformance: suppressesCopyable,
            genericParameterName: genericParameterName
        )
        return try ClassDeclSyntax(
            modifiers: [
                DeclModifierSyntax(name: .keyword(.private)),
                DeclModifierSyntax(name: .keyword(.final)),
            ],
            name: "Concrete",
            genericParameterClause: GenericParameterClauseSyntax {
                GenericParameterSyntax(
                    name: genericParameterName,
                    colon: .colonToken(),
                    inheritedType: constraint
                )
            },
            inheritanceClause: InheritanceClauseSyntax {
                InheritedTypeSyntax(type: IdentifierTypeSyntax(name: superclassName))
            }
        ) {
            "private var inner: \(genericParameterName)"
            """
            init(from inner: consuming \(genericParameterName)) {
                self.inner = inner
            }
            """
            for member in try self.generateMembers(of: protocolDecl, in: context, config: concreteConfig) {
                member
            }
        }
    }

    static func expansion(
        of node: SwiftSyntax.AttributeSyntax,
        providingPeersOf declaration: some SwiftSyntax.DeclSyntaxProtocol,
        in context: some SwiftSyntaxMacros.MacroExpansionContext
    ) throws -> [SwiftSyntax.DeclSyntax] {
        guard let protocolDecl = declaration.as(ProtocolDeclSyntax.self) else {
            throw AnyErasedProtocolError(message: "This macro can only be used on protocol declarations.")
        }
        guard let attributeIdentifier = node.attributeName.as(IdentifierTypeSyntax.self)?.name else {
            throw AnyErasedProtocolError(message: "How did you get here?")
        }

        let suppressesCopyable = try protocolSuppressesCopyable(protocolDecl)
        let inheritanceClause: InheritanceClauseSyntax? = suppressesCopyable
            ? InheritanceClauseSyntax(inheritedTypes: [InheritedTypeSyntax(type: suppressedCopyableType())])
            : nil

        // `consuming some [~Copyable &] P`. The trailing space is part of the specifier's token
        // text, presumably to keep it separated from `some` when the node is interpolated.
        let fromInitialiserType = AttributedTypeSyntax(
            specifiers: [.simpleTypeSpecifier(SimpleTypeSpecifierSyntax(specifier: "consuming "))],
            baseType: SomeOrAnyTypeSyntax(
                someOrAnySpecifier: .keyword(.some),
                constraint: CompositionTypeSyntax(
                    elements: CompositionTypeElementListSyntax {
                        if suppressesCopyable {
                            CompositionTypeElementSyntax(
                                type: suppressedCopyableType(),
                                ampersand: .binaryOperator("&")
                            )
                        }
                        CompositionTypeElementSyntax(type: IdentifierTypeSyntax(name: protocolDecl.name))
                    }
                )
            )
        )

        if attributeIdentifier.identifier == Identifier(canonicalName: "AnyErasedProtocolStruct") {
            let decl = try StructDeclSyntax(
                name: "Any\(protocolDecl.name)",
                inheritanceClause: inheritanceClause
            ) {
                "private let inner: Erased"
                """
                init(from inner: \(fromInitialiserType)) {
                    self.inner = Concrete(from: inner)
                }
                """
                try ClassDeclSyntax(
                    modifiers: [DeclModifierSyntax(name: .keyword(.private))],
                    name: "Erased"
                ) {
                    for member in try self.generateMembers(of: protocolDecl, in: context, config: .erasedClass) {
                        member
                    }
                }
                try self.makeConcreteClass(
                    of: protocolDecl,
                    inheriting: .identifier("Erased"),
                    suppressesCopyable: suppressesCopyable,
                    in: context
                )
                for member in try self.generateMembers(
                    of: protocolDecl,
                    in: context,
                    config: .exposedStruct(suppressedCopyableConformance: suppressesCopyable)
                ) {
                    member
                }
            }
            return [DeclSyntax(decl)]
        } else if attributeIdentifier.identifier == Identifier(canonicalName: "AnyErasedProtocolClass") {
            let decl = try ClassDeclSyntax(name: "Any\(protocolDecl.name)") {
                """
                static func create(from inner: \(fromInitialiserType)) -> Any\(protocolDecl.name) {
                    return Concrete(from: inner)
                }
                """
                for member in try self.generateMembers(of: protocolDecl, in: context, config: .erasedClass) {
                    member
                }
                try self.makeConcreteClass(
                    of: protocolDecl,
                    inheriting: .identifier("Any" + protocolDecl.name.text),
                    suppressesCopyable: suppressesCopyable,
                    in: context
                )
            }
            return [DeclSyntax(decl)]
        } else {
            return []
        }
    }
}
extension DeclModifierListSyntax {
    /// Rewrites protocol-member modifiers for use on a class member, using `classMemberModifierMap`.
    ///
    /// A modifier listed in the map is replaced by its mapped keyword, or dropped when it maps
    /// to `nil`. A modifier not listed there is kept with its trivia trimmed.
    func mapClassMemberModifiers() -> DeclModifierListSyntax {
        let rewritten: [DeclModifierSyntax] = compactMap { (modifier) -> DeclModifierSyntax? in
            switch classMemberModifierMap[modifier.name.tokenKind] {
            case .none:
                return modifier.trimmed
            case .some(.none):
                return nil
            case .some(.some(let replacement)):
                return DeclModifierSyntax(name: TokenSyntax(replacement, presence: .present))
            }
        }
        return DeclModifierListSyntax(rewritten)
    }
}
/// Compiler plug-in entry point registering the macros implemented in this module.
@main
struct Macros: CompilerPlugin {
    let providingMacros: [Macro.Type] = [
        AnyErasedProtocolMacro.self,
    ]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment