Skip to content

Commit

Permalink
interim commit
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 9, 2024
1 parent 78dc873 commit 11429a6
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 34 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,5 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/blake3 v1.3.0 // indirect
)

replace github.com/pion/webrtc/v3 => ../webrtc
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,6 @@ github.com/pion/transport/v3 v3.0.6/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZU
github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc=
github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/webrtc/v3 v3.2.50 h1:C/rwL2mBfCxHv6tlLzDAO3krJpQXfVx8A8WHnGJ2j34=
github.com/pion/webrtc/v3 v3.2.50/go.mod h1:dytYYoSBy7ZUWhJMbndx9UckgYvzNAfL7xgVnrIKxqo=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down
28 changes: 28 additions & 0 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,31 @@ func TestConnDroppedWhenBlocked(t *testing.T) {
})
}
}

// TestConnClosedWhenRemoteCloses tests that a connection is closed locally when it's closed by remote
func TestConnClosedWhenRemoteCloses(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()

client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := client.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
require.NoError(t, err)

require.Eventually(t, func() bool {
return server.Network().Connectedness(client.ID()) != network.NotConnected
}, 5*time.Second, 50*time.Millisecond)
for _, c := range client.Network().ConnsToPeer(server.ID()) {
c.Close()
}
require.Eventually(t, func() bool {
return server.Network().Connectedness(client.ID()) == network.NotConnected
}, 5*time.Second, 50*time.Millisecond)
})
}
}
62 changes: 36 additions & 26 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (

ma "github.com/multiformats/go-multiaddr"
"github.com/pion/datachannel"
"github.com/pion/sctp"
"github.com/pion/webrtc/v3"
)

Expand All @@ -32,6 +31,8 @@ func (errConnectionTimeout) Error() string { return "connection timeout" }
func (errConnectionTimeout) Timeout() bool { return true }
func (errConnectionTimeout) Temporary() bool { return false }

var errConnClosed = errors.New("connection closed")

type dataChannel struct {
stream datachannel.ReadWriteCloser
channel *webrtc.DataChannel
Expand All @@ -56,7 +57,8 @@ type connection struct {
streams map[uint16]*stream
nextStreamID atomic.Int32

acceptQueue chan dataChannel
acceptQueue chan dataChannel
peerConnectionClosedCh chan struct{}

ctx context.Context
cancel context.CancelFunc
Expand All @@ -75,6 +77,7 @@ func newConnection(
remoteKey ic.PubKey,
remoteMultiaddr ma.Multiaddr,
incomingDataChannels chan dataChannel,
peerConnectionClosedCh chan struct{},
) (*connection, error) {
ctx, cancel := context.WithCancel(context.Background())
c := &connection{
Expand Down Expand Up @@ -103,6 +106,18 @@ func newConnection(
}

pc.OnConnectionStateChange(c.onConnectionStateChange)
pc.SCTP().OnClose(func(err error) {
if err != nil {
c.closeWithError(fmt.Errorf("%w: %w", errConnClosed, err))
}
c.closeWithError(errConnClosed)
})
select {
case <-peerConnectionClosedCh:
c.Close()
return nil, errConnClosed
default:
}
return c, nil
}

Expand All @@ -113,27 +128,29 @@ func (c *connection) ConnState() network.ConnectionState {

// Close closes the underlying peerconnection.
func (c *connection) Close() error {
c.closeOnce.Do(func() { c.closeWithError(errors.New("connection closed")) })
c.closeWithError(errConnClosed)
return nil
}

// closeWithError is used to Close the connection when the underlying DTLS connection fails
func (c *connection) closeWithError(err error) {
c.closeErr = err
// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
// ctx.Done can read closeErr without holding the conn lock.
c.cancel()
// closing peerconnection will close the datachannels associated with the streams
c.pc.Close()

c.m.Lock()
streams := c.streams
c.streams = nil
c.m.Unlock()
for _, s := range streams {
s.closeForShutdown(err)
}
c.scope.Done()
c.closeOnce.Do(func() {
c.closeErr = err
// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
// ctx.Done can read closeErr without holding the conn lock.
c.cancel()
// closing peerconnection will close the datachannels associated with the streams
c.pc.Close()

c.m.Lock()
streams := c.streams
c.streams = nil
c.m.Unlock()
for _, s := range streams {
s.closeForShutdown(err)
}
c.scope.Done()
})
}

func (c *connection) IsClosed() bool {
Expand All @@ -152,11 +169,6 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
streamID := uint16(id)
dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID})
if err != nil {
if errors.Is(err, sctp.ErrStreamClosed) {
c.closeOnce.Do(func() {
c.closeWithError(errors.New("connection closed"))
})
}
return nil, err
}
rwc, err := c.detachChannel(ctx, dc)
Expand Down Expand Up @@ -215,9 +227,7 @@ func (c *connection) removeStream(id uint16) {

func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
c.closeOnce.Do(func() {
c.closeWithError(errConnectionTimeout{})
})
c.closeWithError(errConnectionTimeout{})
}
}

Expand Down
1 change: 1 addition & 0 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ func (l *listener) setupConnection(
remotePubKey,
remoteMultiaddr,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
if err != nil {
return nil, err
Expand Down
24 changes: 18 additions & 6 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
remotePubKey,
remoteMultiaddrWithoutCerthash,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -572,9 +573,10 @@ func detachHandshakeDataChannel(ctx context.Context, dc *webrtc.DataChannel) (da
// a small window of time where datachannels created by the peer may not surface to us and cause a
// memory leak.
type webRTCConnection struct {
PeerConnection *webrtc.PeerConnection
HandshakeDataChannel *webrtc.DataChannel
IncomingDataChannels chan dataChannel
PeerConnection *webrtc.PeerConnection
HandshakeDataChannel *webrtc.DataChannel
IncomingDataChannels chan dataChannel
PeerConnectionClosedCh chan struct{}
}

func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configuration) (webRTCConnection, error) {
Expand Down Expand Up @@ -613,10 +615,20 @@ func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configurat
}
})
})

connectionClosedCh := make(chan struct{}, 1)
pc.SCTP().OnClose(func(err error) {
// We only need one message. Closing a connection is a problem as pion might invoke the callback more than once.
select {
case connectionClosedCh <- struct{}{}:
default:
}
})
return webRTCConnection{
PeerConnection: pc,
HandshakeDataChannel: handshakeDataChannel,
IncomingDataChannels: incomingDataChannels,
PeerConnection: pc,
HandshakeDataChannel: handshakeDataChannel,
IncomingDataChannels: incomingDataChannels,
PeerConnectionClosedCh: connectionClosedCh,
}, nil
}

Expand Down

0 comments on commit 11429a6

Please sign in to comment.