From 44d82697b0fbbca43ff9e51577efc33a9f573b1c Mon Sep 17 00:00:00 2001 From: Stuart McAlpine Date: Fri, 7 Mar 2025 15:43:15 +0100 Subject: [PATCH 1/3] Add get_all_tables function to query --- src/dataregistry/query.py | 56 +++++++++++++++++++++++++--- tests/end_to_end_tests/test_query.py | 39 +++++++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py index 865cf46..3dd6ca4 100644 --- a/src/dataregistry/query.py +++ b/src/dataregistry/query.py @@ -115,10 +115,31 @@ def __init__(self, db_connection, root_dir): self._schema = db_connection.schema self._root_dir = root_dir - def get_all_columns(self, include_schema=False): + def get_all_tables(self): + """ + Return all tables of the database. + + Returns + ------- + table_list : set + """ + + table_list = set() + + # Loop over each table + for tbl in self.db_connection.metadata["tables"]: + table_list.add(self.db_connection.metadata["tables"][tbl].name) + + return table_list + + def get_all_columns(self, table=None, include_table=True, include_schema=False): """ Return all columns of the db in . format. + You can limit to the columns of one table by passing the `table` + parameter. By default the `` is included, but this can be + removed setting `include_table=False`. + If `include_schema=True` return all columns of the db in .. format. Note this will essentially duplicate the output, as the working and production schemas have the @@ -126,6 +147,10 @@ def get_all_columns(self, include_schema=False): Parameters ---------- + table : str, optional + Limit results to a given table + include_table : bool, optional + If true, include `.` in the return string include_schema : bool, optional If True, also return the schema name in the column name @@ -137,13 +162,32 @@ def get_all_columns(self, include_schema=False): column_list = set() # Loop over each table - for table in self.db_connection.metadata["tables"]: + for tbl in self.db_connection.metadata["tables"]: + # Loop over each column - for c in self.db_connection.metadata["tables"][table].c: - if include_schema: - column_list.add(".".join((str(c.table), str(c.name)))) + for c in self.db_connection.metadata["tables"][tbl].c: + + # Pull out information + if self.db_connection.dialect == "sqlite": + schema = "" else: - column_list.add(".".join((str(c.table.name), str(c.name)))) + _schema = str(c.table).split(".")[0] + _table = str(c.table.name) + _column = c.name + + # Are we considering this table? + if table is not None and _table != table: + continue + + # Build string + mystr = [] + if include_schema: + mystr.append(_schema) + if include_table: + mystr.append(_table) + mystr.append(_column) + + column_list.add(".".join(mystr)) return column_list diff --git a/tests/end_to_end_tests/test_query.py b/tests/end_to_end_tests/test_query.py index 282fef2..e03d706 100644 --- a/tests/end_to_end_tests/test_query.py +++ b/tests/end_to_end_tests/test_query.py @@ -138,3 +138,42 @@ def test_query_name(dummy_file, op, qstr, ans, tag): assert len(results) > 0 for c, v in results.items(): assert len(v) == ans + +@pytest.mark.parametrize( + "table,include_table,include_schema", + [ + (None, True, False), + (None, False, True), + (None, False, False), + (None, True, True), + ("dataset", True, False), + ("execution", False, False), + ] +) +def test_query_get_all_columns(dummy_file,table,include_table,include_schema): + """Test the `get_all_columns()` function in `query.py`""" + + # Establish connection to database + tmp_src_dir, tmp_root_dir = dummy_file + datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE) + + cols = datareg.Query.get_all_columns(table=table, include_table=include_table, include_schema=include_schema) + + assert len(cols) > 0 + + if table is not None: + for att in cols: + if include_table: + assert table in att + +def test_query_get_all_tables(dummy_file): + """Test the `get_all_tables()` function in `query.py`""" + + # Establish connection to database + tmp_src_dir, tmp_root_dir = dummy_file + datareg = DataRegistry(root_dir=str(tmp_root_dir), namespace=DEFAULT_NAMESPACE) + + tables = datareg.Query.get_all_tables() + + assert len(tables) > 0 + From 7c43c84e1240bb597d1b349cf56d0ee6530ad377 Mon Sep 17 00:00:00 2001 From: Stuart McAlpine Date: Fri, 7 Mar 2025 15:47:26 +0100 Subject: [PATCH 2/3] Fix string bug --- src/dataregistry/query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py index 3dd6ca4..fba15f9 100644 --- a/src/dataregistry/query.py +++ b/src/dataregistry/query.py @@ -169,7 +169,7 @@ def get_all_columns(self, table=None, include_table=True, include_schema=False): # Pull out information if self.db_connection.dialect == "sqlite": - schema = "" + _schema = "" else: _schema = str(c.table).split(".")[0] _table = str(c.table.name) From 61997305cc3119228b5ecd39ab73eb3953e23cbd Mon Sep 17 00:00:00 2001 From: Stuart McAlpine Date: Sat, 15 Mar 2025 00:25:15 +0100 Subject: [PATCH 3/3] Address reviewer comments --- src/dataregistry/query.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/dataregistry/query.py b/src/dataregistry/query.py index fba15f9..6370388 100644 --- a/src/dataregistry/query.py +++ b/src/dataregistry/query.py @@ -132,23 +132,27 @@ def get_all_tables(self): return table_list - def get_all_columns(self, table=None, include_table=True, include_schema=False): + def get_all_columns(self, table="dataset", include_table=True, include_schema=False): """ Return all columns of the db in . format. - You can limit to the columns of one table by passing the `table` - parameter. By default the `` is included, but this can be - removed setting `include_table=False`. + By default results are limited to the dataset table, can be changed via + the `table` parameter (`table=None` returns all tables). By default the + `` is included, but this can be removed setting + `include_table=False`. If `include_schema=True` return all columns of the db in .. format. Note this will essentially duplicate the output, as the working and production schemas have the - same layout. + same layout. Note this makes no difference for sqlite dialects (as + there are no schemas). Also, if the `DbConnection` was made directly + via a schema, not a namespace, only the connected schemas tables will + be returned. Parameters ---------- table : str, optional - Limit results to a given table + Limit results to a given table, default is dataset table include_table : bool, optional If true, include `
.` in the return string include_schema : bool, optional