diff --git a/go.mod b/go.mod index 3aa884062e..08591aa838 100644 --- a/go.mod +++ b/go.mod @@ -130,3 +130,5 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.3.0 // indirect ) + +replace github.com/pion/webrtc/v3 => ../webrtc diff --git a/go.sum b/go.sum index dc2487ca9e..41cd14aeef 100644 --- a/go.sum +++ b/go.sum @@ -316,8 +316,6 @@ github.com/pion/transport/v3 v3.0.6/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZU github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc= github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/webrtc/v3 v3.2.50 h1:C/rwL2mBfCxHv6tlLzDAO3krJpQXfVx8A8WHnGJ2j34= -github.com/pion/webrtc/v3 v3.2.50/go.mod h1:dytYYoSBy7ZUWhJMbndx9UckgYvzNAfL7xgVnrIKxqo= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 297c239793..7cfab5f3ca 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -770,3 +770,31 @@ func TestConnDroppedWhenBlocked(t *testing.T) { }) } } + +// TestConnClosedWhenRemoteCloses tests that a connection is closed locally when it's closed by remote +func TestConnClosedWhenRemoteCloses(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := client.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()}) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return server.Network().Connectedness(client.ID()) != network.NotConnected + }, 5*time.Second, 50*time.Millisecond) + for _, c := range client.Network().ConnsToPeer(server.ID()) { + c.Close() + } + require.Eventually(t, func() bool { + return server.Network().Connectedness(client.ID()) == network.NotConnected + }, 5*time.Second, 50*time.Millisecond) + }) + } +} diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index e73a236a30..f40e4460ae 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -16,7 +16,6 @@ import ( ma "github.com/multiformats/go-multiaddr" "github.com/pion/datachannel" - "github.com/pion/sctp" "github.com/pion/webrtc/v3" ) @@ -32,6 +31,8 @@ func (errConnectionTimeout) Error() string { return "connection timeout" } func (errConnectionTimeout) Timeout() bool { return true } func (errConnectionTimeout) Temporary() bool { return false } +var errConnClosed = errors.New("connection closed") + type dataChannel struct { stream datachannel.ReadWriteCloser channel *webrtc.DataChannel @@ -56,7 +57,8 @@ type connection struct { streams map[uint16]*stream nextStreamID atomic.Int32 - acceptQueue chan dataChannel + acceptQueue chan dataChannel + peerConnectionClosedCh chan struct{} ctx context.Context cancel context.CancelFunc @@ -75,6 +77,7 @@ func newConnection( remoteKey ic.PubKey, remoteMultiaddr ma.Multiaddr, incomingDataChannels chan dataChannel, + peerConnectionClosedCh chan struct{}, ) (*connection, error) { ctx, cancel := context.WithCancel(context.Background()) c := &connection{ @@ -103,6 +106,18 @@ func newConnection( } pc.OnConnectionStateChange(c.onConnectionStateChange) + pc.SCTP().OnClose(func(err error) { + if err != nil { + c.closeWithError(fmt.Errorf("%w: %w", errConnClosed, err)) + } + c.closeWithError(errConnClosed) + }) + select { + case <-peerConnectionClosedCh: + c.Close() + return nil, errConnClosed + default: + } return c, nil } @@ -113,27 +128,29 @@ func (c *connection) ConnState() network.ConnectionState { // Close closes the underlying peerconnection. func (c *connection) Close() error { - c.closeOnce.Do(func() { c.closeWithError(errors.New("connection closed")) }) + c.closeWithError(errConnClosed) return nil } // closeWithError is used to Close the connection when the underlying DTLS connection fails func (c *connection) closeWithError(err error) { - c.closeErr = err - // cancel must be called after closeErr is set. This ensures interested goroutines waiting on - // ctx.Done can read closeErr without holding the conn lock. - c.cancel() - // closing peerconnection will close the datachannels associated with the streams - c.pc.Close() - - c.m.Lock() - streams := c.streams - c.streams = nil - c.m.Unlock() - for _, s := range streams { - s.closeForShutdown(err) - } - c.scope.Done() + c.closeOnce.Do(func() { + c.closeErr = err + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + // closing peerconnection will close the datachannels associated with the streams + c.pc.Close() + + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, s := range streams { + s.closeForShutdown(err) + } + c.scope.Done() + }) } func (c *connection) IsClosed() bool { @@ -152,11 +169,6 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error streamID := uint16(id) dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID}) if err != nil { - if errors.Is(err, sctp.ErrStreamClosed) { - c.closeOnce.Do(func() { - c.closeWithError(errors.New("connection closed")) - }) - } return nil, err } rwc, err := c.detachChannel(ctx, dc) @@ -215,9 +227,7 @@ func (c *connection) removeStream(id uint16) { func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - c.closeOnce.Do(func() { - c.closeWithError(errConnectionTimeout{}) - }) + c.closeWithError(errConnectionTimeout{}) } } diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 3f465b34fc..96174f3457 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -276,6 +276,7 @@ func (l *listener) setupConnection( remotePubKey, remoteMultiaddr, w.IncomingDataChannels, + w.PeerConnectionClosedCh, ) if err != nil { return nil, err diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index 8aaa93e43a..3f23be5b80 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -399,6 +399,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement remotePubKey, remoteMultiaddrWithoutCerthash, w.IncomingDataChannels, + w.PeerConnectionClosedCh, ) if err != nil { return nil, err @@ -572,9 +573,10 @@ func detachHandshakeDataChannel(ctx context.Context, dc *webrtc.DataChannel) (da // a small window of time where datachannels created by the peer may not surface to us and cause a // memory leak. type webRTCConnection struct { - PeerConnection *webrtc.PeerConnection - HandshakeDataChannel *webrtc.DataChannel - IncomingDataChannels chan dataChannel + PeerConnection *webrtc.PeerConnection + HandshakeDataChannel *webrtc.DataChannel + IncomingDataChannels chan dataChannel + PeerConnectionClosedCh chan struct{} } func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configuration) (webRTCConnection, error) { @@ -613,10 +615,20 @@ func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configurat } }) }) + + connectionClosedCh := make(chan struct{}, 1) + pc.SCTP().OnClose(func(err error) { + // We only need one message. Closing a connection is a problem as pion might invoke the callback more than once. + select { + case connectionClosedCh <- struct{}{}: + default: + } + }) return webRTCConnection{ - PeerConnection: pc, - HandshakeDataChannel: handshakeDataChannel, - IncomingDataChannels: incomingDataChannels, + PeerConnection: pc, + HandshakeDataChannel: handshakeDataChannel, + IncomingDataChannels: incomingDataChannels, + PeerConnectionClosedCh: connectionClosedCh, }, nil }