Add aggregate query for datasets #193
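
This PR adds an `aggregate_datasets` method to the `Query` class. It supports
"count", "sum", "min", "max", and "avg" aggregations over the dataset table
(plus "count" over a few related tables), accepts optional filters, and
combines results across schemas.

A minimal usage sketch (hypothetical, following the patterns in the tests
below; assumes a connected `DataRegistry` instance named `datareg`):

    # Count all datasets
    n = datareg.Query.aggregate_datasets(column_name=None, agg_func="count")

    # Average a numeric column, restricted by a filter
    f = datareg.Query.gen_filter("dataset.version_string", "==", "1.0.0")
    avg_id = datareg.Query.aggregate_datasets("dataset_id", agg_func="avg", filters=[f])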

Open · wants to merge 2 commits into base: main
151 changes: 151 additions & 0 deletions src/dataregistry/query.py
@@ -1,3 +1,4 @@
from sqlalchemy import func, Integer, Float, Numeric
from collections import namedtuple
from sqlalchemy import select
import sqlalchemy.sql.sqltypes as sqltypes
@@ -262,6 +263,156 @@ def _parse_selected_columns(self, column_names):

        return list(tables_required), column_list, is_orderable_list

    def aggregate_datasets(
        self, column_name=None, agg_func="count", filters=None, table_name="dataset"
    ):
"""
Perform an aggregation (count, sum, min, max, or avg) on a specified
column in the specified table.

Parameters
----------
column_name : str or None, optional
The column to perform the aggregation on. Can be None for "count"
aggregation.
agg_func : str, optional
The aggregation function to use: "count" (default), "sum", "min",
"max", or "avg".
filters : list, optional
List of filters (WHERE clauses) to apply.
table_name : str, optional
Table to query. Default is "dataset". For "count" aggregations, can
also be "dataset_alias", "keyword", or "dataset_keyword".

Returns
-------
result : int or float
The aggregated value.
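
        Examples
        --------
        A minimal sketch (hypothetical usage, following the patterns in the
        tests; assumes a connected ``DataRegistry`` instance ``datareg``):

        >>> q = datareg.Query
        >>> n_total = q.aggregate_datasets(column_name=None, agg_func="count")
        >>> f = q.gen_filter("dataset.version_string", "==", "1.0.0")
        >>> n_v1 = q.aggregate_datasets(None, agg_func="count", filters=[f])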
"""
        allowed_agg_funcs = {"count", "sum", "min", "max", "avg"}
        allowed_tables = {"dataset", "dataset_alias", "keyword", "dataset_keyword"}

        if agg_func not in allowed_agg_funcs:
            raise ValueError(f"agg_func must be one of {', '.join(sorted(allowed_agg_funcs))}")

        if table_name not in allowed_tables:
            raise ValueError(f"table_name must be one of {', '.join(sorted(allowed_tables))}")

        if agg_func != "count" and table_name != "dataset":
            raise ValueError(f"Can only use agg_func '{agg_func}' on 'dataset' table")

        if column_name is None and agg_func != "count":
            raise ValueError("column_name cannot be None for non-count aggregations")

        results = []
        query_mode = self.db_connection._query_mode

        for table_key in self.db_connection.metadata["tables"]:
            # Extract schema and table name
            parts = table_key.split(".")
            if len(parts) == 2:
                schema, table = parts
            else:
                schema, table = None, parts[0]  # SQLite case (no schema)

            # Skip tables that don't match the requested table
            if table != table_name:
                continue

            # Determine if this schema should be queried: the query mode is
            # matched against the trailing component of the schema name
            # (schemas are assumed to be named "<prefix>_<mode>")
            if query_mode != "both" and schema and query_mode != schema.split("_")[-1]:
                continue

            db_table = self.db_connection.metadata["tables"].get(table_key)
            if db_table is None:
                continue

            # Handle 'count' aggregation with None column
            if agg_func == "count" and column_name is None:
                aggregation = func.count()
            else:
                # Check that the column exists
                if column_name not in db_table.c:
                    raise ValueError(
                        f"Column '{column_name}' does not exist in {table_name} table"
                    )

                # For non-count aggregations, verify the column type is numeric
                if agg_func != "count":
                    col_type = db_table.c[column_name].type
                    is_numeric = isinstance(col_type, (Integer, Float, Numeric)) or (
                        hasattr(col_type, "_type_affinity")
                        and col_type._type_affinity in (Integer, Float, Numeric)
                    )

                    if not is_numeric:
                        raise ValueError(
                            f"Column '{column_name}' must be numeric for '{agg_func}' aggregation"
                        )

                # Look up the matching SQLAlchemy aggregation (func.count,
                # func.sum, func.min, func.max or func.avg); agg_func has
                # already been validated against allowed_agg_funcs
                aggregation = getattr(func, agg_func)(db_table.c[column_name])

            stmt = select(aggregation).select_from(db_table)

            if filters:
                for f in filters:
                    stmt = self._render_filter(f, stmt, schema)

            with self._engine.connect() as conn:
                result = conn.execute(stmt).scalar()

            if result is not None:
                results.append(result)

        # For count and sum, the per-table results simply add up
        if agg_func in ("count", "sum"):
            return sum(results) if results else 0
        # For min/max, take the min/max across all tables
        elif agg_func == "min":
            return min(results) if results else None
        elif agg_func == "max":
            return max(results) if results else None
        # For avg, compute a count-weighted average across all tables:
        # weighted_avg = sum(avg_i * n_i) / sum(n_i)
        elif agg_func == "avg" and results:
            # Re-run the per-table counts (with the same filters) to obtain
            # the weights n_i
            counts = []
            for table_key in self.db_connection.metadata["tables"]:
                parts = table_key.split(".")
                if len(parts) == 2:
                    schema, table = parts
                else:
                    schema, table = None, parts[0]

                if table != table_name:
                    continue

                if query_mode != "both" and schema and query_mode != schema.split("_")[-1]:
                    continue

                db_table = self.db_connection.metadata["tables"].get(table_key)
                if db_table is None:
                    continue

                # Count non-NULL values of the aggregated column so the
                # weights line up with the averages collected above (a table
                # whose average was None contributes no weight)
                stmt = select(func.count(db_table.c[column_name])).select_from(db_table)

                if filters:
                    for f in filters:
                        stmt = self._render_filter(f, stmt, schema)

                with self._engine.connect() as conn:
                    count = conn.execute(stmt).scalar() or 0

                if count:
                    counts.append(count)

            total_count = sum(counts)
            if total_count > 0:
                return sum(avg * n for avg, n in zip(results, counts)) / total_count
            return None

        return None

    def _render_filter(self, f, stmt, schema):
        """
        Append SQL statement with an additional WHERE clause based on a
165 changes: 162 additions & 3 deletions tests/end_to_end_tests/test_query.py
@@ -67,9 +67,7 @@ def test_query_between_columns(dummy_file):
    _NAME = "DESC:datasets:test_query_between_columns"
    _V_STRING = "0.0.1"

-    e_id = _insert_execution_entry(
-        datareg, "test_query_between_columns", "test"
-    )
+    e_id = _insert_execution_entry(datareg, "test_query_between_columns", "test")

    d_id = _insert_dataset_entry(datareg, _NAME, _V_STRING, execution_id=e_id)

@@ -100,6 +98,7 @@ def test_query_between_columns(dummy_file):
    assert results["dataset.name"][0] == _NAME
    assert results["dataset.version_string"][0] == _V_STRING


@pytest.mark.skipif(
    datareg.db_connection._dialect == "sqlite", reason="wildcards break for sqlite"
)
@@ -138,3 +137,163 @@ def test_query_name(dummy_file, op, qstr, ans, tag):
    assert len(results) > 0
    for c, v in results.items():
        assert len(v) == ans


def test_aggregate_datasets_count(dummy_file):
    """Test counting the number of datasets."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        _insert_dataset_entry(datareg, f"test_aggregate_datasets_count_{i}", "1.0.0")

    # Count datasets
    count = datareg.Query.aggregate_datasets("dataset_id", agg_func="count")
    assert count >= 3  # Ensure at least 3 were counted


def test_aggregate_datasets_count_with_none_column(dummy_file):
    """Test counting datasets with None column."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        _insert_dataset_entry(datareg, f"test_count_none_col_{i}", "1.0.0")

    # Count datasets with None column
    count = datareg.Query.aggregate_datasets(column_name=None, agg_func="count")
    assert count >= 3


def test_aggregate_datasets_sum(dummy_file):
    """Test summing the column values."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        _insert_dataset_entry(datareg, f"test_aggregate_datasets_sum_{i}", "1.0.0")

    sum_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="sum")
    assert sum_value >= 3


def test_aggregate_datasets_min(dummy_file):
    """Test finding the minimum value in a column."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        name = f"test_aggregate_datasets_min_{i}"
        _insert_dataset_entry(datareg, name, "1.0.0")

    min_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="min")
    assert min_value >= 0


def test_aggregate_datasets_max(dummy_file):
    """Test finding the maximum value in a column."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        name = f"test_aggregate_datasets_max_{i}"
        _insert_dataset_entry(datareg, name, "1.0.0")

    max_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="max")
    assert max_value >= 3


def test_aggregate_datasets_avg(dummy_file):
    """Test finding the average value in a column."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets
    for i in range(3):
        name = f"test_aggregate_datasets_avg_{i}"
        _insert_dataset_entry(datareg, name, "1.0.0")

    avg_value = datareg.Query.aggregate_datasets("dataset_id", agg_func="avg")
    assert avg_value > 0


def test_aggregate_datasets_with_non_dataset_table(dummy_file):
    """Test counting records in non-dataset tables."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert a dataset
    d_id = _insert_dataset_entry(
        datareg,
        "test_aggregate_datasets_with_non_dataset_table",
        "0.0.1",
    )

    # Attach an alias to it
    a_id = _insert_alias_entry(
        datareg.Registrar, "test_aggregate_datasets_with_non_dataset_table_alias", d_id
    )

    count = datareg.Query.aggregate_datasets(
        column_name=None,
        agg_func="count",
        table_name="dataset_alias",
    )
    assert count >= 1


def test_aggregate_datasets_with_filters(dummy_file):
    """Test aggregation with filters applied."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Insert datasets with a distinctive version string
    for i in range(3):
        _insert_dataset_entry(
            datareg, f"test_aggregate_datasets_with_filters_{i}", "12.123.111"
        )

    # Count with version filter
    f = datareg.Query.gen_filter("dataset.version_string", "==", "12.123.111")
    count = datareg.Query.aggregate_datasets(
        column_name=None, agg_func="count", filters=[f]
    )
    assert count == 3


def test_aggregate_datasets_errors(dummy_file):
    """Test error cases for the aggregation function."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    # Test invalid aggregation function
    with pytest.raises(ValueError, match="agg_func must be one of"):
        datareg.Query.aggregate_datasets("dataset_id", agg_func="invalid")

    # Test invalid table name
    with pytest.raises(ValueError, match="table_name must be one of"):
        datareg.Query.aggregate_datasets("dataset_id", table_name="invalid")

    # Test non-count aggregation on a non-dataset table
    with pytest.raises(ValueError, match="Can only use agg_func"):
        datareg.Query.aggregate_datasets(
            "id", agg_func="sum", table_name="dataset_alias"
        )

    # Test None column with a non-count aggregation
    with pytest.raises(ValueError, match="column_name cannot be None"):
        datareg.Query.aggregate_datasets(None, agg_func="sum")

    # Test non-existent column
    with pytest.raises(ValueError, match="Column.*does not exist"):
        datareg.Query.aggregate_datasets("non_existent_column", agg_func="count")

    # Test numeric aggregation on a non-numeric column ("description" is a
    # text column)
    with pytest.raises(ValueError, match="must be numeric"):
        datareg.Query.aggregate_datasets("description", agg_func="sum")
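

# An extra sanity check (a hypothetical sketch, reusing the helpers above):
# per the `sum(results) if results else 0` branch in aggregate_datasets, a
# count whose filter matches no rows should come back as 0.
def test_aggregate_datasets_count_no_matches(dummy_file):
    """Test that a count whose filter matches nothing returns 0."""
    tmp_src_dir, tmp_root_dir = dummy_file
    datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE)

    f = datareg.Query.gen_filter("dataset.version_string", "==", "99.99.99")
    count = datareg.Query.aggregate_datasets(
        column_name=None, agg_func="count", filters=[f]
    )
    assert count == 0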