Skip to content

Commit df4b967

Browse files
committed
✨ Add middleware to allow dynamic schema
1 parent edfe9dc commit df4b967

File tree

7 files changed

+325
-6
lines changed

7 files changed

+325
-6
lines changed

controller/middleware.go

+306-1
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ package controller
44

55
import (
66
"fmt"
7+
"math/rand/v2"
78
"os"
89
"os/exec"
910
"strconv"
1011
"strings"
1112

1213
"github.com/julien040/anyquery/namespace"
1314
"github.com/julien040/go-ternary"
15+
pg_query "github.com/pganalyze/pg_query_go/v5"
1416
"vitess.io/vitess/go/vt/sqlparser"
1517
)
1618

@@ -219,11 +221,17 @@ func middlewareMySQL(queryData *QueryData) bool {
219221

220222
queryType, stmt, err := namespace.GetQueryType(queryData.SQLQuery)
221223
if err != nil {
222-
// If we wan't parse the query, we just pass it
224+
// If we can't parse the query, we just pass it to the next middleware
223225
return true
224226
}
225227
if queryType == sqlparser.StmtShow {
226228
queryData.SQLQuery, queryData.Args = namespace.RewriteShowStatement(stmt.(*sqlparser.Show))
229+
} else if queryType == sqlparser.StmtExplain {
230+
// We rewrite the EXPLAIN/DESCRIBE statement
231+
if explain, ok := stmt.(*sqlparser.ExplainTab); ok {
232+
queryData.SQLQuery = "SELECT * FROM pragma_table_info(?);"
233+
queryData.Args = append(queryData.Args, explain.Table.Name.String())
234+
}
227235
}
228236

229237
return true
@@ -241,6 +249,16 @@ func middlewareQuery(queryData *QueryData) bool {
241249
return true
242250
}
243251

252+
// Run the pre-execution statements
253+
for i, preExec := range queryData.PreExec {
254+
_, err := queryData.DB.Exec(preExec)
255+
if err != nil {
256+
queryData.Message = fmt.Sprintf("Error in pre-execution statement %d: %s", i, err.Error())
257+
queryData.StatusCode = 2
258+
return false
259+
}
260+
}
261+
244262
// Check whether the query must be run with Query or Exec
245263
// We need to check that because, for example, a CREATE VIRTUAL TABLE statement run with Query
246264
// will not return an error if it fails
@@ -282,6 +300,12 @@ func middlewareQuery(queryData *QueryData) bool {
282300
queryData.Message = fmt.Sprintf("Query executed successfully (%d %s affected)", rowsAffected, ternary.If(rowsAffected > 1, "rows", "row"))
283301
}
284302
}
303+
304+
// Note: we can't run the post-execution statements here
305+
// because the result is not yet processed
306+
//
307+
// The post-execution statements are run at the end of the pipeline
308+
// after the output was printed
285309
return true
286310
}
287311

@@ -348,3 +372,284 @@ func middlewareSlashCommand(queryData *QueryData) bool {
348372
return false
349373

350374
}
375+
376+
// tableFunction describes one function call used as a table in a FROM clause,
// e.g. read_json('file.json') AS t.
type tableFunction struct {
	// name is the first component of the called function's name.
	name string
	// args holds the call arguments, each flattened to its textual form.
	args []string
	// position is the index of the enclosing item in the FROM clause.
	position int
	// alias is the alias given to the call, or "" when there is none.
	alias string
}
382+
383+
// Extract the table functions from the query parsed by pg_query
384+
func extractTableFunctions(fromClause []*pg_query.Node) []tableFunction {
385+
result := []tableFunction{}
386+
387+
for ithTable, item := range fromClause {
388+
funcCall := item.GetRangeFunction()
389+
if funcCall == nil {
390+
continue
391+
}
392+
alias := ""
393+
if funcCall.Alias != nil {
394+
alias = funcCall.Alias.Aliasname
395+
}
396+
for _, function1 := range funcCall.Functions {
397+
// Get the function name
398+
nodeList, ok := function1.Node.(*pg_query.Node_List)
399+
if !ok {
400+
continue
401+
}
402+
403+
for _, item := range nodeList.List.Items {
404+
funcCall := item.GetFuncCall()
405+
if funcCall != nil {
406+
if len(funcCall.Funcname) < 1 {
407+
continue
408+
}
409+
// Get the table name
410+
tableName := funcCall.Funcname[0].GetString_().Sval
411+
// Get args
412+
args := []string{}
413+
for _, arg := range funcCall.Args {
414+
// e.g. "foo", "bar"
415+
columnRef := arg.GetColumnRef()
416+
if columnRef != nil {
417+
args = append(args, columnRef.Fields[0].GetString_().Sval)
418+
}
419+
420+
// e.g. 1, 'a", 1.0, true
421+
constRef := arg.GetAConst()
422+
if constRef != nil {
423+
svalStr := constRef.GetSval()
424+
if svalStr != nil {
425+
args = append(args, svalStr.Sval)
426+
}
427+
svalBool := constRef.GetBoolval()
428+
if svalBool != nil {
429+
if svalBool.Boolval {
430+
args = append(args, "true")
431+
} else {
432+
args = append(args, "false")
433+
}
434+
}
435+
svalInt := constRef.GetIval()
436+
if svalInt != nil {
437+
args = append(args, strconv.Itoa(int(svalInt.Ival)))
438+
}
439+
svalFloat := constRef.GetFval()
440+
if svalFloat != nil {
441+
args = append(args, svalFloat.Fval)
442+
}
443+
}
444+
445+
// e.g. foo = bar
446+
exprRef := arg.GetAExpr()
447+
if exprRef != nil {
448+
leftSide := ""
449+
rightSide := ""
450+
// Get the left side of the expression
451+
left := exprRef.GetLexpr()
452+
if left != nil && left.GetColumnRef() != nil && len(left.GetColumnRef().Fields) > 0 {
453+
leftSide = left.GetColumnRef().Fields[0].GetString_().Sval
454+
} else if left != nil && left.GetAConst() != nil {
455+
if left.GetAConst() != nil {
456+
if left.GetAConst().GetIval() != nil {
457+
leftSide = strconv.Itoa(int(left.GetAConst().GetIval().Ival))
458+
} else if left.GetAConst().GetFval() != nil {
459+
leftSide = left.GetAConst().GetFval().Fval
460+
} else if left.GetAConst().GetSval() != nil {
461+
leftSide = left.GetAConst().GetSval().Sval
462+
} else if left.GetAConst().GetBoolval() != nil {
463+
if left.GetAConst().GetBoolval().Boolval {
464+
leftSide = "true"
465+
} else {
466+
leftSide = "false"
467+
}
468+
} else {
469+
leftSide = "NULL"
470+
}
471+
}
472+
}
473+
474+
// Get the right side of the expression
475+
right := exprRef.GetRexpr()
476+
if right != nil && right.GetColumnRef() != nil && len(right.GetColumnRef().Fields) > 0 {
477+
rightSide = right.GetColumnRef().Fields[0].GetString_().Sval
478+
} else if right != nil && right.GetAConst() != nil {
479+
if right.GetAConst() != nil {
480+
if right.GetAConst().GetIval() != nil {
481+
rightSide = strconv.Itoa(int(right.GetAConst().GetIval().Ival))
482+
} else if right.GetAConst().GetFval() != nil {
483+
rightSide = right.GetAConst().GetFval().Fval
484+
} else if right.GetAConst().GetSval() != nil {
485+
rightSide = right.GetAConst().GetSval().Sval
486+
} else if right.GetAConst().GetBoolval() != nil {
487+
if right.GetAConst().GetBoolval().Boolval {
488+
rightSide = "true"
489+
} else {
490+
rightSide = "false"
491+
}
492+
} else {
493+
rightSide = "NULL"
494+
}
495+
496+
}
497+
}
498+
499+
args = append(args, leftSide+" = "+rightSide)
500+
501+
}
502+
503+
}
504+
505+
result = append(result, tableFunction{
506+
name: tableName,
507+
args: args,
508+
position: ithTable,
509+
alias: alias,
510+
})
511+
}
512+
}
513+
}
514+
}
515+
516+
return result
517+
}
518+
519+
const alphabet = "abcdefghijklmnopqrstuvwxyz"
520+
521+
func generateRandomString(size int) string {
522+
result := strings.Builder{}
523+
for i := 0; i < size; i++ {
524+
result.WriteByte(alphabet[rand.IntN(len(alphabet))])
525+
}
526+
return result.String()
527+
528+
}
529+
530+
// Prefix the query like SELECT * FROM read_json with a CREATE VIRTUAL TABLE statement
531+
func middlewareFileQuery(queryData *QueryData) bool {
532+
// # The problem
533+
//
534+
// To explain what this middleware does, let's take an example
535+
// SELECT * FROM read_json('file.json') WHERE name = 'John'
536+
//
537+
// file.json has a schema that is not known by the database
538+
// and SQLite does not support dynamic schema
539+
// Therefore, we need to register a virtual table that will read the file
540+
// and then we can query it
541+
//
542+
// # The solution
543+
//
544+
// The process is quite cumbersome. Therefore, we use a parser that detects
545+
// the table functions and replaces them with a random table name
546+
// Before running the query, we create a virtual table with the random name
547+
// and once the query is executed, we drop the table
548+
// That's a workaround around the limitation of SQLite
549+
//
550+
// # Implementation issues
551+
//
552+
// The vitess's parser is not able to parse queries
553+
// with table functions like read_json() or read_csv()
554+
//
555+
// At first, I wanted to modify the parser to support
556+
// these functions, but it was too complicated
557+
// My knowledge of YACC is extremely limited
558+
//
559+
// As a temporary solution, I decided to parse these queries
560+
// with pg_query (it adds 6MB to the binary size so it's not ideal)
561+
// As the old saying goes, there is nothing more permanent than a temporary solution
562+
// I hope this is not the case here
563+
//
564+
// There is quite a lot of spaghetti code in this middleware to explore the AST
565+
// Couldn't find a better way to do it
566+
567+
// Parse the query
568+
Result, err := pg_query.Parse(queryData.SQLQuery)
569+
if err != nil {
570+
return true
571+
}
572+
573+
if Result == nil || len(Result.Stmts) == 0 || Result.Stmts[0].Stmt == nil {
574+
return true
575+
}
576+
577+
selectStmt := Result.Stmts[0].Stmt.GetSelectStmt()
578+
if selectStmt == nil {
579+
// To handle INSERT INTO SELECT
580+
insertStmt := Result.Stmts[0].Stmt.GetInsertStmt()
581+
if insertStmt == nil {
582+
// To handle CREATE TABLE AS SELECT
583+
createTableStmt := Result.Stmts[0].Stmt.GetCreateTableAsStmt()
584+
if createTableStmt == nil {
585+
return true
586+
} else {
587+
selectStmt = createTableStmt.Query.GetSelectStmt()
588+
if selectStmt == nil {
589+
return true
590+
}
591+
}
592+
} else {
593+
selectStmt = insertStmt.SelectStmt.GetSelectStmt()
594+
if selectStmt == nil {
595+
return true
596+
}
597+
}
598+
599+
}
600+
601+
// Get the from clause
602+
tableFunctions := extractTableFunctions(selectStmt.FromClause)
603+
for _, tableFunction := range tableFunctions {
604+
// Check if the table function is a file module
605+
if tableFunction.name != "read_json" && tableFunction.name != "read_csv" {
606+
continue
607+
}
608+
609+
// Replace the table function with a random one
610+
tableName := generateRandomString(16)
611+
preExecBuilder := strings.Builder{}
612+
preExecBuilder.WriteString("CREATE VIRTUAL TABLE ")
613+
preExecBuilder.WriteString(tableName)
614+
preExecBuilder.WriteString(" USING ")
615+
if tableFunction.name == "read_json" {
616+
preExecBuilder.WriteString("json_reader")
617+
} else if tableFunction.name == "read_csv" {
618+
preExecBuilder.WriteString("csv_reader")
619+
}
620+
preExecBuilder.WriteString("(")
621+
for i, arg := range tableFunction.args {
622+
if i > 0 {
623+
preExecBuilder.WriteString(", ")
624+
}
625+
preExecBuilder.WriteRune('"')
626+
preExecBuilder.WriteString(arg)
627+
preExecBuilder.WriteRune('"')
628+
}
629+
preExecBuilder.WriteString(");")
630+
631+
// Add the pre-execution statement
632+
queryData.PreExec = append(queryData.PreExec, preExecBuilder.String())
633+
634+
// Add a post-execution statement to drop the table
635+
queryData.PostExec = append(queryData.PostExec, "DROP TABLE "+tableName+";")
636+
637+
// Replace the table function with the new table name
638+
var tempTableName *pg_query.Node
639+
if tableFunction.alias == "" {
640+
tempTableName = pg_query.MakeSimpleRangeVarNode(tableName, int32(tableFunction.position))
641+
} else {
642+
tempTableName = pg_query.MakeFullRangeVarNode("", tableName, tableFunction.alias, int32(tableFunction.position))
643+
}
644+
selectStmt.FromClause[tableFunction.position] = tempTableName
645+
646+
}
647+
648+
newQuery, err := pg_query.Deparse(Result)
649+
if err != nil {
650+
return true
651+
}
652+
queryData.SQLQuery = newQuery
653+
654+
return true
655+
}

controller/query.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func Query(cmd *cobra.Command, args []string) error {
130130
// Create the shell
131131
shell := shell{
132132
DB: db,
133-
Middlewares: []middleware{middlewareSlashCommand, middlewareDotCommand, middlewareMySQL, middlewareQuery},
133+
Middlewares: []middleware{middlewareSlashCommand, middlewareDotCommand, middlewareMySQL, middlewareFileQuery, middlewareQuery},
134134
Config: middlewareConfiguration{
135135
"dot-command": true,
136136
"mysql": true,

controller/shell.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type QueryData struct {
2626

2727
// The query to exec before/after the query
2828
// Useful to create temporary tables
29-
preExec, postExec string
29+
PreExec, PostExec []string
3030

3131
// The query to run
3232
SQLQuery string
@@ -302,6 +302,14 @@ func (p *shell) Run(rawQuery string) bool {
302302

303303
}
304304

305+
// Run all the post exec queries
306+
for _, postExec := range queryData.PostExec {
307+
_, err := queryData.DB.Exec(postExec)
308+
if err != nil {
309+
fmt.Fprintf(tempOutput, "Error running post exec query: %s\n", err.Error())
310+
}
311+
}
312+
305313
// We print a newline to separate the queries
306314
// unless it's the last query
307315
if i != len(queries)-1 {

go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ require (
2323
github.com/julien040/go-ternary v0.0.0-20230119180150-f0435f66948e
2424
github.com/mattn/go-sqlite3 v1.14.22
2525
github.com/olekukonko/tablewriter v0.0.5
26+
github.com/pganalyze/pg_query_go/v5 v5.1.0
2627
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
2728
github.com/spf13/cobra v1.8.0
2829
github.com/spf13/pflag v1.0.5

0 commit comments

Comments
 (0)