Skip to content

Commit

Permalink
Added nested json query support
Browse files Browse the repository at this point in the history
  • Loading branch information
akclace committed Jan 25, 2024
1 parent 034ab1d commit 1c00652
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 42 deletions.
62 changes: 50 additions & 12 deletions internal/app/store/parse_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,27 @@ func init() {
}
}

func parseQuery(query map[string]any) (string, []interface{}, error) {
// fieldMapper maps the given field name to the expression to be passed in the sql
type fieldMapper func(string) (string, error)

func sqliteFieldMapper(field string) (string, error) {
if RESERVED_FIELDS[field] {
if field == JSON_FIELD {
return "", fmt.Errorf("querying %s directly is not supported", field)
}
return field, nil
}

if strings.Contains(field, "'") {
// Protect against sql injection, even though this is the column name rather than value
return "", fmt.Errorf("field path cannot contain ': %s", field)
}

v := fmt.Sprintf("_json ->> '%s'", field)
return v, nil
}

func parseQuery(query map[string]any, mapper fieldMapper) (string, []interface{}, error) {
var conditions []string
var params []interface{}

Expand All @@ -39,7 +59,7 @@ func parseQuery(query map[string]any) (string, []interface{}, error) {

for _, key := range keys {
value := query[key]
condition, subParams, err := parseCondition(key, value)
condition, subParams, err := parseCondition(key, value, mapper)
if err != nil {
return "", nil, err
}
Expand All @@ -51,12 +71,12 @@ func parseQuery(query map[string]any) (string, []interface{}, error) {
return joinedConditions, params, nil
}

func parseCondition(field string, value any) (string, []any, error) {
func parseCondition(field string, value any, mapper fieldMapper) (string, []any, error) {
switch v := value.(type) {
case []map[string]any:
// Check if the map represents a logical operator or multiple conditions
if isLogicalOperator(field) {
return parseLogicalOperator(field, v)
return parseLogicalOperator(field, v, mapper)
}

return "", nil, fmt.Errorf("invalid condition for %s, list supported for logical operators only, got: %#v", field, value)
Expand All @@ -68,7 +88,7 @@ func parseCondition(field string, value any) (string, []any, error) {
if op != "" {
return "", nil, fmt.Errorf("operator %s supported for field conditions only: %#v", field, value)
}
return parseFieldCondition(field, v)
return parseFieldCondition(field, v, mapper)
case map[any]any:
return "", nil, fmt.Errorf("invalid query condition for %s, only map of strings supported: %#v", field, value)
case []any:
Expand All @@ -81,16 +101,24 @@ func parseCondition(field string, value any) (string, []any, error) {
return "", nil, fmt.Errorf("operator %s supported for field conditions only: %#v", field, value)
}
// Simple equality condition
return fmt.Sprintf("%s = ?", field), []any{v}, nil
mappedField := field
if mapper != nil {
var err error
mappedField, err = mapper(field)
if err != nil {
return "", nil, err
}
}
return fmt.Sprintf("%s = ?", mappedField), []any{v}, nil
}
}

func parseLogicalOperator(operator string, query []map[string]any) (string, []any, error) {
func parseLogicalOperator(operator string, query []map[string]any, mapper fieldMapper) (string, []any, error) {
var conditions []string
var params []interface{}

for _, cond := range query {
condition, subParams, err := parseQuery(cond)
condition, subParams, err := parseQuery(cond, mapper)
if err != nil {
return "", nil, err
}
Expand All @@ -108,7 +136,7 @@ func parseLogicalOperator(operator string, query []map[string]any) (string, []an
return " ( " + joinedConditions + " ) ", params, nil
}

func parseFieldCondition(field string, query map[string]any) (string, []any, error) {
func parseFieldCondition(field string, query map[string]any, mapper fieldMapper) (string, []any, error) {
var keys []string
for key := range query {
keys = append(keys, key)
Expand All @@ -128,7 +156,7 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err
case []map[string]any:
// Check if the map represents a logical operator or multiple conditions
if isLogicalOperator(key) {
subCondition, subParams, err = parseFieldLogicalOperator(field, key, v)
subCondition, subParams, err = parseFieldLogicalOperator(field, key, v, mapper)
if err != nil {
return "", nil, err
}
Expand All @@ -146,7 +174,17 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err
if op == "" {
return "", nil, fmt.Errorf("invalid query condition for %s %s, only operators supported: %#v", field, key, value)
}
subCondition = fmt.Sprintf("%s %s ?", field, op)

mappedField := field
if mapper != nil {
var err error
mappedField, err = mapper(field)
if err != nil {
return "", nil, err
}
}

subCondition = fmt.Sprintf("%s %s ?", mappedField, op)
subParams = []any{value}
}

Expand All @@ -158,7 +196,7 @@ func parseFieldCondition(field string, query map[string]any) (string, []any, err
return joinedConditions, params, nil
}

func parseFieldLogicalOperator(field string, operator string, query []map[string]any) (string, []any, error) {
func parseFieldLogicalOperator(field string, operator string, query []map[string]any, mapper fieldMapper) (string, []any, error) {
var conditions []string
var params []interface{}

Expand Down
41 changes: 39 additions & 2 deletions internal/app/store/parse_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,23 @@ import (

func ParseQueryTest(t *testing.T, query map[string]any, expectedConditions string, expectedParams []any) {
t.Helper()
conditions, params, err := parseQuery(query)
conditions, params, err := parseQuery(query, nil)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if conditions != expectedConditions {
t.Errorf("Conditions do not match. Expected: %s, Got: %s.", expectedConditions, conditions)
}

if !slices.Equal(params, expectedParams) {
t.Errorf("Params do not match. Expected: %v, Got: %v.", expectedParams, params)
}
}

func ParseQueryMapperTest(t *testing.T, query map[string]any, expectedConditions string, expectedParams []any) {
t.Helper()
conditions, params, err := parseQuery(query, sqliteFieldMapper)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
Expand All @@ -28,7 +44,13 @@ func ParseQueryTest(t *testing.T, query map[string]any, expectedConditions strin

func ParseQueryErrorTest(t *testing.T, query map[string]any, expected string) {
t.Helper()
_, _, err := parseQuery(query)
_, _, err := parseQuery(query, nil)
testutil.AssertErrorContains(t, err, expected)
}

func ParseMappedErrorTest(t *testing.T, query map[string]any, expected string) {
t.Helper()
_, _, err := parseQuery(query, sqliteFieldMapper)
testutil.AssertErrorContains(t, err, expected)
}

Expand Down Expand Up @@ -79,3 +101,18 @@ func TestErrorQueries(t *testing.T) {
ParseQueryErrorTest(t, map[string]any{"age": map[string]any{"$or": []map[string]any{{"$gt": 1, "$lt": 10}}}}, "invalid logical condition for age $or, only one key supported: map")
ParseQueryErrorTest(t, map[string]any{"age": map[string]any{"$or": []map[string]any{{"$AA": 1}}}}, "invalid logical condition for age $AA, only operators supported: 1")
}

func TestMappedQueries(t *testing.T) {
ParseQueryMapperTest(t, nil, "", nil)
ParseQueryMapperTest(t, map[string]any{}, "", nil)
ParseQueryMapperTest(t, map[string]any{"age": 30, "city": "New York", "state": "California"}, "_json ->> 'age' = ? AND _json ->> 'city' = ? AND _json ->> 'state' = ?", []any{30, "New York", "California"})
ParseQueryMapperTest(t, map[string]any{"_id": 30, "city": "New York", "state": "California", "country": "USA"}, "_id = ? AND _json ->> 'city' = ? AND _json ->> 'country' = ? AND _json ->> 'state' = ?", []any{30, "New York", "USA", "California"})
ParseQueryMapperTest(t, map[string]any{"age": 30, "$AND": []map[string]any{{"city": "New York"}, {"$OR": []map[string]any{{"state": "California"}, {"country": "USA"}}}, {"city": "New York"}}}, " ( _json ->> 'city' = ? AND ( _json ->> 'state' = ? OR _json ->> 'country' = ? ) AND _json ->> 'city' = ? ) AND _json ->> 'age' = ?", []any{"New York", "California", "USA", "New York", 30})
ParseQueryMapperTest(t, map[string]any{"age": map[string]any{"$gt": 30, "$lt": 40}}, "_json ->> 'age' > ? AND _json ->> 'age' < ?", []any{30, 40})
ParseQueryMapperTest(t, map[string]any{"age": map[string]any{"$lte": 30}}, "_json ->> 'age' <= ?", []any{30})
}

func TestMappedError(t *testing.T) {
ParseMappedErrorTest(t, map[string]any{"_json": 30}, "querying _json directly is not supporte")
ParseMappedErrorTest(t, map[string]any{"abc'def": 30}, "field path cannot contain ': abc")
}
51 changes: 43 additions & 8 deletions internal/app/store/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (s *SqlStore) initStore() error {
return err
}

createStmt := "CREATE TABLE IF NOT EXISTS " + table + " (id INTEGER PRIMARY KEY AUTOINCREMENT, version INTEGER, created_by TEXT, updated_by TEXT, created_at INTEGER, updated_at INTEGER, data JSON)"
createStmt := "CREATE TABLE IF NOT EXISTS " + table + " (_id INTEGER PRIMARY KEY AUTOINCREMENT, _version INTEGER, _created_by TEXT, _updated_by TEXT, _created_at INTEGER, _updated_at INTEGER, _json JSON)"
_, err = s.db.Exec(createStmt)
if err != nil {
return fmt.Errorf("error creating table %s: %w", table, err)
Expand Down Expand Up @@ -147,7 +147,7 @@ func (s *SqlStore) Insert(table string, entry *Entry) (EntryId, error) {
return -1, fmt.Errorf("error marshalling data for table %s: %w", table, err)
}

createStmt := "INSERT INTO " + table + " (version, created_by, updated_by, created_at, updated_at, data) VALUES (?, ?, ?, ?, ?, ?)"
createStmt := "INSERT INTO " + table + " (_version, _created_by, _updated_by, _created_at, _updated_at, _json) VALUES (?, ?, ?, ?, ?, ?)"
result, err := s.db.Exec(createStmt, entry.Version, entry.CreatedBy, entry.UpdatedBy, entry.CreatedAt.UnixMilli(), entry.UpdatedAt.UnixMilli(), dataJson)
if err != nil {
return -1, nil
Expand All @@ -172,7 +172,7 @@ func (s *SqlStore) SelectById(table string, id EntryId) (*Entry, error) {
return nil, err
}

query := "SELECT id, version, created_by, updated_by, created_at, updated_at, data FROM " + table + " WHERE id = ?"
query := "SELECT _id, _version, _created_by, _updated_by, _created_at, _updated_at, _json FROM " + table + " WHERE _id = ?"
row := s.db.QueryRow(query, id)

entry := &Entry{}
Expand Down Expand Up @@ -223,7 +223,7 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of

// TODO handle sort

filterStr, params, err := parseQuery(filter)
filterStr, params, err := parseQuery(filter, sqliteFieldMapper)
if err != nil {
return nil, err
}
Expand All @@ -233,7 +233,7 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of
whereStr = " WHERE " + filterStr
}

query := "SELECT id, version, created_by, updated_by, created_at, updated_at, data FROM " + table + whereStr + limitOffsetStr
query := "SELECT _id, _version, _created_by, _updated_by, _created_at, _updated_at, _json FROM " + table + whereStr + limitOffsetStr
s.Trace().Msgf("query: %s, params: %#v", query, params)
rows, err := s.db.Query(query, params...)

Expand All @@ -244,6 +244,41 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of
return NewStoreEntryIterabe(s.Logger, table, rows), nil
}

// Count returns the number of entries matching the filter
func (s *SqlStore) Count(table string, filter map[string]any) (int64, error) {
if err := s.initialize(); err != nil {
return -1, err
}

var err error
table, err = s.genTableName(table)
if err != nil {
return -1, err
}

filterStr, params, err := parseQuery(filter, sqliteFieldMapper)
if err != nil {
return -1, err
}

whereStr := ""
if filterStr != "" {
whereStr = " WHERE " + filterStr
}

query := "SELECT count(_id) FROM " + table + whereStr
s.Trace().Msgf("query: %s, params: %#v", query, params)
row := s.db.QueryRow(query, params...)

var count int64
err = row.Scan(&count)
if err != nil {
return -1, err
}

return count, nil
}

// Update an existing entry in the store
func (s *SqlStore) Update(table string, entry *Entry) (int64, error) {
if err := s.initialize(); err != nil {
Expand All @@ -264,7 +299,7 @@ func (s *SqlStore) Update(table string, entry *Entry) (int64, error) {
return 0, fmt.Errorf("error marshalling data for table %s: %w", table, err)
}

updateStmt := "UPDATE " + table + " set version = ?, updated_by = ?, updated_at = ?, data = ? where id = ? and updated_at = ?"
updateStmt := "UPDATE " + table + " set _version = ?, _updated_by = ?, _updated_at = ?, _json = ? where _id = ? and _updated_at = ?"
result, err := s.db.Exec(updateStmt, entry.Version, entry.UpdatedBy, entry.UpdatedAt.UnixMilli(), dataJson, entry.Id, origUpdateAt.UnixMilli())
if err != nil {
return 0, err
Expand Down Expand Up @@ -292,7 +327,7 @@ func (s *SqlStore) DeleteById(table string, id EntryId) (int64, error) {
return 0, err
}

deleteStmt := "DELETE from " + table + " where id = ?"
deleteStmt := "DELETE from " + table + " where _id = ?"
result, err := s.db.Exec(deleteStmt, id)
if err != nil {
return 0, err
Expand Down Expand Up @@ -321,7 +356,7 @@ func (s *SqlStore) Delete(table string, filter map[string]any) (int64, error) {
return 0, err
}

filterStr, params, err := parseQuery(filter)
filterStr, params, err := parseQuery(filter, sqliteFieldMapper)
if err != nil {
return 0, err
}
Expand Down
Loading

0 comments on commit 1c00652

Please sign in to comment.