Created December 27, 2019 02:14
Save kreeger/7423a72a6b22f98c1d22954c93f062d0 to your computer and use it in GitHub Desktop.
A convenient way to mock out URL requests in Swift tests.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// HTTPRecording.swift | |
// | |
import Foundation | |
/// A canned HTTP response parsed from a plain-text recording.
///
/// Expected layout (LF or CRLF line endings):
///
///     HTTP/1.1 200 OK
///     Header-Name: value
///
///     <body…>
///
/// The header block ends at the first blank line (which is consumed) or at the first line that is
/// not a `Name: value` header (which becomes the first body line).
struct HTTPRecording {
    /// Numeric status code: the second whitespace-separated token of the status line.
    let statusCode: Int
    /// First token of the status line, e.g. `HTTP/1.1`.
    let httpVersion: String
    /// Header fields keyed by the text before the first `": "` on each header line.
    let headers: [String: String]
    /// Lines after the header block, re-joined with `"\n"` and re-encoded; `nil` when there are none.
    let body: Data?

    /// Parses a recording.
    ///
    /// - Parameters:
    ///   - data: Raw bytes of the recording.
    ///   - encoding: Encoding used both to decode `data` and to re-encode the body.
    /// - Returns: `nil` if `data` cannot be decoded or the status line has no integer status code.
    init?(data: Data, encoding: String.Encoding) {
        guard let content = String(data: data, encoding: encoding) else { return nil }

        // Swift treats "\r\n" as a single Character, so splitting on `isNewline` handles LF, CRLF and
        // lone-CR endings without inventing empty lines. (Splitting on the `.newlines` CharacterSet
        // breaks CRLF into two separators, which ended the header block right after the status line.)
        let lines = content
            .split(omittingEmptySubsequences: false, whereSeparator: { $0.isNewline })
            .map(String.init)

        // Status line: "<version> <code> [reason]"; runs of whitespace are tolerated.
        guard let statusLine = lines.first else { return nil }
        let statusElements = statusLine
            .trimmingCharacters(in: .whitespacesAndNewlines)
            .components(separatedBy: .whitespaces)
            .filter { !$0.isEmpty }
        guard statusElements.count >= 2, let statusCode = Int(statusElements[1]) else { return nil }
        self.httpVersion = statusElements[0]
        self.statusCode = statusCode

        // Header block. `bodyStart` ends up as the index of the first body line.
        var headers = [String: String]()
        var bodyStart = 1
        while bodyStart < lines.count {
            let line = lines[bodyStart].trimmingCharacters(in: .whitespacesAndNewlines)
            if line.isEmpty {
                bodyStart += 1 // consume the blank separator line
                break
            }
            // Not a header: keep this line as the start of the body instead of dropping it.
            guard let separator = line.range(of: ": ") else { break }
            headers[String(line[..<separator.lowerBound])] = String(line[separator.upperBound...])
            bodyStart += 1
        }
        self.headers = headers

        // Body lines are kept verbatim (no trimming) so indentation and significant whitespace survive.
        if bodyStart < lines.count {
            self.body = lines[bodyStart...].joined(separator: "\n").data(using: encoding)
        } else {
            self.body = nil
        }
    }

    /// Builds an `HTTPURLResponse` for `url` from the recorded status code, HTTP version and headers.
    func response(for url: URL) -> HTTPURLResponse? {
        return HTTPURLResponse(url: url, statusCode: statusCode, httpVersion: httpVersion, headerFields: headers)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// MockURLProtocol.swift | |
// | |
import Foundation | |
/// A `URLProtocol` that answers requests from an in-memory table of canned results instead of the network.
///
/// Register results with `mock(result:url:)` and install the class through
/// `URLSessionConfiguration.protocolClasses`. A request matches a registered URL exactly or, failing
/// that, when its URL with the query string removed matches.
class MockURLProtocol: URLProtocol {
    /// What a mocked request should produce.
    enum Result {
        /// Report a redirect to `request`, using `response` as the redirect response.
        case redirect(request: URLRequest, response: URLResponse)
        /// Deliver the recording's response, then its body.
        case recording(HTTPRecording)
        /// Fail the load with the given error.
        case fail(Error)
        /// Hand an authentication challenge to the client.
        case authChallenge(URLAuthenticationChallenge)
    }

    // Written by test code and read from URLSession's loading threads, so every access goes through `lock`.
    private static let lock = NSLock()
    private static var testURLs = [URL: Result]()
    private static var requestCounts = [URL: Int]()

    // MARK: - URLProtocol overrides

    override class func canInit(with request: URLRequest) -> Bool {
        guard let url = request.url else { return false }
        return findResponseMock(url: url) != nil
    }

    override class func canonicalRequest(for request: URLRequest) -> URLRequest {
        return request
    }

    override func startLoading() {
        guard let url = request.url else {
            client?.urlProtocolDidFinishLoading(self)
            return
        }
        let mocks = type(of: self)
        mocks.incrementRequestCount(url: url)

        switch mocks.findResponseMock(url: url) {
        case .redirect(let redirectRequest, let redirectResponse)?:
            client?.urlProtocol(self, wasRedirectedTo: redirectRequest, redirectResponse: redirectResponse)
        case .recording(let recording)?:
            // The client expects the response before any body data.
            if let response = recording.response(for: url) {
                client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed)
            }
            if let data = recording.body {
                client?.urlProtocol(self, didLoad: data)
            }
        case .fail(let error)?:
            // A failure is terminal; do not also report a successful finish.
            client?.urlProtocol(self, didFailWithError: error)
            return
        case .authChallenge(let challenge)?:
            client?.urlProtocol(self, didReceive: challenge)
        case nil:
            // canInit(with:) only accepts mocked URLs, so the mock vanished (e.g. clear()) mid-flight.
            assertionFailure("Unexpected network call to url \(url)")
            client?.urlProtocol(self, didFailWithError: URLError(.unsupportedURL))
            return
        }
        client?.urlProtocolDidFinishLoading(self)
    }

    override func stopLoading() { }

    // MARK: - Public class functions

    /// Registers `result` for `url` and resets that URL's request count to 0.
    /// - Precondition: `url` parses as a `URL`; an invalid string is a test-authoring error and traps.
    class func mock(result: MockURLProtocol.Result, url: String) {
        guard let key = URL(string: url) else {
            preconditionFailure("MockURLProtocol.mock: invalid URL string \(url)")
        }
        lock.lock()
        defer { lock.unlock() }
        testURLs[key] = result
        requestCounts[key] = 0
    }

    /// Number of loads counted against `url` (0 if it was never requested or the string is not a URL).
    class func requestCount(url: String) -> Int {
        guard let key = URL(string: url) else { return 0 }
        lock.lock()
        defer { lock.unlock() }
        return requestCounts[key] ?? 0
    }

    /// Removes every registered result and request count.
    class func clear() {
        lock.lock()
        defer { lock.unlock() }
        testURLs.removeAll()
        requestCounts.removeAll()
    }

    // MARK: - Private functions

    /// `url` with its query string removed, or `nil` if it cannot be decomposed and rebuilt.
    private class func urlWithoutQuery(_ url: URL) -> URL? {
        guard var components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { return nil }
        components.queryItems = nil
        return components.url
    }

    /// Bumps the count for `url` if tracked, else for its query-less form if tracked, else starts a
    /// new count on `url` itself.
    private class func incrementRequestCount(url: URL) {
        lock.lock()
        defer { lock.unlock() }
        let key: URL
        if requestCounts[url] != nil {
            key = url
        } else if let stripped = urlWithoutQuery(url), requestCounts[stripped] != nil {
            key = stripped
        } else {
            key = url
        }
        requestCounts[key, default: 0] += 1
    }

    /// The registered result for `url`: exact match first, then the URL without its query string.
    private class func findResponseMock(url: URL) -> Result? {
        lock.lock()
        defer { lock.unlock() }
        if let exact = testURLs[url] {
            return exact
        }
        guard let stripped = urlWithoutQuery(url) else { return nil }
        return testURLs[stripped]
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// NetworkTestable.swift | |
// | |
import Foundation | |
/// Errors thrown by the `NetworkTestable` helpers while preparing mocked responses.
enum NetworkTestableError: Error, CustomStringConvertible {
    /// Intended for a URL string that cannot be parsed.
    case invalidURL
    /// `Bundle(identifier:)` found no bundle for the given identifier.
    case unknownBundle
    /// The named resource was not found in the bundle.
    case missingFile
    /// The resource's contents could not be parsed as an `HTTPRecording`.
    case invalidRecording

    /// Short, lowercase, human-readable summary of the error.
    var description: String {
        let summary: String
        switch self {
        case .invalidURL:
            summary = "invalid URL"
        case .unknownBundle:
            summary = "unknown bundle identifier"
        case .missingFile:
            summary = "missing file"
        case .invalidRecording:
            summary = "invalid recording"
        }
        return summary
    }
}
/// Class-bound marker protocol; conforming types (presumably XCTestCase subclasses) gain the
/// URL-mocking helpers declared in its extension.
protocol NetworkTestable: AnyObject {}
extension NetworkTestable {
    /// Loads an HTTP recording from a bundle resource and registers it as the mocked result for `url`.
    ///
    /// - Parameters:
    ///   - name: Resource file name, with or without an extension (e.g. `"user.http"` or `"user"`).
    ///     Only the text after the last `.` is treated as the extension.
    ///   - url: URL string the recording should answer; must parse as a `URL`.
    ///   - bundle: Identifier of the bundle that contains the resource.
    /// - Throws: `NetworkTestableError.unknownBundle`, `.missingFile`, `.invalidRecording`, `.invalidURL`,
    ///   or any error thrown while reading the file.
    func loadRecording(named name: String, for url: String, bundle: String) throws {
        // Split off the extension only when there is one; previously "user" was looked up as an
        // empty resource name with extension "user".
        var nameParts = name.components(separatedBy: ".")
        let fileExt: String? = nameParts.count > 1 ? nameParts.removeLast() : nil
        let filename = nameParts.joined(separator: ".")

        guard let resourceBundle = Bundle(identifier: bundle) else {
            throw NetworkTestableError.unknownBundle
        }
        guard let fileURL = resourceBundle.url(forResource: filename, withExtension: fileExt) else {
            throw NetworkTestableError.missingFile
        }
        let data = try Data(contentsOf: fileURL)
        guard let recording = HTTPRecording(data: data, encoding: .utf8) else {
            throw NetworkTestableError.invalidRecording
        }
        // MockURLProtocol.mock traps on an unparsable URL string; surface it as a thrown error instead.
        guard URL(string: url) != nil else {
            throw NetworkTestableError.invalidURL
        }
        mockResult(.recording(recording), url: url)
    }

    /// Registers `result` as the mocked outcome for `url` (resetting that URL's request count).
    func mockResult(_ result: MockURLProtocol.Result, url: String) {
        MockURLProtocol.mock(result: result, url: url)
    }

    /// Removes every registered mock and request count.
    func removeAllURLMocks() {
        MockURLProtocol.clear()
    }

    /// Creates an ephemeral `URLSession` whose requests are served by `MockURLProtocol`.
    ///
    /// - Returns: The session, and the serial dispatch queue that underlies its delegate operation queue.
    func vendURLSession() -> (session: URLSession, queue: DispatchQueue) {
        let configuration = URLSessionConfiguration.ephemeral
        configuration.protocolClasses = [MockURLProtocol.self]
        let queue = DispatchQueue(label: "URLSession.Mock.DispatchQueue")
        let opQueue = OperationQueue()
        opQueue.name = "URLSession.Mock.OperationQueue"
        // URLSession wants a serial delegate queue so callbacks arrive in order; state it explicitly.
        opQueue.maxConcurrentOperationCount = 1
        opQueue.underlyingQueue = queue
        return (URLSession(configuration: configuration, delegate: nil, delegateQueue: opQueue), queue)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.