@@ -20,6 +20,7 @@ import (
20
20
"errors"
21
21
"math/big"
22
22
"sync"
23
+ "time"
23
24
24
25
"github.com/ethereum/go-ethereum/common"
25
26
"github.com/ethereum/go-ethereum/eth/downloader"
38
39
// to the peer set, but one with the same id already exists.
39
40
errPeerAlreadyRegistered = errors .New ("peer already registered" )
40
41
42
+ // errPeerWaitTimeout is returned if a peer waits extension for too long
43
+ errPeerWaitTimeout = errors .New ("peer wait timeout" )
44
+
41
45
// errPeerNotRegistered is returned if a peer is attempted to be removed from
42
46
// a peer set, but no peer with the given id exists.
43
47
errPeerNotRegistered = errors .New ("peer not registered" )
51
55
errDiffWithoutEth = errors .New ("peer connected on diff without compatible eth support" )
52
56
)
53
57
58
+ const (
59
+ // extensionWaitTimeout is the maximum allowed time for the extension wait to
60
+ // complete before dropping the connection.= as malicious.
61
+ extensionWaitTimeout = 5 * time .Second
62
+ )
63
+
54
64
// peerSet represents the collection of active peers currently participating in
55
65
// the `eth` protocol, with or without the `snap` extension.
56
66
type peerSet struct {
@@ -169,7 +179,18 @@ func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) {
169
179
ps .snapWait [id ] = wait
170
180
ps .lock .Unlock ()
171
181
172
- return <- wait , nil
182
+ select {
183
+ case peer := <- wait :
184
+ return peer , nil
185
+
186
+ case <- time .After (extensionWaitTimeout ):
187
+ ps .lock .Lock ()
188
+ if _ , ok := ps .snapWait [id ]; ok {
189
+ delete (ps .snapWait , id )
190
+ }
191
+ ps .lock .Unlock ()
192
+ return nil , errPeerWaitTimeout
193
+ }
173
194
}
174
195
175
196
// waitDiffExtension blocks until all satellite protocols are connected and tracked
@@ -203,7 +224,18 @@ func (ps *peerSet) waitDiffExtension(peer *eth.Peer) (*diff.Peer, error) {
203
224
ps .diffWait [id ] = wait
204
225
ps .lock .Unlock ()
205
226
206
- return <- wait , nil
227
+ select {
228
+ case peer := <- wait :
229
+ return peer , nil
230
+
231
+ case <- time .After (extensionWaitTimeout ):
232
+ ps .lock .Lock ()
233
+ if _ , ok := ps .diffWait [id ]; ok {
234
+ delete (ps .diffWait , id )
235
+ }
236
+ ps .lock .Unlock ()
237
+ return nil , errPeerWaitTimeout
238
+ }
207
239
}
208
240
209
241
func (ps * peerSet ) GetDiffPeer (pid string ) downloader.IDiffPeer {
0 commit comments