Skip to content

Commit 0c18ce5

Browse files
tylerni7holiman
authored andcommitted
eth, rpc: add configurable option for wsMessageSizeLimit (ethereum#27801)
This change adds a configurable limit to websocket message. --------- Co-authored-by: Martin Holst Swende <[email protected]>
1 parent 728a990 commit 0c18ce5

5 files changed

+85
-8
lines changed

rpc/client_opt.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ type clientConfig struct {
3434
httpAuth HTTPAuth
3535

3636
// WebSocket options
37-
wsDialer *websocket.Dialer
37+
wsDialer *websocket.Dialer
38+
wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit
3839

3940
// RPC handler options
4041
idgen func() ID
@@ -66,6 +67,14 @@ func WithWebsocketDialer(dialer websocket.Dialer) ClientOption {
6667
})
6768
}
6869

70+
// WithWebsocketMessageSizeLimit configures the websocket message size limit used by the RPC
71+
// client. Passing a limit of 0 means no limit.
72+
func WithWebsocketMessageSizeLimit(messageSizeLimit int64) ClientOption {
73+
return optionFunc(func(cfg *clientConfig) {
74+
cfg.wsMessageSizeLimit = &messageSizeLimit
75+
})
76+
}
77+
6978
// WithHeader configures HTTP headers set by the RPC client. Headers set using this option
7079
// will be used for both HTTP and WebSocket connections.
7180
func WithHeader(key, value string) ClientOption {

rpc/server_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
4545
t.Fatalf("Expected service calc to be registered")
4646
}
4747

48-
wantCallbacks := 13
48+
wantCallbacks := 14
4949
if len(svc.callbacks) != wantCallbacks {
5050
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
5151
}

rpc/testservice_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *
9090
return echoResult{str, i, args}
9191
}
9292

93+
func (s *testService) Repeat(msg string, i int) string {
94+
return strings.Repeat(msg, i)
95+
}
96+
9397
func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
9498
return PeerInfoFromContext(ctx)
9599
}

rpc/websocket.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const (
3838
wsPingInterval = 30 * time.Second
3939
wsPingWriteTimeout = 5 * time.Second
4040
wsPongTimeout = 30 * time.Second
41-
wsMessageSizeLimit = 32 * 1024 * 1024
41+
wsDefaultReadLimit = 32 * 1024 * 1024
4242
)
4343

4444
var wsBufferPool = new(sync.Pool)
@@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
6060
log.Debug("WebSocket upgrade failed", "err", err)
6161
return
6262
}
63-
codec := newWebsocketCodec(conn, r.Host, r.Header)
63+
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
6464
s.ServeCodec(codec, 0)
6565
})
6666
}
@@ -251,7 +251,11 @@ func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, er
251251
}
252252
return nil, hErr
253253
}
254-
return newWebsocketCodec(conn, dialURL, header), nil
254+
messageSizeLimit := int64(wsDefaultReadLimit)
255+
if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 {
256+
messageSizeLimit = *cfg.wsMessageSizeLimit
257+
}
258+
return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil
255259
}
256260
return connect, nil
257261
}
@@ -283,8 +287,8 @@ type websocketCodec struct {
283287
pongReceived chan struct{}
284288
}
285289

286-
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
287-
conn.SetReadLimit(wsMessageSizeLimit)
290+
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec {
291+
conn.SetReadLimit(readLimit)
288292
encode := func(v interface{}, isErrorResponse bool) error {
289293
return conn.WriteJSON(v)
290294
}

rpc/websocket_test.go

+61-1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,66 @@ func TestWebsocketLargeCall(t *testing.T) {
113113
}
114114
}
115115

116+
// This test checks whether the wsMessageSizeLimit option is obeyed.
117+
func TestWebsocketLargeRead(t *testing.T) {
118+
t.Parallel()
119+
120+
var (
121+
srv = newTestServer()
122+
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
123+
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
124+
)
125+
defer srv.Stop()
126+
defer httpsrv.Close()
127+
128+
testLimit := func(limit *int64) {
129+
opts := []ClientOption{}
130+
expLimit := int64(wsDefaultReadLimit)
131+
if limit != nil && *limit >= 0 {
132+
opts = append(opts, WithWebsocketMessageSizeLimit(*limit))
133+
if *limit > 0 {
134+
expLimit = *limit // 0 means infinite
135+
}
136+
}
137+
client, err := DialOptions(context.Background(), wsURL, opts...)
138+
if err != nil {
139+
t.Fatalf("can't dial: %v", err)
140+
}
141+
defer client.Close()
142+
// Remove some bytes for json encoding overhead.
143+
underLimit := int(expLimit - 128)
144+
overLimit := expLimit + 1
145+
if expLimit == wsDefaultReadLimit {
146+
// No point trying the full 32MB in tests. Just sanity-check that
147+
// it's not obviously limited.
148+
underLimit = 1024
149+
overLimit = -1
150+
}
151+
var res string
152+
// Check under limit
153+
if err = client.Call(&res, "test_repeat", "A", underLimit); err != nil {
154+
t.Fatalf("unexpected error with limit %d: %v", expLimit, err)
155+
}
156+
if len(res) != underLimit || strings.Count(res, "A") != underLimit {
157+
t.Fatal("incorrect data")
158+
}
159+
// Check over limit
160+
if overLimit > 0 {
161+
err = client.Call(&res, "test_repeat", "A", expLimit+1)
162+
if err == nil || err != websocket.ErrReadLimit {
163+
t.Fatalf("wrong error with limit %d: %v expecting %v", expLimit, err, websocket.ErrReadLimit)
164+
}
165+
}
166+
}
167+
ptr := func(v int64) *int64 { return &v }
168+
169+
testLimit(ptr(-1)) // Should be ignored (use default)
170+
testLimit(ptr(0)) // Should be ignored (use default)
171+
testLimit(nil) // Should be ignored (use default)
172+
testLimit(ptr(200))
173+
testLimit(ptr(wsDefaultReadLimit * 2))
174+
}
175+
116176
func TestWebsocketPeerInfo(t *testing.T) {
117177
var (
118178
s = newTestServer()
@@ -206,7 +266,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
206266
defer srv.Stop()
207267
defer httpsrv.Close()
208268

209-
respLength := wsMessageSizeLimit - 50
269+
respLength := wsDefaultReadLimit - 50
210270
srv.RegisterName("test", largeRespService{respLength})
211271

212272
c, err := DialWebsocket(context.Background(), wsURL, "")

0 commit comments

Comments
 (0)