Skip to content

Commit

Permalink
Use HTTPClient.shared now that it is available (#609)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Apr 10, 2024
1 parent cefaab7 commit b665919
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 131 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-metrics.git", "1.0.0"..<"3.0.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.7.2"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.13.1"),
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.19.0"),
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.21.0"),
.package(url: "https://github.com/adam-fowler/jmespath.swift.git", from: "1.0.2"),
],
targets: [
Expand Down
71 changes: 5 additions & 66 deletions Sources/SotoCore/AWSClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ public final class AWSClient: Sendable {
public let middleware: AWSMiddlewareProtocol
/// HTTP client used by AWSClient
public let httpClient: AWSHTTPClient
/// Keeps a record of how we obtained the HTTP client
let httpClientProvider: HTTPClientProvider
/// Logger used for non-request based output
let clientLogger: Logger
/// client options
Expand All @@ -71,20 +69,10 @@ public final class AWSClient: Sendable {
retryPolicy retryPolicyFactory: RetryPolicyFactory = .default,
middleware: Middleware,
options: Options = Options(),
httpClientProvider: HTTPClientProvider,
httpClient: AWSHTTPClient = HTTPClient.shared,
logger clientLogger: Logger = AWSClient.loggingDisabled
) {
// setup httpClient
self.httpClientProvider = httpClientProvider
switch httpClientProvider.value {
case .shared(let providedHTTPClient):
self.httpClient = providedHTTPClient
case .createNewWithEventLoopGroup(let elg):
self.httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .shared(elg), configuration: .init(timeout: .init(connect: .seconds(10))))
case .createNew:
self.httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton, configuration: .init(timeout: .init(connect: .seconds(10))))
}

self.httpClient = httpClient
let credentialProvider = credentialProviderFactory.createProvider(context: .init(
httpClient: self.httpClient,
logger: clientLogger,
Expand Down Expand Up @@ -113,20 +101,10 @@ public final class AWSClient: Sendable {
credentialProvider credentialProviderFactory: CredentialProviderFactory = .default,
retryPolicy retryPolicyFactory: RetryPolicyFactory = .default,
options: Options = Options(),
httpClientProvider: HTTPClientProvider,
httpClient: AWSHTTPClient = HTTPClient.shared,
logger clientLogger: Logger = AWSClient.loggingDisabled
) {
// setup httpClient
self.httpClientProvider = httpClientProvider
switch httpClientProvider.value {
case .shared(let providedHTTPClient):
self.httpClient = providedHTTPClient
case .createNewWithEventLoopGroup(let elg):
self.httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .shared(elg), configuration: .init(timeout: .init(connect: .seconds(10))))
case .createNew:
self.httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton, configuration: .init(timeout: .init(connect: .seconds(10))))
}

self.httpClient = httpClient
let credentialProvider = credentialProviderFactory.createProvider(context: .init(
httpClient: self.httpClient,
logger: clientLogger,
Expand Down Expand Up @@ -206,29 +184,6 @@ public final class AWSClient: Sendable {
public static var failedToAccessPayload: ClientError { .init(error: .failedToAccessPayload) }
}

/// Specifies how `HTTPClient` will be created and establishes lifecycle ownership.
public struct HTTPClientProvider: Sendable {
fileprivate enum Internal: Sendable {
case shared(AWSHTTPClient)
case createNewWithEventLoopGroup(EventLoopGroup)
case createNew
}

fileprivate let value: Internal

fileprivate init(_ value: Internal) {
self.value = value
}

/// Use HTTPClient provided by the user. User is responsible for the lifecycle of the HTTPClient.
public static func shared(_ httpClient: AWSHTTPClient) -> Self { .init(.shared(httpClient)) }
/// HTTPClient will be created by AWSClient using provided EventLoopGroup. When `shutdown` is called, created `HTTPClient`
/// will be shut down as well.
public static func createNewWithEventLoopGroup(_ eventLoopGroup: EventLoopGroup) -> Self { .init(.createNewWithEventLoopGroup(eventLoopGroup)) }
/// `HTTPClient` will be created by `AWSClient`. When `shutdown` is called, created `HTTPClient` will be shut down as well.
public static var createNew: Self { .init(.createNew) }
}

/// Additional options
public struct Options: Sendable {
/// log level used for request logging
Expand All @@ -254,30 +209,14 @@ extension AWSClient {
/// Shutdown AWSClient asynchronously.
///
/// Before an `AWSClient` is deleted you need to call this function or the synchronous
/// version `syncShutdown` to do a clean shutdown of the client. It cleans up `CredentialProvider` tasks and shuts down
/// the HTTP client if it was created by the `AWSClient`.
/// version `syncShutdown` to do a clean shutdown of the client to clean up `CredentialProvider` tasks.
public func shutdown() async throws {
guard self.isShutdown.compareExchange(expected: false, desired: true, ordering: .relaxed).exchanged else {
throw ClientError.alreadyShutdown
}
// shutdown credential provider ignoring any errors as credential provider that doesn't initialize
// can cause the shutdown process to fail
try? await self.credentialProvider.shutdown()
// if httpClient was created by AWSClient then it is required to shutdown the httpClient.
switch self.httpClientProvider.value {
case .createNew, .createNewWithEventLoopGroup:
do {
try await self.httpClient.shutdown()
} catch {
self.clientLogger.log(level: self.options.errorLogLevel, "Error shutting down HTTP client", metadata: [
"aws-error": "\(error)",
])
throw error
}

case .shared:
return
}
}

/// Execute a request with an input object and an empty response
Expand Down
4 changes: 2 additions & 2 deletions Sources/SotoCore/Credential/STSAssumeRole.swift
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ struct STSAssumeRoleCredentialProvider: CredentialProviderWithClient {
httpClient: AWSHTTPClient,
endpoint: String? = nil
) {
self.client = AWSClient(credentialProvider: credentialProvider, httpClientProvider: .shared(httpClient))
self.client = AWSClient(credentialProvider: credentialProvider, httpClient: httpClient)
self.request = .assumeRole(arn: roleArn, sessionName: roleSessionName)
self.config = AWSServiceConfig(
region: region,
Expand All @@ -180,7 +180,7 @@ struct STSAssumeRoleCredentialProvider: CredentialProviderWithClient {
endpoint: String? = nil,
threadPool: NIOThreadPool = .singleton
) {
self.client = AWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
self.client = AWSClient(credentialProvider: .empty, httpClient: httpClient)
self.request = .assumeRoleWithWebIdentity(
arn: roleArn,
sessionName: roleSessionName,
Expand Down
5 changes: 3 additions & 2 deletions Sources/SotoTestUtils/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import AsyncHTTPClient
import Foundation
import Logging
import SotoCore
Expand All @@ -37,15 +38,15 @@ public func createAWSClient(
retryPolicy: RetryPolicyFactory = .noRetry,
middlewares: AWSMiddlewareProtocol = TestEnvironment.middlewares,
options: AWSClient.Options = .init(),
httpClientProvider: AWSClient.HTTPClientProvider = .createNew,
httpClient: AWSHTTPClient = HTTPClient.shared,
logger: Logger = TestEnvironment.logger
) -> AWSClient {
return AWSClient(
credentialProvider: credentialProvider,
retryPolicy: retryPolicy,
middleware: middlewares,
options: options,
httpClientProvider: httpClientProvider,
httpClient: httpClient,
logger: logger
)
}
Expand Down
44 changes: 15 additions & 29 deletions Tests/SotoCoreTests/AWSClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,7 @@ class AWSClientTests: XCTestCase {
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }

let client = createAWSClient(httpClientProvider: .shared(httpClient))
try await client.shutdown()
}

func testShutdownWithEventLoop() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }

let client = createAWSClient(httpClientProvider: .createNewWithEventLoopGroup(eventLoopGroup))
let client = createAWSClient(httpClient: httpClient)
try await client.shutdown()
}

Expand All @@ -71,7 +63,7 @@ class AWSClientTests: XCTestCase {
)
let client = createAWSClient(
credentialProvider: .static(accessKeyId: "foo", secretAccessKey: "bar"),
httpClientProvider: .shared(httpClient)
httpClient: httpClient
)
defer {
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -174,13 +166,10 @@ class AWSClientTests: XCTestCase {
let i: Int64
}
do {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(
credentialProvider: .empty,
httpClientProvider: .createNewWithEventLoopGroup(eventLoopGroup)
credentialProvider: .empty
)
defer {
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -218,13 +207,10 @@ class AWSClientTests: XCTestCase {
let data: AWSBase64Data
}
do {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(
credentialProvider: .empty,
httpClientProvider: .createNewWithEventLoopGroup(eventLoopGroup)
credentialProvider: .empty
)
defer {
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -288,7 +274,7 @@ class AWSClientTests: XCTestCase {
let awsServer = AWSTestServer(serviceProtocol: .json)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
let config = createServiceConfig(endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand All @@ -304,7 +290,7 @@ class AWSClientTests: XCTestCase {
let awsServer = AWSTestServer(serviceProtocol: .json)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
let config = createServiceConfig(service: "s3", endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .static(accessKeyId: "foo", secretAccessKey: "bar"), httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .static(accessKeyId: "foo", secretAccessKey: "bar"), httpClient: httpClient)
defer {
XCTAssertNoThrow(try client.syncShutdown())
XCTAssertNoThrow(try awsServer.stop())
Expand Down Expand Up @@ -336,7 +322,7 @@ class AWSClientTests: XCTestCase {
let awsServer = AWSTestServer(serviceProtocol: .json)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
let config = createServiceConfig(endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
// ignore error
try? awsServer.stop()
Expand Down Expand Up @@ -384,7 +370,7 @@ class AWSClientTests: XCTestCase {
let awsServer = AWSTestServer(serviceProtocol: .json)
let httpClient = HTTPClient(eventLoopGroupProvider: .singleton)
let config = createServiceConfig(endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -423,7 +409,7 @@ class AWSClientTests: XCTestCase {
let httpClientConfig = AsyncHTTPClient.HTTPClient.Configuration(redirectConfiguration: .init(.disallow))
let httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton, configuration: httpClientConfig)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand All @@ -450,7 +436,7 @@ class AWSClientTests: XCTestCase {
let httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton)
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .exponential(base: .milliseconds(200)), httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .exponential(base: .milliseconds(200)), httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -489,7 +475,7 @@ class AWSClientTests: XCTestCase {
let httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton)
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .jitter(), httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .jitter(), httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -547,7 +533,7 @@ class AWSClientTests: XCTestCase {
let httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: serverAddress)
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .init(retryPolicy: retryPolicy), httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .init(retryPolicy: retryPolicy), httpClient: httpClient)
defer { XCTAssertNoThrow(try client.syncShutdown()) }
async let responseTask: Void = client.execute(
operation: "test",
Expand Down Expand Up @@ -583,7 +569,7 @@ class AWSClientTests: XCTestCase {
let httpClient = AsyncHTTPClient.HTTPClient(eventLoopGroupProvider: .singleton)
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(serviceProtocol: .json(version: "1.1"), endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .jitter(), httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, retryPolicy: .jitter(), httpClient: httpClient)
defer {
XCTAssertNoThrow(try awsServer.stop())
XCTAssertNoThrow(try client.syncShutdown())
Expand Down Expand Up @@ -617,7 +603,7 @@ class AWSClientTests: XCTestCase {
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
let awsServer = AWSTestServer(serviceProtocol: .json)
let config = createServiceConfig(endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
XCTAssertNoThrow(try client.syncShutdown())
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand Down Expand Up @@ -665,7 +651,7 @@ class AWSClientTests: XCTestCase {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 5)
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
let config = createServiceConfig(endpoint: awsServer.address)
let client = createAWSClient(credentialProvider: .empty, httpClientProvider: .shared(httpClient))
let client = createAWSClient(credentialProvider: .empty, httpClient: httpClient)
defer {
XCTAssertNoThrow(try client.syncShutdown())
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ConfigFileCredentialProviderTests: XCTestCase {
context: context,
endpoint: testServer.address
)
}, httpClientProvider: .shared(httpClient))
}, httpClient: httpClient)
defer { XCTAssertNoThrow(try client.syncShutdown()) }

// Retrieve credentials
Expand Down
Loading

0 comments on commit b665919

Please sign in to comment.