-
Notifications
You must be signed in to change notification settings - Fork 577
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to handle context cancellations for TCP protocol #1389
Changes from 12 commits
36c73bf
73195bc
35f7a54
239f6e6
13043be
19d1a3b
f45be27
8e55989
d6e9e26
14ebd32
7591db5
410ed1b
b10e3fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ _testmain.go | |
|
||
coverage.txt | ||
.idea/** | ||
.vscode/** | ||
dev/* | ||
.run/** | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ import ( | |
"log" | ||
"net" | ||
"os" | ||
"sync" | ||
"syscall" | ||
"time" | ||
|
||
|
@@ -42,6 +43,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er | |
conn net.Conn | ||
debugf = func(format string, v ...any) {} | ||
) | ||
|
||
switch { | ||
case opt.DialContext != nil: | ||
conn, err = opt.DialContext(ctx, addr) | ||
|
@@ -53,9 +55,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er | |
conn, err = net.DialTimeout("tcp", addr, opt.DialTimeout) | ||
} | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if opt.Debug { | ||
if opt.Debugf != nil { | ||
debugf = func(format string, v ...any) { | ||
|
@@ -68,6 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er | |
debugf = log.New(os.Stdout, fmt.Sprintf("[clickhouse][conn=%d][%s]", num, conn.RemoteAddr()), 0).Printf | ||
} | ||
} | ||
|
||
compression := CompressionNone | ||
if opt.Compression != nil { | ||
switch opt.Compression.Method { | ||
|
@@ -96,9 +101,11 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er | |
maxCompressionBuffer: opt.MaxCompressionBuffer, | ||
} | ||
) | ||
|
||
if err := connect.handshake(opt.Auth.Database, opt.Auth.Username, opt.Auth.Password); err != nil { | ||
return nil, err | ||
} | ||
|
||
if connect.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM { | ||
if err := connect.sendAddendum(); err != nil { | ||
return nil, err | ||
|
@@ -109,6 +116,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er | |
if num == 1 && !resources.ClientMeta.IsSupportedClickHouseVersion(connect.server.Version) { | ||
debugf("[handshake] WARNING: version %v of ClickHouse is not supported by this client - client supports %v", connect.server.Version, resources.ClientMeta.SupportedVersions()) | ||
} | ||
|
||
return connect, nil | ||
} | ||
|
||
|
@@ -131,6 +139,8 @@ type connect struct { | |
readTimeout time.Duration | ||
blockBufferSize uint8 | ||
maxCompressionBuffer int | ||
mutex sync.Mutex | ||
mutexClose sync.Mutex | ||
} | ||
|
||
func (c *connect) settings(querySettings Settings) []proto.Setting { | ||
|
@@ -153,15 +163,16 @@ func (c *connect) settings(querySettings Settings) []proto.Setting { | |
for k, v := range c.opt.Settings { | ||
settings = append(settings, settingToProtoSetting(k, v)) | ||
} | ||
|
||
for k, v := range querySettings { | ||
settings = append(settings, settingToProtoSetting(k, v)) | ||
} | ||
|
||
return settings | ||
} | ||
|
||
func (c *connect) isBad() bool { | ||
switch { | ||
case c.closed: | ||
if c.isClosed() { | ||
return true | ||
} | ||
|
||
|
@@ -172,19 +183,43 @@ func (c *connect) isBad() bool { | |
if err := c.connCheck(); err != nil { | ||
return true | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (c *connect) isClosed() bool { | ||
c.mutexClose.Lock() | ||
defer c.mutexClose.Unlock() | ||
|
||
return c.closed | ||
} | ||
|
||
func (c *connect) setClosed() { | ||
c.mutexClose.Lock() | ||
defer c.mutexClose.Unlock() | ||
|
||
c.closed = true | ||
} | ||
|
||
func (c *connect) close() error { | ||
c.mutexClose.Lock() | ||
if c.closed { | ||
c.mutexClose.Unlock() | ||
return nil | ||
} | ||
c.closed = true | ||
c.buffer = nil | ||
c.reader = nil | ||
c.mutexClose.Unlock() | ||
|
||
if err := c.conn.Close(); err != nil { | ||
return err | ||
} | ||
|
||
c.buffer = nil | ||
|
||
c.mutex.Lock() | ||
c.reader = nil | ||
c.mutex.Unlock() | ||
|
||
return nil | ||
} | ||
|
||
|
@@ -193,6 +228,7 @@ func (c *connect) progress() (*Progress, error) { | |
if err := progress.Decode(c.reader, c.revision); err != nil { | ||
return nil, err | ||
} | ||
|
||
c.debugf("[progress] %s", &progress) | ||
return &progress, nil | ||
} | ||
|
@@ -202,6 +238,7 @@ func (c *connect) exception() error { | |
if err := e.Decode(c.reader); err != nil { | ||
return err | ||
} | ||
|
||
c.debugf("[exception] %s", e.Error()) | ||
return &e | ||
} | ||
|
@@ -218,6 +255,18 @@ func (c *connect) compressBuffer(start int) error { | |
} | ||
|
||
func (c *connect) sendData(block *proto.Block, name string) error { | ||
if c.isClosed() { | ||
err := errors.New("attempted sending on closed connection") | ||
c.debugf("[send data] err: %v", err) | ||
return err | ||
} | ||
|
||
if c.buffer == nil { | ||
err := errors.New("attempted sending on nil buffer") | ||
c.debugf("[send data] err: %v", err) | ||
return err | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it something you found when while debugging or only added for safety? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, this is outdated code, please check the latest changes it was all found during debugging, yes. note:
|
||
|
||
c.debugf("[send data] compression=%q", c.compression) | ||
c.buffer.PutByte(proto.ClientData) | ||
c.buffer.PutString(name) | ||
|
@@ -227,6 +276,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { | |
if err := block.EncodeHeader(c.buffer, c.revision); err != nil { | ||
return err | ||
} | ||
|
||
for i := range block.Columns { | ||
if err := block.EncodeColumn(c.buffer, c.revision, i); err != nil { | ||
return err | ||
|
@@ -242,33 +292,50 @@ func (c *connect) sendData(block *proto.Block, name string) error { | |
compressionOffset = 0 | ||
} | ||
} | ||
|
||
if err := c.compressBuffer(compressionOffset); err != nil { | ||
return err | ||
} | ||
|
||
if err := c.flush(); err != nil { | ||
switch { | ||
case errors.Is(err, syscall.EPIPE): | ||
c.debugf("[send data] pipe is broken, closing connection") | ||
c.closed = true | ||
c.setClosed() | ||
case errors.Is(err, io.EOF): | ||
c.debugf("[send data] unexpected EOF, closing connection") | ||
c.closed = true | ||
c.setClosed() | ||
default: | ||
c.debugf("[send data] unexpected error: %v", err) | ||
} | ||
return err | ||
} | ||
|
||
defer func() { | ||
c.buffer.Reset() | ||
}() | ||
|
||
return nil | ||
} | ||
|
||
func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) { | ||
if c.isClosed() { | ||
err := errors.New("attempted reading on closed connection") | ||
c.debugf("[read data] err: %v", err) | ||
return nil, err | ||
} | ||
|
||
if c.reader == nil { | ||
err := errors.New("attempted reading on nil reader") | ||
c.debugf("[read data] err: %v", err) | ||
return nil, err | ||
} | ||
|
||
if _, err := c.reader.Str(); err != nil { | ||
c.debugf("[read data] str error: %v", err) | ||
return nil, err | ||
} | ||
|
||
if compressible && c.compression != CompressionNone { | ||
c.reader.EnableCompression() | ||
defer c.reader.DisableCompression() | ||
|
@@ -285,6 +352,7 @@ func (c *connect) readData(ctx context.Context, packet byte, compressible bool) | |
c.debugf("[read data] decode error: %v", err) | ||
return nil, err | ||
} | ||
|
||
block.Packet = packet | ||
c.debugf("[read data] compression=%q. block: columns=%d, rows=%d", c.compression, len(block.Columns), block.Rows()) | ||
return &block, nil | ||
|
@@ -295,10 +363,12 @@ func (c *connect) flush() error { | |
// Nothing to flush. | ||
return nil | ||
} | ||
|
||
n, err := c.conn.Write(c.buffer.Buf) | ||
if err != nil { | ||
return errors.Wrap(err, "write") | ||
} | ||
|
||
if n != len(c.buffer.Buf) { | ||
return errors.New("wrote less than expected") | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you think it makes sense to name it explicit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, good idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done