Skip to content

Commit 4674007

Browse files
gh-79579: Improve DML query detection in sqlite3 (#93623)
The fix involves using pysqlite_check_remaining_sql(), not only to check for multiple statements, but now also to strip leading comments and whitespace from SQL statements, so we can improve DML query detection. pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more accurately reflect its function, and hardened to handle more SQL comment corner cases.
1 parent e566ce5 commit 4674007

File tree

3 files changed

+103
-84
lines changed

3 files changed

+103
-84
lines changed

Lib/test/test_sqlite3/test_dbapi.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -746,22 +746,44 @@ def test_execute_illegal_sql(self):
746746
with self.assertRaises(sqlite.OperationalError):
747747
self.cu.execute("select asdf")
748748

749-
def test_execute_too_much_sql(self):
750-
self.assertRaisesRegex(sqlite.ProgrammingError,
751-
"You can only execute one statement at a time",
752-
self.cu.execute, "select 5+4; select 4+5")
753-
754-
def test_execute_too_much_sql2(self):
755-
self.cu.execute("select 5+4; -- foo bar")
749+
def test_execute_multiple_statements(self):
750+
msg = "You can only execute one statement at a time"
751+
dataset = (
752+
"select 1; select 2",
753+
"select 1; // c++ comments are not allowed",
754+
"select 1; *not a comment",
755+
"select 1; -*not a comment",
756+
"select 1; /* */ a",
757+
"select 1; /**/a",
758+
"select 1; -",
759+
"select 1; /",
760+
"select 1; -\n- select 2",
761+
"""select 1;
762+
-- comment
763+
select 2
764+
""",
765+
)
766+
for query in dataset:
767+
with self.subTest(query=query):
768+
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
769+
self.cu.execute(query)
756770

757-
def test_execute_too_much_sql3(self):
758-
self.cu.execute("""
771+
def test_execute_with_appended_comments(self):
772+
dataset = (
773+
"select 1; -- foo bar",
774+
"select 1; --",
775+
"select 1; /*", # Unclosed comments ending in \0 are skipped.
776+
"""
759777
select 5+4;
760778
761779
/*
762780
foo
763781
*/
764-
""")
782+
""",
783+
)
784+
for query in dataset:
785+
with self.subTest(query=query):
786+
self.cu.execute(query)
765787

766788
def test_execute_wrong_sql_arg(self):
767789
with self.assertRaises(TypeError):
@@ -906,6 +928,30 @@ def test_rowcount_update_returning(self):
906928
self.assertEqual(self.cu.fetchone()[0], 1)
907929
self.assertEqual(self.cu.rowcount, 1)
908930

931+
def test_rowcount_prefixed_with_comment(self):
932+
# gh-79579: rowcount is updated even if query is prefixed with comments
933+
self.cu.execute("""
934+
-- foo
935+
insert into test(name) values ('foo'), ('foo')
936+
""")
937+
self.assertEqual(self.cu.rowcount, 2)
938+
self.cu.execute("""
939+
/* -- messy *r /* /* ** *- *--
940+
*/
941+
/* one more */ insert into test(name) values ('messy')
942+
""")
943+
self.assertEqual(self.cu.rowcount, 1)
944+
self.cu.execute("/* bar */ update test set name='bar' where name='foo'")
945+
self.assertEqual(self.cu.rowcount, 3)
946+
947+
def test_rowcount_vaccuum(self):
948+
data = ((1,), (2,), (3,))
949+
self.cu.executemany("insert into test(income) values(?)", data)
950+
self.assertEqual(self.cu.rowcount, 3)
951+
self.cx.commit()
952+
self.cu.execute("vacuum")
953+
self.assertEqual(self.cu.rowcount, -1)
954+
909955
def test_total_changes(self):
910956
self.cu.execute("insert into test(name) values ('foo')")
911957
self.cu.execute("insert into test(name) values ('foo')")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
:mod:`sqlite3` now correctly detects DML queries with leading comments.
2+
Patch by Erlend E. Aasland.

Modules/_sqlite/statement.c

+45-74
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,7 @@
2626
#include "util.h"
2727

2828
/* prototypes */
29-
static int pysqlite_check_remaining_sql(const char* tail);
30-
31-
typedef enum {
32-
LINECOMMENT_1,
33-
IN_LINECOMMENT,
34-
COMMENTSTART_1,
35-
IN_COMMENT,
36-
COMMENTEND_1,
37-
NORMAL
38-
} parse_remaining_sql_state;
29+
static const char *lstrip_sql(const char *sql);
3930

4031
pysqlite_Statement *
4132
pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
@@ -73,7 +64,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
7364
return NULL;
7465
}
7566

76-
if (pysqlite_check_remaining_sql(tail)) {
67+
if (lstrip_sql(tail) != NULL) {
7768
PyErr_SetString(connection->ProgrammingError,
7869
"You can only execute one statement at a time.");
7970
goto error;
@@ -82,20 +73,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
8273
/* Determine if the statement is a DML statement.
8374
SELECT is the only exception. See #9924. */
8475
int is_dml = 0;
85-
for (const char *p = sql_cstr; *p != 0; p++) {
86-
switch (*p) {
87-
case ' ':
88-
case '\r':
89-
case '\n':
90-
case '\t':
91-
continue;
92-
}
93-
76+
const char *p = lstrip_sql(sql_cstr);
77+
if (p != NULL) {
9478
is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
9579
|| (PyOS_strnicmp(p, "update", 6) == 0)
9680
|| (PyOS_strnicmp(p, "delete", 6) == 0)
9781
|| (PyOS_strnicmp(p, "replace", 7) == 0);
98-
break;
9982
}
10083

10184
pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
@@ -139,73 +122,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
139122
}
140123

141124
/*
142-
* Checks if there is anything left in an SQL string after SQLite compiled it.
143-
* This is used to check if somebody tried to execute more than one SQL command
144-
* with one execute()/executemany() command, which the DB-API and we don't
145-
* allow.
125+
* Strip leading whitespace and comments from incoming SQL (null terminated C
126+
* string) and return a pointer to the first non-whitespace, non-comment
127+
* character.
146128
*
147-
* Returns 1 if there is more left than should be. 0 if ok.
129+
* This is used to check if somebody tries to execute more than one SQL query
130+
* with one execute()/executemany() command, which the DB-API don't allow.
131+
*
132+
* It is also used to harden DML query detection.
148133
*/
149-
static int pysqlite_check_remaining_sql(const char* tail)
134+
static inline const char *
135+
lstrip_sql(const char *sql)
150136
{
151-
const char* pos = tail;
152-
153-
parse_remaining_sql_state state = NORMAL;
154-
155-
for (;;) {
137+
// This loop is borrowed from the SQLite source code.
138+
for (const char *pos = sql; *pos; pos++) {
156139
switch (*pos) {
157-
case 0:
158-
return 0;
159-
case '-':
160-
if (state == NORMAL) {
161-
state = LINECOMMENT_1;
162-
} else if (state == LINECOMMENT_1) {
163-
state = IN_LINECOMMENT;
164-
}
165-
break;
166140
case ' ':
167141
case '\t':
168-
break;
142+
case '\f':
169143
case '\n':
170-
case 13:
171-
if (state == IN_LINECOMMENT) {
172-
state = NORMAL;
173-
}
144+
case '\r':
145+
// Skip whitespace.
174146
break;
175-
case '/':
176-
if (state == NORMAL) {
177-
state = COMMENTSTART_1;
178-
} else if (state == COMMENTEND_1) {
179-
state = NORMAL;
180-
} else if (state == COMMENTSTART_1) {
181-
return 1;
147+
case '-':
148+
// Skip line comments.
149+
if (pos[1] == '-') {
150+
pos += 2;
151+
while (pos[0] && pos[0] != '\n') {
152+
pos++;
153+
}
154+
if (pos[0] == '\0') {
155+
return NULL;
156+
}
157+
continue;
182158
}
183-
break;
184-
case '*':
185-
if (state == NORMAL) {
186-
return 1;
187-
} else if (state == LINECOMMENT_1) {
188-
return 1;
189-
} else if (state == COMMENTSTART_1) {
190-
state = IN_COMMENT;
191-
} else if (state == IN_COMMENT) {
192-
state = COMMENTEND_1;
159+
return pos;
160+
case '/':
161+
// Skip C style comments.
162+
if (pos[1] == '*') {
163+
pos += 2;
164+
while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
165+
pos++;
166+
}
167+
if (pos[0] == '\0') {
168+
return NULL;
169+
}
170+
pos++;
171+
continue;
193172
}
194-
break;
173+
return pos;
195174
default:
196-
if (state == COMMENTEND_1) {
197-
state = IN_COMMENT;
198-
} else if (state == IN_LINECOMMENT) {
199-
} else if (state == IN_COMMENT) {
200-
} else {
201-
return 1;
202-
}
175+
return pos;
203176
}
204-
205-
pos++;
206177
}
207178

208-
return 0;
179+
return NULL;
209180
}
210181

211182
static PyType_Slot stmt_slots[] = {

0 commit comments

Comments
 (0)