From 17a0a8c154efcbd69db4fa1f422db4e4a0cbcbd4 Mon Sep 17 00:00:00 2001 From: Mykyta Konopelko Date: Tue, 17 Sep 2024 02:19:43 +0300 Subject: [PATCH] support closures --- .../AccessorKeyword+Macro.swift | 9 +- .../SpryableDiagnostic.swift | 4 + .../SpryableMacro/SpryableAccessorMacro.swift | 2 +- .../SpryableMacro/SpryableBodyMacro.swift | 18 +++- .../SwiftSyntax+SpryKit.swift | 48 ++++++++-- SharedTypes/AccessorKeyword.swift | 10 -- SharedTypes/FuncKeyword.swift | 11 +++ SharedTypes/VarKeyword.swift | 15 +++ Source/Argument.swift | 28 +++++- Source/Helpers/InternalHelpers.swift | 10 ++ Source/SpryableMacros.swift | 4 +- Tests/ArgumentTests.swift | 43 +++++++++ Tests/SpryableMacrosTests.swift | 94 +++++++++++++++++-- 13 files changed, 263 insertions(+), 33 deletions(-) delete mode 100644 SharedTypes/AccessorKeyword.swift create mode 100644 SharedTypes/FuncKeyword.swift create mode 100644 SharedTypes/VarKeyword.swift diff --git a/MacroAndCompilerPlugin/AccessorKeyword+Macro.swift b/MacroAndCompilerPlugin/AccessorKeyword+Macro.swift index 3ecaf99..6794d33 100644 --- a/MacroAndCompilerPlugin/AccessorKeyword+Macro.swift +++ b/MacroAndCompilerPlugin/AccessorKeyword+Macro.swift @@ -2,9 +2,16 @@ import Foundation import SharedTypes -internal extension Array where Element == AccessorKeyword { +internal extension Array where Element == VarKeyword { static func ~=(lhs: [Element], rhs: Element) -> Bool { return lhs.contains(rhs) } } + +internal extension Array where Element == FuncKeyword { + static func ~=(lhs: [Element], rhs: Element) -> Bool { + return lhs.contains(rhs) + } +} + #endif diff --git a/MacroAndCompilerPlugin/SpryableDiagnostic.swift b/MacroAndCompilerPlugin/SpryableDiagnostic.swift index 562971a..ee863b4 100644 --- a/MacroAndCompilerPlugin/SpryableDiagnostic.swift +++ b/MacroAndCompilerPlugin/SpryableDiagnostic.swift @@ -6,6 +6,7 @@ enum SpryableDiagnostic: String, DiagnosticMessage, Error { case notAVariable case onlyApplicableToVar case notAFunction + case nonEscapingClosureNotSupported case subscriptsNotSupported case operatorsNotSupported case invalidVariableRequirement @@ -27,6 +28,8 @@ enum SpryableDiagnostic: String, DiagnosticMessage, Error { return "Operator requirements are not supported by @Spryable." case .invalidVariableRequirement: return "Invalid variable requirement. Missing type annotation." + case .nonEscapingClosureNotSupported: + return "'Non-escaping' closures are not supported by `@Spryable`. You should write the body of the function of your 'Fake' manually." } } @@ -34,6 +37,7 @@ enum SpryableDiagnostic: String, DiagnosticMessage, Error { var severity: DiagnosticSeverity { switch self { case .invalidVariableRequirement, + .nonEscapingClosureNotSupported, .notAFunction, .notAVariable, .onlyApplicableToClass, diff --git a/MacroAndCompilerPlugin/SpryableMacro/SpryableAccessorMacro.swift b/MacroAndCompilerPlugin/SpryableMacro/SpryableAccessorMacro.swift index 90c9f13..4aaa82c 100644 --- a/MacroAndCompilerPlugin/SpryableMacro/SpryableAccessorMacro.swift +++ b/MacroAndCompilerPlugin/SpryableMacro/SpryableAccessorMacro.swift @@ -19,7 +19,7 @@ public enum SpryableAccessorMacro: AccessorMacro { throw SpryableDiagnostic.invalidVariableRequirement } - let options = node.options + let options = node.varOptions var effectSpecifiers: AccessorEffectSpecifiersSyntax? if options ~= .async || options ~= .throws { effectSpecifiers = .init(asyncSpecifier: options ~= .async ? .keyword(.async) : nil, diff --git a/MacroAndCompilerPlugin/SpryableMacro/SpryableBodyMacro.swift b/MacroAndCompilerPlugin/SpryableMacro/SpryableBodyMacro.swift index 5d990a5..67199de 100644 --- a/MacroAndCompilerPlugin/SpryableMacro/SpryableBodyMacro.swift +++ b/MacroAndCompilerPlugin/SpryableMacro/SpryableBodyMacro.swift @@ -11,18 +11,30 @@ public enum SpryableBodyMacro: BodyMacro { throw SpryableDiagnostic.notAFunction } - let parameters = syntax.signature.parameterClause.parameters.enumerated().map { idx, param in + let parameters = try syntax.signature.parameterClause.parameters.enumerated().map { _, param in + if param.isNonEscapingClosure { + throw SpryableDiagnostic.nonEscapingClosureNotSupported + } + let name = param.secondName ?? param.firstName if name.text != TokenSyntax.wildcardToken().text { return param } else { - return param.with(\.secondName, .identifier("arg\(idx)")) + return param.with(\.secondName, .identifier("Argument.skipped")) } } + let options = node.funcOptions let arguments = LabeledExprListSyntax { for (idx, parameter) in parameters.enumerated() { - let name = parameter.secondName ?? parameter.firstName + let name: TokenSyntax = { + if parameter.isEscapingClosure, !(options ~= .asRealClosure) { + return idx == parameters.count - 1 ? "Argument.closure" : "Argument.closure," + } else { + return parameter.secondName ?? parameter.firstName + } + }() + if idx == 0 { LabeledExprSyntax(label: "arguments", expression: DeclReferenceExprSyntax(baseName: name)) } else { diff --git a/MacroAndCompilerPlugin/SwiftSyntax+SpryKit.swift b/MacroAndCompilerPlugin/SwiftSyntax+SpryKit.swift index cb13ba6..25ed0cb 100644 --- a/MacroAndCompilerPlugin/SwiftSyntax+SpryKit.swift +++ b/MacroAndCompilerPlugin/SwiftSyntax+SpryKit.swift @@ -36,7 +36,15 @@ internal extension VariableDeclSyntax { } internal extension MemberAccessExprSyntax { - var keyword: AccessorKeyword? { + var varKeyword: VarKeyword? { + guard let name = declName.baseName.identifier?.name else { + return nil + } + + return .init(rawValue: name) + } + + var funcKeyword: FuncKeyword? { guard let name = declName.baseName.identifier?.name else { return nil } @@ -46,11 +54,11 @@ internal extension MemberAccessExprSyntax { } internal extension VariableDeclSyntax { - var options: [AccessorKeyword] { - var options: [AccessorKeyword] = attributes.flatMap { attr in + var options: [VarKeyword] { + var options: [VarKeyword] = attributes.flatMap { attr in attr.as(AttributeSyntax.self)?.arguments?.as(LabeledExprListSyntax.self).map { args in args.compactMap { arg in - arg.expression.as(MemberAccessExprSyntax.self)?.keyword + arg.expression.as(MemberAccessExprSyntax.self)?.varKeyword } } ?? [] } @@ -64,9 +72,9 @@ internal extension VariableDeclSyntax { } internal extension AttributeSyntax { - var options: [AccessorKeyword] { + var varOptions: [VarKeyword] { var options = arguments?.as(LabeledExprListSyntax.self)?.compactMap { expr in - expr.expression.as(MemberAccessExprSyntax.self)?.keyword + expr.expression.as(MemberAccessExprSyntax.self)?.varKeyword } ?? [] if !(options ~= .get) { @@ -75,6 +83,34 @@ internal extension AttributeSyntax { return options } + + var funcOptions: [FuncKeyword] { + var options = arguments?.as(LabeledExprListSyntax.self)?.compactMap { expr in + expr.expression.as(MemberAccessExprSyntax.self)?.funcKeyword + } ?? [] + + if options.isEmpty { + options.append(.asRealClosure) + } + + return options + } +} + +internal extension FunctionParameterSyntax { + var isClosure: Bool { + return isNonEscapingClosure || isEscapingClosure + } + + var isEscapingClosure: Bool { + return type.as(AttributedTypeSyntax.self)?.attributes.contains(where: { elem in + return elem.as(AttributeSyntax.self)?.attributeName.as(IdentifierTypeSyntax.self)?.name.tokenKind == .identifier("escaping") + }) == true + } + + var isNonEscapingClosure: Bool { + return type.as(FunctionTypeSyntax.self) != nil + } } internal extension Macro { diff --git a/SharedTypes/AccessorKeyword.swift b/SharedTypes/AccessorKeyword.swift deleted file mode 100644 index 509d61e..0000000 --- a/SharedTypes/AccessorKeyword.swift +++ /dev/null @@ -1,10 +0,0 @@ -#if swift(>=6.0) -import Foundation - -public enum AccessorKeyword: String, Hashable, CaseIterable { - case get - case set - case async - case `throws` -} -#endif diff --git a/SharedTypes/FuncKeyword.swift b/SharedTypes/FuncKeyword.swift new file mode 100644 index 0000000..91301a5 --- /dev/null +++ b/SharedTypes/FuncKeyword.swift @@ -0,0 +1,11 @@ +#if swift(>=6.0) +import Foundation + +/// Parameters for @SpryableFunc +public enum FuncKeyword: String, Hashable, CaseIterable { + /// spryify parameter as Argument.closure + case asArgument + /// spryify parameter as real closure which you can handle from stub. Default behavior + case asRealClosure +} +#endif diff --git a/SharedTypes/VarKeyword.swift b/SharedTypes/VarKeyword.swift new file mode 100644 index 0000000..33d7733 --- /dev/null +++ b/SharedTypes/VarKeyword.swift @@ -0,0 +1,15 @@ +#if swift(>=6.0) +import Foundation + +/// Parameters for @SpryableVar +public enum VarKeyword: String, Hashable, CaseIterable { + /// generate 'get'. Always generating it + case get + /// generate 'set' + case set + /// add 'async' parameter to 'get' + case async + /// add 'throws' parameter to 'get' + case `throws` +} +#endif diff --git a/Source/Argument.swift b/Source/Argument.swift index 3fd9cbc..0a11997 100644 --- a/Source/Argument.swift +++ b/Source/Argument.swift @@ -10,6 +10,12 @@ public enum Argument { /// Every value matches this qualification. case anything + /// Every value matches this qualification, but not 'Argument.anything'. + case skipped + + /// Any closure + case closure + /// Custom validator case validator((Any?) -> Bool) @@ -42,15 +48,19 @@ public enum Argument { extension Argument: Equatable { public static func ==(lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { - case (.nil, .nil), + case (.closure, .closure), + (.nil, .nil), (.nonNil, .nonNil), + (.skipped, .skipped), (.validator, .validator): return true case (.anything, _), + (.closure, _), (.nil, _), (.nonNil, _), - (.validator(_), _): + (.skipped, _), + (.validator, _): return false } } @@ -69,6 +79,10 @@ extension Argument: CustomStringConvertible { return "Argument.nil" case .validator: return "Argument.validator" + case .closure: + return "Argument.closure" + case .skipped: + return "Argument.skipped" } } @@ -121,6 +135,14 @@ private func isEqualArgs(specifiedArg: Any?, actualArg: Any?) -> Bool { if let specifiedArgAsArgumentEnum = specifiedArg as? Argument { switch specifiedArgAsArgumentEnum { case .anything: + if let actualArg = actualArg as? Argument { + return actualArg != Argument.skipped + } + return true + case .skipped: + if let actualArg = actualArg as? Argument { + return actualArg != Argument.anything + } return true case .nonNil: return !isNil(actualArg) @@ -128,6 +150,8 @@ private func isEqualArgs(specifiedArg: Any?, actualArg: Any?) -> Bool { return isNil(actualArg) case .validator(let validator): return validator(actualArg) + case .closure: + return isClosure(actualArg) } } diff --git a/Source/Helpers/InternalHelpers.swift b/Source/Helpers/InternalHelpers.swift index 0542646..93590a6 100644 --- a/Source/Helpers/InternalHelpers.swift +++ b/Source/Helpers/InternalHelpers.swift @@ -15,6 +15,16 @@ internal func isNil(_ value: Any?) -> Bool { } } +/// This is a helper function to find out if a value is closure. +internal func isClosure(_ value: Any?) -> Bool { + if let unwrappedValue = value { + let mirror = Mirror(reflecting: unwrappedValue) + return String(describing: mirror.subjectType).contains(" -> ") + } else { + return true + } +} + // MARK: - String Extensions extension String { diff --git a/Source/SpryableMacros.swift b/Source/SpryableMacros.swift index 2648275..9bd19ef 100644 --- a/Source/SpryableMacros.swift +++ b/Source/SpryableMacros.swift @@ -7,10 +7,10 @@ public macro Spryable() = #externalMacro(module: "MacroAndCompilerPlugin", type: "SpryableExtensionMacro") @attached(accessor) -public macro SpryableVar(_ accessors: SharedTypes.AccessorKeyword... = [.get]) = +public macro SpryableVar(_ accessors: SharedTypes.VarKeyword... = [.get]) = #externalMacro(module: "MacroAndCompilerPlugin", type: "SpryableAccessorMacro") @attached(body) -public macro SpryableFunc() = +public macro SpryableFunc(_ accessors: SharedTypes.FuncKeyword... = [.asRealClosure]) = #externalMacro(module: "MacroAndCompilerPlugin", type: "SpryableBodyMacro") #endif diff --git a/Tests/ArgumentTests.swift b/Tests/ArgumentTests.swift index 58cf6e0..c0d3cef 100644 --- a/Tests/ArgumentTests.swift +++ b/Tests/ArgumentTests.swift @@ -8,6 +8,7 @@ final class ArgumentTests: XCTestCase { XCTAssertEqual(Argument.nonNil.description, "Argument.nonNil") XCTAssertEqual(Argument.nil.description, "Argument.nil") XCTAssertEqual(Argument.validator { _ in true }.description, "Argument.validator") + XCTAssertEqual(Argument.closure.description, "Argument.closure") } func test_is_equal_args_list() { @@ -40,6 +41,31 @@ final class ArgumentTests: XCTestCase { ] XCTAssertTrue(subjectAction()) + // .skipped + specifiedArgs = [ + Argument.anything, + Argument.anything, + Argument.skipped + ] + actualArgs = [ + "asdf", + 3 as Int?, + Argument.anything + ] + XCTAssertTrue(subjectAction()) + + specifiedArgs = [ + Argument.anything, + Argument.anything, + Argument.skipped + ] + actualArgs = [ + "asdf", + 3 as Int?, + Argument.skipped + ] + XCTAssertTrue(subjectAction()) + // .nonNil specifiedArgs = [Argument.nonNil] actualArgs = [nil as String?] @@ -124,5 +150,22 @@ final class ArgumentTests: XCTestCase { specifiedArgs = [SpryEquatableTestHelper(isEqual: true)] actualArgs = [SpryEquatableTestHelper(isEqual: true)] XCTAssertTrue(subjectAction()) + + specifiedArgs = [Argument.closure] + actualArgs = [{}] + XCTAssertTrue(subjectAction()) + + // .skipped != .anything + specifiedArgs = [ + Argument.anything, + Argument.anything, + Argument.anything + ] + actualArgs = [ + "asdf", + 3 as Int?, + Argument.skipped + ] + XCTAssertFalse(subjectAction()) } } diff --git a/Tests/SpryableMacrosTests.swift b/Tests/SpryableMacrosTests.swift index 00fd691..8851020 100644 --- a/Tests/SpryableMacrosTests.swift +++ b/Tests/SpryableMacrosTests.swift @@ -1,11 +1,12 @@ #if os(macOS) && canImport(SwiftSyntax600) -import MacroAndCompilerPlugin import SpryKit import SwiftSyntax import SwiftSyntaxMacros import SwiftSyntaxMacrosTestSupport import XCTest +@testable import MacroAndCompilerPlugin + final class SpryableMacrosTests: XCTestCase { private let sut: [String: Macro.Type] = [ "SpryableAccessorMacro": SpryableAccessorMacro.self, @@ -44,6 +45,40 @@ final class SpryableMacrosTests: XCTestCase { macros: sut) } + func testNonamedArgs() { + let declaration = + """ + @SpryablePeerMacro + final class FakeFoo { + @SpryableBodyMacro + func bazArg3(some: Int, _: Int, _ some2: Int) + } + """ + + let expected = + """ + + final class FakeFoo { + func bazArg3(some: Int, _: Int, _ some2: Int) { + return spryify(arguments: some, Argument.skipped, some2) + } + } + + extension FakeFoo: Spryable { + enum ClassFunction: String, StringRepresentable { + case _unknown_ = "'enum' must have at least one 'case'" + } + enum Function: String, StringRepresentable { + case bazArg3WithSome_Arg1_Some2 = "bazArg3(some:_:_:)" + } + } + """ + + assertMacroExpansion(declaration, + expandedSource: expected, + macros: sut) + } + func testStaticMacro() { let declaration = """ @@ -109,7 +144,7 @@ final class SpryableMacrosTests: XCTestCase { return spryify(arguments: some, some2) } static func bazArg6(_: Int, _: String) async throws -> Int { - return spryify(arguments: arg0, arg1) + return spryify(arguments: Argument.skipped, Argument.skipped) } } @@ -220,13 +255,13 @@ final class SpryableMacrosTests: XCTestCase { return spryify(arguments: some, some2) } public func bazArg4(_: Int) { - return spryify(arguments: arg0) + return spryify(arguments: Argument.skipped) } func bazArg5(_: Int, _: String) async -> Int { - return spryify(arguments: arg0, arg1) + return spryify(arguments: Argument.skipped, Argument.skipped) } static func bazArg6(_: Int, _: String) async throws -> Int { - return spryify(arguments: arg0, arg1) + return spryify(arguments: Argument.skipped, Argument.skipped) } } @@ -340,13 +375,13 @@ final class SpryableMacrosTests: XCTestCase { return spryify(arguments: some, some2) } public func bazArg4(_: Int) { - return spryify(arguments: arg0) + return spryify(arguments: Argument.skipped) } func bazArg5(_: Int, _: String) async -> Int { - return spryify(arguments: arg0, arg1) + return spryify(arguments: Argument.skipped, Argument.skipped) } static func bazArg6(_: Int, _: String) async throws -> Int { - return spryify(arguments: arg0, arg1) + return spryify(arguments: Argument.skipped, Argument.skipped) } } @@ -376,5 +411,48 @@ final class SpryableMacrosTests: XCTestCase { expandedSource: expected, macros: sut) } + + func testClosures() { + let declaration = + """ + @SpryablePeerMacro + final class FakeClosures { + @SpryableBodyMacro + func sync(execute work: () throws -> R) rethrows -> R + + @SpryableBodyMacro + func escaping(execute work: @escaping () throws -> R) rethrows -> R + } + """ + + let expected = + """ + final class FakeClosures { + func sync(execute work: () throws -> R) rethrows -> R + func escaping(execute work: @escaping () throws -> R) rethrows -> R { + return spryify(arguments: work) + } + } + + extension FakeClosures: Spryable { + enum ClassFunction: String, StringRepresentable { + case _unknown_ = "'enum' must have at least one 'case'" + } + enum Function: String, StringRepresentable { + case syncWithExecute = "sync(execute:)" + case escapingWithExecute = "escaping(execute:)" + } + } + """ + + assertMacroExpansion(declaration, + expandedSource: expected, + diagnostics: [ + .init(message: SpryableDiagnostic.nonEscapingClosureNotSupported.message, + line: 3, + column: 5) + ], + macros: sut) + } } #endif