-
-
Notifications
You must be signed in to change notification settings - Fork 86
Expand file tree
/
Copy pathWebSocketClient.swift
More file actions
289 lines (260 loc) · 12.2 KB
/
WebSocketClient.swift
File metadata and controls
289 lines (260 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import Foundation
import NIOCore
import NIOPosix
import NIOConcurrencyHelpers
import NIOExtras
import NIOHTTP1
import NIOWebSocket
import NIOSSL
import NIOTransportServices
import Atomics
/// A client that opens WebSocket connections over SwiftNIO, either directly or through an HTTP proxy.
///
/// When created with `.createNew`, the client owns a single-threaded `MultiThreadedEventLoopGroup`.
/// In that case `syncShutdown()` must be called before the client is deinitialized; this is checked
/// by an assertion in debug builds.
public final class WebSocketClient: Sendable {
    /// Errors thrown or reported by `WebSocketClient`.
    public enum Error: Swift.Error, LocalizedError {
        /// The URL could not be used. NOTE(review): not produced by the code in this file.
        case invalidURL
        /// The upgrade response carried an unexpected status.
        /// NOTE(review): presumably reported by `HTTPUpgradeRequestHandler`, which is defined elsewhere.
        case invalidResponseStatus(HTTPResponseHead)
        /// `syncShutdown()` was called more than once on a client that created its own event loop group.
        case alreadyShutdown

        public var errorDescription: String? {
            return "\(self)"
        }
    }

    public typealias EventLoopGroupProvider = NIOEventLoopGroupProvider

    /// Options applied to every connection made by a `WebSocketClient`.
    public struct Configuration: Sendable {
        /// TLS settings for `wss` connections. `nil` means NIOSSL's default client configuration.
        public var tlsConfiguration: TLSConfiguration?
        /// Maximum WebSocket frame size, passed to `NIOWebSocketClientUpgrader`.
        public var maxFrameSize: Int
        /// Defends against small payloads in frame aggregation.
        /// See `NIOWebSocketFrameAggregator` for details.
        public var minNonFinalFragmentSize: Int
        /// Max number of fragments in an aggregated frame.
        /// See `NIOWebSocketFrameAggregator` for details.
        public var maxAccumulatedFrameCount: Int
        /// Maximum frame size after aggregation.
        /// See `NIOWebSocketFrameAggregator` for details.
        public var maxAccumulatedFrameSize: Int

        /// Creates a configuration.
        ///
        /// The frame-aggregation limits start at permissive values (`0`, `Int.max`, `Int.max`).
        /// Tighten them by assigning the properties after initialization.
        ///
        /// - Parameters:
        ///   - tlsConfiguration: TLS settings for `wss` connections, or `nil` for the defaults.
        ///   - maxFrameSize: Maximum WebSocket frame size; defaults to 16 KiB (`1 << 14`).
        public init(
            tlsConfiguration: TLSConfiguration? = nil,
            maxFrameSize: Int = 1 << 14
        ) {
            self.tlsConfiguration = tlsConfiguration
            self.maxFrameSize = maxFrameSize
            self.minNonFinalFragmentSize = 0
            self.maxAccumulatedFrameCount = Int.max
            self.maxAccumulatedFrameSize = Int.max
        }
    }

    let eventLoopGroupProvider: EventLoopGroupProvider
    let group: EventLoopGroup
    let configuration: Configuration
    /// Set once the owned event loop group has been shut down. Only used for `.createNew` groups.
    let isShutdown = ManagedAtomic(false)

    /// Creates a client.
    ///
    /// - Parameters:
    ///   - eventLoopGroupProvider: `.shared(group)` uses the caller's group. `.createNew` creates a
    ///     one-thread `MultiThreadedEventLoopGroup` that must be released with `syncShutdown()`.
    ///   - configuration: Options applied to every connection.
    public init(eventLoopGroupProvider: EventLoopGroupProvider, configuration: Configuration = .init()) {
        self.eventLoopGroupProvider = eventLoopGroupProvider
        switch self.eventLoopGroupProvider {
        case .shared(let group):
            self.group = group
        case .createNew:
            self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
        }
        self.configuration = configuration
    }

    /// Establishes a WebSocket connection directly to the origin server, without a proxy.
    ///
    /// Equivalent to calling the proxy-aware overload with `proxy: nil`.
    ///
    /// - Parameters:
    ///   - scheme: `"ws"` or `"wss"`.
    ///   - host: Host of the origin server.
    ///   - port: Port of the origin server.
    ///   - path: Path sent in the upgrade request.
    ///   - query: Optional query component of the upgrade request.
    ///   - headers: Extra headers for the upgrade request.
    ///   - onUpgrade: Called with the `WebSocket` once the upgrade completes.
    /// - Returns: A future that succeeds once the WebSocket upgrade has completed.
    @preconcurrency
    public func connect(
        scheme: String,
        host: String,
        port: Int,
        path: String = "/",
        query: String? = nil,
        headers: HTTPHeaders = [:],
        onUpgrade: @Sendable @escaping (WebSocket) -> ()
    ) -> EventLoopFuture<Void> {
        self.connect(scheme: scheme, host: host, port: port, path: path, query: query, headers: headers, proxy: nil, onUpgrade: onUpgrade)
    }

    /// Establish a WebSocket connection, optionally via a proxy server.
    ///
    /// Connection modes:
    /// - With `proxy == nil`, the client connects straight to `host:port`.
    /// - With a proxy and `ws`, the upgrade request is sent to the proxy using the absolute URI of
    ///   the origin server.
    /// - With a proxy and `wss`, the client first opens a tunnel with an HTTP `CONNECT` request,
    ///   then starts TLS and the upgrade inside that tunnel.
    ///
    /// - Parameters:
    ///   - scheme: Scheme component of the URI for the origin server (`"ws"` or `"wss"`; only
    ///     checked in debug builds).
    ///   - host: Host component of the URI for the origin server.
    ///   - port: Port on which to connect to the origin server.
    ///   - path: Path component of the URI for the origin server.
    ///   - query: Query component of the URI for the origin server.
    ///   - headers: Headers to send to the origin server.
    ///   - proxy: Host component of the URI for the proxy server, or `nil` to connect directly.
    ///   - proxyPort: Port on which to connect to the proxy server. Defaults to `port` when `nil`.
    ///   - proxyHeaders: Headers for the proxy. They are added to the upgrade request for `ws`,
    ///     or sent on the `CONNECT` request for `wss`.
    ///   - proxyConnectDeadline: Deadline for establishing the `CONNECT` tunnel (`wss` via proxy only).
    ///   - onUpgrade: An escaping closure to be executed after the upgrade is completed by `NIOWebSocketClientUpgrader`.
    /// - Returns: A future that succeeds once the WebSocket upgrade has completed. It fails if
    ///   connecting, tunnelling through the proxy, setting up TLS, or upgrading fails.
    @preconcurrency
    public func connect(
        scheme: String,
        host: String,
        port: Int,
        path: String = "/",
        query: String? = nil,
        headers: HTTPHeaders = [:],
        proxy: String?,
        proxyPort: Int? = nil,
        proxyHeaders: HTTPHeaders = [:],
        proxyConnectDeadline: NIODeadline = NIODeadline.distantFuture,
        onUpgrade: @Sendable @escaping (WebSocket) -> ()
    ) -> EventLoopFuture<Void> {
        assert(["ws", "wss"].contains(scheme))
        let upgradePromise = self.group.any().makePromise(of: Void.self)
        let bootstrap = WebSocketClient.makeBootstrap(on: self.group)
            .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
            .channelInitializer { channel -> EventLoopFuture<Void> in
                let uri: String
                var upgradeRequestHeaders = headers
                if proxy == nil {
                    uri = path
                } else {
                    // Requests sent through a proxy name the origin server with an absolute URI.
                    let relativePath = path.hasPrefix("/") ? path : "/" + path
                    uri = "\(scheme)://\(host):\(port)\(relativePath)"
                    if scheme == "ws" {
                        // Without TLS no CONNECT request is sent, so the proxy headers travel
                        // on the upgrade request instead.
                        upgradeRequestHeaders.add(contentsOf: proxyHeaders)
                    }
                }

                let httpUpgradeRequestHandler = HTTPUpgradeRequestHandler(
                    host: host,
                    path: uri,
                    query: query,
                    headers: upgradeRequestHeaders,
                    upgradePromise: upgradePromise
                )
                let httpUpgradeRequestHandlerBox = NIOLoopBound(httpUpgradeRequestHandler, eventLoop: channel.eventLoop)

                let websocketUpgrader = NIOWebSocketClientUpgrader(
                    maxFrameSize: self.configuration.maxFrameSize,
                    automaticErrorHandling: true,
                    upgradePipelineHandler: { channel, _ in
                        return WebSocket.client(on: channel, config: .init(clientConfig: self.configuration), onUpgrade: onUpgrade)
                    }
                )

                let config: NIOHTTPClientUpgradeConfiguration = (
                    upgraders: [websocketUpgrader],
                    completionHandler: { _ in
                        upgradePromise.succeed(())
                        channel.pipeline.removeHandler(httpUpgradeRequestHandlerBox.value, promise: nil)
                    }
                )
                let configBox = NIOLoopBound(config, eventLoop: channel.eventLoop)

                if proxy == nil || scheme == "ws" {
                    if scheme == "wss" {
                        do {
                            let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host)
                            // The sync methods are safe here because the channel initializer runs
                            // on the channel's event loop.
                            try channel.pipeline.syncOperations.addHandler(tlsHandler)
                        } catch {
                            // Surface the real TLS setup error, as the proxy path below does.
                            // Failing the initializer fails `connect`, which `cascadeFailure`
                            // forwards to `upgradePromise`.
                            return channel.eventLoop.makeFailedFuture(error)
                        }
                    }
                    return channel.pipeline.addHTTPClientHandlers(
                        leftOverBytesStrategy: .forwardBytes,
                        withClientUpgrade: config
                    ).flatMap {
                        channel.pipeline.addHandler(httpUpgradeRequestHandlerBox.value)
                    }
                }

                // TLS + proxy
                // we need to handle connecting with an additional CONNECT request
                let proxyEstablishedPromise = channel.eventLoop.makePromise(of: Void.self)
                let encoder = NIOLoopBound(HTTPRequestEncoder(), eventLoop: channel.eventLoop)
                let decoder = NIOLoopBound(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)), eventLoop: channel.eventLoop)

                var connectHeaders = proxyHeaders
                connectHeaders.add(name: "Host", value: host)

                let proxyRequestHandler = NIOHTTP1ProxyConnectHandler(
                    targetHost: host,
                    targetPort: port,
                    headers: connectHeaders,
                    deadline: proxyConnectDeadline,
                    promise: proxyEstablishedPromise
                )

                // These HTTP handlers exist only so the proxy request handler can work. They are
                // removed once the tunnel is up and re-added later by `addHTTPClientHandlers`.
                // This is needed because the HTTP decoder is no longer valid after an upgrade,
                // and the CONNECT exchange counts as one.
                do {
                    try channel.pipeline.syncOperations.addHandler(encoder.value)
                    try channel.pipeline.syncOperations.addHandler(decoder.value)
                    try channel.pipeline.syncOperations.addHandler(proxyRequestHandler)
                } catch {
                    return channel.eventLoop.makeFailedFuture(error)
                }

                proxyEstablishedPromise.futureResult.flatMap {
                    channel.pipeline.removeHandler(decoder.value)
                }.flatMap {
                    channel.pipeline.removeHandler(encoder.value)
                }.whenComplete { result in
                    switch result {
                    case .success:
                        do {
                            let tlsHandler = try self.makeTLSHandler(tlsConfiguration: self.configuration.tlsConfiguration, host: host)
                            // The sync methods here are safe because this callback runs on the
                            // channel's event loop: the promise was made on that loop.
                            try channel.pipeline.syncOperations.addHandler(tlsHandler)
                            try channel.pipeline.syncOperations.addHTTPClientHandlers(
                                leftOverBytesStrategy: .forwardBytes,
                                withClientUpgrade: configBox.value
                            )
                            try channel.pipeline.syncOperations.addHandler(httpUpgradeRequestHandlerBox.value)
                        } catch {
                            // Whatever failed, the upgrade handler is not in the pipeline yet, so
                            // nothing else will complete `upgradePromise`; fail it here so the
                            // caller's future does not hang.
                            upgradePromise.fail(error)
                            channel.pipeline.close(mode: .all, promise: nil)
                        }
                    case .failure(let error):
                        // The CONNECT tunnel could not be established. `connect` has already
                        // succeeded, so `cascadeFailure` will not fire; report the error here.
                        upgradePromise.fail(error)
                        channel.pipeline.close(mode: .all, promise: nil)
                    }
                }

                // The rest of the pipeline is set up asynchronously once the tunnel is up.
                // Failures from that point on are reported through `upgradePromise`.
                return channel.eventLoop.makeSucceededVoidFuture()
            }

        let connect = bootstrap.connect(host: proxy ?? host, port: proxyPort ?? port)
        connect.cascadeFailure(to: upgradePromise)
        return connect.flatMap { _ in
            upgradePromise.futureResult
        }
    }

    /// Builds a TLS client handler for `host`.
    ///
    /// Uses `tlsConfiguration`, falling back to NIOSSL's default client configuration when it is
    /// `nil`. If NIOSSL rejects `host` as an SNI name (`cannotUseIPAddressInSNI`), the handler is
    /// created without a server hostname instead.
    ///
    /// - Throws: Errors from creating the `NIOSSLContext` or the handler.
    @Sendable
    private func makeTLSHandler(tlsConfiguration: TLSConfiguration?, host: String) throws -> NIOSSLClientHandler {
        let context = try NIOSSLContext(
            configuration: tlsConfiguration ?? .makeClientConfiguration()
        )
        do {
            return try NIOSSLClientHandler(context: context, serverHostname: host)
        } catch let error as NIOSSLExtraError where error == .cannotUseIPAddressInSNI {
            return try NIOSSLClientHandler(context: context, serverHostname: nil)
        }
    }

    /// Shuts down the event loop group if this client created it.
    ///
    /// Does nothing for a `.shared` group. For a `.createNew` group, the first call shuts the group
    /// down synchronously.
    ///
    /// - Throws: `WebSocketClient.Error.alreadyShutdown` on any later call, or errors from shutting
    ///   the group down.
    public func syncShutdown() throws {
        switch self.eventLoopGroupProvider {
        case .shared:
            return
        case .createNew:
            if self.isShutdown.compareExchange(
                expected: false,
                desired: true,
                ordering: .relaxed
            ).exchanged {
                try self.group.syncShutdownGracefully()
            } else {
                throw WebSocketClient.Error.alreadyShutdown
            }
        }
    }

    /// Picks a bootstrap compatible with `eventLoop`.
    ///
    /// Where the Network framework is available, a `NIOTSConnectionBootstrap` is preferred.
    /// Otherwise a POSIX `ClientBootstrap` is used. Traps if neither accepts the group.
    private static func makeBootstrap(on eventLoop: EventLoopGroup) -> NIOClientTCPBootstrapProtocol {
        #if canImport(Network)
        if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) {
            return tsBootstrap
        }
        #endif
        if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) {
            return nioBootstrap
        }
        fatalError("No matching bootstrap found")
    }

    /// In debug builds, checks that an owned (`.createNew`) event loop group was shut down.
    deinit {
        switch self.eventLoopGroupProvider {
        case .shared:
            return
        case .createNew:
            assert(self.isShutdown.load(ordering: .relaxed), "WebSocketClient not shutdown before deinit.")
        }
    }
}