diff --git a/Sources/DatadogSDKTesting/NetworkInstrumentation/DDNetworkInstrumentation.swift b/Sources/DatadogSDKTesting/NetworkInstrumentation/DDNetworkInstrumentation.swift index 60ee710f..4bc28e50 100644 --- a/Sources/DatadogSDKTesting/NetworkInstrumentation/DDNetworkInstrumentation.swift +++ b/Sources/DatadogSDKTesting/NetworkInstrumentation/DDNetworkInstrumentation.swift @@ -178,21 +178,29 @@ class DDNetworkInstrumentation { } var originalIMP: IMP? let sessionTaskId = UUID().uuidString - - let block: @convention(block) (URLSession, AnyObject, @escaping (Any?, URLResponse?, Error?) -> Void) -> URLSessionTask = { session, argument, completion in + + let block: @convention(block) (URLSession, AnyObject, ((Any?, URLResponse?, Error?) -> Void)? ) -> URLSessionTask = { session, argument, completion in + if let url = argument as? URL, self.injectHeaders == true { let request = URLRequest(url: url) - if selector == #selector(URLSession.dataTask(with:completionHandler:) as (URLSession) -> (URL, @escaping (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTask) { - return session.dataTask(with: request, completionHandler: completion) + if let completion = completion { + return session.dataTask(with: request, completionHandler: completion) + } else { + return session.dataTask(with: request) + } } else { - return session.downloadTask(with: request, completionHandler: completion) + if let completion = completion { + return session.downloadTask(with: request, completionHandler: completion) + } else { + return session.downloadTask(with: request) + } } } - let castedIMP = unsafeBitCast(originalIMP, to: (@convention(c) (URLSession, Selector, Any, @escaping (Any?, URLResponse?, Error?) -> Void) -> URLSessionDataTask).self) - var task: URLSessionTask + let castedIMP = unsafeBitCast(originalIMP, to: (@convention(c) (URLSession, Selector, Any, ( (Any?, URLResponse?, Error?) -> Void)?) -> URLSessionDataTask).self) + var task: URLSessionTask! var completionBlock = completion if objc_getAssociatedObject(argument, &idKey) == nil { @@ -205,7 +213,11 @@ class DDNetworkInstrumentation { DDNetworkActivityLogger.log(response: response, dataOrFile: object, sessionTaskId: sessionTaskId) } } - completion(object, response, error) + if let completion = completion { + completion(object, response, error) + } else { + (session.delegate as? URLSessionTaskDelegate)?.urlSession?(session, task: task, didCompleteWithError: error) + } } completionBlock = completionWrapper } @@ -243,10 +255,11 @@ class DDNetworkInstrumentation { var originalIMP: IMP? let sessionTaskId = UUID().uuidString - let block: @convention(block) (URLSession, URLRequest, AnyObject, @escaping (Any?, URLResponse?, Error?) -> Void) -> URLSessionTask = { session, request, argument, completion in - - let castedIMP = unsafeBitCast(originalIMP, to: (@convention(c) (URLSession, Selector, URLRequest, AnyObject, @escaping (Any?, URLResponse?, Error?) -> Void) -> URLSessionDataTask).self) + let block: @convention(block) (URLSession, URLRequest, AnyObject, ( (Any?, URLResponse?, Error?) -> Void)?) -> URLSessionTask = { session, request, argument, completion in + let castedIMP = unsafeBitCast(originalIMP, to: (@convention(c) (URLSession, Selector, URLRequest, AnyObject, ( (Any?, URLResponse?, Error?) -> Void)?) -> URLSessionDataTask).self) + + var task: URLSessionTask! var completionBlock = completion if objc_getAssociatedObject(argument, &idKey) == nil { let completionWrapper: (Any?, URLResponse?, Error?) -> Void = { object, response, error in @@ -258,13 +271,17 @@ class DDNetworkInstrumentation { DDNetworkActivityLogger.log(response: response, dataOrFile: object, sessionTaskId: sessionTaskId) } } - completion(object, response, error) + if let completion = completion { + completion(object, response, error) + } else { + (session.delegate as? URLSessionTaskDelegate)?.urlSession?(session, task: task, didCompleteWithError: error) + } } completionBlock = completionWrapper } let instrumentedRequest = self.instrumentedRequest(for: request) - let task = castedIMP(session, selector, instrumentedRequest, argument, completionBlock) + task = castedIMP(session, selector, instrumentedRequest, argument, completionBlock) DDNetworkActivityLogger.log(request: instrumentedRequest, sessionTaskId: sessionTaskId) self.setIdKey(value: sessionTaskId, for: task)