diff --git a/sql/lantern.sql b/sql/lantern.sql index 6bc39f26a..324d46049 100644 --- a/sql/lantern.sql +++ b/sql/lantern.sql @@ -139,13 +139,7 @@ BEGIN CREATE FUNCTION cos_dist(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME', 'vector_cos_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - CREATE FUNCTION hamming_dist(vector, vector) RETURNS float8 - AS 'MODULE_PATHNAME', 'vector_hamming_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - - CREATE OPERATOR <+> ( - LEFTARG = vector, RIGHTARG = vector, PROCEDURE = hamming_dist, - COMMUTATOR = '<+>' - ); + -- pgvecor's vector type requires floats and we cannot define hamming distance for floats CREATE OPERATOR CLASS dist_vec_l2sq_ops DEFAULT FOR TYPE vector USING lantern_hnsw AS @@ -159,12 +153,6 @@ BEGIN OPERATOR 2 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 2 cos_dist(vector, vector); - CREATE OPERATOR CLASS dist_vec_hamming_ops - FOR TYPE vector USING lantern_hnsw AS - OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, - FUNCTION 1 hamming_dist(vector, vector), - OPERATOR 2 <+> (vector, vector) FOR ORDER BY float_ops, - FUNCTION 2 hamming_dist(vector, vector); END IF; diff --git a/sql/updates/0.0.9--0.0.10.sql b/sql/updates/0.0.9--0.0.10.sql new file mode 100644 index 000000000..721bd3214 --- /dev/null +++ b/sql/updates/0.0.9--0.0.10.sql @@ -0,0 +1,5 @@ +-- these go for good. + +DROP OPERATOR CLASS IF EXISTS dist_vec_hamming_ops USING hnsw CASCADE; +DROP FUNCTION IF EXISTS cos_dist(vector, vector); +DROP OPERATOR <+>(vector, vector) CASCADE \ No newline at end of file diff --git a/src/hnsw.c b/src/hnsw.c index d46085ef0..083929f69 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -295,9 +295,12 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri } float4 result; - bool is_int_array = (metric_kind == usearch_metric_hamming_k); - if(is_int_array) { + if(metric_kind == usearch_metric_hamming_k) { + // when computing hamming distance, array element type must be an integer type + if(ARR_ELEMTYPE(a) != INT4OID || ARR_ELEMTYPE(b) != INT4OID) { + elog(ERROR, "expected integer array but got array with element type %d", ARR_ELEMTYPE(a)); + } int32 *ax_int = (int32 *)ARR_DATA_PTR(a); int32 *bx_int = (int32 *)ARR_DATA_PTR(b); @@ -305,10 +308,9 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri // the hamming distance in usearch actually ignores the scalar type // and it will get casted appropriately in usearch even with this scalar type result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k); - } else { - float4 *ax = (float4 *)ARR_DATA_PTR(a); - float4 *bx = (float4 *)ARR_DATA_PTR(b); + float4 *ax = ToFloat4Array(a); + float4 *bx = ToFloat4Array(b); result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k); } diff --git a/src/hnsw.h b/src/hnsw.h index d3b5edc49..e6d9b58f2 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -38,7 +38,7 @@ PGDLLEXPORT Datum vector_hamming_dist(PG_FUNCTION_ARGS); HnswColumnType GetColumnTypeFromOid(Oid oid); HnswColumnType GetIndexColumnType(Relation index); -void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions); +void* DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions); #define LDB_UNUSED(x) (void)(x) diff --git a/src/hnsw/utils.c b/src/hnsw/utils.c index 00be528d3..e310e40b9 100644 --- a/src/hnsw/utils.c +++ b/src/hnsw/utils.c @@ -3,6 +3,7 @@ #include "utils.h" #include +#include #include #include #include @@ -81,3 +82,24 @@ void CheckMem(int limit, Relation index, usearch_index_t uidx, uint32 n_nodes, c elog(WARNING, "%s", msg); } } + +// if the element type of the passed array is already float4, this function just returns that pointer +// otherwise, it allocates a new array, casts all elements to float4 and returns the resulting array +float4 *ToFloat4Array(ArrayType *arr) +{ + Oid element_type = ARR_ELEMTYPE(arr); + if(element_type == FLOAT4OID) { + return (float4 *)ARR_DATA_PTR(arr); + } else if(element_type == INT4OID) { + int arr_dim = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr)); + + float4 *result = palloc(arr_dim * sizeof(int32)); + int32 *typed_src = (int32 *)ARR_DATA_PTR(arr); + for(int i = 0; i < arr_dim; i++) { + result[ i ] = typed_src[ i ]; + } + return result; + } else { + elog(ERROR, "unsupported element type: %d", element_type); + } +} diff --git a/src/hnsw/utils.h b/src/hnsw/utils.h index 106d296ea..90a57a880 100644 --- a/src/hnsw/utils.h +++ b/src/hnsw/utils.h @@ -1,6 +1,7 @@ #ifndef LDB_HNSW_UTILS_H #define LDB_HNSW_UTILS_H #include +#include #include "options.h" #include "usearch.h" @@ -9,6 +10,7 @@ void CheckMem(int limit, Relation index, usearch_index_t uidx, uint32 void LogUsearchOptions(usearch_init_options_t *opts); void PopulateUsearchOpts(Relation index, usearch_init_options_t *opts); usearch_label_t GetUsearchLabel(ItemPointer itemPtr); +float4 *ToFloat4Array(ArrayType *arr); static inline void ldb_invariant(bool condition, const char *msg, ...) { diff --git a/test/expected/hnsw_insert.out b/test/expected/hnsw_insert.out index 4ceb8896c..52fe23510 100644 --- a/test/expected/hnsw_insert.out +++ b/test/expected/hnsw_insert.out @@ -9,7 +9,11 @@ set work_mem = '64kB'; set client_min_messages = 'ERROR'; CREATE TABLE small_world ( id SERIAL PRIMARY KEY, - v REAL[2] + v REAL[2] -- this demonstates that postgres actually does not enforce real[] length as we actually insert vectors of length 3 +); +CREATE TABLE small_world_int ( + id SERIAL PRIMARY KEY, + v INTEGER[] ); CREATE INDEX ON small_world USING hnsw (v) WITH (dim=3); INFO: done init usearch index @@ -28,6 +32,9 @@ INSERT INTO small_world (v) VALUES ('{0,0,1}'), ('{0,1,0}'); INSERT INTO small_world (v) VALUES (NULL); -- Attempt to insert a row with an incorrect vector length \set ON_ERROR_STOP off +-- Cannot create an hnsw index with implicit typecasts (trying to cast integer[] to real[], in this case) +CREATE INDEX ON small_world_int USING hnsw (v dist_l2sq_ops) WITH (dim=3); +ERROR: operator class "dist_l2sq_ops" does not accept data type integer[] INSERT INTO small_world (v) VALUES ('{1,1,1,1}'); ERROR: Wrong number of dimensions: 4 instead of 3 expected \set ON_ERROR_STOP on diff --git a/test/expected/hnsw_operators.out b/test/expected/hnsw_operators.out index b8971acad..7b0265935 100644 --- a/test/expected/hnsw_operators.out +++ b/test/expected/hnsw_operators.out @@ -6,6 +6,7 @@ INFO: done init usearch index INFO: inserted 2 elements INFO: done saving 2 vectors -- should rewrite operator +SET lantern.pgvector_compat=FALSE; SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; v --------- @@ -27,6 +28,63 @@ ERROR: Operator <-> is invalid outside of ORDER BY context SET lantern.pgvector_compat=TRUE; SET enable_seqscan=OFF; \set ON_ERROR_STOP on +-- one-off vector distance calculations should work with relevant operator +-- with integer arrays: +SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]; + ?column? +---------- + 29 +(1 row) + +-- with float arrays: +SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]::real[]; + ?column? +---------- + 29 +(1 row) + +SELECT ARRAY[0,0,0]::real[] <-> ARRAY[2,3,-4]::real[]; + ?column? +---------- + 29 +(1 row) + +SELECT '{1,0,1}' <-> '{0,1,0}'::integer[]; + ?column? +---------- + 3 +(1 row) + +SELECT '{1,0,1}' <=> '{0,1,0}'::integer[]; + ?column? +---------- + 1 +(1 row) + +SELECT ROUND(num::NUMERIC, 2) FROM (SELECT '{1,1,1}' <=> '{0,1,0}'::INTEGER[] AS num) _sub; + round +------- + 0.42 +(1 row) + +SELECT ARRAY[.1,0,0] <=> ARRAY[0,.5,0]; + ?column? +---------- + 1 +(1 row) + +SELECT cos_dist(ARRAY[.1,0,0]::real[], ARRAY[0,.5,0]::real[]); + cos_dist +---------- + 1 +(1 row) + +SELECT ARRAY[1,0,0] <+> ARRAY[0,1,0]; + ?column? +---------- + 2 +(1 row) + -- NOW THIS IS TRIGGERING INDEX SCAN AS WELL -- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES -- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index 5d2013dc6..91aaf6ce8 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -321,30 +321,3 @@ FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; Order By: (v <=> '[0,1,0]'::vector) (3 rows) --- hamming index -CREATE INDEX hamming_idx ON small_world USING lantern_hnsw (v dist_vec_hamming_ops); -INFO: done init usearch index -INFO: inserted 8 elements -INFO: done saving 8 vectors -SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist -FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; - dist -------- - 0.00 - 7.00 - 7.00 - 7.00 - 14.00 - 14.00 - 14.00 -(7 rows) - -EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist -FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; - QUERY PLAN ---------------------------------------------------- - Limit - -> Index Scan using hamming_idx on small_world - Order By: (v <+> '[0,1,0]'::vector) -(3 rows) - diff --git a/test/sql/hnsw_insert.sql b/test/sql/hnsw_insert.sql index 96f931a88..7a46e24ca 100644 --- a/test/sql/hnsw_insert.sql +++ b/test/sql/hnsw_insert.sql @@ -10,8 +10,14 @@ set client_min_messages = 'ERROR'; CREATE TABLE small_world ( id SERIAL PRIMARY KEY, - v REAL[2] + v REAL[2] -- this demonstates that postgres actually does not enforce real[] length as we actually insert vectors of length 3 ); + +CREATE TABLE small_world_int ( + id SERIAL PRIMARY KEY, + v INTEGER[] +); + CREATE INDEX ON small_world USING hnsw (v) WITH (dim=3); SELECT _lantern_internal.validate_index('small_world_v_idx', false); @@ -21,6 +27,8 @@ INSERT INTO small_world (v) VALUES (NULL); -- Attempt to insert a row with an incorrect vector length \set ON_ERROR_STOP off +-- Cannot create an hnsw index with implicit typecasts (trying to cast integer[] to real[], in this case) +CREATE INDEX ON small_world_int USING hnsw (v dist_l2sq_ops) WITH (dim=3); INSERT INTO small_world (v) VALUES ('{1,1,1,1}'); \set ON_ERROR_STOP on diff --git a/test/sql/hnsw_operators.sql b/test/sql/hnsw_operators.sql index ba6514e1e..9002c7ed5 100644 --- a/test/sql/hnsw_operators.sql +++ b/test/sql/hnsw_operators.sql @@ -3,6 +3,7 @@ CREATE TABLE op_test (v REAL[]); INSERT INTO op_test (v) VALUES (ARRAY[0,0,0]), (ARRAY[1,1,1]); CREATE INDEX cos_idx ON op_test USING hnsw(v dist_cos_ops); -- should rewrite operator +SET lantern.pgvector_compat=FALSE; SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; -- should throw error @@ -20,6 +21,19 @@ SET lantern.pgvector_compat=TRUE; SET enable_seqscan=OFF; \set ON_ERROR_STOP on +-- one-off vector distance calculations should work with relevant operator +-- with integer arrays: +SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]; +-- with float arrays: +SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]::real[]; +SELECT ARRAY[0,0,0]::real[] <-> ARRAY[2,3,-4]::real[]; +SELECT '{1,0,1}' <-> '{0,1,0}'::integer[]; +SELECT '{1,0,1}' <=> '{0,1,0}'::integer[]; +SELECT ROUND(num::NUMERIC, 2) FROM (SELECT '{1,1,1}' <=> '{0,1,0}'::INTEGER[] AS num) _sub; +SELECT ARRAY[.1,0,0] <=> ARRAY[0,.5,0]; +SELECT cos_dist(ARRAY[.1,0,0]::real[], ARRAY[0,.5,0]::real[]); +SELECT ARRAY[1,0,0] <+> ARRAY[0,1,0]; + -- NOW THIS IS TRIGGERING INDEX SCAN AS WELL -- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES -- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index 8d000a8a4..8c7b25726 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -143,13 +143,4 @@ SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; EXPLAIN (COSTS FALSE) SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist -FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; - --- hamming index -CREATE INDEX hamming_idx ON small_world USING lantern_hnsw (v dist_vec_hamming_ops); - -SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist -FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; - -EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist -FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; +FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; \ No newline at end of file