diff --git a/Sources/AuthProvider/InverseRedirectMiddleware.swift b/Sources/AuthProvider/InverseRedirectMiddleware.swift new file mode 100644 index 0000000..1dcf223 --- /dev/null +++ b/Sources/AuthProvider/InverseRedirectMiddleware.swift @@ -0,0 +1,35 @@ +import HTTP +import Authentication + +/// Redirects authenticated requests to a supplied path. +public final class InverseRedirectMiddleware: Middleware { + /// The path to redirect to + public let path: String + + /// Which type of redirect to perform + public let redirectType: RedirectType + + /// Create a new inverse redirect middleware. + public init( + _ userType: U.Type = U.self, + path: String, + redirectType: RedirectType = .normal + ) { + self.path = path + self.redirectType = redirectType + } + + public func respond(to req: Request, chainingTo next: Responder) throws -> Response { + guard !req.auth.isAuthenticated(U.self) else { + return Response(redirect: path, redirectType) + } + + return try next.respond(to: req) + } + + /// Use this middleware to redirect authenticated + /// away from login pages back to a secure home page. + public static func home(_ userType: U.Type = U.self, path: String = "/") -> InverseRedirectMiddleware { + return InverseRedirectMiddleware(U.self, path: path) + } +} diff --git a/Sources/AuthProvider/RedirectMiddleware.swift b/Sources/AuthProvider/RedirectMiddleware.swift new file mode 100644 index 0000000..8d31279 --- /dev/null +++ b/Sources/AuthProvider/RedirectMiddleware.swift @@ -0,0 +1,34 @@ +import HTTP +import Authentication + +/// Redirects unauthenticated requests to a supplied path. +public final class RedirectMiddleware: Middleware { + /// The path to redirect to + public let path: String + + /// Which type of redirect to perform + public let redirectType: RedirectType + + /// Create a new redirect middleware. + public init( + path: String, + redirectType: RedirectType = .normal + ) { + self.path = path + self.redirectType = redirectType + } + + public func respond(to req: Request, chainingTo next: Responder) throws -> Response { + do { + return try next.respond(to: req) + } catch is AuthenticationError { + return Response(redirect: path, redirectType) + } + } + + /// Use this middleware to redirect users away from + /// protected content to a login page + public static func login(path: String = "/login") -> RedirectMiddleware { + return RedirectMiddleware(path: path) + } +} diff --git a/Tests/AuthProviderTests/MiddlewareTests.swift b/Tests/AuthProviderTests/MiddlewareTests.swift new file mode 100644 index 0000000..b4ba714 --- /dev/null +++ b/Tests/AuthProviderTests/MiddlewareTests.swift @@ -0,0 +1,56 @@ +import XCTest +import Vapor +import HTTP +import AuthProvider +import Authentication +import Testing + +class MiddlewareTests: XCTestCase { + override func setUp() { + Testing.onFail = XCTFail + } + + /// Test that an unauthenticated request to a secure + /// page gets redirected to the login page. + func testRedirectMiddleware() throws { + let drop = try Droplet() + + let redirect = RedirectMiddleware.login() + let auth = TokenAuthenticationMiddleware(TestUser.self) + + let protected = drop.grouped([redirect, auth]) + protected.get { req in + let user = try req.auth.assertAuthenticated(TestUser.self) + return "Welcome to the dashboard, \(user.name)" + } + + try drop.testResponse(to: .get, at: "/") + .assertStatus(is: .seeOther) + .assertHeader("Location", contains: "/login") + } + + /// Test that an authenticated request to login + /// gets redirected to the home page. + func testInverseRedirectMiddleware() throws { + let drop = try Droplet() + + let redirect = InverseRedirectMiddleware.home(TestUser.self) + let group = drop.grouped([redirect]) + group.get("login") { req in + return "Please login" + } + + let req = Request.makeTest(method: .get, path: "/login") + let user = TestUser(name: "Foo") + req.auth.authenticate(user) + + try drop.testResponse(to: req) + .assertStatus(is: .seeOther) + .assertHeader("Location", contains: "/") + } + + static var allTests = [ + ("testRedirectMiddleware", testRedirectMiddleware), + ("testInverseRedirectMiddleware", testInverseRedirectMiddleware) + ] +}