Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: graphql_transport_ws protocol should send 'complete' to end subscription #2320

Merged
merged 3 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Sources/ApolloTestSupport/MockWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ public class MockWebSocket: WebSocketClient {
public var delegate: WebSocketClientDelegate? = nil
public var isConnected: Bool = false

public required init(request: URLRequest) {
public required init(request: URLRequest, protocol: WebSocket.WSProtocol) {
self.request = request

self.request.setValue(`protocol`.description, forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName)
}

open func reportDidConnect() {
Expand Down
32 changes: 27 additions & 5 deletions Sources/ApolloWebSocket/WebSocketTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,14 @@ public class WebSocketTransport {
autoPersistQuery: false)
let identifier = operationMessageIdCreator.requestId()

var type: OperationMessage.Types = .start
if case WebSocket.WSProtocol.graphql_transport_ws.description = websocket.request.value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) {
type = .subscribe
let messageType: OperationMessage.Types
switch websocket.request.wsProtocol {
case .graphql_ws: messageType = .start
case .graphql_transport_ws: messageType = .subscribe
default: return nil
}

guard let message = OperationMessage(payload: body, id: identifier, type: type).rawMessage else {
guard let message = OperationMessage(payload: body, id: identifier, type: messageType).rawMessage else {
return nil
}

Expand All @@ -302,7 +304,13 @@ public class WebSocketTransport {
}

public func unsubscribe(_ subscriptionId: String) {
let str = OperationMessage(id: subscriptionId, type: .stop).rawMessage
let messageType: OperationMessage.Types
switch websocket.request.wsProtocol {
case .graphql_transport_ws: messageType = .complete
default: messageType = .stop
}

let str = OperationMessage(id: subscriptionId, type: messageType).rawMessage

processingQueue.async {
if let str = str {
Expand Down Expand Up @@ -359,6 +367,20 @@ public class WebSocketTransport {
}
}

extension URLRequest {
fileprivate var wsProtocol: WebSocket.WSProtocol? {
guard let header = value(forHTTPHeaderField: WebSocket.Constants.headerWSProtocolName) else {
return nil
}

switch header {
case WebSocket.WSProtocol.graphql_transport_ws.description: return .graphql_transport_ws
case WebSocket.WSProtocol.graphql_ws.description: return .graphql_ws
default: return nil
}
}
}

// MARK: - NetworkTransport conformance

extension WebSocketTransport: NetworkTransport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ class StarWarsSubscriptionTests: XCTestCase {
func testConcurrentConnectAndCloseConnection() {
let webSocketTransport = WebSocketTransport(
websocket: MockWebSocket(
request: URLRequest(url: TestServerURL.starWarsWebSocket.url)
request: URLRequest(url: TestServerURL.starWarsWebSocket.url),
protocol: .graphql_ws
),
store: ApolloStore()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
return request
}

private func buildWebSocket() {
buildWebSocket(protocol: .graphql_transport_ws)
}

// MARK: Initializer Tests

func test__designatedInitializer__shouldSetRequestProtocolHeader() {
Expand Down Expand Up @@ -123,7 +127,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
}
}

func test__messaging__givenSubscriptionCancel_shouldSendStop() {
func test__messaging__givenSubscriptionCancel_shouldSendComplete() {
// given
buildWebSocket()
buildClient()
Expand All @@ -136,7 +140,7 @@ class GraphqlTransportWsProtocolTests: WSProtocolTestsBase {
waitUntil { done in
self.mockWebSocketDelegate.didReceiveMessage = { message in
// then
let expected = OperationMessage(id: "1", type: .stop).rawMessage!
let expected = OperationMessage(id: "1", type: .complete).rawMessage!
if message == expected {
done()
}
Expand Down
4 changes: 4 additions & 0 deletions Tests/ApolloTests/WebSocket/GraphqlWsProtocolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class GraphqlWsProtocolTests: WSProtocolTestsBase {
return request
}

private func buildWebSocket() {
buildWebSocket(protocol: .graphql_ws)
}

// MARK: Initializer Tests

func test__designatedInitializer__shouldSetRequestProtocolHeader() {
Expand Down
4 changes: 2 additions & 2 deletions Tests/ApolloTests/WebSocket/WSProtocolTestsBase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class WSProtocolTestsBase: XCTestCase {
fatalError("Subclasses must override this property!")
}

func buildWebSocket() {
func buildWebSocket(protocol: WebSocket.WSProtocol) {
mockWebSocketDelegate = MockWebSocketDelegate()
mockWebSocket = MockWebSocket(request: urlRequest)
mockWebSocket = MockWebSocket(request: urlRequest, protocol: `protocol`)
websocketTransport = WebSocketTransport(websocket: mockWebSocket, store: store)
}

Expand Down
10 changes: 8 additions & 2 deletions Tests/ApolloTests/WebSocket/WebSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class WebSocketTests: XCTestCase {
super.setUp()

let store = ApolloStore()
let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url))
let websocket = MockWebSocket(
request:URLRequest(url: TestURL.mockServer.url),
protocol: .graphql_ws
)
networkTransport = WebSocketTransport(websocket: websocket, store: store)
client = ApolloClient(networkTransport: networkTransport!, store: store)
}
Expand Down Expand Up @@ -133,7 +136,10 @@ class WebSocketTests: XCTestCase {
let expectation = self.expectation(description: "Single Subscription with Custom Operation Message Id Creator")

let store = ApolloStore()
let websocket = MockWebSocket(request:URLRequest(url: TestURL.mockServer.url))
let websocket = MockWebSocket(
request:URLRequest(url: TestURL.mockServer.url),
protocol: .graphql_ws
)
networkTransport = WebSocketTransport(websocket: websocket, store: store, operationMessageIdCreator: CustomOperationMessageIdCreator())
client = ApolloClient(networkTransport: networkTransport!, store: store)

Expand Down
22 changes: 14 additions & 8 deletions Tests/ApolloTests/WebSocket/WebSocketTransportTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ class WebSocketTransportTests: XCTestCase {
var request = URLRequest(url: TestURL.mockServer.url)
request.addValue("OldToken", forHTTPHeaderField: "Authorization")

self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
store: ApolloStore())
self.webSocketTransport = WebSocketTransport(
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
store: ApolloStore()
)

self.webSocketTransport.updateHeaderValues(["Authorization": "UpdatedToken"])

Expand All @@ -28,9 +30,11 @@ class WebSocketTransportTests: XCTestCase {
func testUpdateConnectingPayload() {
let request = URLRequest(url: TestURL.mockServer.url)

self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
store: ApolloStore(),
connectingPayload: ["Authorization": "OldToken"])
self.webSocketTransport = WebSocketTransport(
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
store: ApolloStore(),
connectingPayload: ["Authorization": "OldToken"]
)

let mockWebSocketDelegate = MockWebSocketDelegate()

Expand Down Expand Up @@ -59,9 +63,11 @@ class WebSocketTransportTests: XCTestCase {
func testCloseConnectionAndInit() {
let request = URLRequest(url: TestURL.mockServer.url)

self.webSocketTransport = WebSocketTransport(websocket: MockWebSocket(request: request),
store: ApolloStore(),
connectingPayload: ["Authorization": "OldToken"])
self.webSocketTransport = WebSocketTransport(
websocket: MockWebSocket(request: request, protocol: .graphql_ws),
store: ApolloStore(),
connectingPayload: ["Authorization": "OldToken"]
)
self.webSocketTransport.closeConnection()
self.webSocketTransport.updateConnectingPayload(["Authorization": "UpdatedToken"])
self.webSocketTransport.initServer()
Expand Down