diff --git a/README.md b/README.md index b4789fd..2eba299 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ To use this package, include it in your `Package.swift` dependencies: .package(url: "https://github.com/GraphQLSwift/GraphQLTransportWS", from: "") ``` -Then create a class to implement the `Messenger` protocol. Here's an example using +Then create a concrete type that conforms to the `Messenger` protocol. Here's an example using [`WebSocketKit`](https://github.com/vapor/websocket-kit): ```swift @@ -31,12 +31,12 @@ import GraphQLTransportWS struct WebSocketMessenger: Messenger { let websocket: WebSocket - func send(_ message: S) where S: Collection, S.Element == Character async throws { - try await websocket.send(message) + func send(_ message: Data) async throws { + try await websocket.send(String(decoding: message, as: UTF8.self)) } func error(_ message: String, code: Int) async throws { - try await websocket.send("\(code): \(message)") + try await websocket.close(code: code) } func close() async throws { @@ -73,9 +73,9 @@ routes.webSocket( ) } ) - let incoming = AsyncStream { continuation in + let incoming = AsyncStream { continuation in websocket.onText { _, message in - continuation.yield(message) + continuation.yield(Data(message.utf8)) } } try await server.listen(to: incoming) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 27d15c6..ca15d84 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -39,89 +39,102 @@ public actor Client { /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. /// - Parameter incoming: The server message sequence that the client should react to. public func listen(to incoming: A) async throws - where A.Element == String { + where A.Element == Data { for try await message in incoming { - // Detect and ignore error responses. - if message.starts(with: "44") { - // TODO: Determine what to do with returned error messages + try await respond(to: message) + } + } + + /// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed. + /// - Parameter incoming: The server message sequence that the client should react to. + @available(*, deprecated, message: "Use `Data` sequence instead.") + public func listen(to incoming: A) async throws + where A.Element == String { + for try await stringMessage in incoming { + guard let message = stringMessage.data(using: .utf8) else { + try await self.error(.invalidEncoding()) return } + try await respond(to: message) + } + } + + private func respond(to message: Data) async throws { + let response: Response + do { + response = try decoder.decode(Response.self, from: message) + } catch { + try await self.error(.noType()) + return + } - guard let json = message.data(using: .utf8) else { - try await error(.invalidEncoding()) + switch response.type { + case .connectionAck: + guard + let connectionAckResponse = try? decoder.decode( + ConnectionAckResponse.self, + from: message + ) + else { + try await error(.invalidResponseFormat(messageType: .connectionAck)) return } - - let response: Response - do { - response = try decoder.decode(Response.self, from: json) - } catch { - try await self.error(.noType()) + try await onConnectionAck(connectionAckResponse, self) + case .next: + guard let nextResponse = try? decoder.decode(NextResponse.self, from: message) else { + try await error(.invalidResponseFormat(messageType: .next)) return } - - switch response.type { - case .connectionAck: - guard - let connectionAckResponse = try? decoder.decode( - ConnectionAckResponse.self, - from: json - ) - else { - try await error(.invalidResponseFormat(messageType: .connectionAck)) - return - } - try await onConnectionAck(connectionAckResponse, self) - case .next: - guard let nextResponse = try? decoder.decode(NextResponse.self, from: json) else { - try await error(.invalidResponseFormat(messageType: .next)) - return - } - try await onNext(nextResponse, self) - case .error: - guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: json) else { - try await error(.invalidResponseFormat(messageType: .error)) - return - } - try await onError(errorResponse, self) - case .complete: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: json) - else { - try await error(.invalidResponseFormat(messageType: .complete)) - return - } - try await onComplete(completeResponse, self) - default: - try await error(.invalidType()) + try await onNext(nextResponse, self) + case .error: + guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { + try await error(.invalidResponseFormat(messageType: .error)) + return + } + try await onError(errorResponse, self) + case .complete: + guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) + else { + try await error(.invalidResponseFormat(messageType: .complete)) + return } + try await onComplete(completeResponse, self) + default: + try await error(.invalidType()) } } /// Send a `connection_init` request through the messenger public func sendConnectionInit(payload: InitPayload) async throws { try await messenger.send( - ConnectionInitRequest( - payload: payload - ).toJSON(encoder) + encoder.encode( + ConnectionInitRequest( + payload: payload + ) + ) ) } /// Send a `subscribe` request through the messenger public func sendStart(payload: GraphQLRequest, id: String) async throws { try await messenger.send( - SubscribeRequest( - payload: payload, - id: id - ).toJSON(encoder) + encoder.encode( + SubscribeRequest( + payload: payload, + id: id + ) + ) ) } /// Send a `complete` request through the messenger public func sendStop(id: String) async throws { try await messenger.send( - CompleteRequest( - id: id - ).toJSON(encoder) + encoder.encode( + CompleteRequest( + id: id + ) + ) ) } diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 2b200dc..d2ae94b 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -9,105 +9,86 @@ struct GraphQLTransportWSError: Error { self.code = code } - static func unauthorized() -> Self { + static func forbidden() -> Self { return self.init( - "Unauthorized", - code: .unauthorized + "Forbidden", + code: .forbidden ) } static func notInitialized() -> Self { return self.init( "Connection not initialized", - code: .notInitialized + code: .unauthorized ) } static func tooManyInitializations() -> Self { return self.init( "Too many initialisation requests", - code: .tooManyInitializations + code: .tooManyRequests ) } static func subscriberAlreadyExists(id: String) -> Self { return self.init( "Subscriber for \(id) already exists", - code: .subscriberAlreadyExists + code: .conflict ) } static func invalidEncoding() -> Self { return self.init( "Message was not encoded in UTF8", - code: .invalidEncoding + code: .miscellaneous ) } static func noType() -> Self { return self.init( "Message has no 'type' field", - code: .noType + code: .miscellaneous ) } static func invalidType() -> Self { return self.init( "Message 'type' value does not match supported types", - code: .invalidType + code: .miscellaneous ) } static func invalidRequestFormat(messageType: RequestMessageType) -> Self { return self.init( "Request message doesn't match '\(messageType.type.rawValue)' JSON format", - code: .invalidRequestFormat + code: .miscellaneous ) } static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { return self.init( "Response message doesn't match '\(messageType.type.rawValue)' JSON format", - code: .invalidResponseFormat + code: .miscellaneous ) } static func internalAPIStreamIssue(errors: [GraphQLError]) -> Self { return self.init( "API Response did not result in a stream type, contained errors\n \(errors.map { $0.message }.joined(separator: "\n"))", - code: .internalAPIStreamIssue - ) - } - - static func graphQLError(_ error: Error) -> Self { - return self.init( - "\(error)", - code: .graphQLError + code: .internalServerError ) } } /// Error codes for miscellaneous issues -public enum ErrorCode: Int, CustomStringConvertible, Sendable { +enum ErrorCode: Int, CustomStringConvertible, Sendable { /// Miscellaneous case miscellaneous = 4400 - - // Internal errors - case graphQLError = 4401 - case internalAPIStreamIssue = 4402 - - // Message errors - case invalidEncoding = 4410 - case noType = 4411 - case invalidType = 4412 - case invalidRequestFormat = 4413 - case invalidResponseFormat = 4414 - - // Initialization errors - case unauthorized = 4430 - case notInitialized = 4431 - case tooManyInitializations = 4432 - case subscriberAlreadyExists = 4433 + case unauthorized = 4401 + case forbidden = 4403 + case conflict = 4409 + case tooManyRequests = 4429 + case internalServerError = 4500 public var description: String { return "\(rawValue)" diff --git a/Sources/GraphQLTransportWS/JsonEncodable.swift b/Sources/GraphQLTransportWS/JsonEncodable.swift deleted file mode 100644 index b54f881..0000000 --- a/Sources/GraphQLTransportWS/JsonEncodable.swift +++ /dev/null @@ -1,23 +0,0 @@ -import Foundation -import GraphQL - -/// Indicates an object that can be converted into JSON for messaging -protocol JsonEncodable: Codable {} - -extension JsonEncodable { - /// Converts the object into a JSON string - /// - Parameter encoder: JSON Encoder used to encode the object into a string - /// - Returns: The JSON string representation of the object, or an error JSON if not possible - func toJSON(_ encoder: GraphQLJSONEncoder) -> String { - let data: Data - do { - data = try encoder.encode(self) - } catch { - return EncodingErrorResponse("Unable to encode response").toJSON(encoder) - } - guard let body = String(data: data, encoding: .utf8) else { - return EncodingErrorResponse("Encoded response can't be cast to string").toJSON(encoder) - } - return body - } -} diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index e0ba6d9..543c6fa 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -4,7 +4,7 @@ import Foundation public protocol Messenger: Sendable { /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) async throws where S.Element == Character + func send(_ message: Data) async throws /// Close the messenger func close() async throws diff --git a/Sources/GraphQLTransportWS/Requests.swift b/Sources/GraphQLTransportWS/Requests.swift index 5807190..5639622 100644 --- a/Sources/GraphQLTransportWS/Requests.swift +++ b/Sources/GraphQLTransportWS/Requests.swift @@ -2,12 +2,12 @@ import Foundation import GraphQL /// A general request. This object's type is used to triage to other, more specific request objects. -public struct Request: Equatable, JsonEncodable { +public struct Request: Equatable, Codable { public let type: RequestMessageType } /// A websocket `connection_init` request from the client to the server -public struct ConnectionInitRequest: Equatable, JsonEncodable { +public struct ConnectionInitRequest: Equatable, Codable { public let type: RequestMessageType = .connectionInit public let payload: InitPayload @@ -30,7 +30,7 @@ public struct ConnectionInitRequest: Equatable } /// A websocket `subscribe` request from the client to the server -public struct SubscribeRequest: Equatable, JsonEncodable { +public struct SubscribeRequest: Equatable, Codable { public let type = RequestMessageType.subscribe public let payload: GraphQLRequest public let id: String @@ -56,7 +56,7 @@ public struct SubscribeRequest: Equatable, JsonEncodable { } /// A websocket `complete` request from the client to the server -public struct CompleteRequest: Equatable, JsonEncodable { +public struct CompleteRequest: Equatable, Codable { public let type = RequestMessageType.complete public let id: String diff --git a/Sources/GraphQLTransportWS/Responses.swift b/Sources/GraphQLTransportWS/Responses.swift index 8a2bf3d..f24a59d 100644 --- a/Sources/GraphQLTransportWS/Responses.swift +++ b/Sources/GraphQLTransportWS/Responses.swift @@ -2,12 +2,12 @@ import Foundation import GraphQL /// A general response. This object's type is used to triage to other, more specific response objects. -public struct Response: Equatable, JsonEncodable { +public struct Response: Equatable, Codable { public let type: ResponseMessageType } /// A websocket `connection_ack` response from the server to the client -public struct ConnectionAckResponse: Equatable, JsonEncodable { +public struct ConnectionAckResponse: Equatable, Codable { public let type: ResponseMessageType = .connectionAck public let payload: [String: Map]? @@ -30,7 +30,7 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable { } /// A websocket `next` response from the server to the client -public struct NextResponse: Equatable, JsonEncodable { +public struct NextResponse: Equatable, Codable { public let type: ResponseMessageType = .next public let payload: GraphQLResult? public let id: String @@ -56,7 +56,7 @@ public struct NextResponse: Equatable, JsonEncodable { } /// A websocket `complete` response from the server to the client -public struct CompleteResponse: Equatable, JsonEncodable { +public struct CompleteResponse: Equatable, Codable { public let type: ResponseMessageType = .complete public let id: String @@ -79,7 +79,7 @@ public struct CompleteResponse: Equatable, JsonEncodable { } /// A websocket `error` response from the server to the client -public struct ErrorResponse: Equatable, JsonEncodable { +public struct ErrorResponse: Equatable, Codable { public let type: ResponseMessageType = .error public let payload: [GraphQLError] public let id: String @@ -148,7 +148,7 @@ public struct ResponseMessageType: Equatable, Codable, Sendable { /// A websocket `error` response from the server to the client that indicates an issue with encoding /// a response JSON -struct EncodingErrorResponse: Equatable, Codable, JsonEncodable { +struct EncodingErrorResponse: Equatable, Codable { let type: ResponseMessageType let payload: [String: String] diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 3e79e03..803f798 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -52,67 +52,75 @@ where self.onOperationError = onOperationError } + deinit { + subscriptionTasks.values.forEach { $0.cancel() } + } + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. /// - Parameter incoming: The client message sequence that the server should react to. public func listen(to incoming: A) async throws - where A.Element == String { + where A.Element == Data { for try await message in incoming { - // Detect and ignore error responses. - if message.starts(with: "44") { - // TODO: Determine what to do with returned error messages - return - } + try await respond(to: message) + } + } - guard let json = message.data(using: .utf8) else { + /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. + /// - Parameter incoming: The client message sequence that the server should react to. + @available(*, deprecated, message: "Use `Data` sequence instead.") + public func listen(to incoming: A) async throws + where A.Element == String { + for try await stringMessage in incoming { + guard let message = stringMessage.data(using: .utf8) else { try await error(.invalidEncoding()) return } - let request: Request - do { - request = try decoder.decode(Request.self, from: json) - } catch { - try await self.error(.noType()) + try await respond(to: message) + } + } + + private func respond(to message: Data) async throws { + let request: Request + do { + request = try decoder.decode(Request.self, from: message) + } catch { + try await self.error(.noType()) + return + } + + // handle incoming message + switch request.type { + case .connectionInit: + guard + let connectionInitRequest = try? decoder.decode( + ConnectionInitRequest.self, + from: message + ) + else { + try await error(.invalidRequestFormat(messageType: .connectionInit)) return } - - // handle incoming message - switch request.type { - case .connectionInit: - guard - let connectionInitRequest = try? decoder.decode( - ConnectionInitRequest.self, - from: json - ) - else { - try await error(.invalidRequestFormat(messageType: .connectionInit)) - return - } - try await onConnectionInit(connectionInitRequest, messenger) - case .subscribe: - guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: json) - else { - try await error(.invalidRequestFormat(messageType: .subscribe)) - return - } - try await onSubscribe(subscribeRequest) - case .complete: - guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: json) - else { - try await error(.invalidRequestFormat(messageType: .complete)) - return - } - try await onOperationComplete(completeRequest) - default: - try await error(.invalidType()) + try await onConnectionInit(connectionInitRequest, messenger) + case .subscribe: + guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: message) + else { + try await error(.invalidRequestFormat(messageType: .subscribe)) + return + } + try await onSubscribe(subscribeRequest) + case .complete: + guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: message) + else { + try await error(.invalidRequestFormat(messageType: .complete)) + return } + try await onOperationComplete(completeRequest) + default: + try await error(.invalidType()) } } - deinit { - subscriptionTasks.values.forEach { $0.cancel() } - } - private func onConnectionInit( _ connectionInitRequest: ConnectionInitRequest, _: Messenger @@ -125,7 +133,7 @@ where do { initResult = try await onInit(connectionInitRequest.payload) } catch { - try await self.error(.unauthorized()) + try await self.error(.forbidden()) return } initialized = true @@ -198,26 +206,32 @@ where /// Send a `connection_ack` response through the messenger private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { try await messenger.send( - ConnectionAckResponse(payload: payload).toJSON(encoder) + encoder.encode( + ConnectionAckResponse(payload: payload) + ) ) } /// Send a `next` response through the messenger private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws { try await messenger.send( - NextResponse( - payload: payload, - id: id - ).toJSON(encoder) + encoder.encode( + NextResponse( + payload: payload, + id: id + ) + ) ) } /// Send a `complete` response through the messenger private func sendComplete(id: String) async throws { try await messenger.send( - CompleteResponse( - id: id - ).toJSON(encoder) + encoder.encode( + CompleteResponse( + id: id + ) + ) ) try await onOperationComplete(id) } @@ -225,10 +239,12 @@ where /// Send an `error` response through the messenger private func sendError(_ errors: [Error], id: String) async throws { try await messenger.send( - ErrorResponse( - errors, - id: id - ).toJSON(encoder) + encoder.encode( + ErrorResponse( + errors, + id: id + ) + ) ) try await onOperationError(id, errors) } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 759c65b..a1c397d 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -28,21 +28,7 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream - .makeStream() - let serverMessageStream = serverMessenger.stream.map { message in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() - return message - } - let client = Client( - messenger: clientMessenger, - onError: { message, _ in - messageContinuation.finish(throwing: message.payload[0]) - await clientMessenger.close() - } - ) + let client = Client(messenger: clientMessenger) let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) @@ -59,13 +45,16 @@ struct GraphqlTransportWSTests { ), id: UUID().uuidString ) - try await client.listen(to: serverMessageStream) - let messages = try await messageStream.reduce(into: [String]()) { result, message in - result.append(message) + let error = await #expect(throws: TestMessengerError.self) { + try await client.listen(to: serverMessenger.stream) } #expect( - messages == ["\(ErrorCode.notInitialized): Connection not initialized"] + error + == TestMessengerError( + code: 4401, + message: "Connection not initialized" + ) ) } @@ -91,21 +80,7 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream - .makeStream() - let serverMessageStream = serverMessenger.stream.map { message in - messageContinuation.yield(message) - // Expect only one message - messageContinuation.finish() - return message - } - let client = Client( - messenger: clientMessenger, - onError: { message, _ in - messageContinuation.finish(throwing: message.payload[0]) - await clientMessenger.close() - } - ) + let client = Client(messenger: clientMessenger) let clientStream = clientMessenger.stream Task { try await server.listen(to: clientStream) @@ -117,13 +92,16 @@ struct GraphqlTransportWSTests { authToken: "" ) ) - try await client.listen(to: serverMessageStream) - let messages = try await messageStream.reduce(into: [String]()) { result, message in - result.append(message) + let error = await #expect(throws: TestMessengerError.self) { + try await client.listen(to: serverMessenger.stream) } #expect( - messages == ["\(ErrorCode.unauthorized): Unauthorized"] + error + == TestMessengerError( + code: 4403, + message: "Forbidden" + ) ) } @@ -149,7 +127,7 @@ struct GraphqlTransportWSTests { ).get() } ) - let (messageStream, messageContinuation) = AsyncThrowingStream + let (messageStream, messageContinuation) = AsyncThrowingStream .makeStream() let serverMessageStream = serverMessenger.stream.map { message in messageContinuation.yield(message) @@ -191,7 +169,7 @@ struct GraphqlTransportWSTests { ) try await client.listen(to: serverMessageStream) - let messages = try await messageStream.reduce(into: [String]()) { result, message in + let messages = try await messageStream.reduce(into: [Data]()) { result, message in result.append(message) } #expect( @@ -226,7 +204,7 @@ struct GraphqlTransportWSTests { return subscription } ) - let (messageStream, messageContinuation) = AsyncThrowingStream + let (messageStream, messageContinuation) = AsyncThrowingStream .makeStream() // Used to extract the server messages let serverMessageStream = serverMessenger.stream.map { message in @@ -282,7 +260,7 @@ struct GraphqlTransportWSTests { ) try await client.listen(to: serverMessageStream) - let messages = try await messageStream.reduce(into: [String]()) { result, message in + let messages = try await messageStream.reduce(into: [Data]()) { result, message in result.append(message) } #expect( diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index a27c1d5..15b7cfc 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -5,25 +5,29 @@ import Foundation /// Messenger for simple testing that doesn't require starting up a websocket server. actor TestMessenger: Messenger { /// An async stream of the messages sent through this messenger. - let stream: AsyncStream - private var continuation: AsyncStream.Continuation + let stream: AsyncThrowingStream + private var continuation: AsyncThrowingStream.Continuation init() { - let (stream, continuation) = AsyncStream.makeStream() + let (stream, continuation) = AsyncThrowingStream.makeStream() self.stream = stream self.continuation = continuation } - func send(_ message: S) async throws where S.Element == Character { - continuation.yield(String(message)) + func send(_ message: Data) async throws { + continuation.yield(message) } func error(_ message: String, code: Int) async throws { - continuation.yield("\(code): \(message)") - continuation.finish() + continuation.finish(throwing: TestMessengerError(code: code, message: message)) } func close() { continuation.finish() } } + +struct TestMessengerError: Error, Equatable { + let code: Int + let message: String +}