Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_all_tables() #185

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions src/dataregistry/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,46 @@ def __init__(self, db_connection, root_dir):
self._schema = db_connection.schema
self._root_dir = root_dir

def get_all_columns(self, include_schema=False):
def get_all_tables(self):
    """
    Return the names of all tables of the database.

    Names are read from the `name` attribute of every table object
    registered in `self.db_connection.metadata["tables"]`. Because the
    result is a set, tables that share the same name are reported once.

    Returns
    -------
    table_list : set
        Unique table names
    """

    # Only the table objects are needed, so iterate over the values
    # directly rather than looping over keys and indexing back in.
    return {
        tbl.name for tbl in self.db_connection.metadata["tables"].values()
    }

def get_all_columns(self, table="dataset", include_table=True, include_schema=False):
"""
Return all columns of the db in <table_name>.<column_name> format.

By default results are limited to the dataset table; this can be changed
via the `table` parameter (`table=None` returns all tables). By default
the `<table_name>` is included, but this can be removed by setting
`include_table=False`.

If `include_schema=True` return all columns of the db in
<schema>.<table_name>.<column_name> format. Note this will essentially
duplicate the output, as the working and production schemas have the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite accurate. It assumes 1) the connection was made using a namespace and 2) dialect is not SQLite.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to make this more clear in the doc string

same layout.
same layout. Note this makes no difference for sqlite dialects (as
there are no schemas). Also, if the `DbConnection` was made directly
via a schema, not a namespace, only the connected schema's tables will
be returned.

Parameters
----------
table : str, optional
Limit results to a given table, default is dataset table
include_table : bool, optional
If true, include `<table>.` in the return string
include_schema : bool, optional
If True, also return the schema name in the column name

Expand All @@ -137,13 +166,32 @@ def get_all_columns(self, include_schema=False):
column_list = set()

# Loop over each table
for table in self.db_connection.metadata["tables"]:
for tbl in self.db_connection.metadata["tables"]:

# Loop over each column
for c in self.db_connection.metadata["tables"][table].c:
if include_schema:
column_list.add(".".join((str(c.table), str(c.name))))
for c in self.db_connection.metadata["tables"][tbl].c:

# Pull out information
if self.db_connection.dialect == "sqlite":
_schema = ""
else:
column_list.add(".".join((str(c.table.name), str(c.name))))
_schema = str(c.table).split(".")[0]
_table = str(c.table.name)
_column = c.name

# Are we considering this table?
if table is not None and _table != table:
continue

# Build string
mystr = []
if include_schema:
mystr.append(_schema)
if include_table:
mystr.append(_table)
mystr.append(_column)

column_list.add(".".join(mystr))

return column_list

Expand Down
39 changes: 39 additions & 0 deletions tests/end_to_end_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,42 @@ def test_query_name(dummy_file, op, qstr, ans, tag):
assert len(results) > 0
for c, v in results.items():
assert len(v) == ans

@pytest.mark.parametrize(
    "table,include_table,include_schema",
    [
        (None, True, False),
        (None, False, True),
        (None, False, False),
        (None, True, True),
        ("dataset", True, False),
        ("execution", False, False),
    ],
)
def test_query_get_all_columns(dummy_file, table, include_table, include_schema):
    """
    Check that `Query.get_all_columns()` returns a non-empty result for
    each parameter combination, and that every entry contains the table
    name when a specific table is requested with `include_table=True`.
    """

    # Connect to the database created by the fixture
    _, root = dummy_file
    registry = DataRegistry(root_dir=str(root), namespace=DEFAULT_NAMESPACE)

    columns = registry.Query.get_all_columns(
        table=table,
        include_table=include_table,
        include_schema=include_schema,
    )

    assert len(columns) > 0

    # The table-name check only applies when a single table was requested
    # and the table prefix is part of the returned strings.
    if table is None or not include_table:
        return

    for name in columns:
        assert table in name

def test_query_get_all_tables(dummy_file):
    """Check that `Query.get_all_tables()` returns at least one table."""

    # Connect to the database created by the fixture
    _, root = dummy_file
    registry = DataRegistry(root_dir=str(root), namespace=DEFAULT_NAMESPACE)

    found = registry.Query.get_all_tables()

    assert len(found) > 0