@@ -4,13 +4,15 @@ package controller
4
4
5
5
import (
6
6
"fmt"
7
+ "math/rand/v2"
7
8
"os"
8
9
"os/exec"
9
10
"strconv"
10
11
"strings"
11
12
12
13
"github.com/julien040/anyquery/namespace"
13
14
"github.com/julien040/go-ternary"
15
+ pg_query "github.com/pganalyze/pg_query_go/v5"
14
16
"vitess.io/vitess/go/vt/sqlparser"
15
17
)
16
18
@@ -219,11 +221,17 @@ func middlewareMySQL(queryData *QueryData) bool {
219
221
220
222
queryType , stmt , err := namespace .GetQueryType (queryData .SQLQuery )
221
223
if err != nil {
222
- // If we wan 't parse the query, we just pass it
224
+ // If we can 't parse the query, we just pass it to the next middleware
223
225
return true
224
226
}
225
227
if queryType == sqlparser .StmtShow {
226
228
queryData .SQLQuery , queryData .Args = namespace .RewriteShowStatement (stmt .(* sqlparser.Show ))
229
+ } else if queryType == sqlparser .StmtExplain {
230
+ // We rewrite the EXPLAIN/DESCRIBE statement
231
+ if explain , ok := stmt .(* sqlparser.ExplainTab ); ok {
232
+ queryData .SQLQuery = "SELECT * FROM pragma_table_info(?);"
233
+ queryData .Args = append (queryData .Args , explain .Table .Name .String ())
234
+ }
227
235
}
228
236
229
237
return true
@@ -241,6 +249,16 @@ func middlewareQuery(queryData *QueryData) bool {
241
249
return true
242
250
}
243
251
252
+ // Run the pre-execution statements
253
+ for i , preExec := range queryData .PreExec {
254
+ _ , err := queryData .DB .Exec (preExec )
255
+ if err != nil {
256
+ queryData .Message = fmt .Sprintf ("Error in pre-execution statement %d: %s" , i , err .Error ())
257
+ queryData .StatusCode = 2
258
+ return false
259
+ }
260
+ }
261
+
244
262
// Check whether the query must be run with Query or Exec
245
263
// We need to check that because, for example, a CREATE VIRTUAL TABLE statement run with Query
246
264
// will not return an error if it fails
@@ -282,6 +300,12 @@ func middlewareQuery(queryData *QueryData) bool {
282
300
queryData .Message = fmt .Sprintf ("Query executed successfully (%d %s affected)" , rowsAffected , ternary .If (rowsAffected > 1 , "rows" , "row" ))
283
301
}
284
302
}
303
+
304
+ // Note: we can't run the post-execution statements here
305
+ // because the result is not yet processed
306
+ //
307
+ // The post-execution statements are run at the end of the pipeline
308
+ // after the output was printed
285
309
return true
286
310
}
287
311
@@ -348,3 +372,284 @@ func middlewareSlashCommand(queryData *QueryData) bool {
348
372
return false
349
373
350
374
}
375
+
376
+ type tableFunction struct {
377
+ name string
378
+ args []string
379
+ position int
380
+ alias string
381
+ }
382
+
383
+ // Extract the table functions from the query parsed by pg_query
384
+ func extractTableFunctions (fromClause []* pg_query.Node ) []tableFunction {
385
+ result := []tableFunction {}
386
+
387
+ for ithTable , item := range fromClause {
388
+ funcCall := item .GetRangeFunction ()
389
+ if funcCall == nil {
390
+ continue
391
+ }
392
+ alias := ""
393
+ if funcCall .Alias != nil {
394
+ alias = funcCall .Alias .Aliasname
395
+ }
396
+ for _ , function1 := range funcCall .Functions {
397
+ // Get the function name
398
+ nodeList , ok := function1 .Node .(* pg_query.Node_List )
399
+ if ! ok {
400
+ continue
401
+ }
402
+
403
+ for _ , item := range nodeList .List .Items {
404
+ funcCall := item .GetFuncCall ()
405
+ if funcCall != nil {
406
+ if len (funcCall .Funcname ) < 1 {
407
+ continue
408
+ }
409
+ // Get the table name
410
+ tableName := funcCall .Funcname [0 ].GetString_ ().Sval
411
+ // Get args
412
+ args := []string {}
413
+ for _ , arg := range funcCall .Args {
414
+ // e.g. "foo", "bar"
415
+ columnRef := arg .GetColumnRef ()
416
+ if columnRef != nil {
417
+ args = append (args , columnRef .Fields [0 ].GetString_ ().Sval )
418
+ }
419
+
420
+ // e.g. 1, 'a", 1.0, true
421
+ constRef := arg .GetAConst ()
422
+ if constRef != nil {
423
+ svalStr := constRef .GetSval ()
424
+ if svalStr != nil {
425
+ args = append (args , svalStr .Sval )
426
+ }
427
+ svalBool := constRef .GetBoolval ()
428
+ if svalBool != nil {
429
+ if svalBool .Boolval {
430
+ args = append (args , "true" )
431
+ } else {
432
+ args = append (args , "false" )
433
+ }
434
+ }
435
+ svalInt := constRef .GetIval ()
436
+ if svalInt != nil {
437
+ args = append (args , strconv .Itoa (int (svalInt .Ival )))
438
+ }
439
+ svalFloat := constRef .GetFval ()
440
+ if svalFloat != nil {
441
+ args = append (args , svalFloat .Fval )
442
+ }
443
+ }
444
+
445
+ // e.g. foo = bar
446
+ exprRef := arg .GetAExpr ()
447
+ if exprRef != nil {
448
+ leftSide := ""
449
+ rightSide := ""
450
+ // Get the left side of the expression
451
+ left := exprRef .GetLexpr ()
452
+ if left != nil && left .GetColumnRef () != nil && len (left .GetColumnRef ().Fields ) > 0 {
453
+ leftSide = left .GetColumnRef ().Fields [0 ].GetString_ ().Sval
454
+ } else if left != nil && left .GetAConst () != nil {
455
+ if left .GetAConst () != nil {
456
+ if left .GetAConst ().GetIval () != nil {
457
+ leftSide = strconv .Itoa (int (left .GetAConst ().GetIval ().Ival ))
458
+ } else if left .GetAConst ().GetFval () != nil {
459
+ leftSide = left .GetAConst ().GetFval ().Fval
460
+ } else if left .GetAConst ().GetSval () != nil {
461
+ leftSide = left .GetAConst ().GetSval ().Sval
462
+ } else if left .GetAConst ().GetBoolval () != nil {
463
+ if left .GetAConst ().GetBoolval ().Boolval {
464
+ leftSide = "true"
465
+ } else {
466
+ leftSide = "false"
467
+ }
468
+ } else {
469
+ leftSide = "NULL"
470
+ }
471
+ }
472
+ }
473
+
474
+ // Get the right side of the expression
475
+ right := exprRef .GetRexpr ()
476
+ if right != nil && right .GetColumnRef () != nil && len (right .GetColumnRef ().Fields ) > 0 {
477
+ rightSide = right .GetColumnRef ().Fields [0 ].GetString_ ().Sval
478
+ } else if right != nil && right .GetAConst () != nil {
479
+ if right .GetAConst () != nil {
480
+ if right .GetAConst ().GetIval () != nil {
481
+ rightSide = strconv .Itoa (int (right .GetAConst ().GetIval ().Ival ))
482
+ } else if right .GetAConst ().GetFval () != nil {
483
+ rightSide = right .GetAConst ().GetFval ().Fval
484
+ } else if right .GetAConst ().GetSval () != nil {
485
+ rightSide = right .GetAConst ().GetSval ().Sval
486
+ } else if right .GetAConst ().GetBoolval () != nil {
487
+ if right .GetAConst ().GetBoolval ().Boolval {
488
+ rightSide = "true"
489
+ } else {
490
+ rightSide = "false"
491
+ }
492
+ } else {
493
+ rightSide = "NULL"
494
+ }
495
+
496
+ }
497
+ }
498
+
499
+ args = append (args , leftSide + " = " + rightSide )
500
+
501
+ }
502
+
503
+ }
504
+
505
+ result = append (result , tableFunction {
506
+ name : tableName ,
507
+ args : args ,
508
+ position : ithTable ,
509
+ alias : alias ,
510
+ })
511
+ }
512
+ }
513
+ }
514
+ }
515
+
516
+ return result
517
+ }
518
+
519
+ const alphabet = "abcdefghijklmnopqrstuvwxyz"
520
+
521
+ func generateRandomString (size int ) string {
522
+ result := strings.Builder {}
523
+ for i := 0 ; i < size ; i ++ {
524
+ result .WriteByte (alphabet [rand .IntN (len (alphabet ))])
525
+ }
526
+ return result .String ()
527
+
528
+ }
529
+
530
+ // Prefix the query like SELECT * FROM read_json with a CREATE VIRTUAL TABLE statement
531
+ func middlewareFileQuery (queryData * QueryData ) bool {
532
+ // # The problem
533
+ //
534
+ // To explain what this middleware does, let's take an example
535
+ // SELECT * FROM read_json('file.json') WHERE name = 'John'
536
+ //
537
+ // file.json has a schema that is not known by the database
538
+ // and SQLite does not support dynamic schema
539
+ // Therefore, we need to register a virtual table that will read the file
540
+ // and then we can query it
541
+ //
542
+ // # The solution
543
+ //
544
+ // The process is quite cumbersome. Therefore, we use a parser that detects
545
+ // the table functions and replaces them with a random table name
546
+ // Before running the query, we create a virtual table with the random name
547
+ // and once the query is executed, we drop the table
548
+ // That's a workaround around the limitation of SQLite
549
+ //
550
+ // # Implementation issues
551
+ //
552
+ // The vitess's parser is not able to parse queries
553
+ // with table functions like read_json() or read_csv()
554
+ //
555
+ // At first, I wanted to modify the parser to support
556
+ // these functions, but it was too complicated
557
+ // My knowledge of YACC is extremely limited
558
+ //
559
+ // As a temporary solution, I decided to parse these queries
560
+ // with pg_query (it adds 6MB to the binary size so it's not ideal)
561
+ // As the old saying goes, there is nothing more permanent than a temporary solution
562
+ // I hope this is not the case here
563
+ //
564
+ // There is quite a lot of spaghetti code in this middleware to explore the AST
565
+ // Couldn't find a better way to do it
566
+
567
+ // Parse the query
568
+ Result , err := pg_query .Parse (queryData .SQLQuery )
569
+ if err != nil {
570
+ return true
571
+ }
572
+
573
+ if Result == nil || len (Result .Stmts ) == 0 || Result .Stmts [0 ].Stmt == nil {
574
+ return true
575
+ }
576
+
577
+ selectStmt := Result .Stmts [0 ].Stmt .GetSelectStmt ()
578
+ if selectStmt == nil {
579
+ // To handle INSERT INTO SELECT
580
+ insertStmt := Result .Stmts [0 ].Stmt .GetInsertStmt ()
581
+ if insertStmt == nil {
582
+ // To handle CREATE TABLE AS SELECT
583
+ createTableStmt := Result .Stmts [0 ].Stmt .GetCreateTableAsStmt ()
584
+ if createTableStmt == nil {
585
+ return true
586
+ } else {
587
+ selectStmt = createTableStmt .Query .GetSelectStmt ()
588
+ if selectStmt == nil {
589
+ return true
590
+ }
591
+ }
592
+ } else {
593
+ selectStmt = insertStmt .SelectStmt .GetSelectStmt ()
594
+ if selectStmt == nil {
595
+ return true
596
+ }
597
+ }
598
+
599
+ }
600
+
601
+ // Get the from clause
602
+ tableFunctions := extractTableFunctions (selectStmt .FromClause )
603
+ for _ , tableFunction := range tableFunctions {
604
+ // Check if the table function is a file module
605
+ if tableFunction .name != "read_json" && tableFunction .name != "read_csv" {
606
+ continue
607
+ }
608
+
609
+ // Replace the table function with a random one
610
+ tableName := generateRandomString (16 )
611
+ preExecBuilder := strings.Builder {}
612
+ preExecBuilder .WriteString ("CREATE VIRTUAL TABLE " )
613
+ preExecBuilder .WriteString (tableName )
614
+ preExecBuilder .WriteString (" USING " )
615
+ if tableFunction .name == "read_json" {
616
+ preExecBuilder .WriteString ("json_reader" )
617
+ } else if tableFunction .name == "read_csv" {
618
+ preExecBuilder .WriteString ("csv_reader" )
619
+ }
620
+ preExecBuilder .WriteString ("(" )
621
+ for i , arg := range tableFunction .args {
622
+ if i > 0 {
623
+ preExecBuilder .WriteString (", " )
624
+ }
625
+ preExecBuilder .WriteRune ('"' )
626
+ preExecBuilder .WriteString (arg )
627
+ preExecBuilder .WriteRune ('"' )
628
+ }
629
+ preExecBuilder .WriteString (");" )
630
+
631
+ // Add the pre-execution statement
632
+ queryData .PreExec = append (queryData .PreExec , preExecBuilder .String ())
633
+
634
+ // Add a post-execution statement to drop the table
635
+ queryData .PostExec = append (queryData .PostExec , "DROP TABLE " + tableName + ";" )
636
+
637
+ // Replace the table function with the new table name
638
+ var tempTableName * pg_query.Node
639
+ if tableFunction .alias == "" {
640
+ tempTableName = pg_query .MakeSimpleRangeVarNode (tableName , int32 (tableFunction .position ))
641
+ } else {
642
+ tempTableName = pg_query .MakeFullRangeVarNode ("" , tableName , tableFunction .alias , int32 (tableFunction .position ))
643
+ }
644
+ selectStmt .FromClause [tableFunction .position ] = tempTableName
645
+
646
+ }
647
+
648
+ newQuery , err := pg_query .Deparse (Result )
649
+ if err != nil {
650
+ return true
651
+ }
652
+ queryData .SQLQuery = newQuery
653
+
654
+ return true
655
+ }
0 commit comments