Skip to content

Commit

Permalink
filter: Use print_err and AugurError
Browse files Browse the repository at this point in the history
This also simplifies the implementation of validate_arguments() to raise AugurErrors directly instead of printing an error message and returning a boolean for the caller to translate into an exit code.
  • Loading branch information
victorlin committed May 28, 2022
1 parent 446b39d commit 84992ba
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 63 deletions.
75 changes: 22 additions & 53 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
import pandas as pd
import random
import re
import sys
from tempfile import NamedTemporaryFile
from typing import Collection

from .dates import numeric_date, numeric_date_type, SUPPORTED_DATE_HELP_TEXT, is_date_ambiguous, get_numerical_dates
from .errors import AugurError
from .index import index_sequences, index_vcf
from .io import open_file, read_metadata, read_sequences, write_sequences, is_vcf as filename_is_vcf, write_vcf
from .io import open_file, print_err, read_metadata, read_sequences, write_sequences, is_vcf as filename_is_vcf, write_vcf
from .utils import read_strains

comment_char = '#'
Expand Down Expand Up @@ -48,8 +47,7 @@ def constant_factory(value):
for elems in (line.strip().split('\t') if '\t' in line else line.strip().split() for line in pfile.readlines())
})
except Exception as e:
print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr)
raise e
raise AugurError(f"ERROR: missing or malformed priority scores file {fname}")

# Define metadata filters.

Expand Down Expand Up @@ -634,7 +632,7 @@ def construct_filters(args, sequence_index):
is_vcf = filename_is_vcf(args.sequences)

if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...", file=sys.stderr)
print_err("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
exclude_by.append((
filter_by_sequence_length,
Expand Down Expand Up @@ -925,8 +923,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
if 'year' in group_by_set or 'month' in group_by_set:
if 'date' not in metadata:
# set year/month/day = unknown
print(f"WARNING: A 'date' column could not be found to group-by year or month.", file=sys.stderr)
print(f"Filtering by group may behave differently than expected!", file=sys.stderr)
print_err(f"WARNING: A 'date' column could not be found to group-by year or month.")
print_err(f"Filtering by group may behave differently than expected!")
df_dates = pd.DataFrame({'year': 'unknown', 'month': 'unknown'}, index=metadata.index)
metadata = pd.concat([metadata, df_dates], axis=1)
else:
Expand Down Expand Up @@ -966,8 +964,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):

unknown_groups = group_by_set - set(metadata.columns)
if unknown_groups:
print(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}", file=sys.stderr)
print("Filtering by group may behave differently than expected!", file=sys.stderr)
print_err(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}")
print_err("Filtering by group may behave differently than expected!")
for group in unknown_groups:
metadata[group] = 'unknown'

Expand Down Expand Up @@ -1164,41 +1162,25 @@ def register_arguments(parser):


def validate_arguments(args):
"""Validate arguments and return a boolean representing whether all validation
rules succeeded.
"""Validate arguments.
Parameters
----------
args : argparse.Namespace
Parsed arguments from argparse
Returns
-------
bool :
Validation succeeded.
"""
# Don't allow sequence output when no sequence input is provided.
if args.output and not args.sequences:
print(
"ERROR: You need to provide sequences to output sequences.",
file=sys.stderr)
return False
raise AugurError("You need to provide sequences to output sequences.")

# Confirm that at least one output was requested.
if not any((args.output, args.output_metadata, args.output_strains)):
print(
"ERROR: You need to select at least one output.",
file=sys.stderr)
return False
raise AugurError("You need to select at least one output.")

# Don't allow filtering on sequence-based information, if no sequences or
# sequence index is provided.
if not args.sequences and not args.sequence_index and any(getattr(args, arg) for arg in SEQUENCE_ONLY_FILTERS):
print(
"ERROR: You need to provide a sequence index or sequences to filter on sequence-specific information.",
file=sys.stderr)
return False
raise AugurError("You need to provide a sequence index or sequences to filter on sequence-specific information.")

# Set flags if VCF
is_vcf = filename_is_vcf(args.sequences)
Expand All @@ -1207,29 +1189,20 @@ def validate_arguments(args):
if is_vcf:
from shutil import which
if which("vcftools") is None:
print("ERROR: 'vcftools' is not installed! This is required for VCF data. "
"Please see the augur install instructions to install it.",
file=sys.stderr)
return False
raise AugurError("'vcftools' is not installed! This is required for VCF data. "
"Please see the augur install instructions to install it.")

# If user requested grouping, confirm that other required inputs are provided, too.
if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)):
print(
"ERROR: You must specify a number of sequences per group or maximum sequences to subsample.",
file=sys.stderr
)
return False

return True
raise AugurError("You must specify a number of sequences per group or maximum sequences to subsample.")


def run(args):
'''
filter and subsample a set of sequences into an analysis set
'''
# Validate arguments before attempting any I/O.
if not validate_arguments(args):
return 1
validate_arguments(args)

# Determine whether the sequence index exists or whether should be
# generated. We need to generate an index if the input sequences are in a
Expand All @@ -1252,10 +1225,9 @@ def run(args):
with NamedTemporaryFile(delete=False) as sequence_index_file:
sequence_index_path = sequence_index_file.name

print(
print_err(
"Note: You did not provide a sequence index, so Augur will generate one.",
"You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.",
file=sys.stderr
"You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`."
)

if is_vcf:
Expand Down Expand Up @@ -1494,8 +1466,7 @@ def run(args):
args.probabilistic_sampling,
)
except TooManyGroupsError as error:
print(f"ERROR: {error}", file=sys.stderr)
sys.exit(1)
raise AugurError(error)

if (probabilistic_used):
print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
Expand Down Expand Up @@ -1620,10 +1591,9 @@ def run(args):
# Warn the user if the expected strains from the sequence index are
# not a superset of the observed strains.
if sequence_strains is not None and observed_sequence_strains > sequence_strains:
print(
print_err(
"WARNING: The sequence index is out of sync with the provided sequences.",
"Metadata and strain output may not match sequence output.",
file=sys.stderr
"Metadata and strain output may not match sequence output."
)

# Update the set of available sequence strains.
Expand Down Expand Up @@ -1682,8 +1652,7 @@ def run(args):
print("\t%i of these were dropped because of subsampling criteria%s" % (num_excluded_subsamp, seed_txt))

if total_strains_passed == 0:
print("ERROR: All samples have been dropped! Check filter rules and metadata file format.", file=sys.stderr)
return 1
raise AugurError("All samples have been dropped! Check filter rules and metadata file format.")

print(f"{total_strains_passed} strains passed all filters")

Expand Down Expand Up @@ -1729,7 +1698,7 @@ def calculate_sequences_per_group(target_max_value, counts_per_group, allow_prob
)
except TooManyGroupsError as error:
if allow_probabilistic:
print(f"WARNING: {error}", file=sys.stderr)
print_err(f"WARNING: {error}")
sequences_per_group = _calculate_fractional_sequences_per_group(
target_max_value,
counts_per_group,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ This should fail because the requested filters rely on sequence information.
> --min-length 10000 \
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
ERROR: You need to provide a sequence index or sequences to filter on sequence-specific information.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This should produce no results because the intersection of metadata and sequence
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ rm -f "$TMP/filtered_strains.txt"
Expand All @@ -30,7 +30,7 @@ Repeat with sequence and strain outputs. We should get the same results.
> --output-sequences "$TMP/filtered.fasta" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ grep "^>" "$TMP/filtered.fasta" | wc -l
Expand All @@ -47,7 +47,7 @@ Since we expect metadata to be filtered by presence of strains in input sequence
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ rm -f "$TMP/filtered_strains.txt"
2 changes: 1 addition & 1 deletion tests/functional/filter/cram/filter-no-outputs-error.t
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Try to filter without any outputs.
> --metadata filter/data/metadata.tsv \
> --min-length 10000 > /dev/null
ERROR: You need to select at least one output.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ This should fail.
> --min-length 10000 \
> --output "$TMP/filtered.fasta" > /dev/null
ERROR: You need to provide sequences to output sequences.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ This should fail, as probabilistic sampling is explicitly disabled.
> --no-probabilistic-sampling \
> --output "$TMP/filtered.fasta"
ERROR: Asked to provide at most 5 sequences, but there are 8 groups.
[1]
[2]
$ rm -f "$TMP/filtered.fasta"
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ This should fail with a helpful error message.
> --group-by year month \
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
ERROR: You must specify a number of sequences per group or maximum sequences to subsample.
[1]
[2]
7 changes: 5 additions & 2 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import augur.filter
from augur.io import read_metadata
from augur.utils import AugurError

@pytest.fixture
def argparser():
Expand Down Expand Up @@ -75,9 +76,10 @@ def test_read_priority_scores_valid(self, mock_priorities_file_valid):
assert priorities["strain42"] == -np.inf, "Default priority is negative infinity for unlisted sequences"

def test_read_priority_scores_malformed(self, mock_priorities_file_malformed):
with pytest.raises(ValueError):
with pytest.raises(AugurError) as e_info:
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
augur.filter.read_priority_scores("tests/builds/tb/data/lee_2015.vcf")
assert str(e_info.value) == "ERROR: missing or malformed priority scores file tests/builds/tb/data/lee_2015.vcf"

def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_file_valid_with_spaces_and_tabs):
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
Expand All @@ -88,8 +90,9 @@ def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_f
assert priorities == {"strain 1": 5, "strain 2": 6, "strain 3": 8}

def test_read_priority_scores_does_not_exist(self):
with pytest.raises(FileNotFoundError):
with pytest.raises(AugurError) as e_info:
augur.filter.read_priority_scores("/does/not/exist.txt")
assert str(e_info.value) == "ERROR: missing or malformed priority scores file /does/not/exist.txt"

def test_filter_on_query_good(self, tmpdir, sequences):
"""Basic filter_on_query test"""
Expand Down

0 comments on commit 84992ba

Please sign in to comment.