Created February 14, 2026 17:41
Save needlesslygrim/4d726369461d58482c0a30f55d6951de to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import SwiftCompilerPlugin | |
| import SwiftSyntax | |
| import SwiftSyntaxMacros | |
| import SwiftDiagnostics | |
/// Error thrown when the `AnyErasedProtocol*` macros cannot expand a declaration.
///
/// Conforms to `CustomStringConvertible` so that `String(describing:)` yields the message
/// itself rather than the memberwise dump `AnyErasedProtocolError(message: "...")`.
/// NOTE(review): this assumes the macro host renders thrown errors via `String(describing:)`.
struct AnyErasedProtocolError: Error, CustomStringConvertible {
  /// Human-readable explanation of the failure; empty when the thrower gave none.
  let message: String

  init(message: String = "") {
    self.message = message
  }

  /// The message, or a generic fallback so an empty message never produces blank diagnostics.
  var description: String {
    message.isEmpty ? "AnyErasedProtocol macro expansion failed." : message
  }
}
/// Selects which flavour of declaration a protocol requirement is turned into.
enum Config {
  /// A member of the erased base class; its body is `subclassResponsibilitySyntax` (a `fatalError` call).
  case erasedClass

  /// A member of the generic `Concrete` subclass that forwards to the wrapped value.
  ///
  /// `genericParameterName` is the subclass's generic parameter; static requirements are
  /// forwarded to it rather than to the stored instance.
  case concreteClass(suppressedCopyableConformance: Bool, genericParameterName: TokenSyntax)

  /// A member of the generated wrapper struct, forwarding to its `inner` box.
  case exposedStruct(suppressedCopyableConformance: Bool)
}
/// The expression `fatalError("Subclass responsibility")`, used as the body of abstract
/// members on the erased base class.
let subclassResponsibilitySyntax: FunctionCallExprSyntax = {
  let messageArgument = LabeledExprSyntax(
    expression: "Subclass responsibility".makeLiteralSyntax()
  )
  return FunctionCallExprSyntax(
    calledExpression: DeclReferenceExprSyntax(baseName: "fatalError"),
    leftParen: .leftParenToken(),
    arguments: [messageArgument],
    rightParen: .rightParenToken()
  )
}()
/// Translation table from protocol-member modifiers to class-member modifiers.
///
/// A `nil` value means the modifier is dropped (`mutating`/`nonmutating` are not valid on class
/// members). `static` becomes `class` so the generated subclass can `override` the member.
/// Modifiers missing from the table are kept unchanged.
let classMemberModifierMap: [TokenKind: TokenKind?] = [
  .keyword(.static): .keyword(.class),
  .keyword(.mutating): nil,
  .keyword(.nonmutating): nil,
]
/// Peer macro that generates a type-erased wrapper next to the protocol it is attached to.
///
/// The attribute name selects the shape of the output:
/// - `@AnyErasedProtocolStruct` emits `struct Any<Protocol>` that stores a private `Erased` box
///   (an abstract class) whose generic `Concrete<T>` subclass forwards to the wrapped value.
/// - `@AnyErasedProtocolClass` emits `class Any<Protocol>` with abstract members plus a private
///   generic `Concrete<T>` subclass, created through `static func create(from:)`.
///
/// Supported requirements are initialisers, functions and properties with `get`/`set`
/// accessors; anything else makes the expansion throw `AnyErasedProtocolError`.
struct AnyErasedProtocolMacro: PeerMacro {
  // MARK: - Forwarding helpers

  /// The pieces needed to re-declare a parameter list and pass every parameter straight through.
  private struct Forwarding {
    /// The parameter list to declare; wildcard (`_`) internal names are replaced by fresh ones.
    var parameters: FunctionParameterListSyntax
    /// Source text of the call's argument list, without parentheses (e.g. `a: a, &b, c`).
    var arguments: String
  }

  /// Builds a `Forwarding` for `parameters`.
  ///
  /// Unlabeled parameters (`_ x`) are passed without a label, `inout` parameters get `&`,
  /// and parameters whose internal name is `_` are renamed with `context.makeUniqueName` so the
  /// generated body can reference them.
  /// NOTE(review): variadic parameters are not specially handled.
  private static func makeForwarding(
    for parameters: FunctionParameterListSyntax,
    in context: some MacroExpansionContext
  ) -> Forwarding {
    var declared: [FunctionParameterSyntax] = []
    var arguments: [String] = []
    for original in parameters {
      var parameter = original
      var localName = parameter.secondName ?? parameter.firstName
      if localName.tokenKind == .wildcard {
        // `_` cannot be referenced in a body, so give the value a name we can forward.
        localName = context.makeUniqueName("argument")
        parameter.secondName = localName.with(\.leadingTrivia, .space)
      }
      let value = (isInout(parameter) ? "&" : "") + localName.text
      if parameter.firstName.tokenKind == .wildcard {
        arguments.append(value)
      } else {
        arguments.append("\(parameter.firstName.text): \(value)")
      }
      declared.append(parameter)
    }
    return Forwarding(
      parameters: FunctionParameterListSyntax(declared),
      arguments: arguments.joined(separator: ", ")
    )
  }

  /// Whether `parameter` is declared `inout` (its argument must then be passed with `&`).
  private static func isInout(_ parameter: FunctionParameterSyntax) -> Bool {
    guard let attributed = parameter.type.as(AttributedTypeSyntax.self) else { return false }
    return attributed.specifiers.contains { element in
      if case .simpleTypeSpecifier(let specifier) = element {
        return specifier.specifier.tokenKind == .keyword(.inout)
      }
      return false
    }
  }

  /// The `try `/`await ` prefix required to call something with the given effects.
  private static func callPrefix(isThrowing: Bool, isAsync: Bool) -> String {
    (isThrowing ? "try " : "") + (isAsync ? "await " : "")
  }

  /// `override` followed by `modifiers` translated through `mapClassMemberModifiers()`.
  private static func overridingModifiers(
    _ modifiers: DeclModifierListSyntax
  ) -> DeclModifierListSyntax {
    DeclModifierListSyntax {
      DeclModifierSyntax(name: .keyword(.override))
      for modifier in modifiers.mapClassMemberModifiers() {
        modifier
      }
    }
  }

  /// Translates an accessor modifier through `classMemberModifierMap`; unmapped modifiers keep
  /// their kind, and modifiers mapped to `nil` are dropped.
  private static func classAccessorModifier(
    for modifier: DeclModifierSyntax?
  ) -> DeclModifierSyntax? {
    guard let modifier,
      let kind = classMemberModifierMap[modifier.name.tokenKind, default: modifier.name.tokenKind]
    else { return nil }
    return DeclModifierSyntax(name: TokenSyntax(kind, presence: .present))
  }

  /// Emits the warning used when a static requirement cannot be exposed on the wrapper struct.
  private static func diagnoseStaticMemberOnStruct(
    _ node: some SyntaxProtocol,
    in context: some MacroExpansionContext
  ) {
    context.diagnose(
      Diagnostic(
        node: node,
        message: MacroExpansionWarningMessage(
          "Static protocol members cannot be exposed on struct types. Use `AnyErasedProtocolClass` instead if this is worth more than value semantics."
        )
      )
    )
  }

  // MARK: - Member generation

  /// Generates the counterpart of an initialiser requirement, or `nil` for the erased base class.
  ///
  /// - `.concreteClass`: same signature; constructs the wrapped value via the generic parameter.
  /// - `.exposedStruct`: adds a leading `of _: T.Type` generic parameter and builds `Concrete<T>`.
  /// Failable requirements use `guard let … else { return nil }`; `throws`/`async` effects are
  /// preserved on the signature and reflected as `try`/`await` on the forwarded call.
  static func generateInitialiser(
    from initialiser: InitializerDeclSyntax,
    forProtocol protocolDecl: ProtocolDeclSyntax,
    in context: some MacroExpansionContext,
    config: Config
  ) -> InitializerDeclSyntax? {
    let signature = initialiser.signature
    let isFailable = initialiser.optionalMark != nil
    let prefix = callPrefix(
      isThrowing: signature.effectSpecifiers?.throwsClause != nil,
      isAsync: signature.effectSpecifiers?.asyncSpecifier != nil
    )

    switch config {
    case .erasedClass:
      return nil

    case .concreteClass(_, let genericParameterName):
      let forwarding = makeForwarding(for: signature.parameterClause.parameters, in: context)
      let construction = "\(prefix)\(genericParameterName.text)(\(forwarding.arguments))"
      return InitializerDeclSyntax(
        attributes: initialiser.attributes,
        modifiers: initialiser.modifiers,
        optionalMark: initialiser.optionalMark,
        genericParameterClause: initialiser.genericParameterClause,
        signature: signature.with(
          \.parameterClause,
          signature.parameterClause.with(\.parameters, forwarding.parameters)
        ),
        genericWhereClause: initialiser.genericWhereClause
      ) {
        if isFailable {
          "guard let inner = \(raw: construction) else { return nil }"
          "self.inner = inner"
        } else {
          "inner = \(raw: construction)"
        }
      }

    case .exposedStruct(let suppressedCopyableConformance):
      let forwarding = makeForwarding(for: signature.parameterClause.parameters, in: context)
      let genericParameterName = context.makeUniqueName("T")
      // `of _: T.Type` lets callers choose the concrete conforming type explicitly.
      let parameterClause = FunctionParameterClauseSyntax {
        FunctionParameterSyntax(
          firstName: "of",
          secondName: "_",
          type: MetatypeTypeSyntax(
            baseType: IdentifierTypeSyntax(name: genericParameterName),
            metatypeSpecifier: .keyword(.Type)
          )
        )
        for parameter in forwarding.parameters {
          parameter.trimmed
        }
      }
      let constraintElements = CompositionTypeElementListSyntax {
        if suppressedCopyableConformance {
          CompositionTypeElementSyntax(
            type: SuppressedTypeSyntax(
              withoutTilde: .prefixOperator("~"),
              type: IdentifierTypeSyntax(name: "Copyable")
            ),
            ampersand: .binaryOperator("&")
          )
        }
        CompositionTypeElementSyntax(type: IdentifierTypeSyntax(name: protocolDecl.name))
      }
      let construction =
        "\(prefix)Concrete<\(genericParameterName.text)>(\(forwarding.arguments))"
      return InitializerDeclSyntax(
        attributes: initialiser.attributes,
        modifiers: initialiser.modifiers,
        optionalMark: initialiser.optionalMark,
        genericParameterClause: GenericParameterClauseSyntax {
          GenericParameterSyntax(
            name: genericParameterName,
            colon: .colonToken(),
            inheritedType: CompositionTypeSyntax(elements: constraintElements)
          )
          for parameter in initialiser.genericParameterClause?.parameters ?? [] {
            parameter.trimmed
          }
        },
        // Keep the requirement's effects (`throws`/`async`); only the parameters change.
        signature: signature.with(\.parameterClause, parameterClause),
        genericWhereClause: initialiser.genericWhereClause
      ) {
        if isFailable {
          "guard let inner = \(raw: construction) else { return nil }"
          "self.inner = inner"
        } else {
          "inner = \(raw: construction)"
        }
      }
    }
  }

  /// Generates the counterpart of a function requirement.
  ///
  /// - `.erasedClass`: class-mapped modifiers, body traps with `fatalError`.
  /// - `.concreteClass`: `override` member forwarding to `inner` (or to the generic parameter
  ///   for static requirements).
  /// - `.exposedStruct`: member forwarding to `inner`; static requirements are diagnosed and
  ///   skipped (returns `nil`).
  static func generateFunction(
    from member: FunctionDeclSyntax,
    forProtocol protocolDecl: ProtocolDeclSyntax,
    in context: some MacroExpansionContext,
    config: Config
  ) -> FunctionDeclSyntax? {
    let isStatic = member.modifiers.contains { $0.name.tokenKind == .keyword(.static) }

    switch config {
    case .erasedClass:
      return FunctionDeclSyntax(
        attributes: member.attributes,
        modifiers: member.modifiers.mapClassMemberModifiers(),
        name: member.name,
        genericParameterClause: member.genericParameterClause,
        signature: member.signature,
        genericWhereClause: member.genericWhereClause
      ) {
        subclassResponsibilitySyntax
      }

    case .concreteClass(_, let genericParameterName):
      return makeForwardingFunction(
        from: member,
        modifiers: overridingModifiers(member.modifiers),
        receiver: isStatic ? genericParameterName.text : "inner",
        in: context
      )

    case .exposedStruct:
      if isStatic {
        diagnoseStaticMemberOnStruct(member, in: context)
        return nil
      }
      return makeForwardingFunction(
        from: member,
        modifiers: member.modifiers,
        receiver: "inner",
        in: context
      )
    }
  }

  /// Re-declares `member` with `modifiers` and a body calling the same function on `receiver`.
  private static func makeForwardingFunction(
    from member: FunctionDeclSyntax,
    modifiers: DeclModifierListSyntax,
    receiver: String,
    in context: some MacroExpansionContext
  ) -> FunctionDeclSyntax {
    let signature = member.signature
    let forwarding = makeForwarding(for: signature.parameterClause.parameters, in: context)
    let prefix = callPrefix(
      isThrowing: signature.effectSpecifiers?.throwsClause != nil,
      isAsync: signature.effectSpecifiers?.asyncSpecifier != nil
    )
    let call = "\(prefix)\(receiver).\(member.name.text)(\(forwarding.arguments))"
    return FunctionDeclSyntax(
      attributes: member.attributes,
      modifiers: modifiers,
      name: member.name,
      genericParameterClause: member.genericParameterClause,
      signature: signature.with(
        \.parameterClause,
        signature.parameterClause.with(\.parameters, forwarding.parameters)
      ),
      genericWhereClause: member.genericWhereClause
    ) {
      "\(raw: call)"
    }
  }

  /// Generates the counterpart of a property requirement.
  ///
  /// `get` accessors forward reads (with `try`/`await` when the accessor has those effects),
  /// `set` accessors forward `newValue`; on the erased class both trap with `fatalError`.
  /// Static properties forward to the generic parameter on `Concrete` and are diagnosed and
  /// skipped (returns `nil`) on the wrapper struct.
  ///
  /// - Throws: `AnyErasedProtocolError` for non-identifier patterns or accessors other than
  ///   `get`/`set`.
  static func generateVariable(
    from member: VariableDeclSyntax,
    forProtocol protocolDecl: ProtocolDeclSyntax,
    in context: some MacroExpansionContext,
    config: Config
  ) throws -> VariableDeclSyntax? {
    let isStatic = member.modifiers.contains { $0.name.tokenKind == .keyword(.static) }

    // What generated accessors read from / write to; `nil` means "trap instead of forwarding".
    let receiver: String?
    switch config {
    case .erasedClass:
      receiver = nil
    case .concreteClass(_, let genericParameterName):
      // Static requirements live on the wrapped type, not on the stored `inner` value.
      receiver = isStatic ? genericParameterName.text : "inner"
    case .exposedStruct:
      if isStatic {
        diagnoseStaticMemberOnStruct(member, in: context)
        return nil
      }
      receiver = "inner"
    }

    var bindings = PatternBindingListSyntax()
    for binding in member.bindings {
      guard let name = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier else {
        throw AnyErasedProtocolError(
          message: "Unsupported property pattern `\(binding.pattern.trimmedDescription)`."
        )
      }

      var accessors = AccessorDeclListSyntax()
      for accessor in binding.accessorBlock?.accessors.as(AccessorDeclListSyntax.self) ?? [] {
        let isSetter: Bool
        switch accessor.accessorSpecifier.tokenKind {
        case .keyword(.get):
          isSetter = false
        case .keyword(.set):
          isSetter = true
        default:
          throw AnyErasedProtocolError(
            message:
              "Unsupported accessor `\(accessor.accessorSpecifier.text)` on property `\(name.text)`."
          )
        }

        let modifier: DeclModifierSyntax?
        switch config {
        case .exposedStruct:
          modifier = accessor.modifier
        case .erasedClass, .concreteClass:
          modifier = classAccessorModifier(for: accessor.modifier)
        }

        // e.g. `try inner.value` or `inner.value = newValue`; `nil` on the erased class.
        let forwarded: String? = receiver.map { receiver in
          if isSetter {
            return "\(receiver).\(name.text) = newValue"
          }
          let prefix = callPrefix(
            isThrowing: accessor.effectSpecifiers?.throwsClause != nil,
            isAsync: accessor.effectSpecifiers?.asyncSpecifier != nil
          )
          return "\(prefix)\(receiver).\(name.text)"
        }

        accessors.append(
          AccessorDeclSyntax(
            modifier: modifier,
            accessorSpecifier: accessor.accessorSpecifier,
            effectSpecifiers: accessor.effectSpecifiers
          ) {
            if let forwarded {
              "\(raw: forwarded)"
            } else {
              subclassResponsibilitySyntax
            }
          }
        )
      }

      bindings.append(
        PatternBindingSyntax(
          pattern: binding.pattern,
          typeAnnotation: binding.typeAnnotation,
          accessorBlock: AccessorBlockSyntax(accessors: .accessors(accessors))
        )
      )
    }

    let modifiers: DeclModifierListSyntax
    switch config {
    case .exposedStruct:
      modifiers = DeclModifierListSyntax()
    case .concreteClass:
      modifiers = overridingModifiers(member.modifiers)
    case .erasedClass:
      modifiers = member.modifiers.mapClassMemberModifiers()
    }

    return VariableDeclSyntax(
      modifiers: modifiers,
      bindingSpecifier: .keyword(.var),
      bindings: bindings
    )
  }

  /// Dispatches one protocol member to the matching generator.
  ///
  /// - Returns: the generated member, or `nil` when the generator chose to skip it.
  /// - Throws: `AnyErasedProtocolError` for member kinds that are not supported.
  static func generateMember(
    from member: MemberBlockItemSyntax,
    forProtocol protocolDecl: ProtocolDeclSyntax,
    in context: some MacroExpansionContext,
    config: Config
  ) throws -> MemberBlockItemSyntax? {
    let decl = member.decl
    if let initialiser = decl.as(InitializerDeclSyntax.self) {
      return generateInitialiser(
        from: initialiser,
        forProtocol: protocolDecl,
        in: context,
        config: config
      ).map { MemberBlockItemSyntax(decl: $0) }
    }
    if let function = decl.as(FunctionDeclSyntax.self) {
      return generateFunction(
        from: function,
        forProtocol: protocolDecl,
        in: context,
        config: config
      ).map { MemberBlockItemSyntax(decl: $0) }
    }
    if let variable = decl.as(VariableDeclSyntax.self) {
      return try generateVariable(
        from: variable,
        forProtocol: protocolDecl,
        in: context,
        config: config
      ).map { MemberBlockItemSyntax(decl: $0) }
    }
    throw AnyErasedProtocolError(
      message: "Unsupported protocol member: `\(decl.trimmedDescription)`."
    )
  }

  /// Generates every member of `protocolDecl` for `config`, dropping skipped ones.
  private static func generateMembers(
    of protocolDecl: ProtocolDeclSyntax,
    in context: some MacroExpansionContext,
    config: Config
  ) throws -> [MemberBlockItemSyntax] {
    try protocolDecl.memberBlock.members.compactMap { member in
      try generateMember(from: member, forProtocol: protocolDecl, in: context, config: config)
    }
  }

  // MARK: - Type generation

  /// The attribute's bare name, accepting both `@Name` and module-qualified `@Module.Name`.
  private static func macroName(of node: AttributeSyntax) throws -> TokenSyntax {
    if let identifier = node.attributeName.as(IdentifierTypeSyntax.self) {
      return identifier.name
    }
    if let member = node.attributeName.as(MemberTypeSyntax.self) {
      return member.name
    }
    throw AnyErasedProtocolError(
      message: "Unsupported attribute spelling `\(node.attributeName.trimmedDescription)`."
    )
  }

  /// Returns `: ~Copyable` when the protocol suppresses `Copyable`, otherwise `nil`.
  ///
  /// - Throws: `AnyErasedProtocolError` for any other inherited type.
  private static func copyableInheritanceClause(
    for protocolDecl: ProtocolDeclSyntax
  ) throws -> InheritanceClauseSyntax? {
    var clause: InheritanceClauseSyntax? = nil
    for inherited in protocolDecl.inheritanceClause?.inheritedTypes ?? [] {
      guard let suppressedType = inherited.type.as(SuppressedTypeSyntax.self),
        let innerType = suppressedType.type.as(IdentifierTypeSyntax.self),
        innerType.name.identifier == Identifier(canonicalName: "Copyable")
      else {
        throw AnyErasedProtocolError(
          message: "Protocol inheritance from \(inherited.type.trimmedDescription) is not supported."
        )
      }
      clause = InheritanceClauseSyntax(inheritedTypes: [
        InheritedTypeSyntax(
          type: SuppressedTypeSyntax(
            withoutTilde: .prefixOperator("~"),
            type: IdentifierTypeSyntax(name: .identifier("Copyable"))
          )
        )
      ])
    }
    return clause
  }

  /// The parameter type of the wrapper's `from:` entry point: `consuming some [~Copyable &] P`.
  private static func makeFromInitialiserType(
    for protocolDecl: ProtocolDeclSyntax,
    suppressesCopyable: Bool
  ) -> AttributedTypeSyntax {
    AttributedTypeSyntax(
      specifiers: [.simpleTypeSpecifier(SimpleTypeSpecifierSyntax(specifier: "consuming "))],
      baseType: SomeOrAnyTypeSyntax(
        someOrAnySpecifier: .keyword(.some),
        constraint: CompositionTypeSyntax(
          elements: CompositionTypeElementListSyntax {
            if suppressesCopyable {
              CompositionTypeElementSyntax(
                type: SuppressedTypeSyntax(
                  withoutTilde: .prefixOperator("~"),
                  type: IdentifierTypeSyntax(name: "Copyable")
                ),
                ampersand: .binaryOperator("&")
              )
            }
            CompositionTypeElementSyntax(type: IdentifierTypeSyntax(name: protocolDecl.name))
          }
        )
      )
    )
  }

  /// Builds `private final class Concrete<T: P [& ~Copyable]>: <superclassName>` that stores the
  /// wrapped value in `inner` and overrides every requirement to forward to it.
  private static func makeConcreteClass(
    for protocolDecl: ProtocolDeclSyntax,
    superclassName: TokenSyntax,
    suppressesCopyable: Bool,
    in context: some MacroExpansionContext
  ) throws -> ClassDeclSyntax {
    let inheritedType = CompositionTypeSyntax(
      elements: CompositionTypeElementListSyntax {
        CompositionTypeElementSyntax(
          type: IdentifierTypeSyntax(name: protocolDecl.name),
          ampersand: suppressesCopyable ? .binaryOperator("&") : nil
        )
        if suppressesCopyable {
          CompositionTypeElementSyntax(
            type: SuppressedTypeSyntax(
              withoutTilde: .prefixOperator("~"),
              type: IdentifierTypeSyntax(name: "Copyable")
            )
          )
        }
      }
    )
    let typeParameterName = context.makeUniqueName("T")
    let members = try generateMembers(
      of: protocolDecl,
      in: context,
      config: .concreteClass(
        suppressedCopyableConformance: suppressesCopyable,
        genericParameterName: typeParameterName
      )
    )
    return try ClassDeclSyntax(
      modifiers: [
        DeclModifierSyntax(name: .keyword(.private)),
        DeclModifierSyntax(name: .keyword(.final)),
      ],
      name: "Concrete",
      genericParameterClause: GenericParameterClauseSyntax {
        GenericParameterSyntax(
          name: typeParameterName,
          colon: .colonToken(),
          inheritedType: inheritedType
        )
      },
      inheritanceClause: InheritanceClauseSyntax {
        InheritedTypeSyntax(type: IdentifierTypeSyntax(name: superclassName))
      }
    ) {
      "private var inner: \(typeParameterName)"
      """
      init(from inner: consuming \(typeParameterName)) {
        self.inner = inner
      }
      """
      for member in members {
        member
      }
    }
  }

  /// Expands `@AnyErasedProtocolStruct` / `@AnyErasedProtocolClass` on a protocol.
  ///
  /// - Returns: the generated `Any<Protocol>` declaration, or `[]` for other attribute names.
  /// - Throws: `AnyErasedProtocolError` when attached to a non-protocol, when the protocol
  ///   inherits from anything except `~Copyable`, or when a member is unsupported.
  static func expansion(
    of node: SwiftSyntax.AttributeSyntax,
    providingPeersOf declaration: some SwiftSyntax.DeclSyntaxProtocol,
    in context: some SwiftSyntaxMacros.MacroExpansionContext
  ) throws -> [SwiftSyntax.DeclSyntax] {
    guard let protocolDecl = declaration.as(ProtocolDeclSyntax.self) else {
      throw AnyErasedProtocolError(message: "This macro can only be used on protocol declarations.")
    }
    let attributeIdentifier = try macroName(of: node)
    let inheritanceClause = try copyableInheritanceClause(for: protocolDecl)
    let suppressesCopyable = inheritanceClause != nil
    let fromInitialiserType = makeFromInitialiserType(
      for: protocolDecl,
      suppressesCopyable: suppressesCopyable
    )

    if attributeIdentifier.identifier == Identifier(canonicalName: "AnyErasedProtocolStruct") {
      let erasedMembers = try generateMembers(of: protocolDecl, in: context, config: .erasedClass)
      let concreteClass = try makeConcreteClass(
        for: protocolDecl,
        superclassName: .identifier("Erased"),
        suppressesCopyable: suppressesCopyable,
        in: context
      )
      let exposedMembers = try generateMembers(
        of: protocolDecl,
        in: context,
        config: .exposedStruct(suppressedCopyableConformance: suppressesCopyable)
      )
      let decl = try StructDeclSyntax(
        name: "Any\(protocolDecl.name)",
        inheritanceClause: inheritanceClause
      ) {
        "private let inner: Erased"
        """
        init(from inner: \(fromInitialiserType)) {
          self.inner = Concrete(from: inner)
        }
        """
        try ClassDeclSyntax(
          modifiers: [DeclModifierSyntax(name: .keyword(.private))],
          name: "Erased"
        ) {
          for member in erasedMembers {
            member
          }
        }
        concreteClass
        for member in exposedMembers {
          member
        }
      }
      return [DeclSyntax(decl)]
    }

    if attributeIdentifier.identifier == Identifier(canonicalName: "AnyErasedProtocolClass") {
      let erasedMembers = try generateMembers(of: protocolDecl, in: context, config: .erasedClass)
      let concreteClass = try makeConcreteClass(
        for: protocolDecl,
        superclassName: .identifier("Any\(protocolDecl.name.text)"),
        suppressesCopyable: suppressesCopyable,
        in: context
      )
      let decl = try ClassDeclSyntax(name: "Any\(protocolDecl.name)") {
        """
        static func create(from inner: \(fromInitialiserType)) -> Any\(protocolDecl.name) {
          return Concrete(from: inner)
        }
        """
        for member in erasedMembers {
          member
        }
        concreteClass
      }
      return [DeclSyntax(decl)]
    }

    return []
  }
}
extension DeclModifierListSyntax {
  /// Rewrites these modifiers for use on a class member, using `classMemberModifierMap`.
  ///
  /// A modifier missing from the map is kept (with its trivia trimmed). A modifier mapped to `nil`
  /// is removed. Any other modifier is replaced by a fresh modifier carrying the mapped keyword.
  func mapClassMemberModifiers() -> DeclModifierListSyntax {
    let rewritten: [DeclModifierSyntax] = compactMap { modifier in
      guard let replacement = classMemberModifierMap[modifier.name.tokenKind] else {
        return modifier.trimmed
      }
      return replacement.map { kind in
        DeclModifierSyntax(name: TokenSyntax(kind, presence: .present))
      }
    }
    return DeclModifierListSyntax(rewritten)
  }
}
/// Compiler-plugin entry point; lists the macro implementations this plugin provides.
@main
struct Macros: CompilerPlugin {
  var providingMacros: [any Macro.Type] = [
    AnyErasedProtocolMacro.self,
  ]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.