From 63e4bcff806c8a87b19544238dfe84ade12c170a Mon Sep 17 00:00:00 2001
From: Stuart McAlpine
Date: Fri, 14 Mar 2025 22:56:57 +0100
Subject: [PATCH] Add aggregate query for datasets

Add a Query.aggregate_datasets() method that performs a COUNT or SUM
over a column of the dataset table, combining the results across every
schema selected by the connection's query mode.
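A minimal usage sketch (assuming a DataRegistry handle named `datareg`,
as in the tests below; "dataset_id" is just an example column, and
optional WHERE-clause filters may be passed via the `filters` argument):

    # Count the rows of the dataset table
    n_datasets = datareg.Query.aggregate_datasets("dataset_id")

    # Sum a numeric column instead of counting
    total = datareg.Query.aggregate_datasets("dataset_id", agg_func="sum")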
---
 src/dataregistry/query.py            | 69 ++++++++++++++++++++++++++++
 tests/end_to_end_tests/test_query.py | 26 +++++++++++
 2 files changed, 95 insertions(+)

diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py
index 65ad0b4..d802f73 100644
--- a/src/dataregistry/query.py
+++ b/src/dataregistry/query.py
@@ -1,3 +1,4 @@
+from sqlalchemy import func
 from collections import namedtuple
 from sqlalchemy import select
 import sqlalchemy.sql.sqltypes as sqltypes
@@ -262,6 +263,74 @@ def _parse_selected_columns(self, column_names):
 
         return list(tables_required), column_list, is_orderable_list
 
+    def aggregate_datasets(self, column_name, agg_func="count", filters=None):
+        """
+        Perform an aggregation (count or sum) on a specified column in the
+        dataset table.
+
+        Parameters
+        ----------
+        column_name : str
+            The column to perform the aggregation on.
+        agg_func : str, optional
+            The aggregation function to use: "count" (default) or "sum".
+        filters : list, optional
+            List of filters (WHERE clauses) to apply.
+
+        Returns
+        -------
+        result : int or float
+            The aggregated value.
+        """
+        if agg_func not in {"count", "sum"}:
+            raise ValueError("agg_func must be either 'count' or 'sum'")
+
+        results = []
+        query_mode = self.db_connection._query_mode
+
+        for table_key in self.db_connection.metadata["tables"]:
+            # Extract schema and table name
+            parts = table_key.split(".")
+            if len(parts) == 2:
+                schema, table = parts
+            else:
+                schema, table = None, parts[0]  # SQLite case (no schema)
+
+            # Skip non-dataset tables
+            if table != "dataset":
+                continue
+
+            # Determine if this schema should be queried
+            if query_mode != "both" and schema and query_mode != schema.split("_")[-1]:
+                continue
+
+            dataset_table = self.db_connection.metadata["tables"].get(table_key)
+            if dataset_table is None:
+                continue
+
+            # Ensure the column exists before referencing it
+            if column_name not in dataset_table.c:
+                raise ValueError(f"Column '{column_name}' does not exist in dataset table")
+
+            if agg_func == "count":
+                aggregation = func.count()
+            else:
+                aggregation = func.sum(dataset_table.c[column_name])
+
+            stmt = select(aggregation).select_from(dataset_table)
+
+            if filters:
+                for f in filters:
+                    stmt = self._render_filter(f, stmt, schema)
+
+            with self._engine.connect() as conn:
+                result = conn.execute(stmt).scalar()
+
+            if result is not None:
+                results.append(result)
+
+        return sum(results)  # counts and sums alike combine across schemas
+
     def _render_filter(self, f, stmt, schema):
         """
         Append SQL statement with an additional WHERE clause based on a
diff --git a/tests/end_to_end_tests/test_query.py b/tests/end_to_end_tests/test_query.py
index 282fef2..3e6270f 100644
--- a/tests/end_to_end_tests/test_query.py
+++ b/tests/end_to_end_tests/test_query.py
@@ -138,3 +138,29 @@ def test_query_name(dummy_file, op, qstr, ans, tag):
     assert len(results) > 0
     for c, v in results.items():
         assert len(v) == ans
+
+def test_aggregate_datasets_count(dummy_file):
+    """Test counting the number of datasets."""
+    tmp_src_dir, tmp_root_dir = dummy_file
+    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)
+
+    # Insert datasets
+    for i in range(3):
+        _insert_dataset_entry(datareg, f"test_aggregate_datasets_count_{i}", "1.0.0")
+
+    # Count datasets
+    count = datareg.Query.aggregate_datasets("dataset_id", agg_func="count")
+    assert count >= 3  # Ensure at least 3 were counted
+
+def test_aggregate_datasets_sum(dummy_file):
+    """Test summing a numeric column (dataset_id)."""
+    tmp_src_dir, tmp_root_dir = dummy_file
+    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)
+
+    # Insert datasets
+    for i in range(3):
+        _insert_dataset_entry(datareg, f"test_aggregate_datasets_sum_{i}", "1.0.0")
+
+    # Sum dataset_id values (positive ints, so the total is >= the count)
+    sum_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="sum")
+    assert sum_value >= 3