Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max allowed packet for cast_as_binary #8014

Closed
wants to merge 10 commits into from
24 changes: 22 additions & 2 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -268,7 +269,12 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E
bf := newBaseBuiltinFunc(ctx, args)
bf.tp = c.tp
if args[0].GetType().Hybrid() || IsBinaryLiteral(args[0]) {
sig = &builtinCastStringAsStringSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinCastStringAsStringSig{bf, maxAllowedPacket}
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
return sig, nil
}
Expand All @@ -293,7 +299,12 @@ func (c *castAsStringFunctionClass) getFunction(ctx sessionctx.Context, args []E
sig = &builtinCastJSONAsStringSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastJsonAsString)
case types.ETString:
sig = &builtinCastStringAsStringSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinCastStringAsStringSig{bf, maxAllowedPacket}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these be extracted into a function?

sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
default:
panic("unsupported types.EvalType in castAsStringFunctionClass")
Expand Down Expand Up @@ -1007,11 +1018,13 @@ func (b *builtinCastDecimalAsDurationSig) evalDuration(row chunk.Row) (res types

type builtinCastStringAsStringSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinCastStringAsStringSig) Clone() builtinFunc {
newSig := &builtinCastStringAsStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -1021,6 +1034,13 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx

isTooLarge := uint64(len(res)) > b.maxAllowedPacket || (b.tp.Flen != types.UnspecifiedLength && uint64(b.tp.Flen) > b.maxAllowedPacket)
if types.IsBinaryStr(b.tp) && isTooLarge {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("cast_as_binary", b.maxAllowedPacket))
return "", true, nil
}

res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc)
return res, false, errors.Trace(err)
}
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
case 5:
sig = &builtinCastJSONAsStringSig{stringFunc}
case 6:
sig = &builtinCastStringAsStringSig{stringFunc}
sig = &builtinCastStringAsStringSig{stringFunc, 1 << 20}
}
res, isNull, err := sig.evalString(t.row.ToRow())
c.Assert(isNull, Equals, false)
Expand Down Expand Up @@ -731,7 +731,7 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) {
sig = &builtinCastDurationAsStringSig{stringFunc}
case 5:
stringFunc.tp.Charset = charset.CharsetUTF8
sig = &builtinCastStringAsStringSig{stringFunc}
sig = &builtinCastStringAsStringSig{stringFunc, 1 << 20}
}
res, isNull, err := sig.evalString(t.row.ToRow())
c.Assert(isNull, Equals, false)
Expand Down
2 changes: 0 additions & 2 deletions expression/builtin_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package expression

import (
"fmt"
"math"
"strings"
"time"
Expand Down Expand Up @@ -1615,7 +1614,6 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {
}

for _, test := range tests {
fmt.Printf("Begin Test %v\n", test)
expr := s.datumsToConstants([]types.Datum{test.input})
expr[0].GetType().Decimal = test.inputDecimal
resetStmtContext(s.ctx)
Expand Down
9 changes: 8 additions & 1 deletion expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ package expression

import (
"fmt"
"strconv"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/mock"
Expand Down Expand Up @@ -115,7 +117,12 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
case tipb.ScalarFuncSig_CastTimeAsString:
f = &builtinCastTimeAsStringSig{base}
case tipb.ScalarFuncSig_CastStringAsString:
f = &builtinCastStringAsStringSig{base}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}
f = &builtinCastStringAsStringSig{base, maxAllowedPacket}
case tipb.ScalarFuncSig_CastJsonAsString:
f = &builtinCastJSONAsStringSig{base}

Expand Down
3 changes: 3 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,9 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) {
result = tk.MustQuery("select a,b,concat_ws(',',a,b) from t")
result.Check(testkit.Rows("114.57011441 38.04620115 114.57011441,38.04620115",
"-38.04620119 38.04620115 -38.04620119,38.04620115"))

result = tk.MustQuery("SELECT CAST('a' AS BINARY(67108865));")
result.Check(testkit.Rows("<nil>"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check the result of show warnings here?

}

func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) {
Expand Down
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,7 @@ func CreateSession4Test(store kv.Storage) (Session, error) {
if err == nil {
// initialize session variables for test.
s.GetSessionVars().MaxChunkSize = 2
s.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, variable.SysVars[variable.MaxAllowedPacket].Value)
}
return s, errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ func NewSessionVars() *SessionVars {
enableStreaming = "0"
}
terror.Log(vars.SetSystemVar(TiDBEnableStreaming, enableStreaming))
terror.Log(vars.SetSystemVar(MaxAllowedPacket, SysVars[MaxAllowedPacket].Value))
return vars
}

Expand Down