From b4c4372b82d9440f4f58d372bb878b4594629fb3 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 2 May 2024 14:49:59 -0500 Subject: [PATCH 1/2] updated tests for DB selection --- locidex/extract.py | 13 ++----------- locidex/search.py | 14 +++----------- locidex/utils.py | 19 ++++++++++++++++++- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/locidex/extract.py b/locidex/extract.py index 71fa894..c0637cd 100644 --- a/locidex/extract.py +++ b/locidex/extract.py @@ -15,7 +15,7 @@ from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BLAST_TABLE_COLS, DBConfig, DB_EXPECTED_FILES, NT_SUB, EXTRACT_MODES, OPTION_GROUPS from locidex.version import __version__ from locidex.classes.aligner import perform_alignment, aligner -import locidex.manifest as manifest +from locidex.utils import check_db_groups def add_args(parser=None): if parser is None: @@ -263,16 +263,7 @@ def run(cmd_args=None): analysis_parameters = vars(cmd_args) - for opt in OPTION_GROUPS: - if analysis_parameters[opt] is not None: - for option in OPTION_GROUPS[opt]: - if analysis_parameters[option] is None: - raise AttributeError("Missing required parameter: {}".format(option)) - - if cmd_args.db_group is not None: - analysis_parameters["db"] = str(manifest.get_manifest_db(input_file=Path(cmd_args.db_group), name=cmd_args.db_name, version=cmd_args.db_version)) - - + analysis_parameters = check_db_groups(analysis_params=analysis_parameters, cmd_args=cmd_args) config_file = cmd_args.config diff --git a/locidex/search.py b/locidex/search.py index 06853c0..3fa045c 100644 --- a/locidex/search.py +++ b/locidex/search.py @@ -8,12 +8,11 @@ import pandas as pd -import locidex.manifest as manifest from locidex.classes.blast import blast_search, parse_blast from locidex.classes.db import search_db_conf, db_config from locidex.classes.seq_intake import seq_intake, seq_store from locidex.constants import SEARCH_RUN_DATA, FILE_TYPES, BLAST_TABLE_COLS, DB_EXPECTED_FILES, OPTION_GROUPS, DBConfig -from locidex.utils import write_seq_dict +from locidex.utils import write_seq_dict, check_db_groups from locidex.version import __version__ def add_args(parser=None): @@ -302,16 +301,9 @@ def run(cmd_args=None): parser = add_args() cmd_args = parser.parse_args() analysis_parameters = vars(cmd_args) - - for opt in OPTION_GROUPS: - if analysis_parameters[opt] is not None: - for option in OPTION_GROUPS[opt]: - if analysis_parameters[option] is None: - raise AttributeError("Missing required parameter: {}".format(option)) - - if cmd_args.db_group is not None: - analysis_parameters["db"] = str(manifest.get_manifest_db(input_file=Path(cmd_args.db_group), name=cmd_args.db_name, version=cmd_args.db_version)) + analysis_parameters = check_db_groups(analysis_params=analysis_parameters, cmd_args=cmd_args) + config_file = cmd_args.config config = {} diff --git a/locidex/utils.py b/locidex/utils.py index a8c941a..01c51e5 100644 --- a/locidex/utils.py +++ b/locidex/utils.py @@ -1,12 +1,29 @@ import hashlib import json import os +import argparse from collections import Counter +from pathlib import Path from Bio.Seq import Seq -from locidex.constants import NT_SUB, PROTEIN_ALPHA, DNA_ALPHA +from locidex.constants import NT_SUB, PROTEIN_ALPHA, DNA_ALPHA, OPTION_GROUPS +import locidex.manifest as manifest +def check_db_groups(analysis_params: dict, cmd_args: argparse.Namespace, param_db: str = "db") -> dict: + """ + Verify that a locidex database, or database group passed has all of the require parameters + """ + for opt in OPTION_GROUPS: + if analysis_params[opt] is not None: + for option in OPTION_GROUPS[opt]: + if analysis_params[option] is None: + raise AttributeError("Missing required parameter: {}".format(option)) + + if cmd_args.db_group is not None: + analysis_params[param_db] = str(manifest.get_manifest_db(input_file=Path(cmd_args.db_group), name=cmd_args.db_name, version=cmd_args.db_version)) + + return analysis_params def revcomp(s): """ From 9150ae64e716c942c98a2344db35700fa85b0437 Mon Sep 17 00:00:00 2001 From: Matthew Wells Date: Thu, 2 May 2024 14:52:17 -0500 Subject: [PATCH 2/2] forgot test file --- tests/test_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d32d3f9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +""" +Test the ever growing util functions +""" + +import pytest +from locidex import utils +from locidex import manifest +from argparse import Namespace + + +def test_check_db_groups_pass(monkeypatch): + nm_group = Namespace(db_group="Db1", db_name="test_name", db_version="1.0.0") + analysis_params = {"db_group": "Db1", "db_name": "test_name", "db_version": "1.0.0"} + + def mockreturn(*args, **kwargs): + return True + monkeypatch.setattr(manifest, "get_manifest_db", mockreturn) + analysis_params = utils.check_db_groups(analysis_params, nm_group) + assert analysis_params["db"] + +def test_check_db_groups_fail(): + nm_group = Namespace(db_group="Db1", db_name="test_name", db_version="1.0.0") + analysis_params = {"db_group": "Db1", "db_name": "test_name"} + + with pytest.raises(KeyError): + analysis_params = utils.check_db_groups(analysis_params, nm_group) \ No newline at end of file