diff --git a/internal/app/store/parse_query.go b/internal/app/store/parse_query.go index ceff2ef..39172c0 100644 --- a/internal/app/store/parse_query.go +++ b/internal/app/store/parse_query.go @@ -27,7 +27,27 @@ func init() { } } -func parseQuery(query map[string]any) (string, []interface{}, error) { +// fieldMapper maps the given field name to the expression to be passed in the sql +type fieldMapper func(string) (string, error) + +func sqliteFieldMapper(field string) (string, error) { + if RESERVED_FIELDS[field] { + if field == JSON_FIELD { + return "", fmt.Errorf("querying %s directly is not supported", field) + } + return field, nil + } + + if strings.Contains(field, "'") { + // Protect against sql injection, even though this is the column name rather than value + return "", fmt.Errorf("field path cannot contain ': %s", field) + } + + v := fmt.Sprintf("_json ->> '%s'", field) + return v, nil +} + +func parseQuery(query map[string]any, mapper fieldMapper) (string, []interface{}, error) { var conditions []string var params []interface{} @@ -39,7 +59,7 @@ func parseQuery(query map[string]any) (string, []interface{}, error) { for _, key := range keys { value := query[key] - condition, subParams, err := parseCondition(key, value) + condition, subParams, err := parseCondition(key, value, mapper) if err != nil { return "", nil, err } @@ -51,12 +71,12 @@ func parseQuery(query map[string]any) (string, []interface{}, error) { return joinedConditions, params, nil } -func parseCondition(field string, value any) (string, []any, error) { +func parseCondition(field string, value any, mapper fieldMapper) (string, []any, error) { switch v := value.(type) { case []map[string]any: // Check if the map represents a logical operator or multiple conditions if isLogicalOperator(field) { - return parseLogicalOperator(field, v) + return parseLogicalOperator(field, v, mapper) } return "", nil, fmt.Errorf("invalid condition for %s, list supported for logical operators only, got: %#v", field, value) @@ -68,7 +88,7 @@ func parseCondition(field string, value any) (string, []any, error) { if op != "" { return "", nil, fmt.Errorf("operator %s supported for field conditions only: %#v", field, value) } - return parseFieldCondition(field, v) + return parseFieldCondition(field, v, mapper) case map[any]any: return "", nil, fmt.Errorf("invalid query condition for %s, only map of strings supported: %#v", field, value) case []any: @@ -81,16 +101,24 @@ func parseCondition(field string, value any) (string, []any, error) { return "", nil, fmt.Errorf("operator %s supported for field conditions only: %#v", field, value) } // Simple equality condition - return fmt.Sprintf("%s = ?", field), []any{v}, nil + mappedField := field + if mapper != nil { + var err error + mappedField, err = mapper(field) + if err != nil { + return "", nil, err + } + } + return fmt.Sprintf("%s = ?", mappedField), []any{v}, nil } } -func parseLogicalOperator(operator string, query []map[string]any) (string, []any, error) { +func parseLogicalOperator(operator string, query []map[string]any, mapper fieldMapper) (string, []any, error) { var conditions []string var params []interface{} for _, cond := range query { - condition, subParams, err := parseQuery(cond) + condition, subParams, err := parseQuery(cond, mapper) if err != nil { return "", nil, err } @@ -108,7 +136,7 @@ func parseLogicalOperator(operator string, query []map[string]any) (string, []an return " ( " + joinedConditions + " ) ", params, nil } -func parseFieldCondition(field string, query map[string]any) (string, []any, error) { +func parseFieldCondition(field string, query map[string]any, mapper fieldMapper) (string, []any, error) { var keys []string for key := range query { keys = append(keys, key) @@ -128,7 +156,7 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err case []map[string]any: // Check if the map represents a logical operator or multiple conditions if isLogicalOperator(key) { - subCondition, subParams, err = parseFieldLogicalOperator(field, key, v) + subCondition, subParams, err = parseFieldLogicalOperator(field, key, v, mapper) if err != nil { return "", nil, err } @@ -146,7 +174,17 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err if op == "" { return "", nil, fmt.Errorf("invalid query condition for %s %s, only operators supported: %#v", field, key, value) } - subCondition = fmt.Sprintf("%s %s ?", field, op) + + mappedField := field + if mapper != nil { + var err error + mappedField, err = mapper(field) + if err != nil { + return "", nil, err + } + } + + subCondition = fmt.Sprintf("%s %s ?", mappedField, op) subParams = []any{value} } @@ -158,7 +196,7 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err return joinedConditions, params, nil } -func parseFieldLogicalOperator(field string, operator string, query []map[string]any) (string, []any, error) { +func parseFieldLogicalOperator(field string, operator string, query []map[string]any, mapper fieldMapper) (string, []any, error) { var conditions []string var params []interface{} diff --git a/internal/app/store/parse_query_test.go b/internal/app/store/parse_query_test.go index 6e6c148..9d83c1d 100644 --- a/internal/app/store/parse_query_test.go +++ b/internal/app/store/parse_query_test.go @@ -12,7 +12,23 @@ import ( func ParseQueryTest(t *testing.T, query map[string]any, expectedConditions string, expectedParams []any) { t.Helper() - conditions, params, err := parseQuery(query) + conditions, params, err := parseQuery(query, nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if conditions != expectedConditions { + t.Errorf("Conditions do not match. Expected: %s, Got: %s.", expectedConditions, conditions) + } + + if !slices.Equal(params, expectedParams) { + t.Errorf("Params do not match. Expected: %v, Got: %v.", expectedParams, params) + } +} + +func ParseQueryMapperTest(t *testing.T, query map[string]any, expectedConditions string, expectedParams []any) { + t.Helper() + conditions, params, err := parseQuery(query, sqliteFieldMapper) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -28,7 +44,13 @@ func ParseQueryTest(t *testing.T, query map[string]any, expectedConditions strin func ParseQueryErrorTest(t *testing.T, query map[string]any, expected string) { t.Helper() - _, _, err := parseQuery(query) + _, _, err := parseQuery(query, nil) + testutil.AssertErrorContains(t, err, expected) +} + +func ParseMappedErrorTest(t *testing.T, query map[string]any, expected string) { + t.Helper() + _, _, err := parseQuery(query, sqliteFieldMapper) testutil.AssertErrorContains(t, err, expected) } @@ -79,3 +101,18 @@ func TestErrorQueries(t *testing.T) { ParseQueryErrorTest(t, map[string]any{"age": map[string]any{"$or": []map[string]any{{"$gt": 1, "$lt": 10}}}}, "invalid logical condition for age $or, only one key supported: map") ParseQueryErrorTest(t, map[string]any{"age": map[string]any{"$or": []map[string]any{{"$AA": 1}}}}, "invalid logical condition for age $AA, only operators supported: 1") } + +func TestMappedQueries(t *testing.T) { + ParseQueryMapperTest(t, nil, "", nil) + ParseQueryMapperTest(t, map[string]any{}, "", nil) + ParseQueryMapperTest(t, map[string]any{"age": 30, "city": "New York", "state": "California"}, "_json ->> 'age' = ? AND _json ->> 'city' = ? AND _json ->> 'state' = ?", []any{30, "New York", "California"}) + ParseQueryMapperTest(t, map[string]any{"_id": 30, "city": "New York", "state": "California", "country": "USA"}, "_id = ? AND _json ->> 'city' = ? AND _json ->> 'country' = ? AND _json ->> 'state' = ?", []any{30, "New York", "USA", "California"}) + ParseQueryMapperTest(t, map[string]any{"age": 30, "$AND": []map[string]any{{"city": "New York"}, {"$OR": []map[string]any{{"state": "California"}, {"country": "USA"}}}, {"city": "New York"}}}, " ( _json ->> 'city' = ? AND ( _json ->> 'state' = ? OR _json ->> 'country' = ? ) AND _json ->> 'city' = ? ) AND _json ->> 'age' = ?", []any{"New York", "California", "USA", "New York", 30}) + ParseQueryMapperTest(t, map[string]any{"age": map[string]any{"$gt": 30, "$lt": 40}}, "_json ->> 'age' > ? AND _json ->> 'age' < ?", []any{30, 40}) + ParseQueryMapperTest(t, map[string]any{"age": map[string]any{"$lte": 30}}, "_json ->> 'age' <= ?", []any{30}) +} + +func TestMappedError(t *testing.T) { + ParseMappedErrorTest(t, map[string]any{"_json": 30}, "querying _json directly is not supporte") + ParseMappedErrorTest(t, map[string]any{"abc'def": 30}, "field path cannot contain ': abc") +} diff --git a/internal/app/store/sql_store.go b/internal/app/store/sql_store.go index 2cdfe93..d2c983e 100644 --- a/internal/app/store/sql_store.go +++ b/internal/app/store/sql_store.go @@ -116,7 +116,7 @@ func (s *SqlStore) initStore() error { return err } - createStmt := "CREATE TABLE IF NOT EXISTS " + table + " (id INTEGER PRIMARY KEY AUTOINCREMENT, version INTEGER, created_by TEXT, updated_by TEXT, created_at INTEGER, updated_at INTEGER, data JSON)" + createStmt := "CREATE TABLE IF NOT EXISTS " + table + " (_id INTEGER PRIMARY KEY AUTOINCREMENT, _version INTEGER, _created_by TEXT, _updated_by TEXT, _created_at INTEGER, _updated_at INTEGER, _json JSON)" _, err = s.db.Exec(createStmt) if err != nil { return fmt.Errorf("error creating table %s: %w", table, err) @@ -147,7 +147,7 @@ func (s *SqlStore) Insert(table string, entry *Entry) (EntryId, error) { return -1, fmt.Errorf("error marshalling data for table %s: %w", table, err) } - createStmt := "INSERT INTO " + table + " (version, created_by, updated_by, created_at, updated_at, data) VALUES (?, ?, ?, ?, ?, ?)" + createStmt := "INSERT INTO " + table + " (_version, _created_by, _updated_by, _created_at, _updated_at, _json) VALUES (?, ?, ?, ?, ?, ?)" result, err := s.db.Exec(createStmt, entry.Version, entry.CreatedBy, entry.UpdatedBy, entry.CreatedAt.UnixMilli(), entry.UpdatedAt.UnixMilli(), dataJson) if err != nil { return -1, nil @@ -172,7 +172,7 @@ func (s *SqlStore) SelectById(table string, id EntryId) (*Entry, error) { return nil, err } - query := "SELECT id, version, created_by, updated_by, created_at, updated_at, data FROM " + table + " WHERE id = ?" + query := "SELECT _id, _version, _created_by, _updated_by, _created_at, _updated_at, _json FROM " + table + " WHERE _id = ?" row := s.db.QueryRow(query, id) entry := &Entry{} @@ -223,7 +223,7 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of // TODO handle sort - filterStr, params, err := parseQuery(filter) + filterStr, params, err := parseQuery(filter, sqliteFieldMapper) if err != nil { return nil, err } @@ -233,7 +233,7 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of whereStr = " WHERE " + filterStr } - query := "SELECT id, version, created_by, updated_by, created_at, updated_at, data FROM " + table + whereStr + limitOffsetStr + query := "SELECT _id, _version, _created_by, _updated_by, _created_at, _updated_at, _json FROM " + table + whereStr + limitOffsetStr s.Trace().Msgf("query: %s, params: %#v", query, params) rows, err := s.db.Query(query, params...) @@ -244,6 +244,41 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of return NewStoreEntryIterabe(s.Logger, table, rows), nil } +// Count returns the number of entries matching the filter +func (s *SqlStore) Count(table string, filter map[string]any) (int64, error) { + if err := s.initialize(); err != nil { + return -1, err + } + + var err error + table, err = s.genTableName(table) + if err != nil { + return -1, err + } + + filterStr, params, err := parseQuery(filter, sqliteFieldMapper) + if err != nil { + return -1, err + } + + whereStr := "" + if filterStr != "" { + whereStr = " WHERE " + filterStr + } + + query := "SELECT count(_id) FROM " + table + whereStr + s.Trace().Msgf("query: %s, params: %#v", query, params) + row := s.db.QueryRow(query, params...) + + var count int64 + err = row.Scan(&count) + if err != nil { + return -1, err + } + + return count, nil +} + // Update an existing entry in the store func (s *SqlStore) Update(table string, entry *Entry) (int64, error) { if err := s.initialize(); err != nil { @@ -264,7 +299,7 @@ func (s *SqlStore) Update(table string, entry *Entry) (int64, error) { return 0, fmt.Errorf("error marshalling data for table %s: %w", table, err) } - updateStmt := "UPDATE " + table + " set version = ?, updated_by = ?, updated_at = ?, data = ? where id = ? and updated_at = ?" + updateStmt := "UPDATE " + table + " set _version = ?, _updated_by = ?, _updated_at = ?, _json = ? where _id = ? and _updated_at = ?" result, err := s.db.Exec(updateStmt, entry.Version, entry.UpdatedBy, entry.UpdatedAt.UnixMilli(), dataJson, entry.Id, origUpdateAt.UnixMilli()) if err != nil { return 0, err @@ -292,7 +327,7 @@ func (s *SqlStore) DeleteById(table string, id EntryId) (int64, error) { return 0, err } - deleteStmt := "DELETE from " + table + " where id = ?" + deleteStmt := "DELETE from " + table + " where _id = ?" result, err := s.db.Exec(deleteStmt, id) if err != nil { return 0, err @@ -321,7 +356,7 @@ func (s *SqlStore) Delete(table string, filter map[string]any) (int64, error) { return 0, err } - filterStr, params, err := parseQuery(filter) + filterStr, params, err := parseQuery(filter, sqliteFieldMapper) if err != nil { return 0, err } diff --git a/internal/app/store/store.go b/internal/app/store/store.go index 93d4dba..e60ab88 100644 --- a/internal/app/store/store.go +++ b/internal/app/store/store.go @@ -12,6 +12,26 @@ import ( "go.starlark.net/starlark" ) +const ( + ID_FIELD = "_id" + VERSION_FIELD = "_version" + CREATED_BY_FIELD = "_created_by" + UPDATED_BY_FIELD = "_updated_by" + CREATED_AT_FIELD = "_created_at" + UPDATED_AT_FIELD = "_updated_at" + JSON_FIELD = "_json" +) + +var RESERVED_FIELDS = map[string]bool{ + ID_FIELD: true, + VERSION_FIELD: true, + CREATED_BY_FIELD: true, + UPDATED_BY_FIELD: true, + CREATED_AT_FIELD: true, + UPDATED_AT_FIELD: true, + JSON_FIELD: true, +} + type EntryId int64 type UserId string type Document map[string]any @@ -38,36 +58,36 @@ func (e *Entry) Unpack(value starlark.Value) error { entryData := make(map[string]any) for _, attr := range v.AttrNames() { switch attr { - case "_id": + case ID_FIELD: id, err := util.GetIntAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) } e.Id = EntryId(id) - case "_version": + case VERSION_FIELD: e.Version, err = util.GetIntAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) } - case "_created_by": + case CREATED_BY_FIELD: createdBy, err := util.GetStringAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) } e.CreatedBy = UserId(createdBy) - case "_updated_by": + case UPDATED_BY_FIELD: updatedBy, err := util.GetStringAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) } e.UpdatedBy = UserId(updatedBy) - case "_created_at": + case CREATED_AT_FIELD: createdAt, err := util.GetIntAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) } e.CreatedAt = time.UnixMilli(createdAt) - case "_updated_at": + case UPDATED_AT_FIELD: updatedAt, err := util.GetIntAttr(v, attr) if err != nil { return fmt.Errorf("error reading %s: %w", attr, err) @@ -101,6 +121,9 @@ type Store interface { // Select returns the entries matching the filter Select(table string, filter map[string]any, sort []string, offset, limit int64) (starlark.Iterable, error) + // Count returns the count of entries matching the filter + Count(table string, filter map[string]any) (int64, error) + // Update an existing entry in the store Update(table string, Entry *Entry) (int64, error) @@ -114,12 +137,12 @@ type Store interface { func CreateType(name string, entry *Entry) (*utils.StarlarkType, error) { data := make(map[string]starlark.Value) - data["_id"] = starlark.MakeInt(int(entry.Id)) - data["_version"] = starlark.MakeInt(int(entry.Version)) - data["_created_by"] = starlark.String(string(entry.CreatedBy)) - data["_updated_by"] = starlark.String(string(entry.UpdatedBy)) - data["_created_at"] = starlark.MakeInt(int(entry.CreatedAt.UnixMilli())) - data["_updated_at"] = starlark.MakeInt(int(entry.UpdatedAt.UnixMilli())) + data[ID_FIELD] = starlark.MakeInt(int(entry.Id)) + data[VERSION_FIELD] = starlark.MakeInt(int(entry.Version)) + data[CREATED_BY_FIELD] = starlark.String(string(entry.CreatedBy)) + data[UPDATED_BY_FIELD] = starlark.String(string(entry.UpdatedBy)) + data[CREATED_AT_FIELD] = starlark.MakeInt(int(entry.CreatedAt.UnixMilli())) + data[UPDATED_AT_FIELD] = starlark.MakeInt(int(entry.UpdatedAt.UnixMilli())) var err error for k, v := range entry.Data { diff --git a/internal/app/store/store_plugin.go b/internal/app/store/store_plugin.go index 6af293a..ca70b64 100644 --- a/internal/app/store/store_plugin.go +++ b/internal/app/store/store_plugin.go @@ -18,6 +18,7 @@ func init() { pluginFuncs := []utils.PluginFunc{ app.CreatePluginApiName(h.SelectById, app.READ, "select_by_id"), app.CreatePluginApi(h.Select, app.READ), + app.CreatePluginApi(h.Count, app.READ), app.CreatePluginApi(h.Insert, app.WRITE), app.CreatePluginApi(h.Update, app.WRITE), app.CreatePluginApiName(h.DeleteById, app.WRITE, "delete_by_id"), @@ -155,15 +156,41 @@ func (s *storePlugin) Select(thread *starlark.Thread, builtin *starlark.Builtin, if err != nil { return utils.NewErrorResponse(err), nil } + + iterator, err := s.sqlStore.Select(table, filterMap, sortList, offsetVal, limitVal) if err != nil { + return utils.NewErrorResponse(err), nil + } + return utils.NewResponse(iterator), nil +} + +func (s *storePlugin) Count(thread *starlark.Thread, builtin *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var table string + var filter *starlark.Dict + + if err := starlark.UnpackArgs("select", args, kwargs, "table", &table, "filter", &filter); err != nil { return nil, err } - iterator, err := s.sqlStore.Select(table, filterMap, sortList, offsetVal, limitVal) + if filter == nil { + filter = starlark.NewDict(0) + } + + filterUnmarshalled, err := utils.UnmarshalStarlark(filter) if err != nil { return utils.NewErrorResponse(err), nil } - return utils.NewResponse(iterator), nil + + filterMap, ok := filterUnmarshalled.(map[string]any) + if !ok { + return utils.NewErrorResponse(errors.New("invalid filter")), nil + } + + count, err := s.sqlStore.Count(table, filterMap) + if err != nil { + return utils.NewErrorResponse(err), nil + } + return utils.NewResponse(count), nil } func (s *storePlugin) Delete(thread *starlark.Thread, builtin *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { diff --git a/internal/app/tests/store_test.go b/internal/app/tests/store_test.go index 5b03424..f86e26c 100644 --- a/internal/app/tests/store_test.go +++ b/internal/app/tests/store_test.go @@ -27,6 +27,7 @@ permissions=[ ace.permission("store.in", "delete_by_id"), ace.permission("store.in", "select"), ace.permission("store.in", "delete"), + ace.permission("store.in", "count"), ] ) @@ -41,6 +42,11 @@ def handler(req): ret2 = store.insert(table.test1, myt) if not ret2: return {"error": ret2.error} + myt.aint=30 + myt.adict = {"a": 2} + ret3 = store.insert(table.test1, myt) + if not ret3: + return {"error": ret3.error} id = ret.value ret = store.select_by_id(table.test1, id) @@ -60,14 +66,26 @@ def handler(req): if upd_status: return {"error": "Expected duplicate update to fail"} + q1 = store.count(table.test1, {"aint": 100}) + if not q1: + return {"error": q1.error} + if q1.value != 1: + return {"error": "Expected count to be 1, got %d" % q1.value} + + q2 = store.count(table.test1, {"adict.a": 2}) + if not q2: + return {"error": q2.error} + if q2.value != 1: + return {"error": "Expected count to be 1, got %d" % q2.value} + + ret = store.select_by_id(table.test1, id) select_result = store.select(table.test1, {}) - rows = [] + all_rows = [] for row in select_result.value: - print("rrr", row) - rows.append(row) + all_rows.append(row) del_status = store.delete_by_id(table.test1, id) if not del_status: @@ -79,7 +97,7 @@ def handler(req): return {"intval": ret.value.aint, "stringval": ret.value.astring, "_id": ret.value._id, "creator": ret.value._created_by, "created_at": ret.value._created_at, - "rows": rows} + "all_rows": all_rows} `, "schema.star": ` @@ -101,6 +119,7 @@ type("test1", fields=[ {Plugin: "store.in", Method: "delete_by_id"}, {Plugin: "store.in", Method: "select"}, {Plugin: "store.in", Method: "delete"}, + {Plugin: "store.in", Method: "count"}, }, map[string]utils.PluginSettings{ "store.in": { "db_connection": "sqlite:/tmp/clace_app.db?_journal_mode=WAL", @@ -131,8 +150,8 @@ type("test1", fields=[ if id <= 0 { t.Errorf("Expected _id to be > 0, got %f", id) } - testutil.AssertEqualsInt(t, "length", 2, len(ret["rows"].([]any))) - rows := ret["rows"].([]any) + testutil.AssertEqualsInt(t, "length", 3, len(ret["all_rows"].([]any))) + rows := ret["all_rows"].([]any) if rows[0].(map[string]any)["aint"].(float64) != 100 { t.Errorf("Expected aint to be 100, got %f", rows[0].(map[string]any)["aint"].(float64)) }