Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate query for datasets #193

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/dataregistry/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sqlalchemy import func
from collections import namedtuple
from sqlalchemy import select
import sqlalchemy.sql.sqltypes as sqltypes
Expand Down Expand Up @@ -262,6 +263,74 @@ def _parse_selected_columns(self, column_names):

return list(tables_required), column_list, is_orderable_list

def aggregate_datasets(self, column_name, agg_func="count", filters=None):
    """
    Perform an aggregation ("count" or "sum") on a specified column in
    the dataset table, combining results across all queried schemas.

    Parameters
    ----------
    column_name : str
        The column to perform the aggregation on. Must exist in the
        dataset table.
    agg_func : str, optional
        The aggregation function to use: "count" (default) or "sum".
    filters : list, optional
        List of filters (WHERE clauses) to apply.

    Returns
    -------
    result : int or float
        The aggregated value, summed over every schema that was queried.

    Raises
    ------
    ValueError
        If `agg_func` is not "count"/"sum", or `column_name` is not a
        column of the dataset table.
    """
    if agg_func not in {"count", "sum"}:
        raise ValueError("agg_func must be either 'count' or 'sum'")

    # Avoid the shared mutable-default-argument pitfall: normalize the
    # default to a fresh empty list per call.
    if filters is None:
        filters = []

    results = []
    query_mode = self.db_connection._query_mode

    for table_key in self.db_connection.metadata["tables"]:
        # Table keys are "<schema>.<table>", or just "<table>" for
        # SQLite, which has no schemas.
        parts = table_key.split(".")
        if len(parts) == 2:
            schema, table = parts
        else:
            schema, table = None, parts[0]

        # Only the dataset table is aggregated.
        if table != "dataset":
            continue

        # Respect the connection's query mode: unless querying "both",
        # skip schemas whose suffix does not match the mode.
        if query_mode != "both" and schema and query_mode != schema.split("_")[-1]:
            continue

        dataset_table = self.db_connection.metadata["tables"].get(table_key)
        if dataset_table is None:
            continue

        # Fail early with a clear message for an unknown column.
        if column_name not in dataset_table.c:
            raise ValueError(f"Column '{column_name}' does not exist in dataset table")

        if agg_func == "count":
            aggregation = func.count()
        else:
            aggregation = func.sum(dataset_table.c[column_name])

        stmt = select(aggregation).select_from(dataset_table)

        # A for-loop over an empty list is a no-op, so no guard needed.
        for f in filters:
            stmt = self._render_filter(f, stmt, schema)

        with self._engine.connect() as conn:
            result = conn.execute(stmt).scalar()

        if result is not None:
            results.append(result)

    # Per-schema counts and per-schema sums are both combined by
    # summation. (The original ternary returned sum(results) on both
    # branches — a dead conditional, collapsed here.)
    return sum(results)

def _render_filter(self, f, stmt, schema):
"""
Append SQL statement with an additional WHERE clause based on a
Expand Down
26 changes: 26 additions & 0 deletions tests/end_to_end_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,29 @@ def test_query_name(dummy_file, op, qstr, ans, tag):
assert len(results) > 0
for c, v in results.items():
assert len(v) == ans

def test_aggregate_datasets_count(dummy_file):
    """Check the count aggregation reports at least the datasets we add."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Register three datasets to be counted.
    names = [f"test_aggregate_datasets_count_{i}" for i in range(3)]
    for name in names:
        _insert_dataset_entry(datareg, name, "1.0.0")

    # Other tests may have inserted datasets too, so only a lower
    # bound on the count can be asserted.
    num_found = datareg.Query.aggregate_datasets("dataset_id", agg_func="count")
    assert num_found >= 3

def test_aggregate_datasets_sum(dummy_file):
    """Test summing the dataset_id column with the sum aggregation."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert three datasets; each gets its own dataset_id.
    for i in range(3):
        _insert_dataset_entry(datareg, f"test_aggregate_datasets_sum_{i}", "1.0.0")

    # Sum the dataset_id values. With at least three positive ids the
    # total is at least 3 (only a lower bound: other tests may have
    # inserted datasets too).
    sum_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="sum")
    assert sum_value >= 3